74 lines
2.3 KiB
C
74 lines
2.3 KiB
C
/** @brief Trains the neural network on the MNIST dataset and tests it.
 *
 * A neural network of 784 input neurons, 80 hidden neurons, and 10 output neurons is created.
 * The network is trained on 10K MNIST images and tested on another 10K unseen images.
 * The accuracy of the network is then printed.
 * In its current form, the test takes just over 2 minutes to run. This
 * obviously depends on the machine, but expect a few minutes.
 * The accuracy is around 95%!
 *
 * @author Saleh Bubshait
 */
|
|
|
|
#include "../../src/network.h"
|
|
#include "mnist_loader.h"
|
|
|
|
|
|
/* Experiments and Performance:
 * 1. 95% accuracy, 2 minutes training, parameters: 784, 150, 10, 0.01, 10 epochs
 * 2. ~98% accuracy with 784, 80, 10, 0.2, 25 epochs
 */
|
|
|
|
int main(void) {
|
|
// 784 because pics are 28x28
|
|
int dimensions[] = {784, 80, 10};
|
|
Network *network = network_create(3, dimensions);
|
|
|
|
// load_mnist is a function from mnist_loader.h. This is an external library,
|
|
// used for the sake of this test.
|
|
load_mnist(train_image, train_label, test_image, test_label);
|
|
printf("STARTING TRAINING.\nNOTE: THIS MAY TAKE FEW MINUTES TO TRAIN. PLEASE WAIT!\n");
|
|
|
|
int epochs = 25;
|
|
Matrix *input = matrix_create(784, 1);
|
|
Matrix *target = matrix_create(10, 1);
|
|
|
|
for (int i = 0; i < epochs; i++) {
|
|
for (int j = 0; j < NUM_TEST; j++) {
|
|
matrix_initialise(input, train_image[j]);
|
|
matrix_fill(target, 0);
|
|
target->data[train_label[j]][0] = 1;
|
|
|
|
network_train(network, input, target, 0.2);
|
|
}
|
|
}
|
|
int correct = 0;
|
|
|
|
for (int i = 0; i < NUM_TEST; i++) {
|
|
matrix_initialise(input, test_image[i]);
|
|
|
|
network_predict(network, input);
|
|
int maxIndex = 0;
|
|
for (int j = 1; j < 10; j++) {
|
|
if (network_output(network)[j] > network_output(network)[maxIndex]) {
|
|
maxIndex = j;
|
|
}
|
|
}
|
|
if (maxIndex == test_label[i]) {
|
|
correct++;
|
|
}
|
|
}
|
|
putchar('\n');
|
|
printf("====================================\n");
|
|
printf("Accuracy: %lf \n", (double)correct / NUM_TEST);
|
|
printf("====================================\n");
|
|
|
|
// Save the trained network
|
|
network_store(network, "../examples/MNIST/data/trained_mnist");
|
|
|
|
matrix_free(input);
|
|
matrix_free(target);
|
|
network_free(network);
|
|
|
|
return 0;
|
|
} |