ARMv8/extension/examples/MNIST/test_mnist.c

51 lines
1.7 KiB
C

/** @brief Shows the output of the network after it has been trained on the MNIST dataset.
* NOTE! This program does not train the network; it only loads an already-trained network and reports its output.
* The output is identical to that of mnist_train.c: we trained the network in advance so that you can see its
* accuracy (95%) without waiting. Feel free to train it yourself using mnist_train.c; training takes about
* 3 minutes, although that of course depends on your machine.
*
* @author Saleh Bubshait
*/
#include <stdio.h>
#include <stdlib.h>

#include "../../src/network.h"
#include "mnist_loader.h"
/**
 * @brief Evaluates the pre-trained MNIST network on the test set and prints its accuracy.
 *
 * Loads the MNIST data via load_mnist(), then loads the network stored at
 * "examples/MNIST/data/trained_mnist". Because this is a relative path, the
 * program presumably has to be run from the extension's root directory
 * (TODO confirm). Every test image is fed through the network. A prediction
 * counts as correct when the most strongly activated output neuron's index
 * equals the test label.
 *
 * This program only evaluates; it never trains. It therefore does not write
 * the network back to disk.
 *
 * @return EXIT_SUCCESS on success, or EXIT_FAILURE if the network could not
 *         be loaded or the input matrix could not be created.
 */
int main(void) {
    /* 28x28 pixels per image in; one output neuron per digit class (0-9) out. */
    enum { INPUT_NEURONS = 28 * 28, OUTPUT_CLASSES = 10 };

    load_mnist(train_image, train_label, test_image, test_label);
    printf("LOADING TRAINED MNIST DATASET. PLEASE WAIT!\n");

    // Load the trained neural network. This is the one trained in mnist_train.c
    Network *network = network_load("examples/MNIST/data/trained_mnist");
    /* NOTE(review): assumes network_load signals failure by returning NULL — confirm in network.h. */
    if (network == NULL) {
        fprintf(stderr, "Failed to load trained network from examples/MNIST/data/trained_mnist\n");
        return EXIT_FAILURE;
    }

    Matrix *input = matrix_create(INPUT_NEURONS, 1);
    if (input == NULL) {
        fprintf(stderr, "Failed to allocate the input matrix\n");
        network_free(network);
        return EXIT_FAILURE;
    }

    int correct = 0;
    for (int i = 0; i < NUM_TEST; i++) {
        matrix_initialise(input, test_image[i]);
        network_predict(network, input);

        // The predicted digit is the index of the largest output activation.
        int predicted = 0;
        for (int j = 1; j < OUTPUT_CLASSES; j++) {
            if (network_output(network)[j] > network_output(network)[predicted]) {
                predicted = j;
            }
        }
        if (predicted == test_label[i]) {
            correct++;
        }
    }

    putchar('\n');
    printf("====================================\n");
    printf("Accuracy: %lf \n", (double)correct / NUM_TEST);
    printf("====================================\n");

    /* The network was only evaluated, not trained, so it is deliberately not
     * stored again. Storing it would silently write an unchanged copy into the
     * current working directory. */
    matrix_free(input);
    network_free(network);
    return EXIT_SUCCESS;
}