Training the Symbol Recognizer
Time to train a model! In this post, we will do a bunch of work preparing data, and then we’ll train a deep neural network to classify symbols.
As a recap, right now we have a backend talking to a frontend Overleaf extension:
I want to explain one of my design decisions right out of the gate. A digitally drawn symbol can be represented either as a list of points grouped into mouse strokes, or as an image.
The dataset I’ll be using stores the stroke information, but I will be transforming it into images. Although this loses the order in which the strokes were drawn, I expect the image representation to be more robust, since different people draw the same symbol in very different ways.
Also, when I started this project, ONNX, which can be used to deploy models in the browser, didn’t properly support LSTM models, which is what I would have used for the stroke representation. That’s the real reason I’m using images.
The dataset I’m using is available here. It comes from Detexify, the project I’m emulating and improving upon. Detexify’s author stored the data in a SQL database, but I’m more comfortable with NumPy, so my first order of business was to export this database to NumPy.
Note that I already had Postgres installed, so I could work with the SQL database easily; if you don’t, you’ll need to set up Postgres (or another SQL database) first.
After downloading the database, I ran the following command in my Linux terminal to load it into Postgres:
This runs the detexify.sql file as the postgres user.
Next, I wrote a script read_from_database.py, which I’ll now go through:
tqdm draws progress bars, and psycopg2 is a PostgreSQL driver for reading the database from Python.
Next, I connect to the database, and get the number of datapoints:
Then, I prepare to read 5,000 datapoints at a time out of the database:
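A minimal sketch of these two steps, assuming only standard Python DB-API calls (`cursor.execute`, `fetchone`, `fetchmany`), which psycopg2 cursors support. The table name `samples` and its columns are hypothetical, since the actual schema isn’t shown, and the demo substitutes an in-memory SQLite database so the snippet runs without Postgres; with psycopg2 you would pass the connection returned by `psycopg2.connect(...)` instead.

```python
import sqlite3


def count_rows(conn, table):
    """Return the number of rows in `table` (a trusted, hard-coded name)."""
    cur = conn.cursor()
    cur.execute(f"SELECT COUNT(*) FROM {table}")
    (n,) = cur.fetchone()
    cur.close()
    return n


def iter_batches(conn, query, batch_size=5000):
    """Yield lists of up to `batch_size` rows using only standard DB-API calls."""
    cur = conn.cursor()
    try:
        cur.execute(query)
        while True:
            rows = cur.fetchmany(batch_size)
            if not rows:
                break
            yield rows
    finally:
        cur.close()


if __name__ == "__main__":
    # Stand-in for the Postgres connection: an in-memory SQLite database
    # holding a hypothetical `samples` table with 12 rows.
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE samples (id INTEGER, label TEXT)")
    conn.executemany("INSERT INTO samples VALUES (?, ?)", [(i, "x") for i in range(12)])
    print(count_rows(conn, "samples"))  # 12
    print([len(b) for b in iter_batches(conn, "SELECT id, label FROM samples", batch_size=5)])  # [5, 5, 2]
    conn.close()
```

One design note: a default psycopg2 cursor transfers the whole result set to the client on `execute`, so `fetchmany` only limits how much Python processes at once; a named (server-side) cursor, created with `conn.cursor(name=...)`, streams rows from the server instead.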
After this, I go through all of the data and store it in a large list. Note that I have a lot of assertions here, just to make sure I understand the format of the data. For each point in each stroke, I store the x coordinate, the y coordinate, and then a 1 or a 0 depending on whether the user ended the stroke at that point.
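A sketch of that per-point layout, assuming each drawing arrives as a list of strokes where each stroke is a non-empty list of `(x, y)` pairs (the real records may be structured differently). `strokes_to_array` is a hypothetical helper name.

```python
import numpy as np


def strokes_to_array(strokes):
    """Flatten one drawing into rows of (x, y, end_of_stroke).

    `strokes` is assumed to be a list of strokes, each a non-empty list of
    (x, y) pairs; the last point of each stroke gets end_of_stroke = 1.
    """
    rows = []
    for stroke in strokes:
        assert len(stroke) > 0, "empty stroke"
        last = len(stroke) - 1
        for i, (x, y) in enumerate(stroke):
            rows.append((x, y, 1 if i == last else 0))
    # reshape keeps the (N, 3) shape even for a drawing with no points
    return np.array(rows, dtype=np.float32).reshape(-1, 3)


drawing = [[(0, 0), (1, 1)], [(2, 2)]]  # two strokes: a segment, then a dot
print(strokes_to_array(drawing))
```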
Lastly, I save and close everything:
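A sketch of the saving step, with hypothetical `save_dataset`/`load_dataset` helpers. Because drawings have different numbers of points, `dataX` can’t be a regular 2-D array, so this version stores it as a 1-D NumPy object array; loading it back then requires `allow_pickle=True`.

```python
import os
import tempfile

import numpy as np


def save_dataset(xs, ys, out_dir):
    """xs: list of (num_points, 3) arrays of varying length; ys: list of labels."""
    data_x = np.empty(len(xs), dtype=object)  # one slot per drawing
    for i, arr in enumerate(xs):
        data_x[i] = arr
    np.save(os.path.join(out_dir, "dataX.npy"), data_x, allow_pickle=True)
    np.save(os.path.join(out_dir, "dataY.npy"), np.array(ys))


def load_dataset(in_dir):
    # Object arrays are pickled, so np.load refuses them unless allow_pickle=True.
    data_x = np.load(os.path.join(in_dir, "dataX.npy"), allow_pickle=True)
    data_y = np.load(os.path.join(in_dir, "dataY.npy"))
    return data_x, data_y


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        xs = [np.zeros((2, 3), np.float32), np.ones((5, 3), np.float32)]
        save_dataset(xs, ["a", "b"], d)
        loaded_x, loaded_y = load_dataset(d)
        print([x.shape for x in loaded_x], list(loaded_y))
```

Filling a preallocated object array, rather than calling `np.array(xs, dtype=object)`, avoids NumPy building a 3-D array when every drawing happens to have the same number of points.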
At this point, I had a dataX.npy and a dataY.npy to load data from, rather than the database.
As I’m viewing this as an image classification problem, the next step was to load these points and convert them into images. In retrospect, I could have combined this step with the previous step. However, initially I wasn’t sure I wanted to convert from points to images, so these are separate steps.
I created a dataset class to hold the strokes, since initially I wasn’t going to use images. Here’s the constructor, which just normalizes each drawn symbol so that its minimum and maximum coordinates are 0 and 1 respectively:
Note that I actually use some padding when normalizing strokes; otherwise, points on the edges of the image wouldn’t show.
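A sketch of normalization with padding, assuming points are stored as rows whose first two columns are x and y. One assumption beyond the text: both axes are scaled by the larger of the two extents so symbols keep their aspect ratio (scaling each axis independently would distort tall or wide symbols and divide by zero for perfectly flat ones like a minus sign). The padding value of 0.05 is illustrative, and `normalize_points` is a hypothetical name.

```python
import numpy as np


def normalize_points(points, pad=0.05):
    """Map (x, y) into [pad, 1 - pad], preserving aspect ratio.

    `points` is an (N, 2) or (N, 3) array-like; columns after the first two
    (e.g. an end-of-stroke flag) are copied through unchanged.
    """
    pts = np.asarray(points, dtype=np.float64).copy()
    xy = pts[:, :2]
    mins = xy.min(axis=0)
    extent = (xy.max(axis=0) - mins).max()  # larger of the width and height
    if extent == 0:
        # A single dot: there is nothing to scale, so center it.
        pts[:, :2] = 0.5
        return pts
    pts[:, :2] = pad + (xy - mins) / extent * (1 - 2 * pad)
    return pts


print(normalize_points([[0, 0], [2, 1]], pad=0.1))  # x spans 0.1..0.9, y spans 0.1..0.5
```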
The next two functions for the dataset class are __len__ and __getitem__. The former tells whatever uses the dataset how many items it holds, and the latter returns the ith item. Here they are:
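A minimal sketch of such a class. `StrokeDataset` is a hypothetical name, and it also maps label strings to integer class indices, which the text doesn’t mention. PyTorch’s `DataLoader` works with any map-style dataset, i.e. any object implementing `__len__` and `__getitem__`; subclassing `torch.utils.data.Dataset` is the usual convention but is left out here so the sketch has no dependencies.

```python
class StrokeDataset:
    """Holds one point array per drawing plus its symbol label."""

    def __init__(self, drawings, labels):
        if len(drawings) != len(labels):
            raise ValueError("need exactly one label per drawing")
        self.drawings = list(drawings)
        # Number the classes by sorted name so indices are reproducible.
        self.classes = sorted(set(labels))
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}
        self.targets = [self.class_to_idx[c] for c in labels]

    def __len__(self):
        # How many drawings the dataset holds.
        return len(self.drawings)

    def __getitem__(self, i):
        # The i-th drawing and its integer class index.
        return self.drawings[i], self.targets[i]


ds = StrokeDataset(["d0", "d1", "d2"], ["beta", "alpha", "beta"])
print(len(ds), ds[0], ds.classes)
```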
The other important function for rendering the images was create_image. When I loop through the dataset, I call this function to create the images before I save them to a directory. Here it is:
Here I’m looping through the strokes, and drawing them to the image. I also have to manually draw the last point in the image, otherwise it won’t always appear.
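A sketch of rasterizing normalized strokes, written with plain NumPy rather than whatever drawing library the original `create_image` uses. Each segment is drawn by sampling one point per pixel along its longer axis (a simple DDA-style line), and the last point of every stroke is set explicitly so one-point strokes, such as dots, still show up. The 32-pixel image size is an assumption.

```python
import numpy as np


def create_image(strokes, size=32):
    """Rasterize one drawing into a size x size uint8 image (255 = ink).

    `strokes` is a list of strokes, each a sequence of (x, y) points already
    normalized to [0, 1]; row index = y, column index = x.
    """
    img = np.zeros((size, size), dtype=np.uint8)
    scale = size - 1
    for stroke in strokes:
        if len(stroke) == 0:
            continue
        pts = np.asarray(stroke, dtype=np.float64)[:, :2]
        pix = np.clip(np.rint(pts * scale), 0, scale).astype(int)
        # One sample per pixel along the segment's longer axis gives a connected line.
        for (x0, y0), (x1, y1) in zip(pix[:-1], pix[1:]):
            n = max(abs(x1 - x0), abs(y1 - y0)) + 1
            xs = np.rint(np.linspace(x0, x1, n)).astype(int)
            ys = np.rint(np.linspace(y0, y1, n)).astype(int)
            img[ys, xs] = 255
        # Draw the last point explicitly: a one-point stroke has no segments at all.
        x_last, y_last = pix[-1]
        img[y_last, x_last] = 255
    return img


img = create_image([[(0.1, 0.1), (0.9, 0.9)], [(0.5, 0.2)]])  # a diagonal plus a dot
print(np.count_nonzero(img))
```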
Another important bit is that I had to split the images into three sets: the training set to fit the model, the validation set to tune it, and the test set (which I never really used). I won’t show this script, as it isn’t very involved.
However, there was one sticking point because of how image datasets work in PyTorch. The PyTorch image dataset class first sorts the classes by name, and then assigns a number to each class. Here’s how that might play out if we have a training dataset, validation dataset, and test dataset:
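As an illustration (not the original example), this sketch mimics how torchvision’s `ImageFolder`, which is presumably the dataset class in question, numbers classes: it sorts the class folder names and enumerates them. The class names are hypothetical.

```python
def class_to_idx(class_names):
    """Number classes the way torchvision's ImageFolder does: sort names, then enumerate."""
    return {name: i for i, name in enumerate(sorted(class_names))}


# Hypothetical class folders; the test split happens to lack "Equals".
train_classes = ["Alpha", "Beta", "Delta", "Equals", "Nabla"]
test_classes = ["Alpha", "Beta", "Delta", "Nabla"]

train_map = class_to_idx(train_classes)
test_map = class_to_idx(test_classes)
print(train_map["Nabla"], test_map["Nabla"])  # 4 3
```

Because the numbering depends only on which folders exist, a class missing from one split shifts every index after it. The sort is also case-sensitive, so uppercase names sort before lowercase ones.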
Note that since the test dataset doesn’t have the “Equals” class, it assigns Nabla the number 3 instead of 4. The labels then silently disagree with the model’s outputs: a test image of Nabla is labeled 3, which the model treats as Equals. For this reason, I had to ensure that each symbol was in all three sets. I found this out the hard way, and ended up having to read through the source code of PyTorch’s image dataset.
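One way to guarantee that, sketched here under the assumption that samples are `(item, label)` pairs, is a stratified split: shuffle and split each class separately, reserving at least one sample per class for validation and for test. The function name, fractions, and seed are hypothetical; the original splitting script isn’t shown.

```python
import random
from collections import defaultdict


def split_per_class(samples, val_frac=0.1, test_frac=0.1, seed=0):
    """Split (item, label) pairs into train/val/test so every label lands in all three."""
    by_label = defaultdict(list)
    for item, label in samples:
        by_label[label].append(item)

    rng = random.Random(seed)
    train, val, test = [], [], []
    for label in sorted(by_label):
        items = by_label[label]
        rng.shuffle(items)
        # At least one sample per class goes to validation and to test.
        n_val = max(1, round(len(items) * val_frac))
        n_test = max(1, round(len(items) * test_frac))
        if len(items) - n_val - n_test < 1:
            raise ValueError(f"class {label!r} has too few samples to appear in every split")
        val += [(x, label) for x in items[:n_val]]
        test += [(x, label) for x in items[n_val:n_val + n_test]]
        train += [(x, label) for x in items[n_val + n_test:]]
    return train, val, test


samples = [(i, "a") for i in range(10)] + [(i, "b") for i in range(3)]
train, val, test = split_per_class(samples)
print(len(train), len(val), len(test))  # 9 2 2
```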
After I’d fixed this issue, one folder of images might look like this:
At this point, I was ready to train.
I trained the model on Google Colab in less than 20 minutes. I pay for Colab, but it’s also available for free. There should be a public Colab recitation available at this course website if you would like to learn how to use it.
The model I’m using isn’t very big, though; if you make it a little smaller, I suspect it would still train within a few hours on an ordinary computer.
Now I’ll go over some of the model training code. This isn’t meant to replace a deep-learning course or the excellent PyTorch tutorials, so I will skip some things; eventually the whole file should be publicly available here.
First, some imports:
Next, I set up some hyperparameters:
I make 20 passes over the training data (epochs). If I’m running on a GPU, which runs the code much faster, I put 512 images through the model at a time; otherwise, I use batches of two.
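In code, these settings might look like the following sketch; the constant names are hypothetical, and the values are the ones given above.

```python
import torch

EPOCHS = 20                                 # passes over the training data
USE_GPU = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_GPU else "cpu")
BATCH_SIZE = 512 if USE_GPU else 2          # images per batch
```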
For the loss function, which measures how far the model’s predictions are from the correct labels and therefore how the model needs to change, I use cross-entropy loss, which is common for classification:
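A small worked example of what cross-entropy computes, using PyTorch’s `nn.CrossEntropyLoss` with made-up numbers: it takes raw scores (logits) and integer class indices, and returns the batch mean of the negative log softmax probability assigned to the true class.

```python
import math

import torch
from torch import nn

criterion = nn.CrossEntropyLoss()

# Raw scores (logits) for a batch of 2 drawings over 3 hypothetical classes.
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.0, 0.0, 0.0]])
targets = torch.tensor([0, 2])  # integer class indices, not one-hot vectors

loss = criterion(logits, targets)

# Same value by hand: the batch mean of -log(softmax(logits)[true class]).
log_probs = torch.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(targets)), targets].mean()

# The second row is a uniform guess over 3 classes, so its loss is log(3).
uniform_loss = -log_probs[1, 2].item()
print(loss.item(), manual.item(), uniform_loss, math.log(3))
```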
Then, I create the datasets, and then the dataloaders, which handle packaging data from the datasets so that the model can use it:
Then, I create the model and an optimizer to update it. For the optimizer, I use weight decay, which penalizes large weights and discourages the model from memorizing the training data. Lastly, I run the training routine:
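The optimizer and its settings aren’t shown, so this sketch uses SGD with made-up values purely to illustrate what weight decay does in PyTorch: it adds `weight_decay * w` to each parameter’s gradient, pulling weights toward zero on every step. With a zero loss gradient, the only change comes from the decay.

```python
import torch

# One parameter, and a loss whose gradient with respect to it is zero,
# so any change after optimizer.step() comes purely from weight decay.
w = torch.nn.Parameter(torch.tensor([1.0]))
optimizer = torch.optim.SGD([w], lr=0.1, weight_decay=0.5)

optimizer.zero_grad()
loss = (w * 0.0).sum()
loss.backward()
optimizer.step()

# SGD update: w <- w - lr * (grad + weight_decay * w) = 1 - 0.1 * (0 + 0.5 * 1) = 0.95
print(w.item())
```

`torch.optim.Adam`’s `weight_decay` is likewise added to the gradient (an L2 penalty), whereas `torch.optim.AdamW` applies the decay directly to the weights, decoupled from the adaptive update.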
Next is the actual model architecture. Basic convolutional models often stack a convolutional layer, a ReLU layer, and a batch normalization layer, and then repeat that block several times. After that, they often use an average-pooling layer that reduces the remaining width and height to 1, and finally a linear layer with one output per class. My model follows that structure; I didn’t tune it much, and it works well enough.
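A sketch of that structure in PyTorch. The channel widths, the number of blocks, and the 2x2 max-pooling layers between blocks are my own illustrative choices rather than the original’s; `nn.AdaptiveAvgPool2d(1)` is the averaging layer that reduces width and height to 1. `SymbolNet` and the class count of 10 are hypothetical.

```python
import torch
from torch import nn


def conv_block(in_ch, out_ch):
    # Convolution -> ReLU -> batch normalization, the repeating unit described above.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(out_ch),
    )


class SymbolNet(nn.Module):
    def __init__(self, num_classes, in_channels=1):
        super().__init__()
        self.features = nn.Sequential(
            conv_block(in_channels, 32),
            conv_block(32, 32),
            nn.MaxPool2d(2),            # halve height and width
            conv_block(32, 64),
            conv_block(64, 64),
            nn.MaxPool2d(2),
            conv_block(64, 128),
        )
        self.pool = nn.AdaptiveAvgPool2d(1)            # (N, 128, H, W) -> (N, 128, 1, 1)
        self.classifier = nn.Linear(128, num_classes)  # one output per class

    def forward(self, x):
        x = self.pool(self.features(x))
        return self.classifier(torch.flatten(x, 1))


model = SymbolNet(num_classes=10)
out = model(torch.zeros(4, 1, 32, 32))
print(out.shape)  # torch.Size([4, 10])
```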
For the function train_model, I train for an epoch, and then after each epoch, I save the model if the top-5 accuracy has improved:
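A sketch of that save-the-best loop, with the per-epoch training and evaluation passed in as hypothetical callbacks so the checkpointing logic stands on its own. It saves the model’s `state_dict` with `torch.save`; the checkpoint path is up to the caller.

```python
import os
import tempfile

import torch
from torch import nn


def train_model(model, train_one_epoch, evaluate_top5, epochs, path):
    """Train for `epochs` epochs, checkpointing whenever top-5 accuracy improves.

    `train_one_epoch(model)` and `evaluate_top5(model)` are hypothetical callbacks
    standing in for the real training and validation code.
    """
    best = float("-inf")
    history = []
    for epoch in range(epochs):
        train_one_epoch(model)
        acc = evaluate_top5(model)
        history.append(acc)
        if acc > best:
            best = acc
            torch.save(model.state_dict(), path)
            print(f"epoch {epoch}: top-5 {acc:.4f} (new best, saved)")
    return best, history


if __name__ == "__main__":
    scores = iter([0.5, 0.7, 0.6])  # fake validation results
    with tempfile.TemporaryDirectory() as d:
        best, history = train_model(nn.Linear(2, 2), lambda m: None, lambda m: next(scores),
                                    epochs=3, path=os.path.join(d, "best.pt"))
    print(best, history)
```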
Top-5 accuracy is similar to accuracy, except that a prediction counts as correct if the correct answer is anywhere among the model’s top 5 predictions. I use this metric since the extension will show the user the top 5 predictions (or possibly more).
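A sketch of top-k accuracy using `torch.topk`; `topk_accuracy` is a hypothetical helper name, and it caps `k` at the number of classes.

```python
import torch


def topk_accuracy(logits, targets, k=5):
    """Fraction of samples whose true class is among the k highest-scoring classes."""
    k = min(k, logits.size(1))
    top = logits.topk(k, dim=1).indices               # (N, k) class indices
    hits = (top == targets.unsqueeze(1)).any(dim=1)  # (N,) True where the label appears
    return hits.float().mean().item()


logits = torch.tensor([[0., 1., 2., 3., 4., 5., 6.],
                       [6., 5., 4., 3., 2., 1., 0.]])
targets = torch.tensor([0, 0])
print(topk_accuracy(logits, targets, k=5))  # 0.5: class 0 is last for row 0 but first for row 1
```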
Next is the code that actually trains the model:
Looping through the training dataloader, I first clear the model’s gradients (which tell us how best to update the model) with optimizer.zero_grad(). Then I pass the images (here, “x”) through the model and use the criterion to calculate how badly the model did. Calling the loss’s backward() method populates the model’s gradients, so we know how to update it, and the optimizer’s step() method then applies that update. The rest of the function updates the progress bar so we can see how training is going.
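A sketch of such a loop, following the steps above; the progress-bar updates are left out, and the function name and signature are hypothetical. The demo trains a linear model on random synthetic data only to show the loop running.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one pass over `loader`, updating `model`; return the mean training loss."""
    model.train()  # enable training behavior (e.g. batch norm uses batch statistics)
    total_loss, total_seen = 0.0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()            # clear gradients left over from the last batch
        loss = criterion(model(x), y)    # forward pass, then score the predictions
        loss.backward()                  # populate .grad for every parameter
        optimizer.step()                 # update the parameters using those gradients
        total_loss += loss.item() * x.size(0)
        total_seen += x.size(0)
    return total_loss / max(total_seen, 1)


if __name__ == "__main__":
    # Tiny synthetic run: 16 random 4-feature samples, 3 classes, a linear model.
    torch.manual_seed(0)
    data = TensorDataset(torch.randn(16, 4), torch.randint(0, 3, (16,)))
    model = nn.Linear(4, 3)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    for epoch in range(3):
        loader = DataLoader(data, batch_size=4, shuffle=True)
        print(epoch, train_one_epoch(model, loader, nn.CrossEntropyLoss(), opt, torch.device("cpu")))
```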
This model ends up reaching almost 99% top-5 accuracy on the validation set, which is good enough for me. In other words, if a user draws a symbol similarly to the drawings in our dataset, we expect the correct LaTeX symbol to appear among the top five suggestions about 99% of the time.
That’s pretty much it! Again, if you’re interested in training image models, the PyTorch tutorials are great.
In this post, we prepared the data as images, split it up, and then trained a model. For me, the most interesting bit was preparing the data; the actual training is mostly copy-pasted from image-classification tutorials.
Next time, I’ll modify the server to use our model, and then talk through hosting the server on Heroku. Until then!