Auto-Encoding Variational Bayes

Kingma and Welling (2013) introduced the Variational Auto-Encoder (VAE) to showcase how their Auto-Encoding Variational Bayes (AEVB) algorithm can be used in practice. Assuming i.i.d. datasets and continuous latent variables, the AEVB algorithm learns an approximate probabilistic encoder $q_\phi(z \mid x)$ jointly with the probabilistic decoder $p_\theta(x \mid z)$ (where $\phi$ and $\theta$ parametrize the corresponding distributions) by learning the optimal model parameters $\theta, \phi$ through optimizing an objective function with standard gradient ascent methods. In summary, a VAE is a probabilistic autoencoder which uses variational inference to regularize the coding space. Furthermore, a VAE is a deep generative model since sampling from the coding space is possible, i.e., new observations can be generated.

Model Description

The AEVB algorithm basically assumes a generative process, introduces a variational approximation (see figure below) and optimizes the model parameters by maximizing an objective function. The objective function consists of the (reparametrized) variational lower bound of each datapoint. Reparametrization is necessary to allow the explicit formulation of gradients with respect to the model parameters.

Objective Function Derivation: Let $X = \{x^{(i)}\}_{i=1}^{N}$ denote the dataset consisting of $N$ i.i.d. samples and let $z$ denote the unobserved continuous random variable (i.e., hidden or code variable). Kingma and Welling (2013) assume that each observed sample $x^{(i)}$ comes from a generative process in which: Firstly, a hidden variable $z^{(i)}$ is generated from a prior distribution $p_{\theta^*}(z)$. Secondly, $x^{(i)}$ is generated from the conditional distribution $p_{\theta^*}(x \mid z)$. Note that we do not know the true parameters $\theta^*$ nor do we have information about the hidden variables $z^{(i)}$. In order to recover this generative process, they introduce $q_\phi(z \mid x)$ as an approximation to the intractable true posterior $p_\theta(z \mid x)$. The marginal log likelihood of each individual datapoint $x^{(i)}$ can then be stated as follows (see Eric Jang's amazing blog post for a detailed derivation)

$$\log p_\theta\!\left(x^{(i)}\right) = D_{KL}\!\left(q_\phi\!\left(z \mid x^{(i)}\right) \,\|\, p_\theta\!\left(z \mid x^{(i)}\right)\right) + \mathcal{L}\!\left(\theta, \phi; x^{(i)}\right)$$

where $D_{KL}\!\left(q_\phi(z \mid x^{(i)}) \,\|\, p_\theta(z \mid x^{(i)})\right)$ denotes the KL divergence of the approximate from the true posterior (this quantity remains unknown since the true posterior $p_\theta(z \mid x^{(i)})$ is intractable). $\mathcal{L}(\theta, \phi; x^{(i)})$ is called the variational lower bound or evidence lower bound (ELBO). The goal is to optimize $\theta$ and $\phi$ such that the variational lower bound is maximized; since the KL divergence is non-negative, this indirectly maximizes the marginal log likelihood. The variational lower bound can be rewritten such that the objective function is obtained (also derived in Eric Jang's blog post)

$$\mathcal{L}\!\left(\theta, \phi; x^{(i)}\right) = -D_{KL}\!\left(q_\phi\!\left(z \mid x^{(i)}\right) \,\|\, p_\theta(z)\right) + \mathbb{E}_{q_\phi\left(z \mid x^{(i)}\right)}\!\left[\log p_\theta\!\left(x^{(i)} \mid z\right)\right]$$
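For readers who do not want to leave the post, the two equations above can be recovered in a few lines (a compact sketch using the same notation; see Eric Jang's post for the full details):

$$\begin{aligned}
\log p_\theta\!\left(x^{(i)}\right)
  &= \mathbb{E}_{q_\phi\left(z \mid x^{(i)}\right)}\!\left[\log p_\theta\!\left(x^{(i)}\right)\right]
   = \mathbb{E}_{q_\phi}\!\left[\log \frac{p_\theta\!\left(x^{(i)}, z\right)}{p_\theta\!\left(z \mid x^{(i)}\right)}\right] \\
  &= \underbrace{\mathbb{E}_{q_\phi}\!\left[\log \frac{p_\theta\!\left(x^{(i)}, z\right)}{q_\phi\!\left(z \mid x^{(i)}\right)}\right]}_{\mathcal{L}\left(\theta, \phi;\, x^{(i)}\right)}
   + \underbrace{\mathbb{E}_{q_\phi}\!\left[\log \frac{q_\phi\!\left(z \mid x^{(i)}\right)}{p_\theta\!\left(z \mid x^{(i)}\right)}\right]}_{D_{KL}\left(q_\phi\left(z \mid x^{(i)}\right) \| p_\theta\left(z \mid x^{(i)}\right)\right)} \\
\mathcal{L}\!\left(\theta, \phi; x^{(i)}\right)
  &= \mathbb{E}_{q_\phi}\!\left[\log p_\theta\!\left(x^{(i)} \mid z\right)\right]
   + \mathbb{E}_{q_\phi}\!\left[\log \frac{p_\theta(z)}{q_\phi\!\left(z \mid x^{(i)}\right)}\right]
   = \mathbb{E}_{q_\phi}\!\left[\log p_\theta\!\left(x^{(i)} \mid z\right)\right]
   - D_{KL}\!\left(q_\phi\!\left(z \mid x^{(i)}\right) \,\|\, p_\theta(z)\right)
\end{aligned}$$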

The two terms have an associated interpretation in autoencoder language:

  • Reconstruction Accuracy (opposite of Reconstruction Error): The expectation can be interpreted using Monte Carlo integration, i.e.,

$$\mathbb{E}_{q_\phi\left(z \mid x^{(i)}\right)}\!\left[\log p_\theta\!\left(x^{(i)} \mid z\right)\right] \approx \frac{1}{L}\sum_{l=1}^{L} \log p_\theta\!\left(x^{(i)} \mid z^{(i,l)}\right), \qquad z^{(i,l)} \sim q_\phi\!\left(z \mid x^{(i)}\right)$$

which results in an unbiased estimate. Sampling $z^{(i,l)} \sim q_\phi(z \mid x^{(i)})$ can be understood as encoding the observed input $x^{(i)}$ into a code $z^{(i,l)}$ using the probabilistic encoder $q_\phi(z \mid x)$. Clearly, the expectation is maximized when the decoder $p_\theta(x \mid z)$ maps the encoded input $z^{(i,l)}$ back to the original input $x^{(i)}$, i.e., assigns high probability to $x^{(i)}$.

  • Regularization Term: The KL divergence is non-negative and only zero if both distributions are identical. Since it enters the lower bound with a negative sign, maximizing the objective forces the encoder distribution $q_\phi(z \mid x^{(i)})$ to be close to the prior $p_\theta(z)$. In VAEs, the prior is typically set to be an isotropic normal distribution, resulting in a regularized code space, i.e., encouraging a code space that is close to a normal distribution. (Both terms are illustrated numerically in the toy sketch after this list.)
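To make these two terms concrete, here is a small toy sketch. Everything in it, including the linear "decoder", is made up purely for illustration and is not part of the original post or paper; it estimates the reconstruction accuracy by Monte Carlo integration and computes the regularization term analytically with torch.distributions:

import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)

# Toy setup: one observation with 5 features and a 2-dimensional latent code
x = torch.randn(5)
mu_E, sigma_E = torch.zeros(2), 0.5 * torch.ones(2)   # toy encoder output for q(z|x)
q_z_given_x = Normal(mu_E, sigma_E)
prior = Normal(torch.zeros(2), torch.ones(2))

# Reconstruction accuracy: Monte Carlo estimate of E_q[log p(x|z)]
L = 100
z = q_z_given_x.sample((L,))            # "encode" x into L codes z ~ q(z|x)
W = torch.randn(2, 5)                   # toy linear decoder: p(x|z) = N(x; zW, I)
log_px_given_z = Normal(z @ W, 1.0).log_prob(x).sum(dim=-1)
reconstruction_accuracy = log_px_given_z.mean()

# Regularization term: -KL(q(z|x) || p(z)), analytic for two Gaussians
regularization_term = -kl_divergence(q_z_given_x, prior).sum()

elbo_estimate = reconstruction_accuracy + regularization_term
print(reconstruction_accuracy, regularization_term, elbo_estimate)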

Reparametrization Trick: While the KL-divergence $D_{KL}\!\left(q_\phi(z \mid x^{(i)}) \,\|\, p_\theta(z)\right)$ (i.e., the regularization term) can often be integrated analytically, the second term $\mathbb{E}_{q_\phi\left(z \mid x^{(i)}\right)}\!\left[\log p_\theta\!\left(x^{(i)} \mid z\right)\right]$ (i.e., the reconstruction accuracy) requires sampling from $q_\phi(z \mid x^{(i)})$. There are two downsides associated with such naive sampling approaches:

  1. Backpropagation cannot propagate gradients through a sampling operation, i.e., the implementation of VAEs would be more difficult.

  2. The usual Monte Carlo gradient estimator of the expectation (which relies on sampling from $q_\phi(z \mid x^{(i)})$) w.r.t. $\phi$ exhibits very high variance.

To overcome these problems, Kingma and Welling (2013) use the reparametrization trick:

Substitute sampling $z \sim q_\phi(z \mid x)$ by a deterministic mapping $z = g_\phi(\epsilon, x)$, i.e., a differentiable transformation $g_\phi$ of an auxiliary noise variable $\epsilon$ with $\epsilon \sim p(\epsilon)$.

As a result, the reparametrized objective function can be written as follows

$$\mathcal{L}\!\left(\theta, \phi; x^{(i)}\right) = -D_{KL}\!\left(q_\phi\!\left(z \mid x^{(i)}\right) \,\|\, p_\theta(z)\right) + \mathbb{E}_{p(\epsilon)}\!\left[\log p_\theta\!\left(x^{(i)} \mid g_\phi\!\left(\epsilon, x^{(i)}\right)\right)\right]$$

in which the second term can be approximated with Monte Carlo integration yielding

$$\tilde{\mathcal{L}}^{B}\!\left(\theta, \phi; x^{(i)}\right) = -D_{KL}\!\left(q_\phi\!\left(z \mid x^{(i)}\right) \,\|\, p_\theta(z)\right) + \frac{1}{L}\sum_{l=1}^{L} \log p_\theta\!\left(x^{(i)} \mid z^{(i,l)}\right)$$

with $z^{(i,l)} = g_\phi\!\left(\epsilon^{(i,l)}, x^{(i)}\right)$ and $\epsilon^{(i,l)} \sim p(\epsilon)$. Note that Kingma and Welling denote this estimator as the second version of the Stochastic Gradient Variational Bayes (SGVB) estimator. Assuming that the KL-divergence can be integrated analytically, the derivatives $\nabla_{\theta, \phi}\tilde{\mathcal{L}}^{B}\!\left(\theta, \phi; x^{(i)}\right)$ can be taken (see figure below), i.e., this estimator can be optimized using standard stochastic gradient methods.
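To see that gradients indeed flow through the reparametrized sample, here is a tiny illustrative PyTorch sketch (not from the paper); the sampled noise is held fixed and the objective is just a stand-in for the reconstruction term:

import torch

torch.manual_seed(0)
mu = torch.tensor(0.5, requires_grad=True)        # encoder mean (toy, scalar latent)
log_var = torch.tensor(-1.0, requires_grad=True)  # encoder log-variance

# Reparametrized sample: z = g(eps) = mu + sigma * eps with eps ~ N(0, 1)
eps = torch.randn(())
z = mu + torch.exp(0.5 * log_var) * eps

# Any differentiable function of z works as a stand-in for log p(x|z)
objective = -(z - 2.0) ** 2
objective.backward()

print(mu.grad, log_var.grad)  # well-defined gradients w.r.t. the variational parameters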

To increase stability and performance, Kingma and Welling introduce a minibatch estimator of the lower bound:

$$\mathcal{L}(\theta, \phi; X) \approx \tilde{\mathcal{L}}^{M}\!\left(\theta, \phi; X^{M}\right) = \frac{N}{M}\sum_{i=1}^{M} \tilde{\mathcal{L}}^{B}\!\left(\theta, \phi; x^{(i)}\right)$$

where $X^{M} = \{x^{(i)}\}_{i=1}^{M}$ denotes a minibatch of $M$ datapoints drawn from the full dataset $X$ of $N$ datapoints.

Learning the Model

Learning the probabilistic encoder $q_\phi(z \mid x)$ and decoder $p_\theta(x \mid z)$ comes down to learning the optimal model parameters $\theta, \phi$ using the AEVB algorithm, which can be summarized in 5 steps (a schematic sketch of this loop follows after the list):

  1. Initialize the model parameters $\theta, \phi$ randomly.

  2. Sample a random minibatch $X^{M}$.

  3. Compute the gradients $\nabla_{\theta, \phi}\tilde{\mathcal{L}}^{M}\!\left(\theta, \phi; X^{M}\right)$.

  4. Update the model parameters $\theta, \phi$ by taking a gradient ascent step.

  5. Repeat steps 2-4 until the model parameters have converged.
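Schematically, the five steps look like this. The helper functions sample_minibatch and elbo are hypothetical placeholders; the concrete PyTorch implementation follows in the next section.

import torch

def aevb(params, sample_minibatch, elbo, lr=1e-3, steps=10_000):
    # Step 1: `params` are assumed to be randomly initialized tensors/parameters
    optimizer = torch.optim.Adam(params, lr=lr)
    for _ in range(steps):                 # Step 5: repeat until convergence
        x_batch = sample_minibatch()       # Step 2: sample a random minibatch
        loss = -elbo(x_batch)              # negative lower bound (minimizing it maximizes the ELBO)
        optimizer.zero_grad()
        loss.backward()                    # Step 3: compute gradients of the estimator
        optimizer.step()                   # Step 4: take a gradient step
    return params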

VAE Implementation

A VAE simply uses deep neural networks (DNNs) as function approximators to parametrize the probabilistic encoder $q_\phi(z \mid x)$ and decoder $p_\theta(x \mid z)$. The optimal parameters $\theta, \phi$ are learned jointly by training the VAE using the AEVB algorithm.

Regularization Term: Typically, the prior over the latent variables is set to be the centered isotropic Gaussian, i.e., $p_\theta(z) = \mathcal{N}(z; 0, I)$. Note that this prior is needed to compute the regularization term in the objective function. Furthermore, it is commonly assumed that the true posterior $p_\theta(z \mid x)$ may be approximated by a multivariate Gaussian with diagonal covariance structure, $q_\phi(z \mid x^{(i)}) = \mathcal{N}\!\left(z; \mu_E^{(i)}, \operatorname{diag}\!\left(\left(\sigma_E^{(i)}\right)^{2}\right)\right)$ (subscripts denote that these parameters come from the encoder network). As a result, the regularization term can be integrated analytically, leading to a term that only depends on $\mu_E^{(i)}$ and $\sigma_E^{(i)}$ (see Appendix B of Kingma and Welling)

$$-D_{KL}\!\left(q_\phi\!\left(z \mid x^{(i)}\right) \,\|\, p_\theta(z)\right) = \frac{1}{2}\sum_{j=1}^{J}\left(1 + \log\left(\sigma_{E,j}^{(i)}\right)^{2} - \left(\mu_{E,j}^{(i)}\right)^{2} - \left(\sigma_{E,j}^{(i)}\right)^{2}\right)$$

where $J$ denotes the latent space dimension.
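As a quick numerical sanity check of this closed-form expression (a toy snippet added here, not part of the original post), it can be compared against torch.distributions.kl_divergence:

import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(20)        # toy encoder mean for one datapoint (J = 20)
log_var = torch.randn(20)   # toy encoder log-variance

# Closed-form regularization term: -KL(q(z|x) || N(0, I))
analytic = 0.5 * (1 + log_var - mu**2 - log_var.exp()).sum()

# Reference value computed by torch.distributions
q = Normal(mu, (0.5 * log_var).exp())
reference = -kl_divergence(q, Normal(torch.zeros(20), torch.ones(20))).sum()

print(torch.allclose(analytic, reference))  # True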

Encoder/Decoder Network: Kingma and Welling (2013) use simple neural networks with only one hidden layer to approximate the parameters of the probabilistic encoder and decoder. As stated above, the encoder network is fixed to compute the parameters $\mu_E^{(i)}$ and $\sigma_E^{(i)}$ of the Gaussian distribution $q_\phi(z \mid x^{(i)})$. In fact, the encoder network takes a sample $x^{(i)}$ and outputs the mean $\mu_E^{(i)}$ and the logarithmized variance $\log\left(\sigma_E^{(i)}\right)^{2}$, i.e.,

$$h^{(i)} = \operatorname{ReLU}\!\left(W_1 x^{(i)} + b_1\right), \qquad \mu_E^{(i)} = W_2 h^{(i)} + b_2, \qquad \log\left(\sigma_E^{(i)}\right)^{2} = W_3 h^{(i)} + b_3$$

Note that using the logarithmized variance $\log\left(\sigma_E^{(i)}\right)^{2}$ instead of the variance itself increases numerical stability and simplifies training.

In principle, the encoder and decoder networks are very similar; only the dimensions of their input and output are reversed. While the encoder network is fixed to approximate a multivariate Gaussian with diagonal covariance structure, the decoder network can approximate a multivariate Gaussian (real-valued data) or Bernoulli (binary data) distribution.

Below is a simple Python class that can be used to instantiate the encoder or decoder network as described in appendix C of Kingma and Welling (2013).

import torch.nn as nn
from collections import OrderedDict
class CoderNetwork(nn.Module):
    """Encoder/Decoder for use in VAE based on Kingma and Welling
    
    Args:
        input_dim: input dimension (int)
        hidden_dim: hidden layer dimension (int)
        output_dim: output dimension (int)
        coder_type: encoder/decoder type can be 
                   'Gaussian'   - Gaussian with diagonal covariance structure
                   'I-Gaussian' - Gaussian with identity as covariance matrix 
                   'Bernoulli'  - Bernoulli distribution       
    """
    
    def __init__(self, input_dim, hidden_dim, output_dim, coder_type):
        super().__init__()
        
        assert coder_type in ['Gaussian', 'I-Gaussian', 'Bernoulli'], \
            'unknown coder_type'
        
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.coder_type = coder_type
        
        self.coder = nn.Sequential(OrderedDict([
            ('h', nn.Linear(input_dim, hidden_dim)),
            ('ReLU', nn.ReLU()) # ReLU instead of Tanh proposed by K. and W.       
        ]))
        self.fc_mu = nn.Linear(hidden_dim, output_dim)
        
        if coder_type == 'Gaussian':
            self.fc_log_var = nn.Linear(hidden_dim, output_dim)
        elif coder_type == 'Bernoulli':
            self.sigmoid_mu = nn.Sigmoid()
        return
    
    def forward(self, inp):
        out = self.coder(inp)
        mu = self.fc_mu(out)
        
        if self.coder_type == 'Gaussian':
            log_var = self.fc_log_var(out)
            return [mu, log_var]
        elif self.coder_type == 'I-Gaussian':
            return mu
        elif self.coder_type == 'Bernoulli':
            return self.sigmoid_mu(mu)
        return
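
A quick (illustrative) shape check of the class above:

import torch

# Illustrative dimensions: flattened 28x28 images, 400 hidden units, 20 latent dimensions
enc = CoderNetwork(input_dim=784, hidden_dim=400, output_dim=20, coder_type='Gaussian')
dec = CoderNetwork(input_dim=20, hidden_dim=400, output_dim=784, coder_type='Bernoulli')

x = torch.rand(8, 784)        # toy batch of 8 inputs
mu_E, log_var_E = enc(x)      # each of shape (8, 20)
x_recon = dec(mu_E)           # Bernoulli means of shape (8, 784), values in (0, 1)
print(mu_E.shape, log_var_E.shape, x_recon.shape)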

Reconstruction Accuracy: Sampling from the encoder distribution is avoided by using the reparameterization trick, i.e., the latent variable $z^{(i,l)}$ is expressed as a deterministic variable

$$z^{(i,l)} = \mu_E^{(i)} + \sigma_E^{(i)} \odot \epsilon^{(l)}, \qquad \epsilon^{(l)} \sim \mathcal{N}(0, I),$$

where $\odot$ denotes element-wise multiplication.

Note that we do not need to sample from the decoder distribution: during training, the reconstruction accuracy in the objective function only sums the log-likelihood of each sample $x^{(i)}$, and at test time we are mostly interested in the most probable reconstruction, i.e., the decoder mean.

The reconstruction accuracy in the reparametrized form is given by

$$\frac{1}{L}\sum_{l=1}^{L} \log p_\theta\!\left(x^{(i)} \mid z^{(i,l)}\right)$$

where $L$ denotes the number of samples used during the reparameterization trick. Depending on the chosen decoder distribution, the log-likelihood can be stated in terms of the estimated distribution parameters (a short numerical check of the resulting losses follows after this list):

  • Gaussian distribution (diagonal covariance) $p_\theta\!\left(x \mid z^{(i,l)}\right) = \mathcal{N}\!\left(x; \mu_D^{(i,l)}, \operatorname{diag}\!\left(\left(\sigma_D^{(i,l)}\right)^{2}\right)\right)$

$$\log p_\theta\!\left(x^{(i)} \mid z^{(i,l)}\right) = -\frac{1}{2}\sum_{j}\left(\log 2\pi + \log\left(\sigma_{D,j}^{(i,l)}\right)^{2} + \frac{\left(x_j^{(i)} - \mu_{D,j}^{(i,l)}\right)^{2}}{\left(\sigma_{D,j}^{(i,l)}\right)^{2}}\right)$$

with the original observation $x^{(i)}$. In this form, the objective function is ill-posed since there are no limitations on the form of the normal distribution. As a result, the objective function is unbounded: for example, the decoder can match the mean $\mu_D^{(i,l)}$ to the observation exactly and shrink the variance $\left(\sigma_D^{(i,l)}\right)^{2}$ towards zero, which drives the log-likelihood towards infinity (see this post). Note that in the encoder network, the prior $p_\theta(z)$ is used to constrain the encoder distribution (i.e., the mean and variance).

  • Gaussian distribution (identity covariance matrix) $p_\theta\!\left(x \mid z^{(i,l)}\right) = \mathcal{N}\!\left(x; \mu_D^{(i,l)}, I\right)$

$$\log p_\theta\!\left(x^{(i)} \mid z^{(i,l)}\right) = -\frac{1}{2}\sum_{j}\left(x_j^{(i)} - \mu_{D,j}^{(i,l)}\right)^{2} + \text{const}$$

with the original observation $x^{(i)}$. In this case the reconstruction accuracy is proportional to the negative mean squared error, which is typically used as the loss function in standard autoencoders.

  • Bernoulli distribution

$$p_\theta\!\left(x \mid z^{(i,l)}\right) = \prod_{j}\left(\mu_{D,j}^{(i,l)}\right)^{x_j}\left(1 - \mu_{D,j}^{(i,l)}\right)^{1 - x_j}$$

$$\log p_\theta\!\left(x^{(i)} \mid z^{(i,l)}\right) = \sum_{j}\left(x_j^{(i)} \log \mu_{D,j}^{(i,l)} + \left(1 - x_j^{(i)}\right)\log\left(1 - \mu_{D,j}^{(i,l)}\right)\right)$$

with the original observation $x^{(i)}$. In this case the reconstruction accuracy equals the negative binary cross entropy loss. Note that there are plenty of VAE implementations that use the binary cross entropy loss on non-binary observations; see discussions in this thread.
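The last two bullet points can be verified numerically. The small sketch below (illustrative only, using made-up tensors) checks that the identity-covariance Gaussian term matches PyTorch's MSELoss up to the factor $-\tfrac{1}{2}$ and that the Bernoulli term matches the negative BCELoss:

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.rand(4, 10)        # toy "observations" with values in [0, 1)
mu_D = torch.rand(4, 10)     # toy decoder means

# I-Gaussian: reconstruction accuracy = -1/2 * sum of squared errors (up to an additive constant)
recons_acc_gauss = -0.5 * ((x - mu_D) ** 2).sum()
mse_based = -0.5 * nn.MSELoss(reduction='sum')(mu_D, x)
print(torch.allclose(recons_acc_gauss, mse_based))   # True

# Bernoulli: reconstruction accuracy = -(binary cross entropy)
recons_acc_bern = (x * torch.log(mu_D) + (1 - x) * torch.log(1 - mu_D)).sum()
bce_based = -nn.BCELoss(reduction='sum')(mu_D, x)
print(torch.allclose(recons_acc_bern, bce_based))    # True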

To put this into practice, below is a simple VAE Python class which will be used to compare the different decoder distributions.

import torch
from torch.distributions.multivariate_normal import MultivariateNormal
class VAE(nn.Module):
    """A simple VAE class based on Kingma and Welling
        
    Args:
        encoder_network:  instance of CoderNetwork class
        decoder_network:  instance of CoderNetwork class
        L:                number of samples used during reparameterization trick
    """
    
    def __init__(self, encoder_network, decoder_network, L=1):
        super().__init__()
        self.encoder = encoder_network
        self.decoder = decoder_network
        self.L = L
        
        latent_dim = encoder_network.output_dim
                
        self.normal_dist = MultivariateNormal(torch.zeros(latent_dim), 
                                              torch.eye(latent_dim))
        return
    
    def forward(self, x):
        L = self.L
        
        z, mu_E, log_var_E = self.encode(x, L)
        # regularization term per batch, i.e., size: (batch_size)
        regularization_term = (1/2) * (1 + log_var_E - mu_E**2
                                       - torch.exp(log_var_E)).sum(axis=1)
        
        # upsample x and reshape
        batch_size = x.shape[0]
        x_ups = x.repeat(L, 1).view(batch_size, L, -1)    
        if self.decoder.coder_type == 'Gaussian':
            # mu_D, log_var_D have shape (batch_size, L, output_dim)
            mu_D, log_var_D = self.decode(z)
            # reconstruction accuracy per batch, i.e., size: (batch_size)       
            recons_acc = (1/L) * (-(0.5)*(log_var_D.sum(axis=2)).sum(axis=1)
                                  -(0.5)*((1/torch.exp(log_var_D))
                                          *((x_ups - mu_D)**2)
                                         ).sum(axis=2).sum(axis=1))
        elif self.decoder.coder_type == 'I-Gaussian':
            # mu_D has shape (batch_size, L, output_dim)
            mu_D = self.decode(z)
            # reconstruction accuracy per batch, i.e., size: (batch_size)
            recons_acc = (1/L) * (-(0.5) * ((x_ups - mu_D)**2
                                            ).sum(axis=2).sum(axis=1))
        elif self.decoder.coder_type == 'Bernoulli':
            # mu_D has shape (batch_size, L, output_dim)
            mu_D = self.decode(z)     
            # reconstruction accuracy per batch, i.e., size: (batch_size)
            # corresponds to the negative binary cross entropy loss (BCELoss)
            recons_acc = (1/L) * (x_ups * torch.log(mu_D) + 
                                  (1 - x_ups) * torch.log(1 - mu_D)
                                  ).sum(axis=2).sum(axis=1)
        loss = - regularization_term.sum() - recons_acc.sum()
        return loss
    
    def encode(self, x, L=1):
        # get encoder distribution parameters
        mu_E, log_var_E = self.encoder(x)
        # sample noise variable L times for each batch
        batch_size = x.shape[0]
        epsilon = self.normal_dist.sample(sample_shape=(batch_size, L, ))
        # upsample mu_E, log_var_E and reshape
        mu_E_ups = mu_E.repeat(L, 1).view(batch_size, L, -1) 
        log_var_E_ups = log_var_E.repeat(L, 1).view(batch_size, L, -1)
        # get latent variable by reparametrization trick
        z = mu_E_ups + torch.sqrt(torch.exp(log_var_E_ups)) * epsilon
        return z, mu_E, log_var_E
    
    def decode(self, z):
        # get decoder distribution parameters
        if self.decoder.coder_type == 'Gaussian':
            mu_D, log_var_D = self.decoder(z)
            return mu_D, log_var_D
        elif self.decoder.coder_type == 'I-Gaussian':
            mu_D = self.decoder(z)
            return mu_D
        elif self.decoder.coder_type == 'Bernoulli':
            mu_D = self.decoder(z)
            return mu_D
        return
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
def train(decoder_type, dataset, x_dim, hid_dim, z_dim, batch_size, L, epochs):
    encoder_network = CoderNetwork(input_dim=x_dim, 
                                   hidden_dim=hid_dim, 
                                   output_dim=z_dim,
                                   coder_type='Gaussian')
    decoder_network = CoderNetwork(input_dim=z_dim, 
                                   hidden_dim=hid_dim, 
                                   output_dim=x_dim,
                                   coder_type=decoder_type)
    
    model = VAE(encoder_network, decoder_network, L=L)
    data_loader = DataLoader(dataset, batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    
    print('Start training with {} decoder distribution\n'.format(decoder_type))
    for epoch in range(1, epochs + 1):
        print('Epoch {}/{}'.format(epoch, epochs))
        avg_loss = 0
        for counter, (mini_batch_data, label) in enumerate(data_loader):
            model.zero_grad()
            
            loss = model(mini_batch_data.view(-1, x_dim))
            loss.backward()
            optimizer.step()
            
            avg_loss += loss.item() / len(dataset)
            
            if counter % 20 == 0 or (counter + 1)==len(data_loader):
                batch_loss = loss.item() / len(mini_batch_data)
                print('\r[{}/{}] batch loss: {:.2f}'.format(counter + 1,
                                                            len(data_loader),
                                                            batch_loss),
                      end='', flush=True)
        print('\nAverage loss: {:.3f}'.format(avg_loss)) 
    print('Done!\n')
    trained_VAE = model
    return trained_VAE
dataset = datasets.MNIST('data/', transform=transforms.ToTensor(), 
                         download=True)
x_dim, hid_dim, z_dim = 28*28, 400, 20
batch_size, L, epochs = 128, 1, 3
Bernoulli_VAE = train('Bernoulli', dataset, x_dim, hid_dim, z_dim, 
                      batch_size, L, epochs)
Gaussian_VAE = train('Gaussian', dataset, x_dim, hid_dim, z_dim, 
                     batch_size, L, epochs)
I_Gaussian_VAE = train('I-Gaussian', dataset, x_dim, hid_dim, z_dim, 
                       batch_size, L, epochs)

Let's look at the differences in the reconstructions:

import matplotlib.pyplot as plt
def plot_results(trained_model, dataset, n_samples):
    decoder_type = trained_model.decoder.coder_type
    
    fig = plt.figure(figsize=(14, 3))
    fig.suptitle(decoder_type + ' Distribution: Observations (top row) and ' +
                 'their reconstructions (bottom row)')
    for i_sample in range(n_samples):
        x_sample = dataset[i_sample][0].view(-1, 28*28)
        
        z, mu_E, log_var_E = trained_model.encode(x_sample, L=1)
        if decoder_type in ['Bernoulli', 'I-Gaussian']:
            x_prime = trained_model.decode(z)
        else:
            x_prime = trained_model.decode(z)[0]
    
        plt.subplot(2, n_samples, i_sample + 1)
        plt.imshow(x_sample.view(28, 28).data.numpy())
        plt.axis('off')
        plt.subplot(2, n_samples, i_sample + 1 + n_samples)
        plt.imshow(x_prime.view(28, 28).data.numpy())
        plt.axis('off')
    return fig
n_samples = 10
plot_results(Bernoulli_VAE, dataset, n_samples)
plot_results(Gaussian_VAE, dataset, n_samples)
plot_results(I_Gaussian_VAE, dataset, n_samples)

Acknowledgement

Daniel Daza's blog was really helpful, and the presented code is heavily inspired by his summary of VAEs.
