Neural Style Transfer in PyTorch

In this article we implement a neural algorithm of artistic style based on the original paper by Gatys et al. The algorithm lets us separate the content and style of images and create new images that take the style of one image and the content of another. For this purpose we'll use a pre-trained convolutional neural network and perform gradient descent on an input image to simultaneously minimize a distance measure for content and one for style. The model of choice is the so-called VGG-19, where the 19 stands for the number of weight layers in the network. It is the same model the authors used in the original paper, and it is readily available with pre-trained weights in torchvision.

Note that the gradient descent process here differs from how neural networks are usually trained. Where one would normally adjust the weights of the network's layers, here we keep them fixed and instead treat the input image's pixel values as the parameters. The gradients of the distance measure are then backpropagated all the way to the input, thus transforming the input (and therefore the image itself).
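
To make the idea concrete, here is a minimal, self-contained sketch (a toy setup, not part of the style transfer pipeline below): a tiny frozen convolution acts as the "network", and the optimizer receives the image tensor instead of the model's parameters.

import torch

# Toy sketch only: a stand-in "network" with frozen weights and a trainable input image
model = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
for p in model.parameters():
  p.requires_grad_(False)                               # weights stay fixed

image = torch.rand(1, 3, 64, 64, requires_grad=True)    # pixels act as the parameters
optimizer = torch.optim.Adam([image], lr=0.01)          # optimize the image, not the model

target_response = torch.zeros(1, 8, 64, 64)             # an arbitrary target response
for _ in range(10):
  optimizer.zero_grad()
  loss = torch.mean((model(image) - target_response) ** 2)
  loss.backward()                                       # gradients flow into the image
  optimizer.step()                                      # the image's pixel values get updated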

The image transformation process itself is therefore quite simple - we don't have to worry about overfitting, dataset splits, cross-validation, etc. In fact, what we want here is precisely to overfit to a single example. Where we do have to pay attention to the details is in constructing the "correct" loss function and weighting its individual terms.

1. Imports

In addition to PyTorch, we'll make heavy use of the torchvision package, which offers handy image transformation methods as well as pre-trained models. Let's import all the packages we need for our neural style transfer algorithm.

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.optim as optim
from torchvision import transforms, models

2. Utility Functions

Before we start setting up the model, we'll first define some utility functions, which will help us convert images to and from PyTorch tensors, compute the features of given layers, and compute the Gram matrix we'll use in our style loss.

Since we want to work with our own uploaded images, we have to define a function that loads such an image and turns it into a PyTorch tensor, so we can use it with our model. The function takes an image path as argument, as well as a maximum size and an optional shape argument. Large images would slow down the processing later on, and since we're impatient, we'll downscale any image whose larger side exceeds 400 pixels. With the size set, we build up a list of image transformations using transforms.Compose. The first transform in this list resizes the image. Afterwards we turn the image into a PyTorch tensor because our model expects tensor inputs. The last transformation is also specific to the model we will be using: a normalization of the kind that was used when the model was trained for its original purpose - image classification on the ImageNet dataset. The exact numbers are statistics of the ImageNet dataset; we'll just take them as a given for now.

def load_image(img_path, max_size=400, shape=None):
  image = Image.open(img_path).convert('RGB')

  # Downscale large images to speed up processing
  if max(image.size) > max_size:
    size = max_size
  else:
    size = max(image.size)

  # An explicit shape (e.g. to match another image) overrides the size
  if shape is not None:
    size = shape

  in_transform = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
    # ImageNet statistics used when the model was originally trained
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

  # Drop a potential alpha channel and add a batch dimension
  image = in_transform(image)[:3, :, :].unsqueeze(0)

  return image

Next we also want a function that does the opposite: convert an image tensor back into a NumPy array we can display. For that we have to rearrange the dimensions and undo the normalization.

def im_convert(tensor):
  # Move to CPU and detach from the computation graph
  image = tensor.to("cpu").clone().detach()
  image = image.numpy().squeeze()
  # Rearrange from (channels, height, width) to (height, width, channels)
  image = image.transpose(1, 2, 0)
  # Undo the ImageNet normalization
  image = image * np.array((0.229, 0.224, 0.225)) + np.array(
    (0.485, 0.456, 0.406))
  image = image.clip(0, 1)

  return image
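
As a quick sanity check of the two helpers, we can load an image and display it again (the file name here is only a placeholder for whatever image you upload):

# 'content.jpg' is a placeholder path for an uploaded image
img = load_image('content.jpg')
print(img.shape)   # a batch of one RGB image, e.g. torch.Size([1, 3, 400, 533])

plt.imshow(im_convert(img))
plt.show()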

Now for the function that computes our model's features. Why do we even need it? The feature maps of certain layers within a deep convolutional neural network (CNN) have been shown to capture both the style and the content of the images fed into the model. In their seminal paper, Gatys et al. write:

"When Convolutional Neural Networks are trained on object recognition, they develop a representation of the image that makes object information increasingly explicit along the processing hierarchy. Therefore, along the processing hierarchy of the network, the input image is transformed into representations that increasingly care about the actual content of the image compared to its detailed pixel values."

Therefore, we can refer to features of the later layers in the model as content features. Figure 1 in the above-mentioned paper provides a nice visualization of this.

Calculating the features of given layers in our model requires the image for which we want to compute the features and the model itself. We'll keep the function general so we can use it to obtain the features of the style layers as well as the content layer. The function performs a forward pass through the model, one layer at a time, and stores a layer's feature map response if its name matches one of the keys in the predefined layer dict. This dict maps the layer indices of the PyTorch VGG19 implementation to the layer names used in the paper. If no layers are specified, we'll use the full set of content and style layers as a default.

def get_features(image, model, layers=None):
  if layers is None:
    layers = {'0': 'conv1_1',
              '5': 'conv2_1',
              '10': 'conv3_1',
              '19': 'conv4_1',
              '21': 'conv4_2',  ## content layer
              '28': 'conv5_1'}
  features = {}
  x = image
  for name, layer in enumerate(model.features):
    x = layer(x)
    if str(name) in layers:
      features[layers[str(name)]] = x
  
  return features

With that in place we can turn our focus to the representation of style in the layers of our model. As it turns out, style representations can be obtained by measuring the correlations between the different feature map responses of a given layer, which boils down to computing the Gram matrix of the vectorised feature maps. For this we take a feature map tensor as input, flatten its spatial extent (height and width) into one vector per filter, and then compute the inner product of the reshaped tensor with itself.

def gram_matrix(tensor):
  # Assumes a batch size of 1: flatten each of the n_filters feature maps
  # into a single vector of length h * w
  _, n_filters, h, w = tensor.size()
  tensor = tensor.view(n_filters, h * w)
  # The Gram matrix is the inner product of the flattened feature maps
  gram = torch.mm(tensor, tensor.t())

  return gram
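
As a small illustration (with arbitrary numbers), a feature map with 64 filters yields a 64×64 Gram matrix regardless of its spatial size:

# A dummy feature map: batch size 1, 64 filters, 50x60 spatial extent
dummy = torch.rand(1, 64, 50, 60)
print(gram_matrix(dummy).shape)   # torch.Size([64, 64])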

3. Model Setup

Now we have all the ingredients in place for the actual style transfer process. Next we load the VGG model with its pre-trained weights and set requires_grad to False for all of its parameters (weights). First we download the weights.

torch.utils.model_zoo.load_url('https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', model_dir='/results/')

As mentioned above, we don't want to compute gradients with respect to our model's weights, but only with respect to our input image tensor.

vgg = models.vgg19()
vgg.load_state_dict(torch.load('/results/vgg19-dcbb9e9d.pth'))

for param in vgg.parameters():
  param.requires_grad_(False)
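
A quick, optional check confirms that none of the VGG parameters will receive gradients:

# All VGG parameters should now be frozen
print(all(not p.requires_grad for p in vgg.parameters()))   # True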

3.1. A Little Trick

The authors propose to replace all max-pooling layers in the network with average pooling for better-looking results. So let's go ahead and do exactly that:

for i, layer in enumerate(vgg.features):
  if isinstance(layer, torch.nn.MaxPool2d):
    vgg.features[i] = torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
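
To confirm the swap worked, we can list the pooling layers of the modified feature extractor; all five of VGG19's pooling layers should now read AvgPool2d (an optional check):

# Optional check: VGG19 contains five pooling layers, now all AvgPool2d
print([type(layer).__name__ for layer in vgg.features
       if 'Pool' in type(layer).__name__])
# -> ['AvgPool2d', 'AvgPool2d', 'AvgPool2d', 'AvgPool2d', 'AvgPool2d']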

Now we select the device on which we'll run the image transformation process. If available, we of course want to utilize the power of a GPU.

print(torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg.to(device).eval()

Let's not forget to store all the relevant tensors on the same device.

4. Loading the Images

Now we're ready to load the content and style images. Here we'll use this beautiful image taken in Shanghai by photographer Cagdas Eli as the content image and transfer the style from the famous Starry Night by Vincent van Gogh. We'll also resize the style image to match the content image, so we don't have to bother with dimensions later on.

Cagdas Eli, Shanghai Orientpearl Tower, 2017, photograph

For both images we can then compute the feature map responses of the layers we specified as defaults in get_features (the file paths in the code below are placeholders for the uploaded images).

Vincent Van Gogh, Starry Night

# The paths below are placeholders for the uploaded content and style images
content = load_image('shanghai.jpg').to(device)
# Resize the style image to match the content image's dimensions
style = load_image('starry_night.jpg', shape=content.shape[-2:]).to(device)

content_features = get_features(content, vgg)
style_features = get_features(style, vgg)
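
If you're curious which layers were captured and what their feature maps look like, a quick optional check prints the keys and shapes of the content features:

# Optional: inspect the captured layers and their feature map shapes
for name, feat in content_features.items():
  print(name, tuple(feat.shape))
# Expected keys: conv1_1, conv2_1, conv3_1, conv4_1, conv4_2, conv5_1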

As stated above, we also want to compute the Gram matrices for all the style layers. Let's go ahead and build a dictionary with all the style Gram matrices.

style_grams = {
  layer: gram_matrix(style_features[layer]) for layer in style_features}

Now we can create a third image, which will serve as our starting point for the image transformation process. One could also start from random noise, but we'll choose a copy of the content image here. This way we can visualize a nice transformation from the original content to the stylized version. Note that we set requires_grad to True for this tensor, so we can perform gradient descent updates on the image.

# Alternative starting point: random noise
# target = torch.randn_like(content).requires_grad_(True).to(device)
target = content.clone().requires_grad_(True).to(device)

5. A Loss Function of Artistic Style

We already established that our loss function will rely on the feature map responses of various content and style layers. Now let's implement the full loss function with its content and style terms.

In the style term of our loss function, multiple style layers contribute. It's helpful to have different layers contribute to different extents, which we can achieve by assigning a simple weight to each style layer. This lets us tune the style artifacts to our liking. As a tendency, larger weights for earlier layers yield larger artifacts.

style_weights = {'conv1_1': 0.75,
                 'conv2_1': 0.5,
                 'conv3_1': 0.2,
                 'conv4_1': 0.2,
                 'conv5_1': 0.2}

Of course we also want weights for the overall strength of the two individual loss terms (content and style). While the original paper reports content-to-style weight ratios of 1·10⁻³ or 1·10⁻⁴, we'll go for a different ratio here.

content_weight = 1e3
style_weight = 1e2

Now we have all the weights in place, but what does the loss function actually look like? As it turns out, it's rather simple: In the case of the content loss it's a mean squared-error loss between the two feature map responses.

The style loss will look pretty similar, just replacing the feature map responses by the Gram matrices and also dividing the mean squared-error loss by the total number of elements in the respective feature map.
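
Written out, and matching the implementation in the loop below (F denotes feature map responses, G Gram matrices, λ_l the per-layer style_weights, d_l, h_l, w_l the depth, height and width of the feature map at style layer l, and α, β the content_weight and style_weight):

$$\mathcal{L}_{\mathrm{content}} = \mathrm{mean}\!\left(\left(F^{\mathrm{target}}_{\mathrm{conv4\_2}} - F^{\mathrm{content}}_{\mathrm{conv4\_2}}\right)^{2}\right)$$

$$\mathcal{L}_{\mathrm{style}} = \sum_{l} \frac{\lambda_{l}}{d_{l}\,h_{l}\,w_{l}}\,\mathrm{mean}\!\left(\left(G^{\mathrm{target}}_{l} - G^{\mathrm{style}}_{l}\right)^{2}\right)$$

$$\mathcal{L}_{\mathrm{total}} = \alpha\,\mathcal{L}_{\mathrm{content}} + \beta\,\mathcal{L}_{\mathrm{style}}$$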

6. The Style Transfer Loop

But before we construct the total loss, let's set a few hyperparameters for the style transfer process. First we need an optimizer. The original paper reports using an L-BFGS optimizer, but we'll just stick with the standard Adam optimizer. If we wanted to use L-BFGS instead, we'd replace optim.Adam with optim.LBFGS and wrap the loss computation in a closure, as sketched below.

optimizer = optim.Adam([target], lr=0.01)
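
For reference, switching to L-BFGS is slightly more involved than swapping the class name, because PyTorch's optim.LBFGS expects a closure that re-evaluates the loss. A rough sketch (not used in the rest of this article, and showing only the content term) could look like this:

# Sketch only: L-BFGS needs a closure that recomputes the loss on every call
lbfgs = optim.LBFGS([target], lr=1.0)

def closure():
  lbfgs.zero_grad()
  feats = get_features(target, vgg)
  loss = torch.mean((feats['conv4_2'] - content_features['conv4_2']) ** 2)
  # ... add the weighted style terms here, as in the loop below ...
  loss.backward()
  return loss

# lbfgs.step(closure)   # one optimization step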

Now let's define for how many iterations we wish to run the style transfer loop. To track the progress in a visual way, we'll not only print out our total loss value every now and then, but also show the current state of the image from time to time.

num_iterations = 300
show_every = 100

Now we can define the actual loop:

For the defined number of iterations we'll compute the content and style losses (remember that multiple layers contribute to the style loss), multiply them by their respective weights and add them up for the total loss. With the total loss we can then perform the backpropagation step and iteratively update our image until finished.

for i in range(1, num_iterations + 1):
  optimizer.zero_grad()
  target_features = get_features(target, vgg)

  # Content loss: mean squared error between the content feature maps
  content_loss = torch.mean((target_features['conv4_2'] -
                             content_features['conv4_2']) ** 2)

  # Style loss: weighted, normalized Gram matrix differences over all style layers
  style_loss = 0
  for layer in style_weights:
    target_feature = target_features[layer]
    target_gram = gram_matrix(target_feature)
    _, d, h, w = target_feature.shape
    style_gram = style_grams[layer]
    layer_style_loss = style_weights[layer] * torch.mean(
      (target_gram - style_gram) ** 2)
    style_loss += layer_style_loss / (d * h * w)

  # Combine the weighted loss terms and update the image
  total_loss = content_weight * content_loss + style_weight * style_loss
  total_loss.backward()
  optimizer.step()

  if i % 50 == 0:
    print('content: {}'.format(round(content_weight * content_loss.item(), 2)))
    print('style: {}'.format(round(style_weight * style_loss.item(), 2)))
    content_fraction = round(
      content_weight * content_loss.item() / total_loss.item(), 2)
    style_fraction = round(
      style_weight * style_loss.item() / total_loss.item(), 2)
    print('Iteration {}, Total loss: {} - (content: {}, style: {})'.format(
      i, total_loss.item(), content_fraction, style_fraction))

  # Show the current state of the image from time to time
  if i % show_every == 0:
    plt.imshow(im_convert(target))
    plt.show()
      
# Convert the final image tensor back to a displayable array and show it
final_img = im_convert(target)
plt.figure()
plt.imshow(final_img)
plt.show()