/** @brief Trains the neural network on the MNIST dataset and tests it.
 *
 * A neural network of 784 input neurons, 80 hidden neurons, and 10 output
 * neurons is created. The network is trained on 10K images and tested on
 * another unseen 10K images. The accuracy of the network is then printed.
 * In its current form, the test takes just over 2 minutes to run. This
 * obviously depends on the machine, but expect a few minutes.
 * The accuracy is around 95%!
 *
 * @author Saleh Bubshait
 */

#include <stdio.h>

#include "../../src/network.h"
#include "mnist_loader.h"

/* Experiments and Performance:
 * 1. 95% accuracy, 2 minutes training, parameters: 784, 150, 10, 0.01, 10 epochs
 * 2. ~98% accuracy with 784, 80, 10, 0.2, 25 epochs (the configuration used below)
 */

/**
 * Builds a 784-80-10 network, trains it for 25 epochs on the first NUM_TEST
 * training images, evaluates it on the NUM_TEST test images, prints the
 * resulting accuracy and stores the trained network on disk.
 *
 * @return 0 on success, 1 if the network or one of the matrices could not be
 *         created.
 */
int main(void) {
    const int input_size = 784;  // 784 because pics are 28x28
    const int hidden_size = 80;
    const int output_size = 10;  // one output neuron per label value
    const int epochs = 25;
    const double learning_rate = 0.2;
    // Only the first NUM_TEST training images are used, which matches the
    // "10K images" figure quoted in the file header.
    const int train_samples = NUM_TEST;

    int dimensions[] = {input_size, hidden_size, output_size};
    // NOTE(review): the NULL checks below assume network_create/matrix_create
    // report failure by returning NULL - confirm against network.h.
    Network *network = network_create(3, dimensions);
    if (network == NULL) {
        fprintf(stderr, "Failed to create the network.\n");
        return 1;
    }

    // load_mnist is a function from mnist_loader.h. This is an external library,
    // used for the sake of this test. The image/label arrays are not declared
    // in this file; presumably they are provided by mnist_loader.h.
    load_mnist(train_image, train_label, test_image, test_label);

    printf("STARTING TRAINING.\n"
           "NOTE: THIS MAY TAKE FEW MINUTES TO TRAIN. PLEASE WAIT!\n");

    // These two column matrices are reused for every sample.
    Matrix *input = matrix_create(input_size, 1);
    Matrix *target = matrix_create(output_size, 1);
    if (input == NULL || target == NULL) {
        fprintf(stderr, "Failed to create the input/target matrices.\n");
        if (input != NULL) {
            matrix_free(input);
        }
        if (target != NULL) {
            matrix_free(target);
        }
        network_free(network);
        return 1;
    }

    for (int epoch = 0; epoch < epochs; epoch++) {
        for (int sample = 0; sample < train_samples; sample++) {
            matrix_initialise(input, train_image[sample]);

            // One-hot target: every entry 0 except the row of the true label.
            matrix_fill(target, 0);
            target->data[train_label[sample]][0] = 1;

            network_train(network, input, target, learning_rate);
        }
    }

    int correct = 0;
    for (int i = 0; i < NUM_TEST; i++) {
        matrix_initialise(input, test_image[i]);
        network_predict(network, input);

        // The prediction is the index of the largest network output.
        int best = 0;
        for (int candidate = 1; candidate < output_size; candidate++) {
            if (network_output(network)[candidate] > network_output(network)[best]) {
                best = candidate;
            }
        }

        if (best == test_label[i]) {
            correct++;
        }
    }

    putchar('\n');
    printf("====================================\n");
    printf("Accuracy: %lf \n", (double)correct / NUM_TEST);
    printf("====================================\n");

    // Save the trained network
    network_store(network, "../examples/MNIST/data/trained_mnist");

    matrix_free(input);
    matrix_free(target);
    network_free(network);
    return 0;
}