Spatial Transformer Networks
Jaderberg et al. (2015) introduced the learnable Spatial Transformer (ST) module, which enables standard neural networks to actively spatially transform feature maps or input data. In essence, the ST can be understood as a black box that applies some spatial transformation (e.g., crop, scale, rotate) to a given input (or part of it), conditioned on that particular input, during a single forward pass. In general, STs can also be seen as a learnable attention mechanism that additionally transforms the region of interest. Notably, STs can easily be integrated into existing neural network architectures without any extra supervision or modification to the optimization, i.e., STs are differentiable plug-in modules. The authors showed that STs help models to learn invariance to translation, scale, rotation and more generic warping, which resulted in state-of-the-art performance on several benchmarks, see image below.
ST Example: Results (after training) of using an ST as the first layer of a fully-connected network (ST-FCN Affine, left) or a convolutional neural network (ST-CNN Affine, right) trained for cluttered MNIST digit recognition. Clearly, the output of the ST exhibits much less translation variance and attends to the digit. Taken from the video accompanying Jaderberg et al. (2015).
Model Description
The aim of STs is to provide neural networks with spatial transformation and attention capabilities in a reasonable and efficient way. Note that standard neural network architectures (e.g., CNNs) are limited in this regard$^1$. Therefore, the ST applies a parametrized transformation $\mathcal{T}_\theta$ that maps a regular grid to a new sampling grid, see image below. Then, some form of interpolation is used to compute the pixel values at the points of the new sampling grid (i.e., interpolation between the values on the old grid).
Two examples of applying the parametrised sampling grid to an image $U$ producing the output $V$. The green dots represent the new sampling grid, which is obtained by transforming the regular grid $G$ (defined on $V$) using the transformation $\mathcal{T}_\theta$. (a) The sampling grid is the regular grid $G = \mathcal{T}_I(G)$, where $I$ is the identity transformation matrix. (b) The sampling grid is the result of warping the regular grid with an affine transformation $\mathcal{T}_\theta(G)$. Taken from Jaderberg et al. (2015).
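The caption describes two cases, and the sketch below makes both concrete. It uses PyTorch's `F.affine_grid` as the grid generator on a tiny 3×4 output grid. It assumes normalized coordinates in $[-1, 1]$ and `align_corners=True`, so the outermost grid points sit exactly at ±1.

- **Case (a):** with the identity parameters, the sampling grid coincides with the regular grid.
- **Case (b):** with a 2×3 affine matrix $A_\theta$, each grid point is mapped as $(x^s, y^s)^\top = A_\theta\,(x^t, y^t, 1)^\top$.

The rotation angle, scale and shift are arbitrary illustrative values, not taken from the paper.

```python
import math
import torch
import torch.nn.functional as F

H_t, W_t = 3, 4  # size of a tiny output grid, chosen only for illustration

# (a) identity transformation theta = [[1, 0, 0], [0, 1, 0]]
theta_identity = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]])
grid_identity = F.affine_grid(theta_identity, size=(1, 1, H_t, W_t),
                              align_corners=True)  # shape (1, H_t, W_t, 2)

# regular grid of normalized target coordinates (x^t, y^t), x first
ys, xs = torch.meshgrid(torch.linspace(-1, 1, H_t),
                        torch.linspace(-1, 1, W_t), indexing='ij')
regular_grid = torch.stack([xs, ys], dim=-1)
print(torch.allclose(grid_identity[0], regular_grid))  # True

# (b) affine transformation: rotate by 30 degrees, scale by 0.5, shift x by 0.2
angle, scale = math.radians(30), 0.5
A_theta = torch.tensor([[scale * math.cos(angle), -scale * math.sin(angle), 0.2],
                        [scale * math.sin(angle),  scale * math.cos(angle), 0.0]])
grid_affine = F.affine_grid(A_theta.unsqueeze(0), size=(1, 1, H_t, W_t),
                            align_corners=True)

# same grid obtained by applying A_theta to homogeneous coordinates (x^t, y^t, 1)
homogeneous = torch.cat([regular_grid, torch.ones(H_t, W_t, 1)], dim=-1)
manual_grid = homogeneous @ A_theta.T  # shape (H_t, W_t, 2)
print(torch.allclose(grid_affine[0], manual_grid, atol=1e-6))  # True
```

The sampling grid always has the shape of the output, while its values are coordinates in the input. This is why the grid is defined on $V$ but points into $U$.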
To this end, the ST is divided into three consecutive parts:
Localisation Network: Its purpose is to retrieve the parameters $\theta$ of the spatial transformation $\mathcal{T}_\theta$, taking the current feature map $U$ as input, i.e., $\theta = f_{\text{loc}}(U)$. Thereby, the spatial transformation is conditioned on the input. Note that the dimensionality of $\theta$ depends on the transformation type, which needs to be defined beforehand (e.g., six parameters for an affine transformation), see some examples below. Furthermore, the localisation network can take any differentiable form, e.g., a CNN or FCN.
Examples of Spatial Transformations
The following examples highlight how a regular grid
$$G = \{G_i\} = \{(x_i^t, y_i^t)\}, \quad i = 1, \dots, H^t W^t,$$
defined on the output/target map $V$ (i.e., $H^t$ and $W^t$ denote the height and width of $V$) can be transformed into a new sampling grid
$$\mathcal{T}_\theta(G) = \{(x_i^s, y_i^s)\}, \quad i = 1, \dots, H^t W^t,$$
defined on the input/source feature map $U$ using a parametrized transformation $\mathcal{T}_\theta$, i.e., $(x_i^s, y_i^s) = \mathcal{T}_\theta(x_i^t, y_i^t)$. The visualizations have been created by me; interactive versions can be found here.
Grid Generator: Its purpose is to create the new sampling grid $\mathcal{T}_\theta(G)$ on the input feature map $U$ by applying the predefined parametrized transformation with the parameters $\theta$ obtained from the localisation network, see examples above.
Sampler: Its purpose is to compute the warped version of the input feature map $U$ by computing the pixel values at the new sampling grid $\mathcal{T}_\theta(G)$ obtained from the grid generator. Note that the new sampling grid does not necessarily align with the input feature map grid; therefore, some kind of interpolation is needed. Jaderberg et al. (2015) formulate this interpolation as the application of a sampling kernel centered at a particular location in the input feature map, i.e.,
$$V_i^c = \sum_{n}^{H} \sum_{m}^{W} U_{nm}^c \, k(x_i^s - m; \Phi_x) \, k(y_i^s - n; \Phi_y) \quad \forall i \in [1, \dots, H^t W^t], \ \forall c \in [1, \dots, C],$$
where $V_i^c$ denotes the new pixel value of the $c$-th channel at the $i$-th position of the new sampling grid coordinates$^2$ $(x_i^s, y_i^s)$. $U_{nm}^c$ is the value of the input feature map (of size $H \times W$ with $C$ channels) at location $(n, m)$ in channel $c$. $\Phi_x$ and $\Phi_y$ are the parameters of a generic sampling kernel $k(\cdot)$, which defines the image interpolation. As the sampling grid coordinates are not channel-dependent, each channel is transformed in the same way, resulting in spatial consistency between channels. Note that although in theory we need to sum over all input locations, in practice we only need to sum over the kernel's support region around each $(x_i^s, y_i^s)$ (similar to CNNs).
The sampling kernel can be chosen freely as long as (sub-)gradients can be defined with respect to $x_i^s$ and $y_i^s$. Some possible choices are shown below.
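Integer sampling kernel (nearest neighbour), where $\delta(\cdot)$ is the Kronecker delta:
$$V_i^c = \sum_{n}^{H} \sum_{m}^{W} U_{nm}^c \, \delta(\lfloor x_i^s + 0.5 \rfloor - m) \, \delta(\lfloor y_i^s + 0.5 \rfloor - n)$$
Bilinear sampling kernel:
$$V_i^c = \sum_{n}^{H} \sum_{m}^{W} U_{nm}^c \, \max(0, 1 - |x_i^s - m|) \, \max(0, 1 - |y_i^s - n|)$$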
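The bilinear kernel $k(d) = \max(0, 1 - |d|)$ can be checked against PyTorch's `F.grid_sample`. The sketch below does this by evaluating the double sum literally, over all input pixels rather than just the kernel support.

- **Coordinate assumption:** absolute pixel coordinates, with pixel centers at integers $m = 0, \dots, W-1$ and $n = 0, \dots, H-1$. This corresponds to `align_corners=True` in `grid_sample`.
- **Boundary handling:** `padding_mode='zeros'` matches the fact that pixels outside the image simply do not appear in the sum.
- **Naming:** the helper name `bilinear_kernel_sample` is hypothetical.

```python
import torch
import torch.nn.functional as F

def bilinear_kernel_sample(U, x_s, y_s):
    """Hypothetical helper: literal evaluation of the bilinear sampling kernel.

    U:        input feature map of shape (C, H, W)
    x_s, y_s: absolute source coordinates of the sampling grid, shape (N,)
    returns:  tensor of shape (C, N) with entry [c, i] = V_i^c
    """
    C, H, W = U.shape
    m = torch.arange(W, dtype=U.dtype)  # column (x) indices of input pixels
    n = torch.arange(H, dtype=U.dtype)  # row (y) indices of input pixels
    # kernel weights max(0, 1 - |x_i^s - m|) and max(0, 1 - |y_i^s - n|)
    k_x = torch.clamp(1 - (x_s[:, None] - m[None, :]).abs(), min=0)  # (N, W)
    k_y = torch.clamp(1 - (y_s[:, None] - n[None, :]).abs(), min=0)  # (N, H)
    # V_i^c = sum_n sum_m U_nm^c * k_y[i, n] * k_x[i, m]
    return torch.einsum('chw,ih,iw->ci', U, k_y, k_x)

torch.manual_seed(0)
C, H, W, N = 2, 5, 6, 10
U = torch.rand(C, H, W)
x_s = torch.rand(N) * (W + 1) - 1  # some points lie slightly outside the image
y_s = torch.rand(N) * (H + 1) - 1

V_kernel = bilinear_kernel_sample(U, x_s, y_s)

# grid_sample expects normalized coordinates in [-1, 1], ordered (x, y)
grid = torch.stack([2 * x_s / (W - 1) - 1, 2 * y_s / (H - 1) - 1], dim=-1)
V_torch = F.grid_sample(U[None], grid.view(1, 1, N, 2), mode='bilinear',
                        padding_mode='zeros', align_corners=True)[0, :, 0, :]
print(torch.allclose(V_kernel, V_torch, atol=1e-5))
```

In practice only the (at most) 2×2 neighbourhood of each sampling point has non-zero weight. An efficient sampler exploits this instead of summing over all $H \times W$ pixels as done here.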
The figure below summarizes the ST architecture and shows how the individual parts interact with each other.
Taken from Jaderberg et al. (2015)
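The interaction of the three parts can be sketched in a few lines of PyTorch. This is a minimal, assumed setup: a tiny fully connected localisation network on random 1×28×28 inputs, not the architecture used in the experiments.

- **Localisation network:** regresses the six affine parameters $\theta$.
- **Grid generator:** `F.affine_grid` builds the sampling grid from $\theta$.
- **Sampler:** `F.grid_sample` interpolates the input bilinearly at the grid points.

The final regression layer is initialised to the identity transformation. As a result, the untrained ST passes its input through unchanged, while gradients still reach $\theta$. The class name `MinimalST` is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MinimalST(nn.Module):
    """Hypothetical minimal spatial transformer with an affine transformation."""

    def __init__(self, height=28, width=28):
        super().__init__()
        # localisation network f_loc: U -> theta (6 affine parameters)
        self.localisation_net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(height * width, 32),
            nn.ReLU(),
            nn.Linear(32, 6),
        )
        # start from the identity transformation [[1, 0, 0], [0, 1, 0]]
        nn.init.zeros_(self.localisation_net[-1].weight)
        with torch.no_grad():
            self.localisation_net[-1].bias.copy_(
                torch.tensor([1., 0., 0., 0., 1., 0.]))

    def forward(self, U):
        theta = self.localisation_net(U).view(-1, 2, 3)                 # localisation network
        grid = F.affine_grid(theta, U.size(), align_corners=False)      # grid generator
        return F.grid_sample(U, grid, mode='bilinear', align_corners=False)  # sampler

torch.manual_seed(0)
st = MinimalST()
U = torch.rand(4, 1, 28, 28)  # batch of 4 single-channel inputs
V = st(U)
print(V.shape)                          # torch.Size([4, 1, 28, 28])
print(torch.allclose(V, U, atol=1e-5))  # True: identity init reproduces the input

# the module is differentiable: gradients flow back to theta's regression layer
V.sum().backward()
print(st.localisation_net[-1].bias.grad.shape)  # torch.Size([6])
```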
Motivation: With the introduction of GPUs, convolutional layers made it computationally efficient to train feature detectors on image patches, thanks to weight sharing and local connectivity. Since then, CNNs have become the dominant framework for computer vision tasks such as image classification or segmentation.
Despite their success, Jaderberg et al. (2015) note that CNNs still lack mechanisms to be spatially invariant to the input data in a computationally and parameter-efficient manner. Convolutional layers are translation-equivariant, and max-pooling layers allow the network to be somewhat spatially invariant to the position of features. However, this invariance is limited to the (typically) small spatial support of max-pooling (e.g., $2 \times 2$ pixels). As a result, CNNs are typically not invariant to larger transformations and thus need to learn complicated functions to approximate these invariances.
What if we could enable the network to learn transformations of the input data? This is the main idea of STs! Learning spatial invariances is much easier when you have spatial transformation capabilities. The second aim of STs is to be computationally and parameter efficient. This is done by using structured, parameterized transformations which can be seen as a weight sharing scheme.
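The two claims about CNNs above can be checked numerically on a toy example:

- **Equivariance:** convolution commutes with shifts. The sketch uses a 3×3 convolution with circular padding (an assumption) so that shifting with `torch.roll` commutes exactly with the convolution.
- **Limited invariance of pooling:** 2×2 max-pooling is only invariant to shifts that stay within its window. The sketch shifts a single bright pixel and compares the pooled outputs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# 1) translation equivariance: conv(shift(x)) == shift(conv(x))
x = torch.rand(1, 1, 8, 8)
kernel = torch.rand(1, 1, 3, 3)

def circular_conv(inp):
    # circular padding makes the convolution exactly shift-equivariant
    return F.conv2d(F.pad(inp, (1, 1, 1, 1), mode='circular'), kernel)

shift = (2, 3)  # shift by 2 rows and 3 columns
lhs = circular_conv(torch.roll(x, shifts=shift, dims=(2, 3)))
rhs = torch.roll(circular_conv(x), shifts=shift, dims=(2, 3))
print(torch.allclose(lhs, rhs, atol=1e-5))  # True

# 2) max-pooling is only invariant to shifts within its 2x2 support
def pooled_single_pixel(row, col):
    img = torch.zeros(1, 1, 4, 4)
    img[0, 0, row, col] = 1.
    return F.max_pool2d(img, kernel_size=2, stride=2)

ref = pooled_single_pixel(0, 0)
print(torch.equal(ref, pooled_single_pixel(1, 1)))  # True: stays inside the window
print(torch.equal(ref, pooled_single_pixel(2, 2)))  # False: moves to another window
```

Neither property helps with rotation or scaling. That gap is what the ST is meant to fill.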
Implementation
Jaderberg et al. (2015) performed several supervised learning tasks (distorted MNIST, Street View House Numbers, fine-grained bird classification). In each, they compared a standard architecture (FCN or CNN) against an architecture that includes one or several ST modules. They empirically showed that including STs results in performance gains, i.e., higher accuracies across multiple tasks.
The following reimplementation aims to reproduce a subset of the distorted MNIST experiment (RTS distorted MNIST), comparing a standard CNN with an ST-CNN architecture. A starting point for the implementation was the PyTorch Spatial Transformer Networks tutorial by Ghassen Hamrouni.
RTS Distorted MNIST
While Jaderberg et al. (2015) explored multiple distortions on the MNIST handwritten digit dataset, this reimplementation focuses on rotation-translation-scale (RTS) distorted MNIST, see image below. As described in appendix A.4 of Jaderberg et al. (2015), this dataset can easily be generated by augmenting the standard MNIST dataset as follows:
randomly rotate by sampling the angle uniformly in $[-45^\circ, 45^\circ]$,
randomly scale by sampling the factor uniformly in $[0.7, 1.2]$,
translate by picking a random location on a $42 \times 42$ image (MNIST digits are $28 \times 28$).
Note that this transformation could also be used as a data augmentation technique, as the resulting images remain (mostly) valid digit representations (humans could still assign correct labels).
The code below can be used to create this dataset:
```python
import torch
from torchvision import datasets, transforms


def load_data():
    """loads MNIST datasets with 'RTS' (rotation, translation, scale)
    transformation

    Returns:
        train_dataset (torch dataset): training dataset
        test_dataset (torch dataset): test dataset
    """
    def place_digit_randomly(img):
        # img: (1, 28, 28) tensor, placed at a random position on a 42x42 canvas
        new_img = torch.zeros([42, 42])
        # the upper bound of randint is exclusive, hence + 1 (offsets 0..14)
        x_pos, y_pos = torch.randint(0, 42 - 28 + 1, (2,))
        new_img[y_pos:y_pos+28, x_pos:x_pos+28] = img.squeeze(0)
        return new_img

    transform = transforms.Compose([
        # random rotation in [-45, 45] degrees and scaling in [0.7, 1.2]
        transforms.RandomAffine(degrees=(-45, 45), scale=(0.7, 1.2)),
        transforms.ToTensor(),
        # random translation within the 42x42 image
        transforms.Lambda(place_digit_randomly),
        # add the channel dimension back: (42, 42) -> (1, 42, 42)
        transforms.Lambda(lambda img: img.unsqueeze(0)),
    ])
    train_dataset = datasets.MNIST('./data', transform=transform,
                                   train=True, download=True)
    # use the official MNIST test split for evaluation
    test_dataset = datasets.MNIST('./data', transform=transform,
                                  train=False, download=True)
    return train_dataset, test_dataset


train_dataset, test_dataset = load_data()
```
```python
import matplotlib.pyplot as plt
import numpy as np

# plot a few random samples of the RTS distorted training set
n_samples = 7
fig = plt.figure(figsize=(n_samples*2.5, 2.5))
i_samples = np.random.choice(range(len(train_dataset)),
                             n_samples, replace=False)
for counter, i_sample in enumerate(i_samples):
    img, label = train_dataset[i_sample]
    plt.subplot(1, n_samples, counter + 1)
    plt.imshow(transforms.ToPILImage()(img), cmap='gray')
    plt.axis('off')
    plt.title(str(label), fontsize=16, y=-0.2)
plt.show()
```
Model Implementation
The model implementation can be divided into three tasks:
Network Architectures: The network architectures are based on the description in appendix A.4 of Jaderberg et al. (2015). Note that there is only one ST at the beginning of the network, so the resulting transformation is only applied to a single channel (the input channel). For the sake of simplicity, we only implement an affine transformation. Clearly, including an ST increases the network's capacity due to the added trainable parameters. To allow for a fair comparison, we therefore increase the capacity of the standard CNN by using 36 instead of 32 channels in its convolutional layers, which also enlarges its first linear layer.
The code below creates both architectures and counts their trainable parameters.
```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def get_number_of_trainable_parameters(model):
    """taken from
    discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325
    """
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    return params


class CNN(nn.Module):

    def __init__(self, img_size=42, include_ST=False):
        super(CNN, self).__init__()
        self.ST = include_ST
        self.name = 'ST-CNN Affine' if include_ST else 'CNN'
        # more channels in the plain CNN to roughly match the ST-CNN's capacity
        c_dim = 32 if include_ST else 36
        self.convs = nn.Sequential(
            nn.Conv2d(1, c_dim, kernel_size=9, stride=1, padding=0),
            nn.MaxPool2d(kernel_size=(2,2), stride=2),
            nn.ReLU(True),
            nn.Conv2d(c_dim, c_dim, kernel_size=7, stride=1, padding=0),
            nn.MaxPool2d(kernel_size=(2,2), stride=2),
            nn.ReLU(True),
        )
        # spatial size after the conv/pool stack (42 -> 34 -> 17 -> 11 -> 5)
        out_conv = int((int((img_size - 8)/2) - 6)/2)
        self.classification = nn.Sequential(
            nn.Linear(out_conv**2*c_dim, 50),
            nn.ReLU(True),
            nn.Linear(50, 10),
            nn.LogSoftmax(dim=1),
        )
        if include_ST:
            # spatial size after the localisation convs (21 -> 17 -> 8 -> 4)
            loc_conv_out_dim = int((int(img_size/2) - 4)/2) - 4
            loc_regression_layer = nn.Linear(20, 6)
            # initialize final regression layer to identity transform
            loc_regression_layer.weight.data.fill_(0)
            loc_regression_layer.bias = nn.Parameter(
                torch.tensor([1., 0., 0., 0., 1., 0.]))
            self.localisation_net = nn.Sequential(
                nn.Conv2d(1, 20, kernel_size=5, stride=1, padding=0),
                nn.MaxPool2d(kernel_size=(2,2), stride=2),
                nn.ReLU(True),
                nn.Conv2d(20, 20, kernel_size=5, stride=1, padding=0),
                nn.ReLU(True),
                nn.Flatten(),
                nn.Linear(loc_conv_out_dim**2*20, 20),
                nn.ReLU(True),
                loc_regression_layer
            )
        return

    def forward(self, img):
        batch_size = img.shape[0]
        if self.ST:
            out_ST = self.ST_module(img)
            img = out_ST
        out_conv = self.convs(img)
        out_classification = self.classification(out_conv.view(batch_size, -1))
        return out_classification

    def ST_module(self, inp):
        # act on twice downsampled inp
        down_inp = F.interpolate(inp, scale_factor=0.5, mode='bilinear',
                                 recompute_scale_factor=False,
                                 align_corners=False)
        theta_vector = self.localisation_net(down_inp)
        # affine transformation
        theta_matrix = theta_vector.view(-1, 2, 3)
        # grid generator
        grid = F.affine_grid(theta_matrix, inp.size(), align_corners=False)
        # sampler
        out = F.grid_sample(inp, grid, align_corners=False)
        return out

    def get_attention_rectangle(self, inp):
        assert inp.shape[0] == 1, 'batch size has to be one'
        # act on twice downsampled inp
        down_inp = F.interpolate(inp, scale_factor=0.5, mode='bilinear',
                                 recompute_scale_factor=False,
                                 align_corners=False)
        theta_vector = self.localisation_net(down_inp)
        # affine transformation matrix
        theta_matrix = theta_vector.view(2, 3).detach()
        # normalized target rectangle (corners in homogeneous coordinates,
        # first corner repeated to close the outline)
        target_rectangle = torch.tensor([
            [-1., -1., 1., 1., -1],
            [-1., 1., 1., -1, -1.],
            [1, 1, 1, 1, 1]]
        ).to(inp.device)
        # get source rectangle by transformation
        source_rectangle = torch.matmul(theta_matrix, target_rectangle)
        return source_rectangle


# instantiate models
cnn = CNN(img_size=42, include_ST=False)
st_cnn = CNN(img_size=42, include_ST=True)
# print trainable parameters
for model in [cnn, st_cnn]:
    num_trainable_params = get_number_of_trainable_parameters(model)
    print(f'{model.name} has {num_trainable_params} trainable parameters')
```
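The fair-comparison argument can be double-checked by hand. The sketch below recomputes the parameter counts from the layer shapes alone (weights plus biases), assuming "valid" convolutions with stride 1 and 2×2 max-pooling with stride 2, as in the code above. The helper names are hypothetical, and the totals should agree with what `get_number_of_trainable_parameters` reports.

```python
def conv_params(c_in, c_out, k):
    return c_in * c_out * k * k + c_out  # weights + biases

def linear_params(n_in, n_out):
    return n_in * n_out + n_out          # weights + biases

def conv_out(size, k):
    return size - k + 1                  # stride 1, no padding

def pool_out(size):
    return size // 2                     # 2x2 max-pooling, stride 2

def classifier_params(c_dim, img_size=42):
    s = pool_out(conv_out(pool_out(conv_out(img_size, 9)), 7))  # 42 -> 34 -> 17 -> 11 -> 5
    return (conv_params(1, c_dim, 9) + conv_params(c_dim, c_dim, 7)
            + linear_params(s * s * c_dim, 50) + linear_params(50, 10))

def localisation_params(img_size=42):
    s = conv_out(pool_out(conv_out(img_size // 2, 5)), 5)  # 21 -> 17 -> 8 -> 4
    return (conv_params(1, 20, 5) + conv_params(20, 20, 5)
            + linear_params(s * s * 20, 20) + linear_params(20, 6))

cnn_total = classifier_params(36)
st_cnn_total = classifier_params(32) + localisation_params()
print(cnn_total, st_cnn_total)  # 112052 110478
```

The localisation network adds 17086 parameters. The four extra channels in the plain CNN roughly compensate for them, leaving the two models within about 1.5% of each other.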
Training Procedure: As described in appendix A.4 of Jaderberg et al. (2015), the networks are trained with standard SGD, a batch size of $256$ and a base learning rate of $0.01$. To reduce computation time, the number of epochs is limited to $50$.
The loss function is the multinomial cross entropy loss, i.e.,
$$\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \sum_{k=1}^{K} p_{i,k} \log \hat{p}_{i,k},$$
where $k$ enumerates the $K$ classes, $i$ enumerates the $N$ images, $p_{i,k}$ denotes the true probability of image $i$ and class $k$, and $\hat{p}_{i,k}$ is the probability predicted by the network. Note that the true probability distribution is categorical (hard labels), i.e.,
$$p_{i,k} = \begin{cases} 1 & \text{if } k = y_i, \\ 0 & \text{otherwise,} \end{cases}$$
where $y_i$ is the label assigned to the $i$-th image $\mathbf{x}_i$. Thus, we can rewrite the loss as follows:
$$\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \log \hat{p}_{i, y_i},$$
which is the definition of the negative log likelihood loss (NLLLoss) in PyTorch (with its default mean reduction). Its inputs are the logarithmized predictions $\log \hat{\mathbf{P}}$ (matrix of size $N \times K$) and the class labels $\mathbf{y}$ (vector of size $N$).
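The reduction from multinomial cross entropy to $\mathcal{L} = -\frac{1}{N}\sum_i \log \hat{p}_{i, y_i}$ can be verified numerically. The sketch below uses random logits for $N = 4$ images and $K = 10$ classes, an assumed toy example, and compares four expressions of the same loss:

- the one-hot (full) formula;
- the simplified formula;
- `F.nll_loss` applied to log-softmax outputs, as the model above produces via `nn.LogSoftmax`;
- `F.cross_entropy`, which combines both steps.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, K = 4, 10
logits = torch.randn(N, K)
labels = torch.randint(0, K, (N,))

log_p_hat = F.log_softmax(logits, dim=1)  # log of predicted probabilities, (N, K)

# full formula with hard (one-hot) labels p_{i,k}
loss_onehot = -(F.one_hot(labels, K) * log_p_hat).sum(dim=1).mean()
# simplified formula: -1/N * sum_i log p_hat[i, y_i]
loss_manual = -log_p_hat[torch.arange(N), labels].mean()
loss_nll = F.nll_loss(log_p_hat, labels)   # default reduction='mean'
loss_ce = F.cross_entropy(logits, labels)  # log-softmax + NLL in one call

print(torch.allclose(loss_onehot, loss_manual),
      torch.allclose(loss_manual, loss_nll),
      torch.allclose(loss_nll, loss_ce))  # True True True
```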
The code below summarizes the whole training procedure.
```python
import numpy as np
import torch
import torch.nn.functional as F
from livelossplot import PlotLosses
from torch.utils.data import DataLoader


def train(model, dataset):
    # fix hyperparameters
    epochs = 50
    learning_rate = 0.01
    batch_size = 256
    # decay the learning rate by 10x every 50k iterations (scheduler is stepped
    # per batch); with 235 iterations per epoch this never triggers in 50 epochs
    step_size_scheduler = 50000
    gamma_scheduler = 0.1
    # set device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Device: {device}')

    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                             num_workers=4)
    model.to(device)
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=gamma_scheduler,
                                                step_size=step_size_scheduler)

    losses_plot = PlotLosses()
    print(f'Start training with {model.name}')
    for epoch in range(1, epochs+1):
        avg_loss = 0
        for data, label in data_loader:
            optimizer.zero_grad()
            log_prob_pred = model(data.to(device))
            # multinomial cross entropy loss (model outputs log-probabilities)
            loss = F.nll_loss(log_prob_pred, label.to(device))
            loss.backward()
            optimizer.step()
            scheduler.step()
            avg_loss += loss.item() / len(data_loader)
        losses_plot.update({'log loss': np.log(avg_loss)})
        losses_plot.send()
    trained_model = model
    return trained_model
```
Test Procedure: A very simple test procedure to evaluate both models is shown below. It is basically the same as in the PyTorch tutorial.
```python
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader


def test(trained_model, test_dataset):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False,
                             num_workers=4)
    trained_model.to(device)
    with torch.no_grad():
        trained_model.eval()
        test_loss = 0
        correct = 0
        for data, label in test_loader:
            data, label = data.to(device), label.to(device)
            log_prob_pred = trained_model(data)
            # predicted class = index of the highest log-probability
            class_pred = log_prob_pred.max(1, keepdim=True)[1]
            test_loss += F.nll_loss(log_prob_pred, label).item()/len(test_loader)
            correct += class_pred.eq(label.view_as(class_pred)).sum().item()
    print(f'{trained_model.name}: avg loss: {np.round(test_loss, 2)}, ' +
          f'avg acc {np.round(100*correct/len(test_dataset), 2)}%')
    return
```
Results
Lastly, the results can also be divided into three sections:
Training Results: Firstly, we train our models on the training dataset and compare the logarithmized losses:
```python
trained_cnn = train(cnn, train_dataset)
trained_st_cnn = train(st_cnn, train_dataset)
```
The logarithmized losses already indicate that the ST-CNN performs better than the standard CNN (at least, it decreases the loss faster). However, it can also be noted that training the ST-CNN seems less stable.
Test Performance: While the performance on the training dataset may be a good indicator, test set performance is much more meaningful. Let's compare the losses and accuracies between both trained models:
```python
for trained_model in [trained_cnn, trained_st_cnn]:
    test(trained_model, test_dataset)
```
Clearly, the ST-CNN performs much better than the standard CNN. Note that training for more epochs would probably result in even better accuracies in both models.
Visualization of Learned Transformations: Lastly, it might be interesting to see what the ST module actually does after training.
```python
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import ConnectionPatch
from torch.utils.data import DataLoader
from torchvision import transforms


@torch.no_grad()  # inference only, no gradients needed for plotting
def visualize_learned_transformations(trained_st_cnn, test_dataset, digit_class):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    trained_st_cnn.to(device)
    trained_st_cnn.eval()
    n_samples = 5

    # pick the first n_samples images of digit_class from a random test batch
    data_loader = DataLoader(test_dataset, batch_size=256, shuffle=True)
    batch_img, batch_label = next(iter(data_loader))
    i_samples = np.where(batch_label.numpy() == digit_class)[0][0:n_samples]

    fig = plt.figure(figsize=(n_samples*2.5, 2.5*4))
    for counter, i_sample in enumerate(i_samples):
        img = batch_img[i_sample]
        label = batch_label[i_sample]

        # input image
        ax1 = plt.subplot(4, n_samples, 1 + counter)
        plt.imshow(transforms.ToPILImage()(img), cmap='gray')
        plt.axis('off')
        if counter == 0:
            ax1.annotate('Input', xy=(-0.3, 0.5), xycoords='axes fraction',
                         fontsize=14, va='center', ha='right')

        # image including border of affine transformation
        img_inp = img.unsqueeze(0).to(device)
        source_normalized = trained_st_cnn.get_attention_rectangle(img_inp)
        # remap normalized coordinates [-1, 1] into absolute values [0, 41]
        source_absolute = 0 + 20.5*(source_normalized.cpu() + 1)
        ax2 = plt.subplot(4, n_samples, 1 + counter + n_samples)
        x = np.arange(42)
        y = np.arange(42)
        X, Y = np.meshgrid(x, y)
        plt.pcolor(X, Y, img.squeeze(0).numpy(), cmap='gray')
        plt.plot(source_absolute[0], source_absolute[1], color='red')
        plt.axis('off')
        ax2.axes.set_aspect('equal')
        ax2.set_ylim(41, 0)
        ax2.set_xlim(0, 41)
        if counter == 0:
            ax2.annotate('ST', xy=(-0.3, 0.5), xycoords='axes fraction',
                         fontsize=14, va='center', ha='right')
        # add arrow between
        con = ConnectionPatch(xyA=(21, 41), xyB=(21, 0), coordsA='data',
                              coordsB='data', axesA=ax1, axesB=ax2,
                              arrowstyle="-|>", shrinkB=5)
        ax2.add_artist(con)

        # ST module output
        st_img = trained_st_cnn.ST_module(img_inp)
        ax3 = plt.subplot(4, n_samples, 1 + counter + 2*n_samples)
        plt.imshow(transforms.ToPILImage()(st_img.squeeze(0).cpu()), cmap='gray')
        plt.axis('off')
        if counter == 0:
            ax3.annotate('ST Output', xy=(-0.3, 0.5), xycoords='axes fraction',
                         fontsize=14, va='center', ha='right')
        # add arrow between
        con = ConnectionPatch(xyA=(21, 41), xyB=(21, 0), coordsA='data',
                              coordsB='data', axesA=ax2, axesB=ax3,
                              arrowstyle="-|>", shrinkB=5)
        ax3.add_artist(con)

        # predicted label
        log_pred = trained_st_cnn(img_inp)
        pred_label = log_pred.max(1)[1].item()
        ax4 = plt.subplot(4, n_samples, 1 + counter + 3*n_samples)
        plt.text(0.45, 0.43, str(pred_label), fontsize=22)
        plt.axis('off')
        #plt.title(f'Ground Truth {label.item()}', y=-0.1, fontsize=14)
        if counter == 0:
            ax4.annotate('Prediction', xy=(-0.3, 0.5), xycoords='axes fraction',
                         fontsize=14, va='center', ha='right')
        # add arrow between
        con = ConnectionPatch(xyA=(21, 41), xyB=(0.5, 0.65), coordsA='data',
                              coordsB='data', axesA=ax3, axesB=ax4,
                              arrowstyle="-|>", shrinkB=5)
        ax4.add_artist(con)
    return


visualize_learned_transformations(trained_st_cnn, test_dataset, 2)
```
Clearly, the ST module attends to the digits such that the ST output has much less variation in terms of rotation, translation and scale, making the classification task for the subsequent CNN easier.
Pretty cool, huh?
-------------------------------------------------------------------------------------------
$^1$: Clearly, convolutional layers are not rotation or scale invariant. Even the translation-equivariance property does not necessarily make CNNs translation-invariant, as typically some fully connected layers are added at the end. Max-pooling layers can introduce some translation invariance; however, they are limited by their size, so large translations are often not captured.
$^2$: Jaderberg et al. (2015) define the transformation with normalized coordinates, i.e., $-1 \le x_i^t, y_i^t \le 1$. However, in the sampling kernel equations it seems more likely that they assume unnormalized/absolute coordinates; e.g., in equation 4 of the paper, normalized coordinates would be nonsensical.