Auto-Encoding Variational Bayes
Kingma and Welling (2013) introduced the Variational Auto-Encoder (VAE) to showcase how their Auto-Encoding Variational Bayes (AEVB) algorithm can be used in practice. Assuming i.i.d. datasets and continuous latent variables, the AEVB algorithm learns an approximate probabilistic encoder $q_\phi(z|x)$ jointly with the probabilistic decoder $p_\theta(x|z)$ (where $\phi$ and $\theta$ parametrize the corresponding distributions) by finding the optimal model parameters $\phi^*, \theta^*$ through optimizing an objective function with standard gradient ascent methods. In summary, a VAE is a probabilistic autoencoder which uses variational inference to regularize the coding space. Furthermore, a VAE is a deep generative model, as sampling from the coding space is possible, i.e., new observations can be generated.
Model Description
The AEVB algorithm basically assumes a generative process, introduces a variational approximation (see figure below) and optimizes the model parameters by maximizing an objective function. The objective function consists of the (reparametrized) variational lower bound of each datapoint. Reparametrization is necessary to allow the explicit formulation of gradients with respect to the model parameters.
Objective Function Derivation: Let $X = \{x^{(i)}\}_{i=1}^{N}$ denote the dataset consisting of $N$ i.i.d. samples and let $z$ denote the unobserved continuous random variable (i.e., hidden or code variable). Kingma and Welling (2013) assume that each observed sample $x^{(i)}$ comes from a generative process in which: Firstly, a hidden variable $z^{(i)}$ is generated from a prior distribution $p_{\theta^*}(z)$. Secondly, $x^{(i)}$ is generated from the conditional distribution $p_{\theta^*}(x|z)$. Note that we do not know the true parameters $\theta^*$ nor do we have information about the latent variables $z^{(i)}$. In order to recover this generative process, they introduce $q_\phi(z|x)$ as an approximation to the intractable true posterior $p_\theta(z|x)$. The marginal log likelihood of each individual datapoint $x^{(i)}$ can then be stated as follows (see Eric Jang's amazing blog post for a detailed derivation)
$$\log p_\theta\big(x^{(i)}\big) = D_{KL}\big(q_\phi(z|x^{(i)}) \,\|\, p_\theta(z|x^{(i)})\big) + \mathcal{L}\big(\theta, \phi; x^{(i)}\big),$$
where $D_{KL}\big(q_\phi(z|x^{(i)}) \,\|\, p_\theta(z|x^{(i)})\big)$ denotes the KL divergence of the approximate from the true posterior (this quantity remains unknown since the true posterior $p_\theta(z|x^{(i)})$ is intractable). $\mathcal{L}\big(\theta, \phi; x^{(i)}\big)$ is called the variational lower bound or evidence lower bound (ELBO). The goal is to optimize $\theta$ and $\phi$ such that the variational lower bound is maximized; since the KL divergence is non-negative, this indirectly maximizes the marginal log likelihood. The variational lower bound can be rewritten such that the objective function is obtained (also derived in Eric Jang's blog post)
$$\mathcal{L}\big(\theta, \phi; x^{(i)}\big) = \underbrace{\mathbb{E}_{q_\phi(z|x^{(i)})}\big[\log p_\theta\big(x^{(i)}|z\big)\big]}_{\text{reconstruction accuracy}} \underbrace{- \, D_{KL}\big(q_\phi(z|x^{(i)}) \,\|\, p_\theta(z)\big)}_{\text{regularization term}}$$
The two terms have an associated interpretation in autoencoder language:
Reconstruction Accuracy (opposite of Reconstruction Error): The expectation can be approximated using Monte Carlo integration, i.e.,
$$\mathbb{E}_{q_\phi(z|x^{(i)})}\big[\log p_\theta\big(x^{(i)}|z\big)\big] \approx \frac{1}{L}\sum_{l=1}^{L} \log p_\theta\big(x^{(i)}\,|\,z^{(i,l)}\big) \quad \text{with} \quad z^{(i,l)} \sim q_\phi(z|x^{(i)}),$$
which results in an unbiased estimate. Sampling $z^{(i,l)} \sim q_\phi(z|x^{(i)})$ can be understood as encoding the observed input $x^{(i)}$ into a code $z^{(i,l)}$ using the probabilistic encoder $q_\phi(z|x)$. Clearly, the expectation is maximized when the decoder $p_\theta(x|z)$ maps the encoded input $z^{(i,l)}$ back to the original input $x^{(i)}$, i.e., assigns high probability to $x^{(i)}$.
Regularization Term: The KL divergence is non-negative and only zero if both distributions are identical. Thus, maximizing this term (i.e., minimizing the KL divergence) forces the encoder distribution $q_\phi(z|x^{(i)})$ to be close to the prior $p_\theta(z)$. In VAEs, the prior is typically set to be an isotropic normal distribution, resulting in a regularized code space, i.e., encouraging a code space that is close to a normal distribution.
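For completeness, both the decomposition of the marginal log likelihood and the two-term form of the lower bound follow directly from Bayes' rule and the definition of the KL divergence:
$$D_{KL}\big(q_\phi(z|x^{(i)}) \,\|\, p_\theta(z|x^{(i)})\big) = \mathbb{E}_{q_\phi(z|x^{(i)})}\big[\log q_\phi(z|x^{(i)}) - \log p_\theta(z|x^{(i)})\big]$$
$$= \mathbb{E}_{q_\phi(z|x^{(i)})}\big[\log q_\phi(z|x^{(i)}) - \log p_\theta(x^{(i)}|z) - \log p_\theta(z)\big] + \log p_\theta\big(x^{(i)}\big),$$
where the second line uses $p_\theta(z|x^{(i)}) = p_\theta(x^{(i)}|z)\, p_\theta(z) / p_\theta(x^{(i)})$ and the fact that $\log p_\theta(x^{(i)})$ does not depend on $z$. Rearranging yields
$$\log p_\theta\big(x^{(i)}\big) = D_{KL}\big(q_\phi(z|x^{(i)}) \,\|\, p_\theta(z|x^{(i)})\big) + \underbrace{\mathbb{E}_{q_\phi(z|x^{(i)})}\big[\log p_\theta\big(x^{(i)}|z\big)\big] - D_{KL}\big(q_\phi(z|x^{(i)}) \,\|\, p_\theta(z)\big)}_{\mathcal{L}(\theta, \phi; x^{(i)})}.$$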
Reparametrization Trick: While the KL divergence $D_{KL}\big(q_\phi(z|x^{(i)}) \,\|\, p_\theta(z)\big)$ (i.e., the regularization term) can often be integrated analytically, the second term $\mathbb{E}_{q_\phi(z|x^{(i)})}\big[\log p_\theta\big(x^{(i)}|z\big)\big]$ (i.e., the reconstruction accuracy) requires sampling from $q_\phi(z|x^{(i)})$. There are two downsides associated with sampling from $q_\phi(z|x^{(i)})$ directly:
Backpropagation does not work through a sampling operation, which would make the implementation of VAEs more difficult.
The usual Monte Carlo gradient estimator (which relies on sampling from $q_\phi(z|x^{(i)})$) w.r.t. $\phi$ exhibits very high variance.
To overcome these problems, Kingma and Welling (2013) use the reparametrization trick:
Substitute sampling $z \sim q_\phi(z|x)$ by the deterministic mapping $z = g_\phi(\epsilon, x)$, where $g_\phi$ is a differentiable transformation of an auxiliary noise variable $\epsilon$ with $\epsilon \sim p(\epsilon)$.
As a result, the reparametrized objective function can be written as follows
$$\mathcal{L}\big(\theta, \phi; x^{(i)}\big) = - D_{KL}\big(q_\phi(z|x^{(i)}) \,\|\, p_\theta(z)\big) + \mathbb{E}_{\epsilon \sim p(\epsilon)}\Big[\log p_\theta\big(x^{(i)}\,|\,g_\phi(\epsilon, x^{(i)})\big)\Big],$$
in which the second term can be approximated with Monte Carlo integration yielding
$$\tilde{\mathcal{L}}^{B}\big(\theta, \phi; x^{(i)}\big) = - D_{KL}\big(q_\phi(z|x^{(i)}) \,\|\, p_\theta(z)\big) + \frac{1}{L}\sum_{l=1}^{L} \log p_\theta\big(x^{(i)}\,|\,z^{(i,l)}\big)$$
with $z^{(i,l)} = g_\phi\big(\epsilon^{(i,l)}, x^{(i)}\big)$ and $\epsilon^{(i,l)} \sim p(\epsilon)$. Note that Kingma and Welling denote this estimator as the second version of the Stochastic Gradient Variational Bayes (SGVB) estimator. Assuming that the KL divergence can be integrated analytically, the derivatives $\nabla_{\theta,\phi}\,\tilde{\mathcal{L}}^{B}\big(\theta, \phi; x^{(i)}\big)$ can be taken (see figure below), i.e., this estimator can be optimized using standard stochastic gradient methods.
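To make the effect of the trick concrete, here is a minimal PyTorch sketch (with placeholder tensors, not taken from the paper): after reparametrization, the sample is a deterministic, differentiable function of the distribution parameters, so gradients with respect to them are well defined.

import torch

# minimal sketch with placeholder tensors: gradients flow through the reparametrized sample
mu = torch.zeros(3, requires_grad=True)            # mean of the approximate posterior (placeholder)
log_var = torch.zeros(3, requires_grad=True)       # logarithmized variance (placeholder)
epsilon = torch.randn(3)                           # auxiliary noise variable, epsilon ~ N(0, I)
z = mu + torch.sqrt(torch.exp(log_var)) * epsilon  # deterministic mapping g_phi(epsilon, x)
z.sum().backward()
print(mu.grad, log_var.grad)                       # well-defined gradients w.r.t. mu and log_var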
To increase stability and performance, Kingma and Welling introduce a minibatch estimator of the lower bound:
$$\mathcal{L}(\theta, \phi; X) \approx \tilde{\mathcal{L}}^{M}\big(\theta, \phi; X^{M}\big) = \frac{N}{M}\sum_{i=1}^{M} \tilde{\mathcal{L}}^{B}\big(\theta, \phi; x^{(i)}\big),$$
where $X^{M} = \{x^{(i)}\}_{i=1}^{M}$ denotes a minibatch of $M$ datapoints from the full dataset $X$ of $N$ datapoints.
Learning the Model
Learning the probabilistic encoder $q_\phi(z|x)$ and decoder $p_\theta(x|z)$ comes down to learning the optimal model parameters $\phi^*, \theta^*$ using the AEVB algorithm, which can be summarized in 5 steps:
1. Initialize the model parameters $\theta$ and $\phi$ randomly.
2. Sample a random minibatch $X^{M}$.
3. Compute the gradients $\nabla_{\theta,\phi}\,\tilde{\mathcal{L}}^{M}\big(\theta, \phi; X^{M}\big)$.
4. Update the model parameters $\theta$ and $\phi$ by taking a gradient ascent step.
5. Repeat steps 2-4 until the model parameters have converged.
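A minimal sketch of this loop in PyTorch might look as follows (hypothetical `model`, `data_loader`, and `num_epochs`; the VAE class defined below returns the negative minibatch lower bound, so gradient ascent on the bound is implemented as gradient descent on the loss):

import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # step 1: parameters initialized by the networks
for epoch in range(num_epochs):                             # step 5: repeat until convergence
    for x_batch, _ in data_loader:                          # step 2: sample a random minibatch
        loss = model(x_batch)                                # negative minibatch lower bound
        optimizer.zero_grad()
        loss.backward()                                      # step 3: compute the gradients
        optimizer.step()                                     # step 4: take a gradient step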
VAE Implementation
A VAE simply uses deep neural networks (DNNs) as function approximators to parametrize the probabilistic encoder $q_\phi(z|x)$ and decoder $p_\theta(x|z)$. The optimal parameters $\phi^*, \theta^*$ are learned jointly by training the VAE using the AEVB algorithm.
Regularization Term: Typically, the prior over the latent variables is set to be the centered isotropic Gaussian, i.e., $p_\theta(z) = \mathcal{N}(z; 0, I)$. Note that this prior is needed to compute the regularization term in the objective function. Furthermore, it is commonly assumed that the true posterior $p_\theta(z|x)$ may be approximated by a Gaussian with diagonal covariance, $q_\phi(z|x^{(i)}) = \mathcal{N}\big(z; \mu_E^{(i)}, (\sigma_E^{(i)})^2 I\big)$ (the subscript $E$ denotes that these parameters come from the encoder network). As a result, the regularization term can be integrated analytically, leading to a term that only depends on $\mu_E^{(i)}$ and $\sigma_E^{(i)}$ (see Appendix B of Kingma and Welling)
$$- D_{KL}\big(q_\phi(z|x^{(i)}) \,\|\, p_\theta(z)\big) = \frac{1}{2}\sum_{j=1}^{J}\Big(1 + \log\big(\sigma_{E,j}^{(i)}\big)^2 - \big(\mu_{E,j}^{(i)}\big)^2 - \big(\sigma_{E,j}^{(i)}\big)^2\Big),$$
where $J$ denotes the latent space dimension.
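As a small sketch (placeholder tensors only; the encoder is assumed to output the logarithmized variance, as described below), this term can be computed per datapoint as

import torch

# analytic regularization term (negative KL divergence) per datapoint
mu_E = torch.randn(4, 20)       # placeholder encoder means, shape (batch_size, latent_dim)
log_var_E = torch.randn(4, 20)  # placeholder encoder log-variances, same shape
neg_kl = 0.5 * (1 + log_var_E - mu_E**2 - torch.exp(log_var_E)).sum(dim=1)
print(neg_kl.shape)             # torch.Size([4]), i.e., one value per datapoint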
Encoder/Decoder Network: Kingma and Welling (2013) use simple neural networks with only one hidden layer to approximate the parameters of the probabilistic encoder and decoder. As stated above, the encoder network is fixed to compute the parameters $\mu_E^{(i)}$ and $\sigma_E^{(i)}$ of the Gaussian distribution $q_\phi(z|x^{(i)})$. In fact, the encoder network takes a sample $x^{(i)}$ and outputs the mean and the logarithmized variance, i.e.,
$$\big(\mu_E^{(i)}, \log\big(\sigma_E^{(i)}\big)^2\big) = \mathrm{EncoderNeuralNet}_\phi\big(x^{(i)}\big)$$
Note that using the logarithmized variance $\log\big(\sigma_E^{(i)}\big)^2$ increases numerical stability and simplifies training, since the network output is unconstrained whereas the variance itself must be positive.
In principle, the encoder and decoder networks are very similar; only the dimensions of their inputs and outputs are reversed. While the encoder network is fixed to approximate a multivariate Gaussian with diagonal covariance structure, the decoder network can approximate a multivariate Gaussian (real-valued data) or Bernoulli (binary data) distribution.
Below is a simple Python class that can be used to instantiate the encoder or decoder network as described in appendix C of Kingma and Welling (2013).
import torch.nn as nn
from collections import OrderedDict
class CoderNetwork(nn.Module):
    """Encoder/Decoder for use in VAE based on Kingma and Welling

    Args:
        input_dim: input dimension (int)
        output_dim: output dimension (int)
        hidden_dim: hidden layer dimension (int)
        coder_type: encoder/decoder type can be
            'Gaussian' - Gaussian with diagonal covariance structure
            'I-Gaussian' - Gaussian with identity as covariance matrix
            'Bernoulli' - Bernoulli distribution
    """
    def __init__(self, input_dim, hidden_dim, output_dim, coder_type):
        super().__init__()
        assert coder_type in ['Gaussian', 'I-Gaussian', 'Bernoulli'], \
            'unknown coder_type'
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.coder_type = coder_type
        self.coder = nn.Sequential(OrderedDict([
            ('h', nn.Linear(input_dim, hidden_dim)),
            ('ReLU', nn.ReLU())  # ReLU instead of Tanh proposed by K. and W.
        ]))
        self.fc_mu = nn.Linear(hidden_dim, output_dim)
        if coder_type == 'Gaussian':
            self.fc_log_var = nn.Linear(hidden_dim, output_dim)
        elif coder_type == 'Bernoulli':
            self.sigmoid_mu = nn.Sigmoid()

    def forward(self, inp):
        out = self.coder(inp)
        mu = self.fc_mu(out)
        if self.coder_type == 'Gaussian':
            log_var = self.fc_log_var(out)
            return [mu, log_var]
        elif self.coder_type == 'I-Gaussian':
            return mu
        elif self.coder_type == 'Bernoulli':
            return self.sigmoid_mu(mu)
Reconstruction Accuracy: Sampling from the encoder distribution is avoided by using the reparameterization trick, i.e., the latent variable $z^{(i,l)}$ is expressed as a deterministic variable
$$z^{(i,l)} = g_\phi\big(\epsilon^{(i,l)}, x^{(i)}\big) = \mu_E^{(i)} + \sigma_E^{(i)} \odot \epsilon^{(i,l)} \quad \text{with} \quad \epsilon^{(i,l)} \sim \mathcal{N}(0, I),$$
where $\odot$ denotes element-wise multiplication.
Note that we do not need to sample from the decoder distribution: during training, the reconstruction accuracy in the objective function only sums the log-likelihood of each sample $x^{(i)}$, and during test time we are mostly interested in the reconstruction $x'$ with the highest probability, i.e., the mean.
The reconstruction accuracy in the reparametrized form is given by
$$\frac{1}{L}\sum_{l=1}^{L} \log p_\theta\big(x^{(i)}\,|\,z^{(i,l)}\big),$$
where $L$ denotes the number of samples used during the reparameterization trick. Depending on the chosen decoder distribution, the log-likelihood can be stated in terms of the estimated distribution parameters:
Gaussian distribution (diagonal covariance) $p_\theta\big(x|z^{(i,l)}\big) = \mathcal{N}\big(x; \mu_D^{(i,l)}, (\sigma_D^{(i,l)})^2 I\big)$:
$$\log p_\theta\big(x^{(i)}\,|\,z^{(i,l)}\big) = -\frac{1}{2}\sum_{j=1}^{D}\Big(\log\big(\sigma_{D,j}^{(i,l)}\big)^2 + \frac{\big(x_j^{(i)} - \mu_{D,j}^{(i,l)}\big)^2}{\big(\sigma_{D,j}^{(i,l)}\big)^2}\Big) + \mathrm{const}$$
with the original observation $x^{(i)}$ and data dimension $D$. In this form, the objective function is ill-posed since there are no limitations on the form of the normal distribution. As a result the objective function is unbounded: whenever the predicted mean $\mu_D^{(i,l)}$ matches the observation, the predicted variance $\big(\sigma_D^{(i,l)}\big)^2$ can be driven towards zero, making the log-likelihood arbitrarily large (see this post). Note that in the encoder network, the prior $p_\theta(z)$ constrains the encoder distribution (i.e., its mean and variance), whereas no such constraint acts on the decoder variance.
Gaussian distribution (identity as covariance matrix) $p_\theta\big(x|z^{(i,l)}\big) = \mathcal{N}\big(x; \mu_D^{(i,l)}, I\big)$:
$$\log p_\theta\big(x^{(i)}\,|\,z^{(i,l)}\big) = -\frac{1}{2}\sum_{j=1}^{D}\big(x_j^{(i)} - \mu_{D,j}^{(i,l)}\big)^2 + \mathrm{const}$$
with the original observation $x^{(i)}$. In this case the reconstruction accuracy is proportional to the negative mean squared error, which is typically used as the loss function in standard autoencoders.
Bernoulli distribution $p_\theta\big(x|z^{(i,l)}\big) = \prod_{j=1}^{D}\mathrm{Bern}\big(x_j; \mu_{D,j}^{(i,l)}\big)$:
$$\log p_\theta\big(x^{(i)}\,|\,z^{(i,l)}\big) = \sum_{j=1}^{D}\Big(x_j^{(i)}\log \mu_{D,j}^{(i,l)} + \big(1 - x_j^{(i)}\big)\log\big(1 - \mu_{D,j}^{(i,l)}\big)\Big)$$
with the original observation $x^{(i)}$. In this case the reconstruction accuracy equals the negative binary cross entropy loss. Note that there are plenty of VAE implementations that use the binary cross entropy loss on non-binary observations, see discussions in this thread.
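As a quick sanity check (a sketch with placeholder tensors, not part of the original code), this reconstruction accuracy is exactly the negative of PyTorch's binary cross entropy with sum reduction:

import torch
import torch.nn.functional as F

x = torch.rand(4, 784).round()                   # binarized placeholder observations
mu_D = torch.rand(4, 784).clamp(1e-6, 1 - 1e-6)  # placeholder decoder means in (0, 1)
recons_acc = (x * torch.log(mu_D) + (1 - x) * torch.log(1 - mu_D)).sum()
bce = F.binary_cross_entropy(mu_D, x, reduction='sum')
print(torch.allclose(recons_acc, -bce))          # True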
To put this into practice, below is a simple VAE Python class which will be used to compare the different decoder distributions.
import torch
from torch.distributions.multivariate_normal import MultivariateNormal
class VAE(nn.Module):
    """A simple VAE class based on Kingma and Welling

    Args:
        encoder_network: instance of CoderNetwork class
        decoder_network: instance of CoderNetwork class
        L: number of samples used during reparameterization trick
    """
    def __init__(self, encoder_network, decoder_network, L=1):
        super().__init__()
        self.encoder = encoder_network
        self.decoder = decoder_network
        self.L = L
        latent_dim = encoder_network.output_dim
        self.normal_dist = MultivariateNormal(torch.zeros(latent_dim),
                                              torch.eye(latent_dim))

    def forward(self, x):
        L = self.L
        z, mu_E, log_var_E = self.encode(x, L)
        # regularization term per batch, i.e., size: (batch_size)
        regularization_term = (1/2) * (1 + log_var_E - mu_E**2
                                       - torch.exp(log_var_E)).sum(axis=1)
        # upsample x and reshape
        batch_size = x.shape[0]
        x_ups = x.repeat(L, 1).view(batch_size, L, -1)
        if self.decoder.coder_type == 'Gaussian':
            # mu_D, log_var_D have shape (batch_size, L, output_dim)
            mu_D, log_var_D = self.decode(z)
            # reconstruction accuracy per batch, i.e., size: (batch_size)
            recons_acc = (1/L) * (-(0.5)*(log_var_D.sum(axis=2)).sum(axis=1)
                                  - (0.5)*((1/torch.exp(log_var_D))
                                           * ((x_ups - mu_D)**2)
                                           ).sum(axis=2).sum(axis=1))
        elif self.decoder.coder_type == 'I-Gaussian':
            # mu_D has shape (batch_size, L, output_dim)
            mu_D = self.decode(z)
            # reconstruction accuracy per batch, i.e., size: (batch_size)
            recons_acc = (1/L) * (-(0.5) * ((x_ups - mu_D)**2
                                            ).sum(axis=2).sum(axis=1))
        elif self.decoder.coder_type == 'Bernoulli':
            # mu_D has shape (batch_size, L, output_dim)
            mu_D = self.decode(z)
            # reconstruction accuracy per batch, i.e., size: (batch_size)
            # corresponds to the negative binary cross entropy loss (BCELoss)
            recons_acc = (1/L) * (x_ups * torch.log(mu_D) +
                                  (1 - x_ups) * torch.log(1 - mu_D)
                                  ).sum(axis=2).sum(axis=1)
        loss = - regularization_term.sum() - recons_acc.sum()
        return loss

    def encode(self, x, L=1):
        # get encoder distribution parameters
        mu_E, log_var_E = self.encoder(x)
        # sample noise variable L times for each batch
        batch_size = x.shape[0]
        epsilon = self.normal_dist.sample(sample_shape=(batch_size, L, ))
        # upsample mu_E, log_var_E and reshape
        mu_E_ups = mu_E.repeat(L, 1).view(batch_size, L, -1)
        log_var_E_ups = log_var_E.repeat(L, 1).view(batch_size, L, -1)
        # get latent variable by reparametrization trick
        z = mu_E_ups + torch.sqrt(torch.exp(log_var_E_ups)) * epsilon
        return z, mu_E, log_var_E

    def decode(self, z):
        # get decoder distribution parameters
        if self.decoder.coder_type == 'Gaussian':
            mu_D, log_var_D = self.decoder(z)
            return mu_D, log_var_D
        elif self.decoder.coder_type == 'I-Gaussian':
            mu_D = self.decoder(z)
            return mu_D
        elif self.decoder.coder_type == 'Bernoulli':
            mu_D = self.decoder(z)
            return mu_D
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
def train(decoder_type, dataset, x_dim, hid_dim, z_dim, batch_size, L, epochs):
    encoder_network = CoderNetwork(input_dim=x_dim,
                                   hidden_dim=hid_dim,
                                   output_dim=z_dim,
                                   coder_type='Gaussian')
    decoder_network = CoderNetwork(input_dim=z_dim,
                                   hidden_dim=hid_dim,
                                   output_dim=x_dim,
                                   coder_type=decoder_type)
    model = VAE(encoder_network, decoder_network, L=L)
    data_loader = DataLoader(dataset, batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    print('Start training with {} decoder distribution\n'.format(decoder_type))
    for epoch in range(1, epochs + 1):
        print('Epoch {}/{}'.format(epoch, epochs))
        avg_loss = 0
        for counter, (mini_batch_data, label) in enumerate(data_loader):
            model.zero_grad()
            loss = model(mini_batch_data.view(-1, x_dim))
            loss.backward()
            optimizer.step()
            avg_loss += loss.item() / len(dataset)
            if counter % 20 == 0 or (counter + 1) == len(data_loader):
                batch_loss = loss.item() / len(mini_batch_data)
                print('\r[{}/{}] batch loss: {:.2f}'.format(counter + 1,
                                                            len(data_loader),
                                                            batch_loss),
                      end='', flush=True)
        print('\nAverage loss: {:.3f}'.format(avg_loss))
    print('Done!\n')
    trained_VAE = model
    return trained_VAE
dataset = datasets.MNIST('data/', transform=transforms.ToTensor(),
download=True)
x_dim, hid_dim, z_dim = 28*28, 400, 20
batch_size, L, epochs = 128, 1, 3
Bernoulli_VAE = train('Bernoulli', dataset, x_dim, hid_dim, z_dim,
batch_size, L, epochs)
Gaussian_VAE = train('Gaussian', dataset, x_dim, hid_dim, z_dim,
batch_size, L, epochs)
I_Gaussian_VAE = train('I-Gaussian', dataset, x_dim, hid_dim, z_dim,
batch_size, L, epochs)
Let's look at the differences in the reconstructions:
import matplotlib.pyplot as plt
def plot_results(trained_model, dataset, n_samples):
    decoder_type = trained_model.decoder.coder_type
    fig = plt.figure(figsize=(14, 3))
    fig.suptitle(decoder_type + ' Distribution: Observations (top row) and ' +
                 'their reconstructions (bottom row)')
    for i_sample in range(n_samples):
        x_sample = dataset[i_sample][0].view(-1, 28*28)
        z, mu_E, log_var_E = trained_model.encode(x_sample, L=1)
        if decoder_type in ['Bernoulli', 'I-Gaussian']:
            x_prime = trained_model.decode(z)
        else:
            x_prime = trained_model.decode(z)[0]
        plt.subplot(2, n_samples, i_sample + 1)
        plt.imshow(x_sample.view(28, 28).data.numpy())
        plt.axis('off')
        plt.subplot(2, n_samples, i_sample + 1 + n_samples)
        plt.imshow(x_prime.view(28, 28).data.numpy())
        plt.axis('off')
    return fig
n_samples = 10
plot_results(Bernoulli_VAE, dataset, n_samples)
plot_results(Gaussian_VAE, dataset, n_samples)
plot_results(I_Gaussian_VAE, dataset, n_samples)
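Since the coding space is regularized towards $\mathcal{N}(0, I)$, new digits can also be generated by sampling codes from the prior and decoding them. A quick sketch using the Bernoulli VAE trained above (the decoded Bernoulli means are shown as images):

n_generated = 10
z = Bernoulli_VAE.normal_dist.sample(sample_shape=(n_generated, 1))  # shape (n_generated, 1, z_dim)
x_generated = Bernoulli_VAE.decode(z)                                # Bernoulli means in (0, 1)
fig = plt.figure(figsize=(14, 1.5))
fig.suptitle('Bernoulli Distribution: Samples generated from the prior')
for i in range(n_generated):
    plt.subplot(1, n_generated, i + 1)
    plt.imshow(x_generated[i].view(28, 28).data.numpy())
    plt.axis('off')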
Acknowledgement
Daniel Daza's blog was really helpful and the presented code is highly inspired by his summary on VAEs.