/** @brief Shows the output of the network after being trained on the MNIST dataset. * NOTE! This does not actually train the network. It only shows the output of the network. * The output is exactly the one from mnist_train.c. We have trained it so you can see the accuracy. * The accuracy is 95%. Please feel free to train it yourself using mnist_train.c, it takes about 3 minutes * although that of course depends on your machine. * * @author Saleh Bubshait */ #include "../../src/network.h" #include "mnist_loader.h" int main(void) { load_mnist(train_image, train_label, test_image, test_label); printf("LOADING TRAINED MNIST DATASET. PLEASE WAIT!\n"); // Load the trained neural network. This is the one trained in mnist_train.c Network *network = network_load("examples/MNIST/data/trained_mnist"); Matrix *input = matrix_create(784, 1); Matrix *target = matrix_create(10, 1); int correct = 0; for (int i = 0; i < NUM_TEST; i++) { matrix_initialise(input, test_image[i]); network_predict(network, input); int maxIndex = 0; for (int j = 1; j < 10; j++) { if (network_output(network)[j] > network_output(network)[maxIndex]) { maxIndex = j; } } if (maxIndex == test_label[i]) { correct++; } } putchar('\n'); printf("====================================\n"); printf("Accuracy: %lf \n", (double)correct / NUM_TEST); printf("====================================\n"); // Save the trained network network_store(network, "trained_mnist"); matrix_free(input); matrix_free(target); network_free(network); return 0; }