diff --git a/extension/.gitignore b/extension/.gitignore new file mode 100755 index 0000000..93edde8 --- /dev/null +++ b/extension/.gitignore @@ -0,0 +1,374 @@ +#This is the canonical lab .gitignore file + +#ignore build +build/ +output/ +out/ +.idea/ +my_network +#ignore temp files +*~ + +#ignore pdf files (just keep source files) +*.pdf + +#ignore junk files from latex output +*.out +*.log +*.aux +*.dvi +*.ps + +#ignore junk files from compiling C code +*.o +emulate +assemble + +#ignore junk files from compiling Haskell code +*.hi + +#ignore junk files from compiling Java code +*.class + +#ignore other junk files +*.backup +*.kate-swp +*.swp +*.snm +*.vrb +*.nav +*.toc + +# Prerequisites +*.d + +# Object files +*.o +*.ko +*.obj +*.elf + +# Linker output +*.ilk +*.map +*.exp + +# Precompiled Headers +*.gch +*.pch + +# Libraries +*.lib +*.a +*.la +*.lo + +# Shared objects (inc. Windows DLLs) +*.dll +*.so +*.so.* +*.dylib + +# Executables +*.exe +*.out +*.app +*.i*86 +*.x86_64 +*.hex + +# Debug files +*.dSYM/ +*.su +*.idb +*.pdb + +# Kernel Module Compile Results +*.mod* +*.cmd +.tmp_versions/ +modules.order +Module.symvers +Mkfile.old +dkms.conf + + +### OS Specific ### + + +## LINUX ## + +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* + +## MACOS ## +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + + +## WINDOWS ## 
+# Windows thumbnail cache files +Thumbs.db +Thumbs.db:encryptable +ehthumbs.db +ehthumbs_vista.db + +# Dump file +*.stackdump + +# Folder config file +[Dd]esktop.ini + +# Recycle Bin used on file shares +$RECYCLE.BIN/ + +# Windows Installer files +*.cab +*.msi +*.msix +*.msm +*.msp + +# Windows shortcuts +*.lnk + + +### Text Editors ### + +## Visual Studio Code ## +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +!.vscode/*.code-snippets + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix + +## Eclipse ## +.metadata +bin/ +tmp/ +*.tmp +*.bak +*.swp +*~.nib +local.properties +.settings/ +.loadpath +.recommenders + +# External tool builders +.externalToolBuilders/ + +# Locally stored "Eclipse launch configurations" +*.launch + +# PyDev specific (Python IDE for Eclipse) +*.pydevproject + +# CDT-specific (C/C++ Development Tooling) +.cproject + +# CDT- autotools +.autotools + +# Java annotation processor (APT) +.factorypath + +# PDT-specific (PHP Development Tools) +.buildpath + +# sbteclipse plugin +.target + +# Tern plugin +.tern-project + +# TeXlipse plugin +.texlipse + +# STS (Spring Tool Suite) +.springBeans + +# Code Recommenders +.recommenders/ + +# Annotation Processing +.apt_generated/ +.apt_generated_test/ + +# Scala IDE specific (Scala & Java development for Eclipse) +.cache-main +.scala_dependencies +.worksheet + +# Uncomment this line if you wish to ignore the project description file. 
+# Typically, this file would be tracked if it contains build/dependency configurations: +#.project + +## JetBrains ## +.metadata +bin/ +tmp/ +*.tmp +*.bak +*.swp +*~.nib +local.properties +.settings/ +.loadpath +.recommenders + +# External tool builders +.externalToolBuilders/ + +# Locally stored "Eclipse launch configurations" +*.launch + +# PyDev specific (Python IDE for Eclipse) +*.pydevproject + +# CDT-specific (C/C++ Development Tooling) +.cproject + +# CDT- autotools +.autotools + +# Java annotation processor (APT) +.factorypath + +# PDT-specific (PHP Development Tools) +.buildpath + +# sbteclipse plugin +.target + +# Tern plugin +.tern-project + +# TeXlipse plugin +.texlipse + +# STS (Spring Tool Suite) +.springBeans + +# Code Recommenders +.recommenders/ + +# Annotation Processing +.apt_generated/ +.apt_generated_test/ + +# Scala IDE specific (Scala & Java development for Eclipse) +.cache-main +.scala_dependencies +.worksheet + +# Uncomment this line if you wish to ignore the project description file. 
+# Typically, this file would be tracked if it contains build/dependency configurations: +#.project + +## Vim ## +# Swap +[._]*.s[a-v][a-z] +!*.svg # comment out if you don't need vector files +[._]*.sw[a-p] +[._]s[a-rt-v][a-z] +[._]ss[a-gi-z] +[._]sw[a-p] + +# Session +Session.vim +Sessionx.vim + +# Temporary +.netrwhist +*~ +# Auto-generated tag files +tags +# Persistent undo +[._]*.un~ + +## Sublime Text ## +# OS metadata +.DS_Store +Thumbs.db + +# Node +node_modules +package-lock.json + +# TypeScript +*.tsbuildinfo + +# build cache +.rollup.cache + +# Build directories +dist +nuclide/dist-nuclide-* +dist-nuclide +commons-atom +commons-ui + +## Xcode ## +## User settings +xcuserdata/ + +## Xcode 8 and earlier +*.xcscmblueprint +*.xccheckout + +## Notepad++ ## +# Notepad++ backups # +*.bak + +## LibreOffice ## +# LibreOffice locks +.~lock.*# + +## Cloud9 ## +# Cloud9 IDE - http://c9.io +.c9revisions +.c9 \ No newline at end of file diff --git a/extension/Makefile b/extension/Makefile new file mode 100644 index 0000000..723d07a --- /dev/null +++ b/extension/Makefile @@ -0,0 +1,52 @@ +# Compiler and flags +CC = gcc +CFLAGS = -Wall -Wextra -std=c99 -I$(UNITY_DIR)/src +### Add -g if you want to debug. It is removed to avoid OS-dependent build files. 
+ +# Linker flags +LDLIBS = -lm + +# Directories +SRC_DIR = src +TEST_DIR = tests +UNITY_DIR = $(TEST_DIR)/unity +BUILD_DIR = build +EXAMPLES_DIR = examples +MNIST_DIR = $(EXAMPLES_DIR)/MNIST + +# Unity +UNITY_SRC = $(UNITY_DIR)/src/unity.c + +# Executable names +TEST_MATRIX = $(BUILD_DIR)/test_matrix +TEST_NETWORK = $(BUILD_DIR)/test_network +TEST_XOR = $(BUILD_DIR)/test_xor +TEST_MNIST = $(BUILD_DIR)/test_mnist +MNIST_TRAIN = $(BUILD_DIR)/mnist_train + +all: $(BUILD_DIR) $(TEST_MATRIX) $(TEST_NETWORK) $(TEST_XOR) $(TEST_MNIST) $(MNIST_TRAIN) + +$(BUILD_DIR): + mkdir -p $(BUILD_DIR) + +$(TEST_MATRIX): $(UNITY_SRC) $(SRC_DIR)/matrix.c $(TEST_DIR)/test_matrix.c + $(CC) $(CFLAGS) -o $@ $^ $(LDLIBS) + +$(TEST_NETWORK): $(UNITY_SRC) $(SRC_DIR)/network.c $(SRC_DIR)/matrix.c $(TEST_DIR)/test_network.c + $(CC) $(CFLAGS) -o $@ $^ $(LDLIBS) + +$(TEST_XOR): $(SRC_DIR)/network.c $(SRC_DIR)/matrix.c $(TEST_DIR)/xor.c + $(CC) $(CFLAGS) -o $@ $^ $(LDLIBS) + +$(TEST_MNIST): $(SRC_DIR)/network.c $(SRC_DIR)/matrix.c $(MNIST_DIR)/test_mnist.c + $(CC) $(CFLAGS) -o $@ $^ $(LDLIBS) + +$(MNIST_TRAIN): $(SRC_DIR)/network.c $(SRC_DIR)/matrix.c $(MNIST_DIR)/mnist_train.c + $(CC) $(CFLAGS) -o $@ $^ $(LDLIBS) + +test: $(TEST_MATRIX) $(TEST_NETWORK) $(TEST_XOR) $(TEST_MNIST) + +clean: + rm -rf $(BUILD_DIR) + +.PHONY: all test clean diff --git a/extension/README.md b/extension/README.md new file mode 100644 index 0000000..0f80184 --- /dev/null +++ b/extension/README.md @@ -0,0 +1,46 @@ +# Extension: Artificial Neural Networks in C + +## Folder Structure + +- src: Contains the source code for the project. Two modules: `matrix` and `network`. +- tests: Contains unit and integration tests for the project. + * `test_matrix.c`: Unit tests for the `matrix` module. + * `test_network.c`: Unit tests for the `network` module. + * `test_xor`: Integration tests for a full neural network to learn the XOR function. 
+- examples: An example of a non-trivial neural network that learns to recognize handwritten digits. The network is trained on the MNIST dataset, with about 100K training examples. +- client: Contains a simple Python interface for interacting with the neural network. There you can draw a digit and the network will try to recognize it in real time! Give it a go! + +## Building the Project + +To build the project, run the following command: + +```bash +make +``` + +## Running the Tests + +To run the tests, run the following commands: + +``` +./build/test_matrix +./build/test_network +./build/test_xor +./build/test_mnist +``` + +As it stands, the MNIST network reaches about 98% accuracy! +``` +==================================== +Accuracy: 0.976300 +==================================== +``` + +## Running the Client + +To run the client, run the following commands: + +``` +cd client +python mnist_client.py +``` \ No newline at end of file diff --git a/extension/client/mnist_client.py b/extension/client/mnist_client.py new file mode 100644 index 0000000..9898e9d --- /dev/null +++ b/extension/client/mnist_client.py @@ -0,0 +1,92 @@ +import tkinter as tk +from PIL import Image, ImageDraw, ImageOps, ImageTk +import numpy as np +import matplotlib.pyplot as plt +import ctypes + +#### Use ctypes to make the C functions from the shared library callable from here. +mnist_lib = ctypes.CDLL('./mnist.so') +mnist_lib.load_mnist.argtypes = [] +mnist_lib.load_mnist.restype = None +mnist_lib.predict_mnist.argtypes = [ctypes.POINTER(ctypes.c_double)] +mnist_lib.predict_mnist.restype = ctypes.c_char_p + + +# Initialize the C library (load the trained network into memory) +mnist_lib.load_mnist() + +def predict(pixel_array): + # Convert to a C array of 784 doubles, then call the C library. 
+ c_array = (ctypes.c_double * 784)(*pixel_array) + prediction = mnist_lib.predict_mnist(c_array) + prediction_str = prediction.decode('utf-8') + return f"{prediction_str}" + +# this is triggered when anything is drawn (like on changes) +def paint(event): + global last_x, last_y + x, y = event.x, event.y + if last_x and last_y: + canvas.create_line(last_x, last_y, x, y, fill="black", width=10, capstyle=tk.ROUND, smooth=tk.TRUE, splinesteps=36) + draw.line([last_x, last_y, x, y], fill="black", width=10) + last_x, last_y = x, y + update_prediction() + +def reset(event): + global last_x, last_y + last_x, last_y = None, None + +# convrets to array of doubles. +def update_prediction(): + # gray_image = ImageOps.grayscale(image) + gray_image = image.convert('L') # I thought this might change something but doesn't seem to. + inverted_image = ImageOps.invert(gray_image) # I believe the dataset is inverted so 0 is white and 1 is black. + + # resizing and normalisation + resized_image = inverted_image.resize((28, 28), Image.LANCZOS) + pixel_array = np.array(resized_image).astype(np.float32) + pixel_array /= 255.0 + pixel_array = pixel_array.flatten() + + prediction = predict(pixel_array) + prediction_label.config(text=prediction) + +# clear the canvas +def clear_canvas(): + canvas.delete("all") + draw.rectangle([0, 0, 280, 280], fill="white") + prediction_label.config(text="") + +# shows the resized 28x28 image bellow the canvas. 
+def show_resized_image(): + global resized_image_label + # gray_image = ImageOps.grayscale(image) + # resized_image = gray_image.resize((28, 28), Image.Resampling.LANCZOS) + # resized_image_tk = ImageTk.PhotoImage(resized_image) + # resized_image_label.config(image=resized_image_tk) + # resized_image_label.image = resized_image_tk + +root = tk.Tk() +root.title("MNIST Digit Recognizer") + +canvas = tk.Canvas(root, width=280, height=280, bg='white') +canvas.grid(row=0, column=0, padx=10, pady=10) + +image = Image.new("RGB", (280, 280), 'white') +draw = ImageDraw.Draw(image) + +canvas.bind("", paint) +canvas.bind("", reset) + +prediction_label = tk.Label(root, text="", font=("Helvetica", 16)) +prediction_label.grid(row=0, column=1, padx=10) +# show_image_button = tk.Button(root, text="Show Resized Image", command=show_resized_image) # this is to be remvoed later +# show_image_button.grid(row=1, column=0, pady=10) +clear_button = tk.Button(root, text="Clear", command=clear_canvas) +clear_button.grid(row=1, column=0) +resized_image_label = tk.Label(root) +resized_image_label.grid(row=2, column=0, columnspan=2, padx=10, pady=10) +last_x, last_y = None, None + +# Run the Tkinter event loop +root.mainloop() diff --git a/extension/client/trained_mnist b/extension/client/trained_mnist new file mode 100644 index 0000000..13c1d93 Binary files /dev/null and b/extension/client/trained_mnist differ diff --git a/extension/examples/MNIST/data/t10k-images.idx3-ubyte b/extension/examples/MNIST/data/t10k-images.idx3-ubyte new file mode 100644 index 0000000..1170b2c Binary files /dev/null and b/extension/examples/MNIST/data/t10k-images.idx3-ubyte differ diff --git a/extension/examples/MNIST/data/t10k-labels.idx1-ubyte b/extension/examples/MNIST/data/t10k-labels.idx1-ubyte new file mode 100644 index 0000000..d1c3a97 Binary files /dev/null and b/extension/examples/MNIST/data/t10k-labels.idx1-ubyte differ diff --git a/extension/examples/MNIST/data/train-images.idx3-ubyte 
b/extension/examples/MNIST/data/train-images.idx3-ubyte new file mode 100644 index 0000000..bbce276 Binary files /dev/null and b/extension/examples/MNIST/data/train-images.idx3-ubyte differ diff --git a/extension/examples/MNIST/data/train-labels.idx1-ubyte b/extension/examples/MNIST/data/train-labels.idx1-ubyte new file mode 100644 index 0000000..d6b4c5d Binary files /dev/null and b/extension/examples/MNIST/data/train-labels.idx1-ubyte differ diff --git a/extension/examples/MNIST/data/trained_mnist b/extension/examples/MNIST/data/trained_mnist new file mode 100644 index 0000000..13c1d93 Binary files /dev/null and b/extension/examples/MNIST/data/trained_mnist differ diff --git a/extension/examples/MNIST/mnist_loader.h b/extension/examples/MNIST/mnist_loader.h new file mode 100644 index 0000000..c4ad139 --- /dev/null +++ b/extension/examples/MNIST/mnist_loader.h @@ -0,0 +1,196 @@ +/* +Takafumi Horiuchi. 2018. +https://github.com/takafumihoriuchi/MNIST_for_C +*/ + +#include <stdio.h> +#include <stdlib.h> +#include <unistd.h> +#include <fcntl.h> +#include <string.h> + +// set appropriate path for data +#define TRAIN_IMAGE "examples/MNIST/data/train-images.idx3-ubyte" +#define TRAIN_LABEL "examples/MNIST/data/train-labels.idx1-ubyte" +#define TEST_IMAGE "examples/MNIST/data/t10k-images.idx3-ubyte" +#define TEST_LABEL "examples/MNIST/data/t10k-labels.idx1-ubyte" + +#define SIZE 784 // 28*28 +#define NUM_TRAIN 60000 +#define NUM_TEST 10000 +#define LEN_INFO_IMAGE 4 +#define LEN_INFO_LABEL 2 + +#define MAX_IMAGESIZE 1280 +#define MAX_BRIGHTNESS 255 +#define MAX_FILENAME 256 +#define MAX_NUM_OF_IMAGES 1 + +unsigned char image[MAX_NUM_OF_IMAGES][MAX_IMAGESIZE][MAX_IMAGESIZE]; +int width[MAX_NUM_OF_IMAGES], height[MAX_NUM_OF_IMAGES]; + +int info_image[LEN_INFO_IMAGE]; +int info_label[LEN_INFO_LABEL]; + +unsigned char train_image_char[NUM_TRAIN][SIZE]; +unsigned char test_image_char[NUM_TEST][SIZE]; +unsigned char train_label_char[NUM_TRAIN][1]; +unsigned char test_label_char[NUM_TEST][1]; + +double train_image[NUM_TRAIN][SIZE]; 
+double test_image[NUM_TEST][SIZE]; +int train_label[NUM_TRAIN]; +int test_label[NUM_TEST]; + + +void FlipLong(unsigned char * ptr) +{ + register unsigned char val; + + // Swap 1st and 4th bytes + val = *(ptr); + *(ptr) = *(ptr+3); + *(ptr+3) = val; + + // Swap 2nd and 3rd bytes + ptr += 1; + val = *(ptr); + *(ptr) = *(ptr+1); + *(ptr+1) = val; +} + + +void read_mnist_char(char *file_path, int num_data, int len_info, int arr_n, unsigned char data_char[][arr_n], int info_arr[]) +{ + int i, j, k, fd; + unsigned char *ptr; + + if ((fd = open(file_path, O_RDONLY)) == -1) { + fprintf(stderr, "couldn't open image file"); + exit(-1); + } + + read(fd, info_arr, len_info * sizeof(int)); + + // read-in information about size of data + for (i=0; i