Handwritten Equation Recognition with Transformers

Learn how to turn handwritten math into LaTeX using Transformers and the CROHME 2023 dataset.

In this tutorial, we’ll be building a transformer-based model for Handwritten Mathematical Expression Recognition (HMER). While the output of the model is LaTeX text, the input is a set of images rendered from stroke data, making this an offline task. The online variant of this task would use the stroke data directly, which, in addition to coordinates, can include useful attributes such as pen pressure and timestamps.

Although this task dates back to the 1960s, the first Competition on Recognition of Online Handwritten Mathematical Expressions (CROHME) was held at the International Conference on Document Analysis and Recognition (ICDAR) in 2011. Since then it has been run many times, most recently in 2023. We’ll be using the CROHME 2023 dataset [1] to train this model.

Tokenization

Tokenization is an important first step in making the LaTeX text interpretable by a Machine Learning model. Instead of using Hugging Face Tokenizers [2], a powerful library that includes tokenization algorithms such as Byte-Pair Encoding (BPE) and WordPiece, we’ll build our own basic tokenizer, LaTeXTokenizer, from scratch so that we can peer into how tokenization, vocabulary building, and encoding/decoding work.

import collections
import re

from typing import List


class LaTeXTokenizer:
    def __init__(self):
        self.special_tokens = ["[PAD]", "[BOS]", "[EOS]", "[UNK]"]
        self.vocab = {}
        self.token_to_id = {}
        self.id_to_token = {}

    def tokenize(self, text: str) -> List[str]:
        # Tokenize LaTeX using regex to capture commands, numbers and other characters
        return re.findall(r"\\[a-zA-Z]+|\\.|[a-zA-Z0-9]|\S", text)

    def build_vocab(self, texts: List[str]):
        # Add special tokens to vocabulary
        for token in self.special_tokens:
            self.vocab[token] = len(self.vocab)

        # Create a counter to hold token frequencies
        counter = collections.Counter()

        # Tokenize each text and update the counter
        for text in texts:
            tokens = self.tokenize(text)
            counter.update(tokens)

        # Add tokens to vocab based on their frequency
        for token, _ in counter.most_common():
            if token not in self.vocab:
                self.vocab[token] = len(self.vocab)

        # Build dictionaries for token to ID and ID to token conversion
        self.token_to_id = self.vocab
        self.id_to_token = {v: k for k, v in self.vocab.items()}

    def encode(self, text: str) -> List[int]:
        # Tokenize the input text and add start and end tokens
        tokens = ["[BOS]"] + self.tokenize(text) + ["[EOS]"]

        # Map tokens to their IDs, using [UNK] for unknown tokens
        unk_id = self.token_to_id["[UNK]"]
        return [self.token_to_id.get(token, unk_id) for token in tokens]

    def decode(self, token_ids: List[int]) -> str:
        # Map token IDs back to tokens
        tokens = [self.id_to_token.get(token_id, "[UNK]") for token_id in token_ids]

        # Remove tokens beyond the [EOS] token
        if "[EOS]" in tokens:
            tokens = tokens[: tokens.index("[EOS]")]

        # Replace [UNK] with ?
        tokens = ["?" if token == "[UNK]" else token for token in tokens]

        # Reconstruct the original text, ignoring special tokens
        return "".join([token for token in tokens if token not in self.special_tokens])

I’d like to highlight several interesting aspects of the LaTeXTokenizer code:

  • we designate 4 special tokens that aid us in various downstream tasks: a padding token [PAD], a beginning of sentence token [BOS], an end of sentence token [EOS] and an unknown token [UNK]
  • we use the regular expression \\[a-zA-Z]+|\\.|[a-zA-Z0-9]|\S to capture LaTeX commands (such as \frac), escaped characters (such as \{), single letters and digits, and any other non-whitespace character
  • decoded LaTeX strings are truncated upon encountering an [EOS] token
  • encoded tokens are prefixed by the [BOS] token and suffixed with the [EOS] token
  • tokens that are not in the vocabulary are mapped to [UNK], meaning that they weren’t seen during training and only appear at validation, test, or prediction time
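
To see how the regular expression splits a string, the minimal sketch below applies the same pattern to a sample expression. The sample is chosen for illustration and is not taken from the dataset. Note that multi-digit numbers are split into individual digits and that whitespace is discarded.

```python
import re

# Same pattern as in LaTeXTokenizer.tokenize
PATTERN = r"\\[a-zA-Z]+|\\.|[a-zA-Z0-9]|\S"

sample = r"\frac{1}{2} \sin x \, 10"  # illustrative input, not from CROHME
tokens = re.findall(PATTERN, sample)
print(tokens)
# ['\\frac', '{', '1', '}', '{', '2', '}', '\\sin', 'x', '\\,', '1', '0']
```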

Let’s have a look at the LaTeXTokenizer class in action.

Vocabulary Building

First, let’s instantiate the LaTeXTokenizer class and build the vocabulary with two example mathematical expressions written in LaTeX: $a^2 + b^2 = c^2$ and $e^{i\pi} + 1 = 0$.

tokenizer = LaTeXTokenizer()
tokenizer.build_vocab(["a^2 + b^2 = c^2", "e^{i\\pi} + 1 = 0"])

After running this code snippet, you can inspect the vocabulary to see the tokens and their corresponding IDs.

print(tokenizer.vocab)
{'[PAD]': 0, '[BOS]': 1, '[EOS]': 2, '[UNK]': 3, '^': 4, '2': 5, '+': 6, '=': 7, 'a': 8, 'b': 9, 'c': 10, 'e': 11, '{': 12, 'i': 13, '\\pi': 14, '}': 15, '1': 16, '0': 17}

Notice that the special tokens [PAD], [BOS], [EOS] and [UNK] are also present in the vocabulary.
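
The IDs after the special tokens follow token frequency: Counter.most_common() sorts by count and, for equal counts, keeps the order in which tokens were first seen. The sketch below reproduces that ordering for the two expressions above, with the regular expression copied from tokenize.

```python
import collections
import re

PATTERN = r"\\[a-zA-Z]+|\\.|[a-zA-Z0-9]|\S"
texts = ["a^2 + b^2 = c^2", "e^{i\\pi} + 1 = 0"]

counter = collections.Counter()
for text in texts:
    counter.update(re.findall(PATTERN, text))

# The four most frequent tokens receive IDs 4 to 7, right after the special tokens
top4 = counter.most_common()[:4]
print(top4)
# [('^', 4), ('2', 3), ('+', 2), ('=', 2)]
```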

Encoding

Now let’s encode a new mathematical expression, $i^2 = -1$, into its corresponding token IDs. The encode method will tokenize the input string and convert each token into its respective ID from the vocabulary. Unknown tokens, if any, will be mapped to [UNK].

ids = tokenizer.encode('i^2 = -1')
print(ids)
[1, 13, 4, 5, 7, 3, 16, 2]

Decoding

Finally, to check that the encoding makes sense, we decode the token IDs back into a LaTeX string. The decode method will convert the token IDs back to their original LaTeX tokens, joining them into a LaTeX expression.

latex = tokenizer.decode(ids)
print(latex)
i^2=?1

Notice that the minus symbol (-) was not present in the vocabulary and so was encoded into [UNK], represented in the decoded string as a ?.
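
A related caveat: because decode joins tokens with no separator, a command followed by a letter fuses into a different command (for example, \pi followed by s becomes \pis). The sketch below shows one possible workaround. join_tokens is a hypothetical helper introduced here for illustration; it is not part of the LaTeXTokenizer above.

```python
from typing import List


def join_tokens(tokens: List[str]) -> str:
    """Join LaTeX tokens, adding a space where a command would fuse with a letter.

    Hypothetical helper for illustration; not part of LaTeXTokenizer.
    """
    out = []
    for i, token in enumerate(tokens):
        out.append(token)
        is_command = len(token) > 1 and token[0] == "\\" and token[1:].isalpha()
        next_is_letter = i + 1 < len(tokens) and tokens[i + 1][:1].isalpha()
        if is_command and next_is_letter:
            out.append(" ")
    return "".join(out)


tokens = ["\\pi", "s", "^", "2"]
print("".join(tokens))      # \pis^2  (reads as the undefined command \pis)
print(join_tokens(tokens))  # \pi s^2
```

Whether this matters depends on how the decoded string is used: for exact-match comparison against ground truth that was decoded the same way it makes no difference, but for rendering the output it does.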

CROHME Dataset

The CROHME dataset is a collection of handwritten mathematical expressions with their corresponding LaTeX annotations. In this section, we’ll be parsing the dataset’s InkML files and preparing them for model training.

InkML is a data format for representing digital ink entered with an electronic pen. It supports many attributes including writer information (like age, gender, and handedness), pen pressure, pen tilt, stroke data, and ground truth data. For our purpose, we’ll only be using the stroke data and ground truth LaTeX string, which we parse using the parse_inkml function. Instead of rendering the stroke data and extracting the LaTeX string on-the-fly, as data gets loaded into the model, we cache them to the filesystem to speed things up.

import io
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET

from pathlib import Path
from PIL import Image


def parse_inkml(inkml_file_path, ns={"inkml": "http://www.w3.org/2003/InkML"}):
    tree = ET.parse(inkml_file_path)
    root = tree.getroot()

    strokes = []
    for trace in root.findall(".//inkml:trace", ns):
        coords = trace.text.strip().split(",")
        coords = [
            (float(x), -float(y))  # Negate y: ink y grows downward, Matplotlib's y grows upward
            for x, y, *z in [coord.split() for coord in coords]
        ]
        strokes.append(coords)

    latex = root.find('.//inkml:annotation[@type="truth"]', ns).text.strip(" $")

    return strokes, latex


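def cache_data(data_dir=Path("../data/CROHME/")):
    # data_dir is the dataset root; the default is an assumption matching the path used later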
    fig, ax = plt.subplots()

    for inkml_file in data_dir.glob("*/*.inkml"):
        img_file = inkml_file.with_suffix(".png")
        txt_file = inkml_file.with_suffix(".txt")

        strokes, latex = parse_inkml(inkml_file)

        # Write LaTeX to file
        with open(txt_file, "w") as f:
            f.write(latex)

        # Render strokes to file
        ax.set_axis_off()
        ax.set_aspect("equal")
        for coords in strokes:
            x, y = zip(*coords)
            ax.plot(x, y, color="black", linewidth=2)
        buf = io.BytesIO()
        plt.savefig(buf, bbox_inches="tight", pad_inches=0)
        plt.cla()
        buf.seek(0)
        img = Image.open(buf).convert("RGB")
        img.save(img_file)


cache_data()
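
To make the parsing step concrete, the sketch below runs the same logic on a tiny hand-written InkML document. The document is a toy example made up for illustration; real CROHME files carry more metadata, and the second trace includes a third channel only to show that extra values are ignored.

```python
import xml.etree.ElementTree as ET

NS = {"inkml": "http://www.w3.org/2003/InkML"}

# Hand-written toy document for illustration; not a real CROHME file
INKML = """<ink xmlns="http://www.w3.org/2003/InkML">
  <annotation type="truth">$x^2$</annotation>
  <trace id="0">10 20, 15 25, 20 30</trace>
  <trace id="1">30 20 0.5, 35 25 0.6</trace>
</ink>"""

root = ET.fromstring(INKML)

strokes = []
for trace in root.findall(".//inkml:trace", NS):
    points = [point.split() for point in trace.text.strip().split(",")]
    # Keep x and y, ignore any extra channels, and negate y for plotting
    strokes.append([(float(x), -float(y)) for x, y, *_ in points])

latex = root.find('.//inkml:annotation[@type="truth"]', NS).text.strip(" $")

print(latex)       # x^2
print(strokes[0])  # [(10.0, -20.0), (15.0, -25.0), (20.0, -30.0)]
```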

Next, we wrap CROHMEDataset around this cached image and LaTeX data, storing the corresponding train/val splits in a CROHMEDataModule. The collate_fn method takes care of centering and padding images within a fixed-size canvas, as well as padding token sequences to the maximum length in the batch. The images are randomly distorted using RandomPerspective to make the model more robust.

import lightning.pytorch as pl
import torch
import torch.nn.functional as F

from pathlib import Path
from PIL import Image
from torch import nn, optim, Tensor
from torch.nn.utils.rnn import pad_sequence
from torch.utils import data
from torchvision import transforms


class CROHMEDataset(data.Dataset):
    def __init__(self, latex_files, tokenizer, transform):
        super().__init__()
        self.latex_files = list(latex_files)
        self.tokenizer = tokenizer
        self.transform = transform

    def __len__(self):
        return len(self.latex_files)

    def __getitem__(self, idx):
        latex_file = self.latex_files[idx]
        image_file = latex_file.with_suffix(".png")

        with open(latex_file) as f:
            latex = f.read()

        x = self.transform(Image.open(image_file))
        y = torch.tensor(self.tokenizer.encode(latex), dtype=torch.long)
        return x, y


class CROHMEDataModule(pl.LightningDataModule):
    def __init__(
        self,
        data_dir: str,
        batch_size: int = 1,
        num_workers: int = 0,
        pin_memory: bool = False,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.data_dir = Path(data_dir)
        self.transform = transforms.Compose(
            [
                transforms.RandomPerspective(distortion_scale=0.1, p=0.5, fill=255),
                transforms.ToTensor(),
            ]
        )

    def setup(self, stage):
        latexes = []
        for latex_file in self.data_dir.glob("train/*.txt"):
            with open(latex_file) as f:
                latexes.append(f.read())

        self.tokenizer = LaTeXTokenizer()
        self.tokenizer.build_vocab(latexes)
        self.vocab_size = len(self.tokenizer.vocab)

        if stage == "fit" or stage is None:
            self.train_dataset = CROHMEDataset(
                self.data_dir.glob("train/*.txt"), self.tokenizer, self.transform
            )
            self.val_dataset = CROHMEDataset(
                self.data_dir.glob("val/*.txt"), self.tokenizer, self.transform
            )

    def collate_fn(self, batch, max_width: int = 512, max_height: int = 384):
        images, labels = zip(*batch)

        # Create a white background for each image in the batch
        src = torch.ones((len(images), 3, max_height, max_width))

        # Center and pad individual images to fit into the white background
        for i, img in enumerate(images):
            height_start = (max_height - img.size(1)) // 2
            height_end = height_start + img.size(1)
            width_start = (max_width - img.size(2)) // 2
            width_end = width_start + img.size(2)
            src[i, :, height_start:height_end, width_start:width_end] = img

        # Pad sequences for labels and create attention mask
        tgt = pad_sequence(labels, batch_first=True).long()
        seq_len = tgt.size(1)
        tgt_mask = torch.triu(torch.ones(seq_len, seq_len) * float("-inf"), diagonal=1)

        return src, tgt, tgt_mask

    def train_dataloader(self):
        return data.DataLoader(
            self.train_dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            collate_fn=self.collate_fn,
            pin_memory=self.hparams.pin_memory,
        )

    def val_dataloader(self):
        return data.DataLoader(
            self.val_dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            collate_fn=self.collate_fn,
            pin_memory=self.hparams.pin_memory,
        )
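
The centering and padding steps in collate_fn can be tried in isolation. The sketch below uses toy tensors and a tiny canvas, both chosen for illustration, to show where each image lands on the white background and how shorter label sequences are right-padded with 0, the ID of [PAD].

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Toy stand-ins for a batch: two black "images" of different sizes and two label sequences
images = [torch.zeros(3, 2, 4), torch.zeros(3, 4, 2)]
labels = [torch.tensor([1, 8, 2]), torch.tensor([1, 8, 4, 5, 2])]
max_height, max_width = 6, 8  # tiny canvas for illustration (the tutorial uses 384 x 512)

canvas = torch.ones(len(images), 3, max_height, max_width)  # white background
for i, img in enumerate(images):
    top = (max_height - img.size(1)) // 2
    left = (max_width - img.size(2)) // 2
    canvas[i, :, top : top + img.size(1), left : left + img.size(2)] = img

# Shorter sequences are right-padded with 0, the ID of [PAD]
tgt = pad_sequence(labels, batch_first=True, padding_value=0)

print(canvas.shape)  # torch.Size([2, 3, 6, 8])
print(tgt.tolist())  # [[1, 8, 2, 0, 0], [1, 8, 4, 5, 2]]
```

One consequence of this design is that an image larger than the canvas raises a shape error, so the canvas size has to cover the largest rendered image.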

Visualising a Batch of Data

Let’s inspect a batch of data to see what’s being fed to the model.

datamodule = CROHMEDataModule("../data/CROHME/", batch_size=16)
datamodule.setup(stage="fit")

tokenizer = datamodule.tokenizer

train_dataloader = datamodule.train_dataloader()
batch = next(iter(train_dataloader))
src, tgt, tgt_mask = batch

Here,

  • src contains the images of the handwritten mathematical expressions;
  • tgt contains the token IDs representing the LaTeX expressions; and
  • tgt_mask is the attention mask, which we’ll explain later on.
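
As a brief preview, tgt_mask is a causal (look-ahead) mask built in collate_fn: position i may attend to positions up to and including i, while later positions are set to negative infinity so that they receive zero weight after a softmax. The sketch below rebuilds it for a sequence length of 4, chosen for illustration, and applies it to uniform attention scores.

```python
import torch

seq_len = 4  # illustrative length
tgt_mask = torch.triu(torch.ones(seq_len, seq_len) * float("-inf"), diagonal=1)
print(tgt_mask.tolist())
# [[0.0, -inf, -inf, -inf], [0.0, 0.0, -inf, -inf], [0.0, 0.0, 0.0, -inf], [0.0, 0.0, 0.0, 0.0]]

# Adding the mask to raw attention scores removes future positions after the softmax
scores = torch.zeros(seq_len, seq_len)
weights = torch.softmax(scores + tgt_mask, dim=-1)
print(weights[1].tolist())  # [0.5, 0.5, 0.0, 0.0]
```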

Handwritten Image

from torchvision.utils import make_grid

plt.imshow(make_grid(src, nrow=4).permute(1, 2, 0))
plt.axis("off")
(-0.5, 2057.5, 1545.5, -0.5)

[Figure: the 16 rendered expressions in the batch, arranged in a 4 × 4 grid]

LaTeX String

for t in tgt:
    print(tokenizer.decode(t.tolist()))
{\mu}_{o}
\frac{\sqrt{99{x^{7}}}}{11{x^{3}}}
x^2+3x
5\pm(137-194+49)\times36
uu_{x}+u_{y}+u_{t}=y
\sum_{n=1}^{\infty}{(\frac{\sum_{i=1}^{n}a_{i}}{n})^{p}}\lt{(\frac{p}{p-1})^{p}}\sum_{n=1}^{\infty}{a_{n}^{p}}
y=aS(t)=a\int_0^{t}\sin(\frac{1}{2}\pis^2)ds
\left(1.8\right)
a+\frac{\sqrt{b+c}}{2}
\frac{\int\sqrt{1+{y^{'}(t)^{2}}}dt}{\int\sqrt{{x^{'}(t)^{2}}+{y^{'}(t)^{2}}}dt}
{{T+\sin{a}^{M}}\leq4.45}
\frac{5}{x+1}+\frac{5}{{x^{2}}+x}
x_x^x+y_y^y+z_z^z-x-y-z
14\times87\neq-196
\mbox{d}
\sum_{i=1}^{n}i=\frac{1}{2}n(n+1)

Positional Encoding for Attention Models

Self-attention on its own treats its input as an unordered set, so models like Transformers do not naturally capture positional information. For this reason, we use a technique called positional encoding to add each token’s position to its embedding. Making token order available to the model is essential for sequence-based tasks, where the same symbols in a different order mean something different.

We’ll be using two types of positional encoding, 1D and 2D, which suit 1D sequences like text and 2D structures like images, respectively.

  1. In the 1D case, each position in the sequence is encoded into a high-dimensional vector. We use sine and cosine functions of different frequencies to encode the positional information.

  2. When dealing with images or any 2D data, a 2D positional encoding is more appropriate. Here, we use a similar technique but extend it to 2D. We calculate the 1D positional encoding first and then combine it with itself through an outer sum (a broadcast addition of the row encoding and the column encoding) to construct the 2D positional encoding. This way, each position in a 2D structure (like an image) gets an encoding that depends on both its row and its column.
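
As a concrete reference for the 1D case, the sketch below computes the sinusoidal encoding of a single position in plain Python. The function name sinusoidal_encoding is hypothetical; the formula is the common sine/cosine scheme, and the temperature of 10000 matches the default in the code that follows.

```python
import math
from typing import List


def sinusoidal_encoding(pos: int, d_model: int, temperature: float = 10000.0) -> List[float]:
    """Encode one position: even indices use sine, odd indices use cosine."""
    encoding = []
    for i in range(0, d_model, 2):
        angle = pos / temperature ** (i / d_model)
        encoding += [math.sin(angle), math.cos(angle)]
    return encoding[:d_model]


print(sinusoidal_encoding(0, 4))  # [0.0, 1.0, 0.0, 1.0]
print([round(v, 3) for v in sinusoidal_encoding(1, 4)])  # [0.841, 0.54, 0.01, 1.0]
```

Low indices oscillate quickly with position while high indices change slowly, which is what lets the vector as a whole distinguish positions.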

import math
import matplotlib.pyplot as plt


class PositionalEncoding1D(nn.Module):
    def __init__(
        self,
        d_model: int,
        dropout: float = 0.1,
        max_len: int = 1000,
        temperature: float = 10000.0,
    ):
        super().__init__()

        # Generate position and dimension tensors for encoding
        position = torch.arange(max_len).unsqueeze(1)
        dim_t = torch.arange(0, d_model, 2)
        div_term = torch.exp(dim_t * (-math.log(temperature) / d_model))

        # Initialize and fill the positional encoding matrix with sine/cosine values
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer("pe", pe)

    def forward(self, x):
        batch, sequence_length, d_model = x.shape
        return self.dropout(x + self.pe[None, :sequence_length, :])


class PositionalEncoding2D(nn.Module):
    def __init__(
        self,
        d_model: int,
        dropout: float = 0.1,
        max_len: int = 30,
        temperature: float = 10000.0,
    ):
        super().__init__()

        # Generate position and dimension tensors for 1D encoding
        position = torch.arange(max_len).unsqueeze(1)
        dim_t = torch.arange(0, d_model, 2)
        div_term = torch.exp(dim_t * (-math.log(temperature) / d_model))

        # Initialize and fill the 1D positional encoding matrix with sine/cosine values
        pe_1D = torch.zeros(max_len, d_model)
        pe_1D[:, 0::2] = torch.sin(position * div_term)
        pe_1D[:, 1::2] = torch.cos(position * div_term)

        # Compute the 2D positional encoding as an outer sum of the row and column encodings
        pe_2D = torch.zeros(max_len, max_len, d_model)
        for i in range(d_model):
            pe_2D[:, :, i] = pe_1D[:, i].unsqueeze(-1) + pe_1D[:, i].unsqueeze(0)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer("pe", pe_2D)

    def forward(self, x):
        batch, height, width, d_model = x.shape
        return self.dropout(x + self.pe[None, :height, :width, :])
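
To make the 2D construction concrete, the sketch below rebuilds the table at a small size and checks two properties of the code above. First, the per-channel loop is equivalent to a single broadcast addition. Second, because the row and column encodings are summed, positions (r, c) and (c, r) share the same vector. The sizes are chosen for illustration.

```python
import math
import torch

max_len, d_model, temperature = 5, 8, 10000.0  # small illustrative sizes

position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(temperature) / d_model))
pe_1d = torch.zeros(max_len, d_model)
pe_1d[:, 0::2] = torch.sin(position * div_term)
pe_1d[:, 1::2] = torch.cos(position * div_term)

# Loop form, as in PositionalEncoding2D
pe_loop = torch.zeros(max_len, max_len, d_model)
for i in range(d_model):
    pe_loop[:, :, i] = pe_1d[:, i].unsqueeze(-1) + pe_1d[:, i].unsqueeze(0)

# Broadcast form: pe[r, c] = pe_1d[r] + pe_1d[c]
pe_broadcast = pe_1d[:, None, :] + pe_1d[None, :, :]

print(torch.allclose(pe_loop, pe_broadcast))                   # True
print(torch.allclose(pe_broadcast[1, 3], pe_broadcast[3, 1]))  # True
```

If distinguishing (r, c) from (c, r) matters for a given model, a common alternative is to give rows and columns separate halves of the channels; that is a different design from the one used in this tutorial.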

Visualising Positional Encodings with Python

Understanding positional encodings can become much clearer when you visualise them.

1D Positional Encoding

First, we will create a 1D positional encoding with a d_model of 256 and a maximum length of 100. We will then use matshow from Matplotlib to display it as a heatmap that reflects the values of the positional encoding. The x-axis represents the encoding dimension (up to d_model) and the y-axis represents the position in the sequence (up to max_len).
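
The plotting code is not shown in this section, so the following is a minimal sketch of one way to produce such a heatmap. It recomputes the table with NumPy rather than reusing PositionalEncoding1D so that it runs on its own; sinusoidal_table is a hypothetical helper, and the figure size and colour map are arbitrary choices.

```python
import matplotlib.pyplot as plt
import numpy as np


def sinusoidal_table(max_len: int, d_model: int, temperature: float = 10000.0) -> np.ndarray:
    """Return a (max_len, d_model) sine/cosine table (hypothetical NumPy helper)."""
    position = np.arange(max_len)[:, None]
    div_term = np.exp(np.arange(0, d_model, 2) * (-np.log(temperature) / d_model))
    table = np.zeros((max_len, d_model))
    table[:, 0::2] = np.sin(position * div_term)
    table[:, 1::2] = np.cos(position * div_term)
    return table


pe = sinusoidal_table(max_len=100, d_model=256)

fig, ax = plt.subplots(figsize=(10, 4))
im = ax.matshow(pe, aspect="auto", cmap="viridis")
ax.set_xlabel("Dimension (d_model)")
ax.set_ylabel("Position")
fig.colorbar(im, ax=ax)
```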

2D Positional Encoding

For the 2D positional encoding, we use a more complex visualisation. We create a 2D positional encoding with d_model set to 32 and max_len set to 64. Each subplot shows one dimension (channel) of the 2D positional encoding, allowing you to see how each dimension encodes 2D positional information differently.
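
The plotting code is not shown in this section either, so the following is a minimal sketch under the same assumptions as the construction above: the 1D table is recomputed with NumPy so the block runs on its own, and the 2D table is the broadcast sum of row and column encodings. The 4 × 8 grid layout is an arbitrary choice that fits the 32 dimensions.

```python
import matplotlib.pyplot as plt
import numpy as np

max_len, d_model, temperature = 64, 32, 10000.0

position = np.arange(max_len)[:, None]
div_term = np.exp(np.arange(0, d_model, 2) * (-np.log(temperature) / d_model))
pe_1d = np.zeros((max_len, d_model))
pe_1d[:, 0::2] = np.sin(position * div_term)
pe_1d[:, 1::2] = np.cos(position * div_term)

# pe_2d[r, c, k] = pe_1d[r, k] + pe_1d[c, k]
pe_2d = pe_1d[:, None, :] + pe_1d[None, :, :]

# One subplot per dimension, arranged in a 4 x 8 grid (layout chosen for illustration)
fig, axes = plt.subplots(4, 8, figsize=(16, 8))
for k, ax in enumerate(axes.flat):
    ax.matshow(pe_2d[:, :, k], cmap="viridis")
    ax.set_title(f"dim {k}", fontsize=8)
    ax.set_axis_off()
fig.tight_layout()
```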