# ARMv8/extension/client/mnist_client.py
# (93 lines, 3.3 KiB, Python)

import tkinter as tk
from PIL import Image, ImageDraw, ImageOps, ImageTk
import numpy as np
import matplotlib.pyplot as plt
import ctypes
# Bridge to the native MNIST code: ctypes loads the shared library and the
# prototypes below tell it how to marshal arguments and return values.
mnist_lib = ctypes.CDLL('./mnist.so')

# predict_mnist is declared as taking a pointer to C doubles and returning a
# C string (which ctypes hands back as a Python bytes object).
mnist_lib.predict_mnist.restype = ctypes.c_char_p
mnist_lib.predict_mnist.argtypes = [ctypes.POINTER(ctypes.c_double)]

# load_mnist is declared as taking no arguments and returning nothing.
mnist_lib.load_mnist.restype = None
mnist_lib.load_mnist.argtypes = []

# One-time initialisation of the C library (loads the trained network into
# memory) before any prediction is requested.
mnist_lib.load_mnist()
def predict(pixel_array):
    """Run the native MNIST network on one flattened image.

    Args:
        pixel_array: Iterable of numeric pixel values. The values are copied
            into a fixed 784-element C ``double`` buffer; if fewer values are
            supplied the remaining slots stay at 0.0, and ctypes refuses
            more than 784 initialisers.

    Returns:
        The text produced by ``predict_mnist`` decoded as UTF-8, or an empty
        string when the library returns a NULL pointer.
    """
    # Build the C array the library expects and hand it over by reference.
    c_array = (ctypes.c_double * 784)(*pixel_array)
    prediction = mnist_lib.predict_mnist(c_array)

    # With restype c_char_p, ctypes maps a NULL return to None; calling
    # .decode on it would raise AttributeError inside the Tk callback.
    if prediction is None:
        return ""

    # Replace undecodable bytes instead of raising, so a malformed native
    # string still yields something displayable.
    return prediction.decode('utf-8', errors='replace')
# Handler for pointer movement while drawing; every movement extends the
# stroke and triggers a fresh prediction.
def paint(event):
    """Draw a stroke segment up to the pointer position and re-run the prediction.

    The segment is drawn twice: on the visible Tk ``canvas`` and, through
    ``draw``, on the off-screen image that is later fed to the network.

    Args:
        event: Tk event object; only its ``x`` and ``y`` attributes are read.
    """
    global last_x, last_y
    x, y = event.x, event.y

    # Compare against None explicitly: 0 is a valid canvas coordinate, and a
    # plain truthiness test would skip segments that start on the left or
    # top edge.
    if last_x is not None and last_y is not None:
        canvas.create_line(
            last_x, last_y, x, y,
            fill="black", width=10, capstyle=tk.ROUND,
            smooth=tk.TRUE, splinesteps=36,
        )
        draw.line([last_x, last_y, x, y], fill="black", width=10)

    # Remember this point as the start of the next segment.
    last_x, last_y = x, y
    update_prediction()
def reset(event):
    """Forget the last pointer position so the next stroke starts fresh.

    Args:
        event: Tk event object; unused.
    """
    global last_x, last_y
    last_x = last_y = None
# Converts the current drawing into the array of doubles the network takes
# and displays the resulting prediction.
def update_prediction():
    """Preprocess the off-screen drawing, classify it and update the label.

    Steps: convert to 8-bit grayscale, invert, downsample to 28x28 with
    LANCZOS resampling, scale the pixel values by 1/255, flatten, then pass
    the result to ``predict`` and show its return value in
    ``prediction_label``.
    """
    # Inversion turns the white background into 0 and the black strokes into
    # the high values (original author's note: the dataset is believed to
    # use this polarity).
    inverted = ImageOps.invert(image.convert('L'))

    # Resize, then normalise as float32.
    small = inverted.resize((28, 28), Image.LANCZOS)
    pixels = np.array(small).astype(np.float32) / 255.0

    prediction_label.config(text=predict(pixels.flatten()))
# Wipes the drawing and the displayed prediction.
def clear_canvas():
    """Reset the UI to an empty state.

    Blanks the prediction label, repaints the off-screen image white and
    removes every item from the visible canvas.
    """
    prediction_label.config(text="")
    draw.rectangle([0, 0, 280, 280], fill="white")
    canvas.delete("all")
# Meant to show the resized 28x28 image below the canvas; currently disabled.
def show_resized_image():
    """Placeholder for the 28x28 preview; does nothing at present.

    The disabled implementation converted the drawing to grayscale, resized
    it to 28x28 with LANCZOS resampling, wrapped it in an
    ``ImageTk.PhotoImage`` and assigned it to ``resized_image_label``
    (also storing it on the label's ``image`` attribute).
    """
    return None
# --- Main window -----------------------------------------------------------
root = tk.Tk()
root.title("MNIST Digit Recognizer")

# Previous pointer position of the stroke in progress (None = no stroke).
last_x = last_y = None

# --- Drawing surface -------------------------------------------------------
# Visible 280x280 canvas; dragging with the left button paints, releasing it
# ends the stroke.
canvas = tk.Canvas(root, width=280, height=280, bg='white')
canvas.grid(row=0, column=0, padx=10, pady=10)
canvas.bind("<B1-Motion>", paint)
canvas.bind("<ButtonRelease-1>", reset)

# Off-screen image of the same size that mirrors the canvas strokes; this is
# what gets preprocessed for the network.
image = Image.new("RGB", (280, 280), 'white')
draw = ImageDraw.Draw(image)

# --- Prediction read-out ---------------------------------------------------
prediction_label = tk.Label(root, text="", font=("Helvetica", 16))
prediction_label.grid(row=0, column=1, padx=10)

# --- Controls --------------------------------------------------------------
# NOTE: a "Show Resized Image" button wired to show_resized_image (row 1,
# column 0, pady=10) is disabled here and was marked for later removal.
clear_button = tk.Button(root, text="Clear", command=clear_canvas)
clear_button.grid(row=1, column=0)

# Empty label reserved for the resized preview, spanning both columns.
resized_image_label = tk.Label(root)
resized_image_label.grid(row=2, column=0, columnspan=2, padx=10, pady=10)

# Hand control to the Tkinter event loop.
root.mainloop()