ARMv8/extension/tests/xor.c

54 lines
1.7 KiB
C

/** @file xor.c
* @brief A test program for the Neural Network library to learn the XOR function.
* The program trains the network on the XOR function and then tests its accuracy.
* The model is a very simple one with only one hidden layer of 3 neurons.
* The accuracy is printed as 1 minus the mean squared error over the four XOR cases.
*
* @author Saleh Bubshait
*/
#include <stdio.h>
#include <stdlib.h>

#include "../src/network.h"
/*
 * Train a 2-3-1 network on the four XOR cases, then print the network's
 * prediction for each case followed by 1 - (mean squared error).
 *
 * Returns EXIT_SUCCESS, or EXIT_FAILURE if the network or a matrix could not
 * be created. NOTE(review): assumes network_create/matrix_create report
 * failure by returning NULL -- confirm against network.h.
 */
int main(void) {
    enum {
        INPUT_SIZE = 2,
        HIDDEN_SIZE = 3,
        OUTPUT_SIZE = 1,
        SAMPLE_COUNT = 4,
        EPOCHS = 15000
    };
    const double learning_rate = 0.1;

    /* Layer widths: input layer, one hidden layer, output layer. */
    int dimensions[] = {INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE};
    const int layer_count = (int)(sizeof dimensions / sizeof dimensions[0]);

    /* The XOR truth table. */
    const double inputs[SAMPLE_COUNT][INPUT_SIZE] = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
    const double targets[SAMPLE_COUNT][OUTPUT_SIZE] = {{0}, {1}, {1}, {0}};

    int status = EXIT_FAILURE;
    Network *network = network_create(layer_count, dimensions);
    /* One input and one target column vector are reused for every sample,
     * rather than allocating a fresh pair per training step; both are fully
     * re-initialised before each use. */
    Matrix *input = matrix_create(INPUT_SIZE, 1);
    Matrix *target = matrix_create(OUTPUT_SIZE, 1);

    if (network == NULL || input == NULL || target == NULL) {
        fprintf(stderr, "xor: failed to create network or matrices\n");
    } else {
        /* Training: plain per-sample updates over the whole table each epoch. */
        for (int epoch = 0; epoch < EPOCHS; epoch++) {
            for (int s = 0; s < SAMPLE_COUNT; s++) {
                matrix_initialise(input, (double[INPUT_SIZE][1]){{inputs[s][0]}, {inputs[s][1]}});
                matrix_initialise(target, (double[OUTPUT_SIZE][1]){{targets[s][0]}});
                network_train(network, input, target, learning_rate);
            }
        }

        /* Evaluation: print each prediction and accumulate its squared error. */
        double squared_error_sum = 0.0;
        for (int s = 0; s < SAMPLE_COUNT; s++) {
            matrix_initialise(input, (double[INPUT_SIZE][1]){{inputs[s][0]}, {inputs[s][1]}});
            network_predict(network, input);
            const double output = network_output(network)[0];
            printf("Input: %d %d, Output: %f\n", (int)inputs[s][0], (int)inputs[s][1], output);
            const double diff = output - targets[s][0];
            squared_error_sum += diff * diff;
        }
        printf("====================================\n");
        /* "Accuracy" here is 1 - mean squared error, not a classification rate. */
        printf("Accuracy: %lf \n", 1 - squared_error_sum / SAMPLE_COUNT);
        status = EXIT_SUCCESS;
    }

    /* Guarded: SOURCE does not show whether these free functions accept NULL. */
    if (target != NULL) {
        matrix_free(target);
    }
    if (input != NULL) {
        matrix_free(input);
    }
    if (network != NULL) {
        network_free(network);
    }
    return status;
}