51 lines
1.7 KiB
C
51 lines
1.7 KiB
C
/** @brief Shows the output of the network after being trained on the MNIST dataset.
|
|
* NOTE: This program does not train the network; it only loads a pre-trained network and reports its accuracy.
|
|
* The network loaded here is the one produced by mnist_train.c; it is shipped pre-trained so you can see its accuracy without training it yourself.
|
|
* The expected accuracy is about 95%. Feel free to train the network yourself using mnist_train.c; training takes about 3 minutes,
|
|
* although that of course depends on your machine.
|
|
*
|
|
* @author Saleh Bubshait
|
|
*/
|
|
|
|
#include <stdio.h>

#include "../../src/network.h"
#include "mnist_loader.h"
|
|
|
|
int main(void) {
|
|
|
|
load_mnist(train_image, train_label, test_image, test_label);
|
|
printf("LOADING TRAINED MNIST DATASET. PLEASE WAIT!\n");
|
|
// Load the trained neural network. This is the one trained in mnist_train.c
|
|
Network *network = network_load("examples/MNIST/data/trained_mnist");
|
|
|
|
Matrix *input = matrix_create(784, 1);
|
|
Matrix *target = matrix_create(10, 1);
|
|
int correct = 0;
|
|
|
|
for (int i = 0; i < NUM_TEST; i++) {
|
|
matrix_initialise(input, test_image[i]);
|
|
|
|
network_predict(network, input);
|
|
int maxIndex = 0;
|
|
for (int j = 1; j < 10; j++) {
|
|
if (network_output(network)[j] > network_output(network)[maxIndex]) {
|
|
maxIndex = j;
|
|
}
|
|
}
|
|
if (maxIndex == test_label[i]) {
|
|
correct++;
|
|
}
|
|
}
|
|
putchar('\n');
|
|
printf("====================================\n");
|
|
printf("Accuracy: %lf \n", (double)correct / NUM_TEST);
|
|
printf("====================================\n");
|
|
|
|
// Save the trained network
|
|
network_store(network, "trained_mnist");
|
|
|
|
matrix_free(input);
|
|
matrix_free(target);
|
|
network_free(network);
|
|
|
|
return 0;
|
|
} |