93 lines
3.3 KiB
Python
93 lines
3.3 KiB
Python
import tkinter as tk
|
|
from PIL import Image, ImageDraw, ImageOps, ImageTk
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import ctypes
|
|
|
|
# Bind the compiled C network through ctypes and declare the signatures of the
# two functions this script calls.
# NOTE(review): './mnist.so' is a relative path, so it is resolved against the
# process's current working directory, not this script's directory.
mnist_lib = ctypes.CDLL('./mnist.so')

# load_mnist: no arguments, no return value.
load_fn = mnist_lib.load_mnist
load_fn.argtypes = []
load_fn.restype = None

# predict_mnist: takes a pointer to doubles, returns a C string (bytes or None).
predict_fn = mnist_lib.predict_mnist
predict_fn.argtypes = [ctypes.POINTER(ctypes.c_double)]
predict_fn.restype = ctypes.c_char_p

# Load the trained network into memory once, before any prediction is made.
load_fn()
|
|
|
|
def predict(pixel_array):
    """Classify one 28x28 image with the C network and return its answer as text.

    Args:
        pixel_array: The 784 pixel values of the image, in any array-like form.
            It is flattened first, so a 28x28 array is accepted too.

    Returns:
        The string produced by the C ``predict_mnist`` function, decoded as UTF-8.

    Raises:
        ValueError: If ``pixel_array`` does not hold exactly 784 values.
        RuntimeError: If ``predict_mnist`` returns a NULL pointer.
    """
    n_pixels = 28 * 28
    pixels = np.asarray(pixel_array, dtype=np.float64).ravel()

    # The C function only receives a bare double*, so it cannot check the
    # length itself. Without this check, a short input would be silently
    # zero-padded by the ctypes array constructor.
    if pixels.size != n_pixels:
        raise ValueError(f"expected {n_pixels} pixel values, got {pixels.size}")

    c_array = (ctypes.c_double * n_pixels)(*pixels.tolist())
    prediction = mnist_lib.predict_mnist(c_array)

    # With restype c_char_p, ctypes maps a NULL return to None.
    if prediction is None:
        raise RuntimeError("predict_mnist returned NULL")
    return prediction.decode('utf-8')
|
|
|
|
# Bound to <B1-Motion>: runs each time the pointer moves with the left button held.
def paint(event):
    """Extend the current stroke to the pointer position and refresh the prediction.

    The segment from the previous pointer position to the event's position is
    drawn twice:
    - on the Tk ``canvas``, which is what the user sees;
    - into the PIL ``image`` through ``draw``, which is what gets classified.

    Args:
        event: Tk event whose ``x``/``y`` give the pointer position on the canvas.
    """
    global last_x, last_y
    brush_width = 10
    x, y = event.x, event.y

    # Compare against None explicitly. 0 is a valid canvas coordinate, and a
    # truthiness test would skip segments that start on the top or left edge.
    if last_x is not None and last_y is not None:
        canvas.create_line(last_x, last_y, x, y, fill="black", width=brush_width,
                           capstyle=tk.ROUND, smooth=tk.TRUE, splinesteps=36)
        draw.line([last_x, last_y, x, y], fill="black", width=brush_width)

        # PIL draws flat-ended line segments. Stamp a disc at each end so the
        # off-screen image has the same round caps/joints as the canvas stroke.
        r = brush_width // 2
        for cx, cy in ((last_x, last_y), (x, y)):
            draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill="black")

    last_x, last_y = x, y
    update_prediction()
|
|
|
|
def reset(event):
    """End the current stroke by forgetting the last pointer position.

    ``event`` is not used; it is accepted because Tk passes one to every
    bound handler.
    """
    global last_x, last_y
    last_x = last_y = None
|
|
|
|
def update_prediction():
    """Downsample the current drawing to 28x28, classify it, and display the result.

    Reads the module-level PIL ``image`` and writes the text returned by
    ``predict`` into ``prediction_label``.
    """
    # Grayscale, then invert. The drawing is black ink on white, so inverting
    # makes the background 0 and the strokes bright.
    # NOTE(review): this assumes the network was trained on light-on-dark
    # digits (the usual MNIST convention) - confirm against the training data.
    inverted = ImageOps.invert(image.convert('L'))

    # Shrink to the 28x28 input size, then flatten to 784 float32 values
    # scaled into [0, 1].
    small = inverted.resize((28, 28), Image.LANCZOS)
    pixels = np.asarray(small, dtype=np.float32).ravel() / 255.0

    prediction_label.config(text=predict(pixels))
|
|
|
|
def clear_canvas():
    """Erase the drawing from both the screen and the off-screen image, and blank the prediction."""
    # Repaint the 280x280 PIL mirror white so it matches the emptied canvas.
    draw.rectangle([0, 0, 280, 280], fill="white")
    canvas.delete("all")
    prediction_label.config(text="")
|
|
|
|
def show_resized_image():
    """Preview hook for the 28x28 downsampled drawing; currently does nothing.

    The preview code that used to live here was disabled, so calling this
    function has no effect and returns None.
    """
|
|
|
|
# --- Main window ------------------------------------------------------------
root = tk.Tk()
root.title("MNIST Digit Recognizer")

# Visible 280x280 drawing surface.
canvas = tk.Canvas(root, width=280, height=280, bg='white')
canvas.grid(row=0, column=0, padx=10, pady=10)

# Off-screen PIL copy of the drawing, same size as the canvas. paint() draws
# into it through `draw`, and update_prediction() reads it.
image = Image.new("RGB", (280, 280), 'white')
draw = ImageDraw.Draw(image)

# Dragging with the left button draws; releasing the button ends the stroke.
canvas.bind("<B1-Motion>", paint)
canvas.bind("<ButtonRelease-1>", reset)

# Prediction text, to the right of the canvas.
prediction_label = tk.Label(root, text="", font=("Helvetica", 16))
prediction_label.grid(row=0, column=1, padx=10)

# Button under the canvas that wipes the drawing.
clear_button = tk.Button(root, text="Clear", command=clear_canvas)
clear_button.grid(row=1, column=0)

# Initially empty label spanning both columns of row 2, reserved for a preview image.
resized_image_label = tk.Label(root)
resized_image_label.grid(row=2, column=0, columnspan=2, padx=10, pady=10)

# No stroke is in progress yet.
last_x = last_y = None

# Hand control to the Tkinter event loop; this blocks until the window closes.
root.mainloop()
|