Solving CAPTCHA

Reading time: 15 minutes

This post introduces sequence-to-sequence prediction. In previous posts, our models have only made single predictions: a class (for classification) or a quantity (for regression). By adopting a Long Short-Term Memory (LSTM) architecture for the neural network and Connectionist Temporal Classification (CTC) loss for the loss function, we can read in a sequence and predict a sequence.

More specifically, we'll take a CAPTCHA image as input and predict the numbers in that image as output. CAPTCHAs (short for Completely Automated Public Turing test to tell Computers and Humans Apart) are used by websites to determine whether a user is human or not, most often to prevent spam bots from submitting fake forms online. Google has famously used its own version of CAPTCHA (called reCAPTCHA) to digitise the archives of The New York Times and books from its own Google Books collection.

In this post, you'll learn how to:

  • generate CAPTCHAs;
  • build a Long Short Term Memory (LSTM) neural network;
  • train a network using Connectionist Temporal Classification (CTC) loss.

By the end, you'll be able to decode the numbers in a CAPTCHA image like this one:

[CAPTCHA image of a 4-digit number]

Import Packages

In [5]:
from itertools import groupby
from pathlib import Path

from PIL import Image

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.data.dataset import random_split

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

Generate CAPTCHAs

Given a string, the Python package captcha can generate a CAPTCHA image of that string. Here, we generate CAPTCHA images for all 4-digit numbers from 0000 to 9999, saving the number in the file name:

In [3]:
from captcha.image import ImageCaptcha

image = ImageCaptcha()

for chars in range(10000):
    image.write(f'{chars:04}', f'data/captcha/{chars:04}.png')  # assumes the data/captcha directory exists

Wrapping the CAPTCHA images in a PyTorch dataset:

In [6]:
class CaptchaDataset(Dataset):
    """CAPTCHA dataset."""

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.image_paths = list(Path(root_dir).glob('*'))
        self.transform = transform

    def __getitem__(self, index):
        image = Image.open(self.image_paths[index])

        if self.transform:
            image = self.transform(image)
        
        label_sequence = [int(c) for c in self.image_paths[index].stem]
        return (image, torch.tensor(label_sequence))
    
    def __len__(self):
        return len(self.image_paths)

Let's load the dataset and preview the first item:
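Here's a minimal sketch of such a preview (with no transform set, __getitem__ returns a PIL image, which plt.imshow can display directly):

In [ ]:
dataset = CaptchaDataset(root_dir='data/captcha')

image, label = dataset[0]
print(f'Label: {label}')  # e.g. tensor([0, 0, 0, 0])
plt.imshow(image)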

To normalise the dataset, we first need to calculate the mean and standard deviation across the entire dataset:

In [7]:
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor(),
])

dataset = CaptchaDataset(root_dir='data/captcha', transform=transform)

dataloader = DataLoader(dataset, batch_size=10000)

for batch_index, (inputs, labels) in enumerate(dataloader):
    print(f'Mean: {inputs.mean()}, Std: {inputs.std()}')
Mean: 0.8909896016120911, Std: 0.14787691831588745

Now, let's use the mean, standard deviation and data from CaptchaDataset to create train and test dataloaders:

In [8]:
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0.89099,), (0.14788,)),
])

dataset = CaptchaDataset(root_dir='data/captcha', transform=transform)

train_size = 128 * 64  # 8192 images for training
test_size = len(dataset) - train_size  # the rest for testing

train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# drop_last=True keeps every batch at exactly 64 examples, since the LSTM
# hidden state we initialise during training has a fixed batch dimension
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, drop_last=True)
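As a quick sanity check (a sketch, assuming the captcha package's default 160x60 image size), each training batch should contain 64 single-channel images and 64 four-digit label sequences:

In [ ]:
inputs, targets = next(iter(train_dataloader))
print(inputs.shape)   # expected: torch.Size([64, 1, 60, 160])
print(targets.shape)  # expected: torch.Size([64, 4])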

Long Short-Term Memory (LSTM) Neural Networks

First, let's set the device:

In [7]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

We'll be making use of a neural network architecture called the stacked Long Short-Term Memory (LSTM) network. But before that, let's start by looking at Recurrent Neural Networks (RNNs), the more general family of networks of which the LSTM is a special case.

When humans process sequences of text or images, like reading a sentence or watching a video, they don't examine each word or image in isolation. Remembering what was processed earlier in a sequence helps us understand things that occur later in the sequence. RNNs are neural networks that contain loops, so information from the past affects the current output. One way to visualise an RNN is as multiple copies of the same neural network, each passing a message (its hidden state) on to its successor.

[Figure: an RNN unrolled into multiple copies of the same network]
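To make the recurrence concrete, here's a minimal sketch using PyTorch's built-in nn.RNN (shapes chosen to match our CAPTCHAs: 60 features per step, one step per pixel column):

In [ ]:
rnn = nn.RNN(input_size=60, hidden_size=512)  # the same weights are reused at every step

x = torch.randn(160, 1, 60)  # (seq_len, batch, features): one image, read column by column
outputs, h = rnn(x)          # outputs: one vector per step; h: the final hidden state
print(outputs.shape, h.shape)  # torch.Size([160, 1, 512]) torch.Size([1, 1, 512])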

Long Short-Term Memory (LSTM) neural networks are a specific form of RNN. While a plain RNN can learn from past information when that information is quite recent, it often fails to incorporate information that appears much earlier in the sequence. LSTM networks resolve this issue, allowing the network to learn long-term dependencies. They do this by replacing the single neural network layer found in a plain RNN cell with four neural network layers that interact with each other.

In [100]:
class StackedLSTM(nn.Module):
    def __init__(self, input_size=60, output_size=11, hidden_size=512, num_layers=2):
        super(StackedLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        self.dropout = nn.Dropout()
        self.fc = nn.Linear(hidden_size, output_size)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)
        
    def forward(self, inputs, hidden):
        seq_len, batch_size, input_size = inputs.shape  # sequence-first: (width, batch, height)
        outputs, hidden = self.lstm(inputs, hidden)
        outputs = self.dropout(outputs)
        outputs = self.fc(outputs)  # the linear layer is applied at every step
        outputs = F.log_softmax(outputs, dim=2)
        return outputs, hidden
    
    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data 
        return (weight.new(self.num_layers, batch_size, self.hidden_size).zero_(),
                weight.new(self.num_layers, batch_size, self.hidden_size).zero_())
    
net = StackedLSTM().to(device)
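Before training, we can sanity-check the shapes with a fake batch (a sketch: two random "images" of height 60 and width 160, fed column by column):

In [ ]:
x = torch.randn(160, 2, 60).to(device)  # (seq_len=width, batch, features=height)
h = net.init_hidden(2)

outputs, h = net(x, h)
print(outputs.shape)  # torch.Size([160, 2, 11]): per-column log-probabilities over the 11 output classes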

Connectionist Temporal Classification (CTC) Loss

Setting the criterion and optimizer:

In [107]:
criterion = nn.CTCLoss(blank=10)
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

For the criterion, we have replaced our usual nn.CrossEntropyLoss() with nn.CTCLoss(). While Cross Entropy Loss compares a single input with a single output, Connectionist Temporal Classification (CTC) loss lets you compare an input sequence with an output sequence without needing to align them first. Previously, to classify sequences like CAPTCHA images, people used rules to segment the characters with bounding boxes before recognising each digit individually, like in our previous MNIST classifier. Now, using CTC loss, we are able to train the classifier end-to-end, without any intermediate steps.

The idea behind CTC loss is to make a classification at each step in the sequence. In our case, each column of pixels in the CAPTCHA image is a step in the sequence, read from left to right. The classifier can choose from the available labels (in our case, 0-9), plus a special BLANK label (denoted by "-") meaning that the classifier couldn't make a prediction at that step.

An example of what the classifier might predict from sequentially classifying a 22 pixel wide CAPTCHA image:

--00000-11---1122222--

The trick to extracting a meaningful prediction from the raw predictions is to collapse repeated characters:

-0-1-12-

And then removing the BLANK labels:

0112
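This decoding step is easy to express with itertools.groupby, which we imported earlier; here's a quick sketch using the example above:

In [ ]:
raw = '--00000-11---1122222--'

pred = [c for c, _ in groupby(raw) if c != '-']  # collapse repeats, then drop blanks
print(''.join(pred))  # 0112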

Loss functions require a quantity to optimise. To compute that quantity, we make a list of all valid sequences that match the target. For example, if the target in the above example were indeed 0112, then other valid sequences to reach the same target would include:

--0000--1----1122222--
--0000--1----11---22--
--0000--1-11111---22--

Since the classifier produces a probability for each possible label at each step in the sequence, these can be multiplied together to get an overall probability for each valid sequence. The quantity for the loss function to minimise is then the negative log of the sum of the overall probabilities across all valid sequences. Although the list of valid sequences can be very long, the calculation can be made efficient with a dynamic programming algorithm.
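As a toy illustration with made-up numbers, here is that sum computed by brute force for a 3-step sequence and the target 01 (the five alignments listed are all the valid ones for this length):

In [ ]:
# made-up per-step probabilities over the labels '0', '1' and BLANK ('-')
probs = [
    {'0': 0.6, '1': 0.1, '-': 0.3},
    {'0': 0.2, '1': 0.5, '-': 0.3},
    {'0': 0.1, '1': 0.6, '-': 0.3},
]

# every length-3 alignment that collapses to the target '01'
alignments = ['001', '011', '0-1', '-01', '01-']

total = sum(
    probs[0][a[0]] * probs[1][a[1]] * probs[2][a[2]]
    for a in alignments
)
print(total)  # ~0.486; the CTC loss is -log of this total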


Unfortunately, PyTorch requires the BLANK label to be an integer, so instead of "-", we assign it to be 10:

In [108]:
BLANK_LABEL = 10

Training and Testing the Model

Training Phase

In [ ]:
net.train()  # set network to training phase
    
epochs = 30
batch_size = 64

# for each pass of the training dataset
for epoch in range(epochs):
    train_loss, train_correct, train_total = 0, 0, 0
    
    h = net.init_hidden(batch_size)
    
    # for each batch of training examples
    for batch_index, (inputs, targets) in enumerate(train_dataloader):
        inputs, targets = inputs.to(device), targets.to(device)
        h = tuple([each.data for each in h])
        
        batch_size, channels, height, width = inputs.shape
        
        # reshape inputs: NxCxHxW -> WxNx(HxC)
        inputs = (inputs
                  .permute(3, 0, 2, 1)
                  .contiguous()
                  .view((width, batch_size, -1)))
        
        optimizer.zero_grad()  # zero the parameter gradients
        outputs, h = net(inputs, h)  # forward pass

        # compare output with ground truth
        input_lengths = torch.IntTensor(batch_size).fill_(width)
        target_lengths = torch.IntTensor([len(t) for t in targets])
        loss = criterion(outputs, targets, input_lengths, target_lengths)

        loss.backward()  # backpropagation
        nn.utils.clip_grad_norm_(net.parameters(), 10)  # clip gradients
        optimizer.step()  # update network weights
        
        # record statistics
        prob, max_index = torch.max(outputs, dim=2)
        train_loss += loss.item()
        train_total += len(targets)

        for i in range(batch_size):
            raw_pred = list(max_index[:, i].cpu().numpy())
            pred = [c for c, _ in groupby(raw_pred) if c != BLANK_LABEL]
            target = list(targets[i].cpu().numpy())
            if pred == target:
                train_correct += 1

        # print statistics every 10 batches
        if (batch_index + 1) % 10 == 0:
            print(f'Epoch {epoch + 1}/{epochs}, ' +
                  f'Batch {batch_index + 1}/{len(train_dataloader)}, ' +
                  f'Train Loss: {(train_loss/10):.5f}, ' +
                  f'Train Accuracy: {(train_correct/train_total):.5f}')
            
            train_loss, train_correct, train_total = 0, 0, 0

Testing Phase

In [ ]:
h = net.init_hidden(batch_size)  # init hidden state

net.eval()

test_loss = 0
test_correct = 0
test_total = 0

with torch.no_grad():  # disable gradient tracking so inference runs faster
    
    # for each batch of testing examples
    for batch_index, (inputs, targets) in enumerate(test_dataloader):
        inputs, targets = inputs.to(device), targets.to(device)
        h = tuple([each.data for each in h])
        batch_size, channels, height, width = inputs.shape
        
        # reshape inputs: NxCxHxW -> WxNx(HxC)
        inputs = (inputs
                  .permute(3, 0, 2, 1)
                  .contiguous()
                  .view((width, batch_size, -1)))

        outputs, h = net(inputs, h)  # forward pass
        
        # record loss
        input_lengths = torch.IntTensor(batch_size).fill_(width)
        target_lengths = torch.IntTensor([len(t) for t in targets])
        loss = criterion(outputs, targets, input_lengths, target_lengths)
        test_loss += loss.item()
        test_total += len(targets)
        
        # compare prediction with ground truth
        prob, max_index = torch.max(outputs, dim=2)
        
        for i in range(batch_size):
            raw_pred = list(max_index[:, i].cpu().numpy())
            pred = [c for c, _ in groupby(raw_pred) if c != BLANK_LABEL]
            target = list(targets[i].cpu().numpy())
            if pred == target:
                test_correct += 1
                
print(f'Test Loss: {(test_loss/len(test_dataloader)):.5f}, ' +
      f'Test Accuracy: {(test_correct/test_total):.5f} ' +
      f'({test_correct}/{test_total})')

Now, let's load the CAPTCHA image mentioned at the start and the ground truth target:

In [11]:
data_iterator = iter(test_dataloader)
inputs, targets = next(data_iterator)

i = 1

image = inputs[i,0,:,:]

print(f"Target: {''.join(map(str, targets[i].numpy()))}")
plt.imshow(image)
Target: 3291
[CAPTCHA image of the digits 3291]

And let's make a prediction using our trained LSTM neural network:

In [163]:
inputs = inputs.to(device)

batch_size, channels, height, width = inputs.shape
h = net.init_hidden(batch_size)

inputs = (inputs
          .permute(3, 0, 2, 1)
          .contiguous()
          .view((width, batch_size, -1)))

# get prediction
outputs, h = net(inputs, h)  # forward pass
prob, max_index = torch.max(outputs, dim=2)
raw_pred = list(max_index[:, i].cpu().numpy())

# print raw prediction with BLANK_LABEL replaced with "-"
print('Raw Prediction: ' + ''.join([str(c) if c != BLANK_LABEL else '-' for c in raw_pred]))

pred = [str(c) for c, _ in groupby(raw_pred) if c != BLANK_LABEL]
print(f"Prediction: {''.join(pred)}")

Raw Prediction: -------------------33----------------------------------------------222-------------999--------------------------------------------------77----------------------
Prediction: 3297

Out of interest, let's annotate the CAPTCHA image with lines where the digits are first recognised:

In [168]:
line_mask = [(a == BLANK_LABEL) & (b != BLANK_LABEL) for a, b in zip(raw_pred, raw_pred[1:])]
indices = [i for i, x in enumerate(line_mask) if x]

annotated_image = image.clone()
for index in indices:
    annotated_image[:, index] = 0
    annotated_image[:, index+1] = -2

plt.imshow(annotated_image)
[Annotated CAPTCHA image with vertical lines marking where each digit is first recognised]

Summary and Next Steps

In this post, you've learnt how to:

  • generate CAPTCHAs;
  • build a Long Short Term Memory (LSTM) neural network;
  • train a network using Connectionist Temporal Classification (CTC) loss.

LSTM neural networks have been applied to image captioning and text summarisation with excellent results.

In the next post, we'll be looking at image style transfer, i.e. how to take an image and transform it into a specific artistic style.