
Spatial Broadcast Decoder

Watters et al. (2019) introduce the Spatial Broadcast Decoder (SBD) as a decoder architecture for Variational Auto-Encoders (VAEs) that improves disentanglement in the latent space[1], reconstruction accuracy and generalization on limited datasets (i.e., held-out regions of data space). Motivated by the limitations of deconvolutional layers in traditional decoders, the Spatial Broadcast decoder replaces these upsampling layers with a tiling operation. Furthermore, explicit spatial information (an inductive bias) is appended in the form of coordinate channels, leading to a simplified optimization problem and improved positional generalization. As a proof of concept, they test the model on the colored sprites dataset (known factors of variation such as position, size and shape), the Chairs and 3D Object-in-Room datasets (no positional variation), a dataset with small objects, and a dataset with dependent factors. They show that the Spatial Broadcast decoder can be used as a complement or an improvement to state-of-the-art disentangling techniques.

Model Description

As the paper's title suggests, the model architecture of the Spatial Broadcast decoder is very simple: take a standard VAE decoder, replace all upsampling deconvolutional layers by tiling the latent code $z$ across the original image space, append fixed coordinate channels and apply an unstrided convolutional network, see the figure below.

Schematic of Spatial Broadcast VAE

(left) Schematic of the Spatial Broadcast VAE. In the decoder, we broadcast (tile) the latent code $z$ of size $k$ to the image width $w$ and height $h$, and concatenate two "coordinate" channels. This is then fed to an unstrided convolutional decoder. (right) Pseudo-code of the spatial broadcast operation. Taken from Watters et al. (2019).
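
The broadcast step itself takes only a few lines of PyTorch. The sketch below (function and variable names are my own, not from the paper) mirrors the pseudo-code in the figure: tile $z$ over the image plane and append fixed coordinate channels; a complete decoder built on this operation follows in the implementation section.

import torch

def spatial_broadcast(z, height, width):
    """Tile the latent code z over the image plane and append coordinate channels."""
    batch_size, latent_dim = z.shape
    # broadcast (tile) z: [batch_size, latent_dim] -> [batch_size, latent_dim, height, width]
    z_b = z.view(batch_size, latent_dim, 1, 1).expand(-1, -1, height, width)
    # fixed coordinate channels in [-1, 1], each of shape [height, width]
    ys = torch.linspace(-1, 1, height)
    xs = torch.linspace(-1, 1, width)
    y_grid, x_grid = torch.meshgrid(ys, xs)
    coords = torch.stack((x_grid, y_grid)).expand(batch_size, -1, -1, -1)
    # result: [batch_size, latent_dim + 2, height, width], fed to an unstrided ConvNet
    return torch.cat((z_b, coords), dim=1)

out = spatial_broadcast(torch.randn(16, 6), 64, 64)  # -> [16, 8, 64, 64]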

Motivation: The presented architecture is mainly motivated by two reasons:

  1. Deconvolutional layers cause optimization difficulties: Watters et al. (2019) argue that upsampling deconvolutional layers should be avoided, since they are prone to produce checkerboard artifacts, i.e., a checkerboard pattern can be identified in the resulting images (when looking closely), see the figure below. These artifacts constrain the reconstruction accuracy, and Watters et al. (2019) hypothesize that the resulting effects may cause problems for learning a disentangled representation in the latent space.

Checkerboard Artifacts

A checkerboard pattern can often be identified in artificially generated images that use deconvolutional layers. Taken from Odena et al. (2016) (very worth reading).

  2. Appended coordinate channels improve positional generalization and optimization: Previous work by Liu et al. (2018) showed that standard convolutional/deconvolutional networks (CNNs) perform badly when trying to learn trivial coordinate transformations (e.g., a mapping from Cartesian space into one-hot pixel space or vice versa). This behavior may seem counterintuitive (easy task, small dataset), however the translational equivariance of CNNs[2] (i.e., shifting an object in the input equally shifts its representation in the output) hinders learning this task: the filters have, by design, no information about their position. Thus, coordinate transformations result in complicated functions, which makes optimization difficult; e.g., changing the input coordinate slightly might push the resulting function in a completely different direction.

CoordConv Solution: To overcome this problem, Liu et al. (2018) propose to append coordinate channels before the convolution and term the resulting layer CoordConv, see the figure below and the sketch after it. In principle, this layer can learn to use or discard translational equivariance and keeps the other advantages of convolutional layers (fast computation, few parameters). With this modification, learning coordinate transformation problems works out of the box with perfect generalization, in less time (150 times faster) and with less memory (10-100 times fewer parameters). As coordinate transformations are implicitly needed in a variety of tasks (such as producing bounding boxes in object detection), using CoordConv instead of standard convolutions might increase the performance of several other models.

CoordConv Layer

Comparison of 2D convolutional and CoordConv layers. Taken from Liu et al. (2018).
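
Liu et al. (2018) released a reference implementation, but the core idea fits in a few lines of PyTorch. The layer below is my own minimal sketch (not the official implementation): it simply concatenates normalized coordinate channels to its input before a regular convolution.

import torch
from torch import nn

class CoordConv2d(nn.Module):
    """Minimal CoordConv-style layer: append x/y coordinate channels, then convolve."""
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels + 2, out_channels, **kwargs)

    def forward(self, inp):
        batch_size, _, height, width = inp.shape
        # two coordinate channels in [-1, 1], expanded to [batch_size, 2, height, width]
        ys = torch.linspace(-1, 1, height, device=inp.device)
        xs = torch.linspace(-1, 1, width, device=inp.device)
        y_grid, x_grid = torch.meshgrid(ys, xs)
        coords = torch.stack((x_grid, y_grid)).expand(batch_size, -1, -1, -1)
        return self.conv(torch.cat((inp, coords), dim=1))

# usage: a drop-in replacement for nn.Conv2d
layer = CoordConv2d(3, 16, kernel_size=3, padding=1)
print(layer(torch.rand(8, 3, 64, 64)).shape)  # torch.Size([8, 16, 64, 64])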

Positional Generalization: Appending fixed coordinate channels is mainly beneficial for datasets in which the same objects may appear at distinct positions (i.e., there is positional variation). The main idea is that rendering an object at a specific position without spatial information (i.e., with standard convolutions/deconvolutions) results in a very complicated function. In contrast, the Spatial Broadcast decoder can leverage the spatial information to reveal objects easily, e.g., by convolving the position encoded in the latent code together with the fixed coordinate channels and applying a threshold operation (a toy example follows below). Thus, Watters et al. (2019) argue that the Spatial Broadcast decoder architecture puts a prior on dissociating positional from non-positional features in the latent distribution. Datasets without positional variation, in turn, seem unlikely to benefit from this architecture. However, Watters et al. (2019) showed that the Spatial Broadcast decoder can still help on these datasets and attribute this to the replacement of the deconvolutional layers.
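
To make this intuition concrete, here is a small hand-crafted toy example (my own construction, not from the paper): given the broadcast position and the fixed coordinate channels, two 1x1 convolutions with fixed weights and a ReLU compute the L1 distance to the encoded position, and a simple threshold then reveals a blob at exactly that location.

import torch
import torch.nn.functional as F

size = 64
x0, y0 = 0.25, -0.5                               # "latent" object position in [-1, 1]
xs = torch.linspace(-1, 1, size)
x_grid, y_grid = torch.meshgrid(xs, xs)
# channels: broadcast x0, broadcast y0, x coordinates, y coordinates
inp = torch.stack([torch.full_like(x_grid, x0),
                   torch.full_like(y_grid, y0),
                   x_grid, y_grid]).unsqueeze(0)  # [1, 4, 64, 64]

# 1x1 conv producing (x - x0), (x0 - x), (y - y0), (y0 - y), followed by ReLU
w1 = torch.tensor([[-1., 0., 1., 0.],
                   [ 1., 0., -1., 0.],
                   [ 0., -1., 0., 1.],
                   [ 0., 1., 0., -1.]]).view(4, 4, 1, 1)
h = F.relu(F.conv2d(inp, w1))

# 1x1 conv summing the parts to -(|x - x0| + |y - y0|) + 0.2, so thresholding
# at zero reveals a blob of "radius" 0.2 centered at (x0, y0)
w2 = torch.full((1, 4, 1, 1), -1.)
blob = (F.conv2d(h, w2, bias=torch.tensor([0.2])) > 0).float()  # [1, 1, 64, 64]
print(int(blob.sum()))                            # number of object pixels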

Implementation

Watters et al. (2019) conducted experiments on several datasets and showed that incorporating the Spatial Broadcast decoder into state-of-the-art VAE architectures consistently increased their performance. While this is impressive, not being able to reproduce such results due to missing implementation details, limited computing resources or simply a lack of time for a reimplementation is always frustrating.

The following reimplementation intends to eliminate that frustration by reproducing some of their experiments on much smaller datasets with similar characteristics, such that training takes less time (less than 30 minutes on an NVIDIA Tesla K80 GPU).

Data Generation

A dataset will be generated that is similar in spirit to the colored sprites dataset, i.e., procedurally generated objects with known factors of variation. Watters et al. (2019) use a binary dsprites dataset consisting of 737,280 images and transform these during training into colored images by sampling uniformly from a predefined HSV space (see Appendix A.3). As a result, the dataset has 8 factors of variation ($x$-position, $y$-position, size, shape, angle and a 3-dimensional color) with effectively infinite samples (due to the color sampling). They also train for far more steps than would be feasible here.

To reduce training time, we are going to generate a much simpler dataset of 3675 images, each containing a circle of fixed size drawn from a predefined set of possible colors and positions, such that there are only 3 factors of variation ($x$-position, $y$-position, discretized color). In this case roughly $3.5 \times 10^{4}$ training steps (150 epochs with batch size 16, see below) suffice for approximate convergence.
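
As a quick sanity check of these numbers (batch size 16 and 150 epochs are the values used in the training setup further below):

# 35 x 35 positions and 3 colors -> 3675 images; 150 epochs with batch size 16
num_pos, num_colors, batch_size, epochs = 35, 3, 16, 150
n_images = num_pos * num_pos * num_colors     # 3675
steps_per_epoch = -(-n_images // batch_size)  # ceil(3675 / 16) = 230
print(n_images, steps_per_epoch * epochs)     # 3675 34500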

Visualization of the generated dataset.

The code below creates the dataset. Note that it is kept more generic than necessary, so that several variations of this dataset can be created for more dedicated experiments.

from PIL import Image, ImageDraw
import torchvision.transforms as transforms
import numpy as np
import torch
from torch.utils.data import TensorDataset
def generate_img(x_position, y_position, shape, color, img_size, size=20):
    """Generate an RGB image from the provided latent factors
    
    Args:
        x_position (float): normalized x position
        y_position (float): normalized y position
        shape (string): can only be 'circle' or 'square'
        color (string): color name or rgb string
        img_size (int): describing the image size (img_size, img_size)
        size (int): size of shape
        
    Returns:
        torch tensor [3, img_size, img_size] (dtype=torch.float32)
    """
    # creation of image
    img = Image.new('RGB', (img_size, img_size), color='black')
    # map (x, y) position to pixel coordinates
    x_position = (img_size - 2 - size) * x_position 
    y_position = (img_size - 2 - size) * y_position
    # define coordinates
    x_0, y_0 = x_position, y_position
    x_1, y_1 = x_position + size, y_position + size
    # draw shapes
    img1 = ImageDraw.Draw(img)
    if shape == 'square':
        img1.rectangle([(x_0, y_0), (x_1, y_1)], fill=color)
    elif shape == 'circle':       
        img1.ellipse([(x_0, y_0), (x_1, y_1)], fill=color)
    return transforms.ToTensor()(img).type(torch.float32)
def generate_dataset(img_size, shape_sizes, num_pos, shapes, colors):
    """procedurally generated from 4 ground truth independent latent factors, 
       these factors are/can be 
           Position X: num_pos values in [0, 1]
           Position Y: num_pos values in [0, 1]
           Shape: square, circle
           Color: standard HTML color name or 'rgb(x, y, z)'
    
    Args:
           img_size (int): describing the image size (img_size, img_size)  
           shape_sizes (list): sizes of shapes
           num_pos (int): discretized positions
           shapes (list): shapes (can only be 'circle', 'square')
           colors (list): colors
    
    Returns:
           data: torch tensor [n_samples, 3, img_size, img_size]
           latents: each entry describes the latents of corresp. data entry
    """
    num_shapes, num_colors, sizes = len(shapes), len(colors), len(shape_sizes)
    
    n_samples = num_pos*num_pos*num_shapes*num_colors*sizes
    data = torch.empty([n_samples, 3, img_size, img_size])
    latents = np.empty([n_samples], dtype=object)
    
    index = 0
    for x_pos in np.linspace(0, 1, num_pos):
        for y_pos in np.linspace(0, 1, num_pos):
            for shape in shapes:
                for size in shape_sizes:
                    for color in colors:
                        img = generate_img(x_pos, y_pos, shape, color, 
                                           img_size, size)
                        data[index] = img
                        latents[index] = [x_pos, y_pos, shape, color]
                    
                        index += 1
    return data, latents
circles_data, latents = generate_dataset(img_size=64, shape_sizes=[16],
                                         num_pos=35,
                                         shapes=['circle'],
                                         colors=['red', 'green', 'blue'])
sprites_dataset = TensorDataset(circles_data)

Model Implementation

Although implementing a VAE is in principle fairly simple (see my post for details), in practice one must choose a lot of hyperparameters. These can be divided into three broader categories:

  • Encoder/Decoder and Prior Distribution: As suggested by Watters et al. (2019) in Appendix A, we use a Gaussian decoder distribution with a fixed diagonal covariance $\sigma^2 I$, hence the reconstruction accuracy can be computed as follows[3]:

    $$\mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] \;\approx\; -\frac{1}{2\sigma^{2}} \sum_{i=1}^{D} \big(x_i - \mu_\theta(z)_i\big)^{2} + \text{const}, \qquad z \sim q_\phi(z \mid x)$$

    For the encoder distribution, a Gaussian with diagonal covariance, $q_\phi(z \mid x) = \mathcal{N}\big(\mu_\phi(x), \operatorname{diag}(\sigma_\phi^2(x))\big)$, and as prior a centered isotropic Gaussian, $p(z) = \mathcal{N}(0, I)$, are chosen (both typical choices); the resulting closed-form KL term is given below.
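
    For reference, with these choices the regularization term of the ELBO has the familiar closed form (this is exactly what the VAE code further below computes per sample), where $\mu_j$ and $\sigma_j^2$ denote the components of $\mu_\phi(x)$ and $\sigma_\phi^2(x)$:

    $$D_{\mathrm{KL}}\big(q_\phi(z \mid x)\,\|\,p(z)\big) = -\frac{1}{2}\sum_{j=1}^{k}\Big(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\Big)$$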

  • Network Architecture for Encoder/Decoder: The network architectures for the standard encoder and decoder consist of convolutional and deconvolutional layers (since these typically perform much better on image data). The Spatial Broadcast decoder defines a different kind of decoder architecture, see Model Description. The exact architectures are taken from Appendix A.1 of Watters et al. (2019), see code below[4]:

from torch import nn
class Encoder(nn.Module):
    """"Encoder class for use in convolutional VAE
      
    Args:
        latent_dim: dimensionality of latent distribution
    Attributes:
        encoder_conv: convolution layers of encoder
        fc_mu: fully connected layer for mean in latent space
        fc_log_var: fully connected layer for log variance in latent space
    """
    def __init__(self, latent_dim=6):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder_conv = nn.Sequential(
            # shape: [batch_size, 3, 64, 64]
            nn.Conv2d(3,  64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # shape: [batch_size, 64, 32, 32]
            nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # shape: [batch_size, 64, 4, 4],
            nn.Flatten(),
            # shape: [batch_size, 1024]
            nn.Linear(1024, 256),
            nn.ReLU(),
            # shape: [batch_size, 256]
        )
        self.fc_mu = nn.Sequential(
            nn.Linear(in_features=256, out_features=self.latent_dim),
        )
        self.fc_log_var = nn.Sequential(
            nn.Linear(in_features=256, out_features=self.latent_dim),
        )
        return
    def forward(self, inp):
        out = self.encoder_conv(inp)
        mu = self.fc_mu(out)
        log_var = self.fc_log_var(out)
        return [mu, log_var]
class Decoder(nn.Module):
    """(standard) Decoder class for use in convolutional VAE,
    a Gaussian distribution with fixed variance (identity times fixed variance
    as covariance matrix) used as the decoder distribution
      
    Args:
        latent_dim: dimensionality of latent distribution
        fixed_variance: variance of distribution
    Attributes:
        decoder_upsampling: linear upsampling layer(s)
        decoder_deconv: deconvolution layers of decoder (also upsampling)
    """
    def __init__(self, latent_dim, fixed_variance):
        super().__init__()
        self.latent_dim = latent_dim
        self.coder_type = 'Gaussian with fixed variance'
        self.fixed_variance = fixed_variance
        self.decoder_upsampling = nn.Sequential(
            nn.Linear(self.latent_dim, 256),
            nn.ReLU(),
            # reshaped into [batch_size, 64, 2, 2]
        )
        self.decoder_deconv = nn.Sequential(
            # shape: [batch_size, 64, 2, 2]
            nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # shape: [batch_size, 64, 4, 4]
            nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64,  3, kernel_size=4, stride=2, padding=1),
            # shape: [batch_size, 3, 64, 64]
        )
        return
    def forward(self, inp):
        ups_inp = self.decoder_upsampling(inp)
        ups_inp = ups_inp.view(-1, 64, 2, 2)
        mu = self.decoder_deconv(ups_inp)
        return mu
          
          
class SpatialBroadcastDecoder(nn.Module):
    """SBD class for use in convolutional VAE,
      a Gaussian distribution with fixed variance (identity times fixed 
      variance as covariance matrix) used as the decoder distribution
    Args:
        latent_dim: dimensionality of latent distribution
        fixed_variance: variance of distribution
    Attributes:
        img_size: image size (necessary for tiling)
        decoder_convs: convolution layers of decoder (also upsampling)
    """
    def __init__(self, latent_dim, fixed_variance):
        super().__init__()
        self.img_size = 64
        self.coder_type = 'Gaussian with fixed variance'
        self.latent_dim = latent_dim
        self.fixed_variance = fixed_variance
        x = torch.linspace(-1, 1, self.img_size)
        y = torch.linspace(-1, 1, self.img_size)
        x_grid, y_grid = torch.meshgrid(x, y)
        # reshape into [1, 1, img_size, img_size] and save in state_dict
        self.register_buffer('x_grid', x_grid.view((1, 1) + x_grid.shape))
        self.register_buffer('y_grid', y_grid.view((1, 1) + y_grid.shape))
        self.decoder_convs = nn.Sequential(
            # shape [batch_size, latent_dim + 2, 64, 64]
            nn.Conv2d(in_channels=self.latent_dim+2, out_channels=64,
                      stride=(1, 1), kernel_size=(3,3), padding=1),           
            nn.ReLU(),
            # shape [batch_size, 64, 64, 64]
            nn.Conv2d(in_channels=64, out_channels=64, stride=(1,1), 
                      kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            # shape [batch_size, 64, 64, 64]
            nn.Conv2d(in_channels=64, out_channels=3, stride=(1,1), 
                      kernel_size=(3, 3), padding=1),
            # shape [batch_size, 3, 64, 64]         
        )
        return
    def forward(self, z):
        batch_size = z.shape[0]
        # reshape z into [batch_size, latent_dim, 1, 1]
        z = z.view(z.shape + (1, 1))
        # tile across image [batch_size, latent_dim, img_size, img_size]
        z_b = z.expand(-1, -1, self.img_size, self.img_size)
        # upsample x_grid and y_grid to [batch_size, 1, img_size, img_size]
        x_b = self.x_grid.expand(batch_size, -1, -1, -1)
        y_b = self.y_grid.expand(batch_size, -1, -1, -1)
        # concatenate vectors [batch_size, latent_dim+2, img_size, img_size]
        z_sb = torch.cat((z_b, x_b, y_b), dim=1)
        # apply convolutional layers
        mu_D = self.decoder_convs(z_sb)
        return mu_D

The VAE implementation below combines the encoder and decoder architectures (a slightly modified version of my last VAE implementation).

from torch.distributions.multivariate_normal import MultivariateNormal
class VAE(nn.Module):
    """A simple VAE class
    Args:
        vae_type: type of VAE, either 'Standard' or 'SBD'
        latent_dim: dimensionality of latent distribution
        fixed_var: fixed variance of decoder distribution
    """
    def __init__(self, vae_type, latent_dim, fixed_var):
        super().__init__()
        self.vae_type = vae_type
        if self.vae_type == 'Standard':
            self.decoder = Decoder(latent_dim=latent_dim, 
                                  fixed_variance=fixed_var)
        else:
            self.decoder = SpatialBroadcastDecoder(latent_dim=latent_dim,
                                                   fixed_variance=fixed_var)
        self.encoder = Encoder(latent_dim=latent_dim)
        self.normal_dist = MultivariateNormal(torch.zeros(latent_dim), 
                                              torch.eye(latent_dim))
        return
    def forward(self, x):      
        z, mu_E, log_var_E = self.encode(x)
        # regularization term per batch, i.e., size: (batch_size)
        regularization_term = 0.5 * (1 + log_var_E - mu_E**2
                                      - torch.exp(log_var_E)).sum(axis=1)
        batch_size = x.shape[0]
        if self.decoder.coder_type == 'Gaussian with fixed variance':
            # x_rec has shape (batch_size, 3, 64, 64)
            x_rec = self.decode(z)
            # reconstruction accuracy per batch, i.e., size: (batch_size)
            factor = 0.5 * (1/self.decoder.fixed_variance)
            recons_acc = - factor * ((x.view(batch_size, -1) - 
                                    x_rec.view(batch_size, -1))**2
                                  ).sum(axis=1)
        return -regularization_term.mean(), -recons_acc.mean()
    def reconstruct(self, x):
        mu_E, log_var_E = self.encoder(x)
        x_rec = self.decoder(mu_E)
        return x_rec
    def encode(self, x):
        # get encoder distribution parameters
        mu_E, log_var_E = self.encoder(x)
        # sample noise variable for each batch
        batch_size = x.shape[0]
        epsilon = self.normal_dist.sample(sample_shape=(batch_size, )
                                          ).to(x.device)
        # get latent variable by reparametrization trick
        z = mu_E + torch.exp(0.5*log_var_E) * epsilon
        return z, mu_E, log_var_E
    def decode(self, z):
        # get decoder distribution parameters
        mu_D = self.decoder(z)
        return mu_D

  • Training Parameters: Lastly, training the neural networks involves several hyperparameters of its own. Again, we use the same setup as defined in Appendix A.1 of Watters et al. (2019), see code below.

from livelossplot import PlotLosses
from torch.utils.data import DataLoader
  
def train(dataset, epochs, VAE):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
        
    print('Device: {}'.format(device))
    data_loader = DataLoader(dataset, batch_size=16, shuffle=True)
    VAE.to(device)
    optimizer = torch.optim.Adam(VAE.parameters(), lr=3e-4)
    losses_plot = PlotLosses(groups={'avg log loss': 
                                    ['kl loss', 'reconstruction loss']})
    print('Start training with {} decoder\n'.format(VAE.vae_type))
    for epoch in range(1, epochs +1):
        avg_kl = 0 
        avg_recons_err = 0
        for counter, mini_batch_data in enumerate(data_loader):
            VAE.zero_grad()
            kl_div, recons_err = VAE(mini_batch_data[0].to(device))
            loss = kl_div + recons_err
            loss.backward()
            optimizer.step()
            avg_kl += kl_div.item() / len(dataset)
            avg_recons_err += recons_err.item() / len(dataset)
        losses_plot.update({'kl loss': np.log(avg_kl), 
                            'reconstruction loss': np.log(avg_recons_err)})
        losses_plot.send()
    trained_VAE = VAE
    return trained_VAE

Visualization Functions

Evaluating the representation quality of trained models is a difficult task, since we are interested not only in the reconstruction accuracy but also in the latent space and its properties. Ideally, the latent space offers a disentangled representation in which each latent variable represents one factor of variation, combined with perfect reconstruction accuracy (for evaluation it is therefore very helpful to know in advance how many and which factors of variation exist). Although there are some metrics to quantify disentanglement, many of them have serious shortcomings and there is as yet no consensus in the literature on which to use (Watters et al., 2019). Instead of focusing on a metric, we are going to visualize the results using two approaches:

  • Reconstructions and Latent Traversals: A very popular and helpful plot shows some (arbitrarily chosen) reconstructions next to the original inputs together with a series of latent space traversals, i.e., taking some encoded input and looking at the reconstructions while sweeping each coordinate of the latent space over a predefined interval (here from -2 to +2) and keeping all other coordinates constant. Ideally, each sweep can be associated with one factor of variation. The code below generates these plots. Note that the reconstructions are clamped to $[0, 1]$, as this is the valid image range.

import matplotlib.pyplot as plt
%matplotlib inline
  
def reconstructions_and_latent_traversals(STD_VAE, SBD_VAE, dataset, SEED=1):
    np.random.seed(SEED)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    latent_dims = STD_VAE.encoder.latent_dim
    n_samples = 7
    i_samples = np.random.choice(range(len(dataset)), n_samples, replace=False)
    # preparation for latent traversal
    i_latent = i_samples[n_samples//2]
    lat_image = dataset[i_latent][0]
    sweep = np.linspace(-2, 2, n_samples)
    fig = plt.figure(constrained_layout=False, figsize=(2*n_samples, 2+latent_dims))
    grid = plt.GridSpec(latent_dims + 5, n_samples*2 + 3, 
                        hspace=0.2, wspace=0.02, figure=fig)
    # standard VAE
    for counter, i_sample in enumerate(i_samples):
        orig_image = dataset[i_sample][0]
        # original
        main_ax = fig.add_subplot(grid[1, counter + 1])
        main_ax.imshow(transforms.ToPILImage()(orig_image))
        main_ax.axis('off')
        main_ax.set_aspect('equal')
        # reconstruction
        x_rec = STD_VAE.reconstruct(orig_image.unsqueeze(0).to(device))
        # clamp output into [0, 1] and prepare for plotting
        recons_image =  torch.clamp(x_rec, 0, 1).squeeze(0).cpu()
        main_ax = fig.add_subplot(grid[2, counter + 1])
        main_ax.imshow(transforms.ToPILImage()(recons_image))
        main_ax.axis('off')
        main_ax.set_aspect('equal')
    # latent dimension traversal
    z, mu_E, log_var_E = STD_VAE.encode(lat_image.unsqueeze(0).to(device))
    for latent_dim in range(latent_dims):
        for counter, z_replaced in enumerate(sweep):
            z_new = z.detach().clone()
            z_new[0][latent_dim] = z_replaced
            # clamp output into [0, 1] and prepare for plotting
            img_rec = torch.clamp(STD_VAE.decode(z_new), 0, 1).squeeze(0).cpu()
            main_ax = fig.add_subplot(grid[4 + latent_dim, counter + 1])
            main_ax.imshow(transforms.ToPILImage()(img_rec))
            main_ax.axis('off')
    # SBD VAE
    for counter, i_sample in enumerate(i_samples):
        orig_image = dataset[i_sample][0]
        # original
        main_ax = fig.add_subplot(grid[1, counter + n_samples + 2])
        main_ax.imshow(transforms.ToPILImage()(orig_image))
        main_ax.axis('off')
        main_ax.set_aspect('equal')
        # reconstruction
        x_rec = SBD_VAE.reconstruct(orig_image.unsqueeze(0).to(device))
        # clamp output into [0, 1] and prepare for plotting
        recons_image = torch.clamp(x_rec, 0, 1).squeeze(0).cpu()
        main_ax = fig.add_subplot(grid[2, counter + n_samples + 2])
        main_ax.imshow(transforms.ToPILImage()(recons_image))
        main_ax.axis('off')
        main_ax.set_aspect('equal')
    # latent dimension traversal
    z, mu_E, log_var_E = SBD_VAE.encode(lat_image.unsqueeze(0).to(device))
    for latent_dim in range(latent_dims):
        for counter, z_replaced in enumerate(sweep):
            z_new = z.detach().clone()
            z_new[0][latent_dim] = z_replaced
            # clamp output into [0, 1] and prepare for plotting
            img_rec = torch.clamp(SBD_VAE.decode(z_new), 0, 1).squeeze(0).cpu()
            main_ax = fig.add_subplot(grid[4+latent_dim, counter+n_samples+2])
            main_ax.imshow(transforms.ToPILImage()(img_rec))
            main_ax.axis('off')
    # prettify by adding annotation texts
    fig = prettify_with_annotation_texts(fig, grid, n_samples, latent_dims)
    return fig
def prettify_with_annotation_texts(fig, grid, n_samples, latent_dims):
    # figure titles
    titles = ['Deconv Reconstructions', 'Spatial Broadcast Reconstructions',
              'Deconv Traversals', 'Spatial Broadcast Traversals']
    idx_title_pos = [[0, 1, n_samples+1], [0, n_samples+2, n_samples*2+2],
                    [3, 1, n_samples+1], [3, n_samples+2, n_samples*2+2]]
    for title, idx_pos in zip(titles, idx_title_pos):
        fig_ax = fig.add_subplot(grid[idx_pos[0], idx_pos[1]:idx_pos[2]])
        fig_ax.annotate(title, xy=(0.5, 0), xycoords='axes fraction', 
                        fontsize=14, va='bottom', ha='center')
        fig_ax.axis('off')
    # left annotations
    fig_ax = fig.add_subplot(grid[1, 0])
    fig_ax.annotate('input', xy=(1, 0.5), xycoords='axes fraction', 
                    fontsize=12,  va='center', ha='right')
    fig_ax.axis('off')
    fig_ax = fig.add_subplot(grid[2, 0])
    fig_ax.annotate('recons', xy=(1, 0.5), xycoords='axes fraction', 
                    fontsize=12, va='center', ha='right')
    fig_ax.axis('off')
    fig_ax = fig.add_subplot(grid[4:latent_dims + 4, 0])
    fig_ax.annotate('latent coordinate traversed', xy=(0.9, 0.5), 
                    xycoords='axes fraction', fontsize=12,
                    va='center', ha='center', rotation=90)
    fig_ax.axis('off')
    # perturbation magnitude
    for i_y_grid in [[1, n_samples+1], [n_samples+2, n_samples*2+2]]:
        fig_ax = fig.add_subplot(grid[latent_dims + 4, i_y_grid[0]:i_y_grid[1]])
        fig_ax.annotate('perturbation magnitude', xy=(0.5, 0), 
                        xycoords='axes fraction', fontsize=12,
                        va='bottom', ha='center')
        fig_ax.set_frame_on(False)
        fig_ax.axes.set_xlim([-2.5, 2.5])
        fig_ax.xaxis.set_ticks([-2, 0, 2])
        fig_ax.xaxis.set_ticks_position('top')
        fig_ax.xaxis.set_tick_params(direction='inout', pad=-16)
        fig_ax.get_yaxis().set_ticks([])
    # latent dim
    for latent_dim in range(latent_dims):
        fig_ax = fig.add_subplot(grid[4 + latent_dim, n_samples*2 + 2])
        fig_ax.annotate('lat dim ' + str(latent_dim + 1), xy=(0, 0.5), 
                        xycoords='axes fraction', 
                        fontsize=12, va='center', ha='left')
        fig_ax.axis('off')
    return fig

  • Latent Space Geometry: While latent traversals may be helpful, Watters et al. (2019) note that this technique suffers from two shortcomings:

    1. Latent space entanglement might be difficult to perceive by eye.

    2. Traversals are only taken at one point in latent space; traversals at some points might look more disentangled than at others. Thus, judging disentanglement by this method ultimately depends on randomness.

To overcome these limitations, they propose a method they term latent space geometry. The main idea is to visualize the transformation from a 2-dimensional generative-factor space (a subspace of all generative factors) into the corresponding 2-dimensional latent subspace (the two latent components that correspond to these factors of variation). A latent space geometry that preserves the geometry of the generative-factor space (where scaling and rotation may be allowed, depending on the chosen generative-factor space) indicates disentanglement.

To put this into practice, the code below creates circle images by varying the $x$ and $y$ positions uniformly while keeping the other generative factors (here only color) constant. Accordingly, the geometry of the generative-factor space is a uniform grid (which will be plotted). These images are encoded into the mean and variance of the latent distribution. To find the latent components that correspond to the $x$ and $y$ position, we choose the two components with the smallest mean variance across all encodings, i.e., the most informative components[5]. Then we plot the latent space geometry using these components of the encoder mean, see code below.

def latent_space_geometry(STD_VAE, SBD_VAE):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    plt.figure(figsize=(18, 6))
    # x,y position grid in [0.2, 0.8] (generative factors)
    equi = np.linspace(0.2, 0.8, 31)
    equi_without_vert = np.setdiff1d(equi, np.linspace(0.2, 0.8, 6))
    x_pos = np.append(np.repeat(np.linspace(0.2, 0.8, 6), len(equi)),
                      np.tile(equi_without_vert, 6))
    y_pos = np.append(np.tile(equi, 6),
                      np.repeat(np.linspace(0.8, 0.2, 6), len(equi_without_vert)))
    labels = np.append(np.repeat(np.arange(6), 31),
                      np.repeat(np.arange(6)+10, 25))
    # plot generative factor geometry
    plt.subplot(1, 3, 1)
    plt.scatter(x_pos, y_pos, c=labels, cmap=plt.cm.get_cmap('rainbow', 10))
    plt.gca().set_title('Ground Truth Factors', fontsize=16)
    plt.xlabel('X-Position')
    plt.ylabel('Y-Position')
    # generate images
    img_size = 64
    shape_size = 16
    images = torch.empty([len(x_pos), 3, img_size, img_size]).to(device)
    for counter, (x, y) in enumerate(zip(x_pos, y_pos)):
        images[counter] = generate_img(x, y, 'circle', 'red', 
                                      img_size, shape_size)
    # STD VAE
    [all_mu, all_log_var] = STD_VAE.encoder(images)
    # most informative latent variable
    lat_1, lat_2 = all_log_var.mean(axis=0).sort()[1][:2]
    # latent coordinates
    x_lat = all_mu[:, lat_1].detach().cpu().numpy()
    y_lat = all_mu[:, lat_2].detach().cpu().numpy()
    # plot latent space geometry
    plt.subplot(1, 3, 2)
    plt.scatter(x_lat, y_lat, c=labels, cmap=plt.cm.get_cmap('rainbow', 10))
    plt.gca().set_title('DeConv', fontsize=16)
    plt.xlabel('latent 1 value')
    plt.ylabel('latent 2 value')
    # SBD VAE
    [all_mu, all_log_var] = SBD_VAE.encoder(images)
    # most informative latent variable
    lat_1, lat_2 = all_log_var.mean(axis=0).sort()[1][:2]
    # latent coordinates
    x_lat = all_mu[:, lat_1].detach().cpu().numpy()
    y_lat = all_mu[:, lat_2].detach().cpu().numpy()
    # plot latent space geometry
    plt.subplot(1, 3, 3)
    plt.scatter(x_lat, y_lat, c=labels, cmap=plt.cm.get_cmap('rainbow', 10))
    plt.gca().set_title('Spatial Broadcast', fontsize=16)
    plt.xlabel('latent 1 value')
    plt.ylabel('latent 2 value')
    return

Results

Lastly, let's train our models and look at the results:

epochs = 150
latent_dims = 5 # x position, y position, color, extra slots
fixed_variance = 0.3
standard_VAE = VAE(vae_type='Standard', latent_dim=latent_dims, 
                   fixed_var=fixed_variance)
SBD_VAE = VAE(vae_type='SBD', latent_dim=latent_dims, 
              fixed_var=fixed_variance)
trained_standard_VAE  = train(sprites_dataset, epochs, standard_VAE)
trained_SBD_VAE = train(sprites_dataset, epochs, SBD_VAE)

From the log-loss plots we can already see that using the Spatial Broadcast decoder results in an improved reconstruction accuracy and regularization term (in this run, training took roughly 11 minutes for the standard VAE and 16 minutes for the Spatial Broadcast VAE). Now let's compare both models visually by their

  • Reconstructions and Latent Traversals:

reconstructions_and_latent_traversals(trained_standard_VAE, 
                                      trained_SBD_VAE, sprites_dataset, 2)

While the reconstructions of both models look pretty good, the latent traversals show an entangled representation for the standard (DeConv) VAE, whereas the Spatial Broadcast model appears quite disentangled.

  • Latent Space Geometry:

latent_space_geometry(trained_standard_VAE, trained_SBD_VAE)

The latent space geometry confirms our previous findings: the DeConv decoder has a highly entangled latent space (the transformation is highly non-linear), whereas for the Spatial Broadcast decoder the latent space geometry closely resembles the geometry of the generative factors (an affine transformation). In other words, the Spatial Broadcast decoder behaves very similarly in the $(x, y)$-position subspace of the generative factors and in the corresponding latent subspace.

Drawbacks of Paper

  • although the Spatial Broadcast decoder has fewer parameters, it requires more memory (in this implementation about 50% more)

  • longer training times compared to standard DeConv VAE

Acknowledgement

Daniel Daza's blog was really helpful and the presented code is highly inspired by his VAE-SBD implementation.

-------------------------------------------------------------------------------------------

[1]: As outlined by Watters et al. (2019), there is "yet no consensus on the definition of a disentangled representation". However, in their paper they focus on feature compositionality (i.e., composing a scene in terms of independent features such as color and object) and refer to this as a disentangled representation.

[2]: In typical image classification problems, translational equivariance is highly valued since it ensures that if a filter detects an object (e.g., edges), it will detect it irrespective of its position.

[3]: For simplicity, we set the number of (noise variable) samples per datapoint to 1 (see the corresponding equation in the Reparametrization Trick paragraph). Note that Kingma and Welling (2013) state that in their experiments one sample per datapoint sufficed as long as the minibatch size was large enough.

[4]: The Spatial Broadcast decoder architecture is slightly modified: a kernel size of 3 instead of 4 is used to obtain the desired output shapes.

[5]: An intuitive way to understand why latent components with smaller variance under the encoder distribution are more informative than others is to think about the sampled noise and the loss function: if the variance is high, the latent code $z$ varies a lot, which in turn makes the decoder's task more difficult. However, the regularization term (KL divergence) pushes the variances towards 1. Thus, the network will only reduce the variance of a component if doing so helps to increase the reconstruction accuracy.
