ARMv8/extension/examples/MNIST/mnist_train.c

74 lines
2.3 KiB
C

/** @brief Trains the neural network on the MNIST dataset and tests it.
*
* A neural network of 784 input neurons, 80 hidden neurons, and 10 output neurons is created.
* The network is trained on the first 10K MNIST training images and tested on
* the 10K unseen test images.
* The accuracy of the network is then printed.
* Run time depends on the machine; expect a few minutes (just over 2 minutes
* was measured for the 95% configuration listed below).
* Accuracy is around 95-98% depending on the parameters (see the experiments below).
*
* @author Saleh Bubshait
*/
#include <stdio.h>

#include "../../src/network.h"
#include "mnist_loader.h"
/* Experiments and performance:
* 1. 95% accuracy, 2 minutes training, parameters: 784, 150, 10, 0.01, 10 epochs
* 2. ~98% accuracy with 784, 80, 10, 0.2, 25 epochs
*/
/* Network shape and training hyper-parameters. */
enum {
    MNIST_INPUT_NEURONS = 784, /* one input per pixel of a 28x28 image */
    MNIST_HIDDEN_NEURONS = 80,
    MNIST_OUTPUT_NEURONS = 10, /* one output per digit class 0-9 */
    MNIST_EPOCHS = 25
};

static const double MNIST_LEARNING_RATE = 0.2;

/* Returns the index of the largest value in network_output(network); intended
 * to be called right after network_predict(). Ties resolve to the lowest index. */
static int mnist_predicted_digit(Network *network) {
    int best = 0;
    for (int k = 1; k < MNIST_OUTPUT_NEURONS; k++) {
        if (network_output(network)[k] > network_output(network)[best]) {
            best = k;
        }
    }
    return best;
}

/* Trains the network for MNIST_EPOCHS epochs, one sample at a time.
 * `input` and `target` are scratch matrices reused for every sample.
 * NOTE(review): the sample loop is bounded by NUM_TEST (10K), so only the first
 * 10K training images are used, as the file header describes; confirm this is
 * intended rather than the full training set. */
static void mnist_train_epochs(Network *network, Matrix *input, Matrix *target) {
    for (int epoch = 0; epoch < MNIST_EPOCHS; epoch++) {
        for (int i = 0; i < NUM_TEST; i++) {
            matrix_initialise(input, train_image[i]);
            /* One-hot target: 1 at the labelled digit, 0 everywhere else. */
            matrix_fill(target, 0);
            target->data[train_label[i]][0] = 1;
            network_train(network, input, target, MNIST_LEARNING_RATE);
        }
    }
}

/* Returns the fraction (0.0 to 1.0) of the NUM_TEST test images whose
 * predicted digit matches the label. `input` is a reused scratch matrix. */
static double mnist_test_accuracy(Network *network, Matrix *input) {
    int correct = 0;
    for (int i = 0; i < NUM_TEST; i++) {
        matrix_initialise(input, test_image[i]);
        network_predict(network, input);
        if (mnist_predicted_digit(network) == test_label[i]) {
            correct++;
        }
    }
    return (double)correct / NUM_TEST;
}

/* Builds a 784-80-10 network, trains it on MNIST, prints the test accuracy and
 * stores the trained network.
 * Returns 0 on success, 1 if the network or the scratch matrices could not be
 * created. */
int main(void) {
    int dimensions[] = {MNIST_INPUT_NEURONS, MNIST_HIDDEN_NEURONS,
                        MNIST_OUTPUT_NEURONS};
    int status = 1;
    Matrix *input = NULL;
    Matrix *target = NULL;
    double accuracy = 0.0;

    Network *network = network_create(
        (int)(sizeof dimensions / sizeof dimensions[0]), dimensions);
    if (network == NULL) {
        fprintf(stderr, "Failed to create the network.\n");
        return 1;
    }

    // load_mnist is a function from mnist_loader.h. This is an external library,
    // used for the sake of this test.
    load_mnist(train_image, train_label, test_image, test_label);
    printf("STARTING TRAINING.\nNOTE: THIS MAY TAKE FEW MINUTES TO TRAIN. PLEASE WAIT!\n");

    input = matrix_create(MNIST_INPUT_NEURONS, 1);
    target = matrix_create(MNIST_OUTPUT_NEURONS, 1);
    if (input == NULL || target == NULL) {
        fprintf(stderr, "Failed to allocate the input/target matrices.\n");
        goto cleanup;
    }

    mnist_train_epochs(network, input, target);
    accuracy = mnist_test_accuracy(network, input);

    putchar('\n');
    printf("====================================\n");
    printf("Accuracy: %lf \n", accuracy);
    printf("====================================\n");

    // Save the trained network
    network_store(network, "../examples/MNIST/data/trained_mnist");
    status = 0;

cleanup:
    /* Whether matrix_free accepts NULL is not shown here, so only free what
     * was actually allocated. */
    if (input != NULL) {
        matrix_free(input);
    }
    if (target != NULL) {
        matrix_free(target);
    }
    network_free(network);
    return status;
}