User-Drawn Sketches
This Sketch Prediction Neural Network takes a sketch that the user draws in a sketchpad GUI and tries to recognize it. We used a convolutional neural network (CNN) to extract features from the images and PyTorch to train the model. PyGame was used to build the sketchpad that feeds the drawn images into the pipeline.
Sample Images
The dataset used for this project is the Sketch Dataset. It contains 20,000 images spanning 250 classes, with 80 images per class, so the dataset is perfectly balanced. Each image is 1111x1111 pixels and grayscale (black and white).
Model Architecture
The final architecture contained three convolutional layers, each with a 5x5 kernel, a stride of 2, and a padding of 2. Batch normalization was applied after each convolutional layer to act as a regularizer, and the output of every convolutional layer, including the last, was pooled to reduce dimensionality. The pooled output was then flattened and fed into a fully connected layer with 250 output neurons, one per class. A dropout rate of 50% was applied at this layer to reduce overfitting and improve generalization.
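The architecture above can be sketched in PyTorch roughly as follows. This is a minimal sketch, not the project's actual code: the class name `SketchCNN`, the channel widths (32, 64, 128), the ReLU activations, the 2x2 max pooling, and placing dropout directly before the final linear layer are all assumptions, since the text does not specify them.

```python
import torch
from torch import nn


class SketchCNN(nn.Module):
    """Hypothetical reconstruction: three conv blocks (5x5 kernel, stride 2,
    padding 2) each followed by batch norm and pooling, then 50% dropout and
    a 250-way fully connected layer.

    Assumed (not stated in the text): channel widths 32/64/128, ReLU
    activations, 2x2 max pooling, dropout placed before the linear layer.
    """

    def __init__(self, num_classes: int = 250, dropout: float = 0.5):
        super().__init__()

        def block(in_ch: int, out_ch: int) -> nn.Sequential:
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=5, stride=2, padding=2),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2),
            )

        # Spatial size for a 1x120x120 grayscale input:
        # conv 120->60, pool 60->30; conv 30->15, pool 15->7; conv 7->4, pool 4->2
        self.features = nn.Sequential(block(1, 32), block(32, 64), block(64, 128))
        self.classifier = nn.Sequential(
            nn.Flatten(),                         # 128 channels * 2 * 2 = 512 features
            nn.Dropout(p=dropout),                # 50% dropout at the fully connected layer
            nn.Linear(128 * 2 * 2, num_classes),  # one output neuron per class
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))
```

Because each stride-2 convolution and each 2x2 pooling step roughly halves the spatial size, three blocks shrink a 120x120 input to 2x2, which keeps the fully connected layer small.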
The training process began by resizing the images from 1111x1111 to 120x120 pixels to reduce the memory and compute cost of training, then converting them to tensors so they could serve as valid inputs to the network. The data was split 85% for training and 15% for validation, while test images came from user input. The training images were then passed through the network, and each loss was appended to a .txt file for debugging and to a list for easy access when computing evaluation metrics. The same was done for the validation images.
The results above show 55% accuracy and 55% precision on the validation set, well above the 0.4% expected from random guessing across 250 classes. At the end of evaluation, the user-drawn image is passed to the network, which returns the top K predicted classes for the sketch. Our team is proud of these results, considering how small the dataset is, with only 80 images per class. However, as the results above also show, the model overfits, and we were unable to overcome this. We tried data augmentation and several combinations of hyperparameters.
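The top-K prediction step for the user-drawn image could be sketched as below, assuming the drawing has already been preprocessed into a 1x120x120 tensor the same way as the training images. The function name `predict_top_k`, the `class_names` list, and the default `k=5` are hypothetical; the text does not say which value of K was used.

```python
import torch
from torch import nn


def predict_top_k(model: nn.Module, image: torch.Tensor, class_names, k=5):
    """Hypothetical helper: return the k most likely (class_name, probability)
    pairs for a single preprocessed image of shape (channels, H, W)."""
    model.eval()  # disable dropout and use running batch-norm statistics
    with torch.no_grad():
        logits = model(image.unsqueeze(0))       # add a batch dimension
        probs = torch.softmax(logits, dim=1)[0]  # convert logits to probabilities
        top_probs, top_idx = torch.topk(probs, k)
    return [(class_names[i], p) for i, p in zip(top_idx.tolist(), top_probs.tolist())]
```

Switching the model to evaluation mode matters here: with dropout active, repeated predictions on the same sketch could return different classes.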