Handwritten Equation Recognition with Transformers
Learn how to turn handwritten math into LaTeX using Transformers and the CROHME 2023 dataset.
In this tutorial, we’ll be building a transformer-based model for Handwritten Mathematical Expression Recognition (HMER). While the output of the model is LaTeX text, the input is a set of images rendered from stroke data, making this an offline task. The online variant of this task would see us use the stroke data directly, which, in addition to coordinates, can include useful attributes like pen pressure and timestamps.
Although this task dates back to the 1960s, the first Competition on Recognition of Online Handwritten Mathematical Expressions (CROHME) was held at the International Conference on Document Analysis and Recognition (ICDAR) in 2011. Since then it has been run many times, most recently in 2023. We’ll be using the CROHME 2023 dataset1 to train this model.
Tokenization
Tokenization is an important first step in making the LaTeX text interpretable by a Machine Learning model. Instead of using Hugging Face Tokenizers2, a powerful library that includes tokenization algorithms such as Byte-Pair Encoding (BPE) and WordPiece, we’ll build our own basic tokeniser, LaTeXTokenizer, from scratch so that we can peer into how tokenization, vocabulary building, and encoding/decoding work.
import collections
import re
from typing import Dict, List, Tuple, Union


class LaTeXTokenizer:
    def __init__(self):
        self.special_tokens = ["[PAD]", "[BOS]", "[EOS]", "[UNK]"]
        self.vocab = {}
        self.token_to_id = {}
        self.id_to_token = {}

    def tokenize(self, text: str) -> List[str]:
        # Tokenize LaTeX using regex to capture commands, numbers and other characters
        return re.findall(r"\\[a-zA-Z]+|\\.|[a-zA-Z0-9]|\S", text)

    def build_vocab(self, texts: List[str]):
        # Add special tokens to vocabulary
        for token in self.special_tokens:
            self.vocab[token] = len(self.vocab)
        # Create a counter to hold token frequencies
        counter = collections.Counter()
        # Tokenize each text and update the counter
        for text in texts:
            tokens = self.tokenize(text)
            counter.update(tokens)
        # Add tokens to vocab based on their frequency
        for token, _ in counter.most_common():
            if token not in self.vocab:
                self.vocab[token] = len(self.vocab)
        # Build dictionaries for token to ID and ID to token conversion
        self.token_to_id = self.vocab
        self.id_to_token = {v: k for k, v in self.vocab.items()}

    def encode(self, text: str) -> List[int]:
        # Tokenize the input text and add start and end tokens
        tokens = ["[BOS]"] + self.tokenize(text) + ["[EOS]"]
        # Map tokens to their IDs, using [UNK] for unknown tokens
        unk_id = self.token_to_id["[UNK]"]
        return [self.token_to_id.get(token, unk_id) for token in tokens]

    def decode(self, token_ids: List[int]) -> str:
        # Map token IDs back to tokens
        tokens = [self.id_to_token.get(id, "[UNK]") for id in token_ids]
        # Remove tokens beyond the [EOS] token
        if "[EOS]" in tokens:
            tokens = tokens[: tokens.index("[EOS]")]
        # Replace [UNK] with ?
        tokens = ["?" if token == "[UNK]" else token for token in tokens]
        # Reconstruct the original text, ignoring special tokens
        return "".join([token for token in tokens if token not in self.special_tokens])
I’d like to highlight several interesting aspects of the LaTeXTokenizer code:
- we designate 4 special tokens that aid us in various downstream tasks: a padding token [PAD], a beginning of sentence token [BOS], an end of sentence token [EOS] and an unknown token [UNK]
- we use the regular expression \\[a-zA-Z]+|\\.|[a-zA-Z0-9]|\S to capture LaTeX commands, numbers, and other characters (see the short example after this list)
- decoded LaTeX strings are truncated upon encountering an [EOS] token
- encoded tokens are prefixed by the [BOS] token and suffixed with the [EOS] token
- encoded tokens that are not in the vocabulary are designated [UNK], meaning that they weren’t seen during training and only appear at val/test/predict time
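To see what the regular expression produces on its own, here is a quick check of the tokenize method (the input expression is just an illustrative example, not taken from the dataset):

tokenizer = LaTeXTokenizer()
# The regex keeps LaTeX commands whole and splits everything else into single characters
print(tokenizer.tokenize("\\frac{a}{b} + 1"))
# ['\\frac', '{', 'a', '}', '{', 'b', '}', '+', '1']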
Let’s have a look at the LaTeXTokenizer class in action.
Vocabulary Building
First, let’s instantiate the LaTeXTokenizer class and build the vocabulary with two example mathematical expressions written in LaTeX: $a^2 + b^2 = c^2$ and $e^{i\pi} + 1 = 0$.
tokenizer = LaTeXTokenizer()
tokenizer.build_vocab(["a^2 + b^2 = c^2", "e^{i\\pi} + 1 = 0"])
After running this code snippet, you can inspect the vocabulary to see the tokens and their corresponding IDs.
print(tokenizer.vocab)
{'[PAD]': 0, '[BOS]': 1, '[EOS]': 2, '[UNK]': 3, '^': 4, '2': 5, '+': 6, '=': 7, 'a': 8, 'b': 9, 'c': 10, 'e': 11, '{': 12, 'i': 13, '\\pi': 14, '}': 15, '1': 16, '0': 17}
Notice that the special tokens [PAD], [BOS], [EOS] and [UNK] are also present in the vocabulary.
Encoding
Now let’s encode a new mathematical expression, $i^2 = -1$, into its corresponding token IDs. The encode method will tokenize the input string and convert each token into its respective ID from the vocabulary. Unknown tokens, if any, will be mapped to [UNK].
ids = tokenizer.encode('i^2 = -1')
print(ids)
[1, 13, 4, 5, 7, 3, 16, 2]
Decoding
Finally, to check that the encoding makes sense, we decode the token IDs back into a LaTeX string. The decode method will convert the token IDs back to their original LaTeX tokens, joining them into a LaTeX expression.
latex = tokenizer.decode(ids)
print(latex)
i^2=?1
Notice that the minus symbol (-) was not present in the vocabulary and so was encoded into [UNK], represented in the decoded string as a ?.
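As a quick aside (this snippet is illustrative and not part of the original pipeline), rebuilding the vocabulary from texts that do contain the minus symbol makes the round-trip lossless:

tokenizer = LaTeXTokenizer()
# Include an expression containing "-" so it ends up in the vocabulary
tokenizer.build_vocab(["a^2 + b^2 = c^2", "e^{i\\pi} + 1 = 0", "x - 1"])
print(tokenizer.decode(tokenizer.encode("i^2 = -1")))
# i^2=-1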
CROHME Dataset
The CROHME dataset is a collection of handwritten mathematical expressions with their corresponding LaTeX annotations. In this section, we’ll be parsing the dataset’s InkML files and preparing them for model training.
InkML is a data format for representing digital ink entered with an electronic pen. It supports many attributes including writer information (like age, gender, and handedness), pen pressure, pen tilt, stroke data, and ground truth data. For our purpose, we’ll only be using the stroke data and ground truth LaTeX string, which we parse using the parse_inkml function. Instead of rendering the stroke data and extracting the LaTeX string on-the-fly, as data gets loaded into the model, we cache them to the filesystem to speed things up.
import io
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
from pathlib import Path
from PIL import Image


def parse_inkml(inkml_file_path, ns={"inkml": "http://www.w3.org/2003/InkML"}):
    tree = ET.parse(inkml_file_path)
    root = tree.getroot()
    strokes = []
    for trace in root.findall(".//inkml:trace", ns):
        coords = trace.text.strip().split(",")
        coords = [
            (float(x), -float(y))  # Invert y-axis to match InkML's coordinate system
            for x, y, *z in [coord.split() for coord in coords]
        ]
        strokes.append(coords)
    latex = root.find('.//inkml:annotation[@type="truth"]', ns).text.strip(" $")
    return strokes, latex


def cache_data(data_dir):
    fig, ax = plt.subplots()
    for inkml_file in data_dir.glob("*/*.inkml"):
        img_file = inkml_file.with_suffix(".png")
        txt_file = inkml_file.with_suffix(".txt")
        strokes, latex = parse_inkml(inkml_file)
        # Write LaTeX to file
        with open(txt_file, "w") as f:
            f.write(latex)
        # Render strokes to file
        ax.set_axis_off()
        ax.set_aspect("equal")
        for coords in strokes:
            x, y = zip(*coords)
            ax.plot(x, y, color="black", linewidth=2)
        buf = io.BytesIO()
        plt.savefig(buf, bbox_inches="tight", pad_inches=0)
        plt.cla()
        buf.seek(0)
        img = Image.open(buf).convert("RGB")
        img.save(img_file)


cache_data(Path("../data/CROHME"))
Next, we wrap CROHMEDataset around this cached image and LaTeX data, storing the corresponding train/val splits in a CROHMEDataModule. The collate_fn method takes care of centering and padding images within a fixed-size canvas, as well as padding token sequences to the maximum length of the batch. The images are randomly distorted using RandomPerspective to make the model more robust.
import lightning.pytorch as pl
import torch
import torch.nn.functional as F
from pathlib import Path
from PIL import Image
from torch import nn, optim, Tensor
from torch.nn.utils.rnn import pad_sequence
from torch.utils import data
from torchvision import transforms


class CROHMEDataset(data.Dataset):
    def __init__(self, latex_files, tokenizer, transform):
        super().__init__()
        self.latex_files = list(latex_files)
        self.tokenizer = tokenizer
        self.transform = transform

    def __len__(self):
        return len(self.latex_files)

    def __getitem__(self, idx):
        latex_file = self.latex_files[idx]
        image_file = latex_file.with_suffix(".png")
        with open(latex_file) as f:
            latex = f.read()
        x = self.transform(Image.open(image_file))
        y = Tensor(self.tokenizer.encode(latex))
        return x, y


class CROHMEDataModule(pl.LightningDataModule):
    def __init__(
        self,
        data_dir: str,
        batch_size: int = 1,
        num_workers: int = 0,
        pin_memory: bool = False,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.data_dir = Path(data_dir)
        self.transform = transforms.Compose(
            [
                transforms.RandomPerspective(distortion_scale=0.1, p=0.5, fill=255),
                transforms.ToTensor(),
            ]
        )

    def setup(self, stage):
        latexes = []
        for latex_file in self.data_dir.glob("train/*.txt"):
            with open(latex_file) as f:
                latexes.append(f.read())
        self.tokenizer = LaTeXTokenizer()
        self.tokenizer.build_vocab(latexes)
        self.vocab_size = len(self.tokenizer.vocab)
        if stage == "fit" or stage is None:
            self.train_dataset = CROHMEDataset(
                self.data_dir.glob("train/*.txt"), self.tokenizer, self.transform
            )
            self.val_dataset = CROHMEDataset(
                self.data_dir.glob("val/*.txt"), self.tokenizer, self.transform
            )

    def collate_fn(self, batch, max_width: int = 512, max_height: int = 384):
        images, labels = zip(*batch)
        # Create a white background for each image in the batch
        src = torch.ones((len(images), 3, max_height, max_width))
        # Center and pad individual images to fit into the white background
        for i, img in enumerate(images):
            height_start = (max_height - img.size(1)) // 2
            height_end = height_start + img.size(1)
            width_start = (max_width - img.size(2)) // 2
            width_end = width_start + img.size(2)
            src[i, :, height_start:height_end, width_start:width_end] = img
        # Pad sequences for labels and create attention mask
        tgt = pad_sequence(labels, batch_first=True).long()
        seq_len = tgt.size(1)
        tgt_mask = torch.triu(torch.ones(seq_len, seq_len) * float("-inf"), diagonal=1)
        return src, tgt, tgt_mask

    def train_dataloader(self):
        return data.DataLoader(
            self.train_dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            collate_fn=self.collate_fn,
            pin_memory=self.hparams.pin_memory,
        )

    def val_dataloader(self):
        return data.DataLoader(
            self.val_dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            collate_fn=self.collate_fn,
            pin_memory=self.hparams.pin_memory,
        )
Visualising a Batch of Data
Let’s inspect a batch of data to see what’s being fed to the model.
datamodule = CROHMEDataModule("../data/CROHME/", batch_size=16)
datamodule.setup(stage="fit")
tokenizer = datamodule.tokenizer
train_dataloader = datamodule.train_dataloader()
batch = next(iter(train_dataloader))
src, tgt, tgt_mask = batch
Here, src contains the images of the handwritten mathematical expressions; tgt contains the token IDs representing the LaTeX expressions; and tgt_mask is the attention mask, which we’ll explain later on.
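To get a feel for these tensors, we can print their shapes (a quick check; the exact sequence length depends on the longest label in the batch):

print(src.shape)       # torch.Size([16, 3, 384, 512]): batch of padded RGB canvases
print(tgt.shape)       # (16, L): token IDs padded to the longest label length L
print(tgt_mask.shape)  # (L, L): causal mask for the decoder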
Handwritten Image
from torchvision.utils import make_grid
plt.imshow(make_grid(src, nrow=4).permute(1, 2, 0))
plt.axis("off")
LaTeX String
for t in tgt:
    print(tokenizer.decode(t.tolist()))
{\mu}_{o}
\frac{\sqrt{99{x^{7}}}}{11{x^{3}}}
x^2+3x
5\pm(137-194+49)\times36
uu_{x}+u_{y}+u_{t}=y
\sum_{n=1}^{\infty}{(\frac{\sum_{i=1}^{n}a_{i}}{n})^{p}}\lt{(\frac{p}{p-1})^{p}}\sum_{n=1}^{\infty}{a_{n}^{p}}
y=aS(t)=a\int_0^{t}\sin(\frac{1}{2}\pis^2)ds
\left(1.8\right)
a+\frac{\sqrt{b+c}}{2}
\frac{\int\sqrt{1+{y^{'}(t)^{2}}}dt}{\int\sqrt{{x^{'}(t)^{2}}+{y^{'}(t)^{2}}}dt}
{{T+\sin{a}^{M}}\leq4.45}
\frac{5}{x+1}+\frac{5}{{x^{2}}+x}
x_x^x+y_y^y+z_z^z-x-y-z
14\times87\neq-196
\mbox{d}
\sum_{i=1}^{n}i=\frac{1}{2}n(n+1)
Positional Encoding for Attention Models
In models like Transformers, positional information is not naturally captured by the self-attention mechanism. For this reason, we use a technique called positional encoding to give the model information about the positions of the tokens. Ensuring that the model can take the order of the tokens into account significantly improves its performance on sequence-based tasks.
We’ll be using two types of positional encodings: 1D and 2D positional encodings, which are essential for 1D sequences like text and 2D structures like images, respectively.
- In the 1D case, each position in the sequence is encoded into a high-dimensional vector. We use sine and cosine functions of different frequencies to encode the positional information (see the formula below).
- When dealing with images or any 2D data, a 2D positional encoding is more appropriate. Here, we use a similar technique but extend it to 2D: we calculate a 1D positional encoding first and then broadcast-sum the row and column encodings to construct the 2D positional encoding. This way, each position in a 2D structure (like an image) gets a unique encoding based on its row and column.
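For reference, the 1D case uses the standard sinusoidal form (this is exactly what div_term computes in the code below, with temperature 10000):

$PE(pos, 2i) = \sin\left(pos / 10000^{2i/d_{model}}\right)$ and $PE(pos, 2i+1) = \cos\left(pos / 10000^{2i/d_{model}}\right)$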
import math
import matplotlib.pyplot as plt


class PositionalEncoding1D(nn.Module):
    def __init__(
        self,
        d_model: int,
        dropout: float = 0.1,
        max_len: int = 1000,
        temperature: float = 10000.0,
    ):
        super().__init__()
        # Generate position and dimension tensors for encoding
        position = torch.arange(max_len).unsqueeze(1)
        dim_t = torch.arange(0, d_model, 2)
        div_term = torch.exp(dim_t * (-math.log(temperature) / d_model))
        # Initialize and fill the positional encoding matrix with sine/cosine values
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("pe", pe)

    def forward(self, x):
        batch, sequence_length, d_model = x.shape
        return self.dropout(x + self.pe[None, :sequence_length, :])


class PositionalEncoding2D(nn.Module):
    def __init__(
        self,
        d_model: int,
        dropout: float = 0.1,
        max_len: int = 30,
        temperature: float = 10000.0,
    ):
        super().__init__()
        # Generate position and dimension tensors for 1D encoding
        position = torch.arange(max_len).unsqueeze(1)
        dim_t = torch.arange(0, d_model, 2)
        div_term = torch.exp(dim_t * (-math.log(temperature) / d_model))
        # Initialize and fill the 1D positional encoding matrix with sine/cosine values
        pe_1D = torch.zeros(max_len, d_model)
        pe_1D[:, 0::2] = torch.sin(position * div_term)
        pe_1D[:, 1::2] = torch.cos(position * div_term)
        # Compute the 2D positional encoding by summing row and column 1D encodings
        pe_2D = torch.zeros(max_len, max_len, d_model)
        for i in range(d_model):
            pe_2D[:, :, i] = pe_1D[:, i].unsqueeze(-1) + pe_1D[:, i].unsqueeze(0)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("pe", pe_2D)

    def forward(self, x):
        batch, height, width, d_model = x.shape
        return self.dropout(x + self.pe[None, :height, :width, :])
Visualising Positional Encodings with Python
Positional encodings are much easier to understand when you visualise them.
1D Positional Encoding
First, we will create a 1D positional encoding with a d_model of 256 and a maximum length of 100. We will then use matshow from Matplotlib to display it as a heatmap that reflects the values of the positional encoding. The x-axis represents the dimension (d_model) and the y-axis represents the sequence length (up to max_len).
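A minimal sketch of the plotting code described above might look like this:

# Sketch, reconstructed from the description above
pe_1d = PositionalEncoding1D(d_model=256, dropout=0.0, max_len=100)
plt.matshow(pe_1d.pe.numpy())  # rows are positions, columns are d_model dimensions
plt.xlabel("dimension (d_model)")
plt.ylabel("position")
plt.show()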
2D Positional Encoding
For the 2D positional encoding, we use a more complex visualisation. We create a 2D positional encoding with d_model
set to 32 and max_len
to 64. Each subplot below show different layers (or dimensions) of the 2D positional encoding, allowing you to see how each dimension encodes 2D positional information differently.
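Again, a minimal sketch based on the description (the choice of eight subplots is an arbitrary assumption):

# Sketch, reconstructed from the description above
pe_2d = PositionalEncoding2D(d_model=32, dropout=0.0, max_len=64)
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for i, ax in enumerate(axes.flat):
    ax.matshow(pe_2d.pe[:, :, i].numpy())  # one 64x64 layer of the encoding
    ax.set_title(f"dim {i}")
    ax.set_axis_off()
plt.tight_layout()
plt.show()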