Hugh Murrell / Aug 15 2019

Chapter 11, Autoencoders

Autoencoders

based on code by Mike J Innes, found at

https://github.com/FluxML/model-zoo/blob/master/vision/mnist/vae.jl

Load the packages

using Flux, Flux.Data.MNIST, Statistics 
using Flux: throttle, params
using Flux: @epochs, mse
using Distances # for pairwise(Euclidean(), ...) used by minSeparation below
using Plots
using Colors
using Images

Load the data

5000 random training samples and 5000 random test samples (the indices are drawn independently with replacement, so the two sets may overlap)

# Load test and training data of size N, 
# and partition training data into mini-batches of M.

X = float.(hcat(vec.(MNIST.images())...)) 
# .> 0.5 # uncomment this for binary data
Y = MNIST.labels()

N, M = 5000, 100

p_train = rand(1:size(X,2),N)
p_test = rand(1:size(X,2),N)

trainData = [(X[:,p_train[i]],Y[p_train[i]]) for i in Iterators.partition(1:N,M)];
testData = [(X[:,p_test[i]],Y[p_test[i]]) for i in Iterators.partition(1:N,M)];
(N,M)
(5000, 100)
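
As a quick sanity check (a sketch, assuming the cell above has run), each mini-batch should hold M = 100 flattened 28×28 images together with their labels:

# each element of trainData is a (features, labels) pair
size(first(trainData)[1])    # expect (784, 100)
length(first(trainData)[2])  # expect 100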

Inspect the first five MNIST samples

img(x) = Gray.(reshape(x, 28, 28))
[img(X[:,i]) for i in 1:5]
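
For reference (assuming the data cell above), the matching labels for these five images:

Y[1:5]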

Display and helper functions

# Show where the innermost layer maps samples on a 2D plot, 
# coloured by each of the classes in the dataset
# includes a function for computing the centroids for each class

my_norm(v) = sqrt(sum(v.^2))

# classify each sample in `range` by the nearest centroid in latent space
# (the -1 maps column index 1:10 to digit 0:9); return the fraction correct
function accuracy(lm, range, centroids)
    sum([ argmin(mapslices(my_norm,centroids.-lm(X[:,i]).data,dims=1))[2]-1==Y[i] 
        for i in range ]) / length(range)
end

function computeClasses(Y)
    sort(unique(Y))
end

# mean latent position of each class, averaged over the mini-batches
# (assumes every mini-batch contains at least one sample of each class)
function computeCentroids(lm, data, classes) 
    centroids = zeros((2,length(classes)))
    for k in 1:length(data)
        x = first(data[k])
        y = last(data[k])
        for i in 1:length(classes)
            points = lm(x[:,y.==classes[i]])
            centroids[:,i] += Flux.data.(mapslices(mean,points,dims=2))
        end
    end;
    centroids ./= length(data)
    return centroids
end

function plotLatentSpace(lm,data,classes,width=0)
    centroids = computeCentroids(lm, data, classes)
    colors = distinguishable_colors(12, [RGB(1,1,1)])[3:12]
    if ( width > 0 )
        p = scatter(xlim=(-width,width),ylim=(-width,width))
    else
        p = scatter()
    end
    for lab = 1:length(classes)
        scatter!([centroids[1,lab]],[centroids[2,lab]],
                label=classes[lab],markersize=[9],markercolor=[colors[lab]])
    end
    for k in 1:length(data)
        x = first(data[k])
        y = last(data[k])
        for i in 1:length(classes)
            points = Flux.data(lm(x[:,y.==classes[i]]))
            scatter!(points[1,:],points[2,:],label="",
               markercolor=[colors[i]])
        end
    end
    for lab = 1:length(classes)
        scatter!([centroids[1,lab]],[centroids[2,lab]],
                label="",markersize=[9],markercolor=[colors[lab]])
    end
    scatter!(centroids[1,:],centroids[2,:],label="centroids")
    return p
end
plotLatentSpace (generic function with 2 methods)
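
For later reference, accuracy above implements a nearest-centroid classifier in the latent space. The same rule for a single sample reads as follows (a sketch; nearestDigit is a hypothetical helper introduced here for illustration):

# hypothetical helper: the digit whose centroid is nearest to x's latent code
nearestDigit(lm, x, centroids) =
    argmin(mapslices(my_norm, centroids .- Flux.data(lm(x)), dims=1))[2] - 1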

Simple Auto-Encoder

# Latent dimensionality and number of hidden units.
Dz, Dh = 2, 500

# encoder
g = Chain(Dense(28^2, Dh, tanh), Dense(Dh, Dz))
# decoder
f = Chain(Dense(Dz, Dh, tanh), Dense(Dh, 28^2, σ))
# model
sae = Chain(g,f)
# loss: Y is ignored, the reconstruction target is the input itself
loss(X,Y) = mse(sae(X), X)
# callback: report the loss on a random mini-batch (indices drawn from the first N columns of X)
evalcb = throttle(() -> ( p=rand(1:N, M); @show(loss(X[:, p],Y[p]))), 20) 
# optimization
opt = ADAM()
# parameters
ps = params(sae);
p=rand(1:N, M)
loss(X[:,p],Y[p])
0.23123043770487028 (tracked)

train the SAE

@epochs 100 Flux.train!(loss, params(sae), trainData, opt, cb = evalcb)
classesMNIST = computeClasses(Y)
10-element Array{Int64,1}: 0 1 2 3 4 5 6 7 8 9
centroidsSAE = computeCentroids(g,trainData,classesMNIST)
2×10 Array{Float64,2}:
  5.83981  -40.3121   -20.9355  -16.0359  …  -6.61001   -19.5566  -4.39321
 55.4499    -9.28185   33.9844   35.0697      0.977353   21.3116   5.72353
plotLatentSpace(g,trainData,classesMNIST)
plotLatentSpace(g,testData,classesMNIST)
accuracy(g, p_train, centroidsSAE)
0.5666
accuracy(g, p_test, centroidsSAE)
0.5534
savefig(plotLatentSpace(g,trainData,classesMNIST),"tmp/sae_latent_space_train_data.png")
savefig(plotLatentSpace(g,testData,classesMNIST),"tmp/sae_latent_space_test_data.png")
# decode each class centroid back to image space
function my_sampleSAE(centroids) 
    [ Flux.data(f(centroids[:,i])) for i in 1:length(classesMNIST) ]
end

sample = hcat(img.(my_sampleSAE(centroidsSAE))...)
save("tmp/sae_sample.png", sample)

Variational Auto-Encoder

################################# Define Model #################################

# Latent dimensionality and number of hidden units.
Dz, Dh = 2, 500

# Components of the encoder: the code now has a mean part (μ) and a log-σ part.
A, μ, logσ = Dense(28^2, Dh, tanh), Dense(Dh, Dz), Dense(Dh, Dz)
g_vae(X) = (h = A(X); (μ(h), logσ(h)))

# latent-space map of the input data (the posterior mean μ)
lm_vae(X) = g_vae(X)[1]

# reparameterised sample from the latent space: z = μ + σ ε with ε ~ N(0,1)
z_vae(μ, logσ) = μ + exp(logσ) * randn(Float32)

# now the decoder
f_vae = Chain(Dense(Dz, Dh, tanh), Dense(Dh, 28^2, σ))

# A numerically stable logpdf for a Bernoulli distribution with p close to 1 or 0
# (only needed for binary data; unused below, where a squared-error loss is used instead).
logpdfBernoulli(p, y) = y * log(p + eps(Float32)) + (1f0 - y) * log(1 - p + eps(Float32))

# KL-divergence between the approximate posterior and the N(c, 1) prior
# (note: the standard closed form has 2 .* logσ in place of the bare logσ term below)
kl_q_p(c, μ, logσ) = - 0.5f0 * sum( 1f0 .+ logσ .- (exp.(logσ)).^2 .- (μ-c).^2 )

# reconstruction term: squared error between the input and the decoded sample,
# proportional to the negative log-likelihood under a Gaussian decoder
logp_x_z(x, z) = sum((f_vae(z) - x).^2) #* ( 1 // M)

# prior mean for every sample: the origin, so all classes are drawn towards zero
place_vae = function (Y)  
    ret = Array{Float64,2}(undef, 2, length(Y))
    for i in 1:length(Y)
        ret[:,i] = zeros((2,1))
    end
    return ret
end

# loss function to be optimised
loss_vae(X,Y) = ((μ̂, logσ̂) = g_vae(X); ĉ = place_vae(Y);  # @show(ĉ);
    pxz = logp_x_z(X, z_vae.(μ̂, logσ̂)); # @show(pxz);
    klqp = kl_q_p(ĉ, μ̂, logσ̂); # @show(klqp);
    ls = ( pxz + klqp) * 1 // M; # @show(ls);
    ls )

# callback, optimiser and parameters
evalcb_vae = throttle(() -> ( p=rand(1:N, M); @show(loss_vae(X[:, p],Y[p])) ), 20)
opt_vae = ADAM()
ps_vae = params(A, μ, logσ, f_vae);
p=rand(1:N, M)
loss_vae(X[:,p],Y[p])
183.54863578556768 (tracked)
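
Before training, a minimal sanity check on the KL term (assuming the definitions above): it should vanish when the approximate posterior coincides with the prior, i.e. when μ = c and logσ = 0.

# expect 0.0: the posterior N(0, 1) matches the prior N(0, 1)
kl_q_p(zeros(Dz), zeros(Dz), zeros(Dz))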

train the VAE

@epochs 100 Flux.train!(loss_vae, ps_vae, trainData, opt_vae, cb = evalcb_vae)
classesMNIST = computeClasses(Y)
plotLatentSpace(lm_vae,trainData,classesMNIST)
plotLatentSpace(lm_vae,testData,classesMNIST)
centroidsVAE = computeCentroids(lm_vae,trainData,classesMNIST)
2×10 Array{Float64,2}:
 2.50798  -0.235618  0.662205   2.04968   …  -2.66616   1.07081   -1.54038
 2.40422  -2.6246    0.141147  -0.690347     -0.166533  -0.400606   0.474194
accuracy(lm_vae, p_train, centroidsVAE)
0.566
accuracy(lm_vae, p_test, centroidsVAE)
0.5356
savefig(plotLatentSpace(lm_vae,trainData,classesMNIST),"tmp/vae_latent_space_train_data.png")
savefig(plotLatentSpace(lm_vae,testData,classesMNIST),"tmp/vae_latent_space_test_data.png")
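
Since the KL term pulls the latent codes towards the N(0, 1) prior, decoding random draws from that prior should already yield digit-like images (a sketch, assuming the trained decoder above; the file name is illustrative):

# decode ten random latent points drawn from the standard normal prior
priorSample = hcat([img(Flux.data(f_vae(randn(Float32, Dz)))) for _ in 1:10]...)
save("tmp/vae_prior_sample.png", priorSample)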

Semi-Supervised VAE

################################# Define Model #################################

# Latent dimensionality and number of hidden units.
Dz, Dh = 2, 500

# Components of the encoder: the code again has a mean part (μ) and a log-σ part.
A, μ, logσ = Dense(28^2, Dh, tanh), Dense(Dh, Dz), Dense(Dh, Dz)
g_ss(X) = (h = A(X); (μ(h), logσ(h)))

# latent-space map of the input data (the posterior mean μ)
lm_ss(X) = g_ss(X)[1]

# reparameterised sample from the latent space, as before
z_ss(μ, logσ) = μ + exp(logσ) * randn(Float32)

# now the decoder
f_ss = Chain(Dense(Dz, Dh, tanh), Dense(Dh, 28^2, σ))

# A numerically stable logpdf for a Bernoulli distribution (again unused below).
logpdfBernoulli(p, y) = y * log(p + eps(Float32)) + (1f0 - y) * log(1 - p + eps(Float32))

# KL-divergence between the approximate posterior and the N(c, 1) prior (as above).
kl_q_p(c, μ, logσ) = - 0.5f0 * sum( 1f0 .+ logσ .- (exp.(logσ)).^2 .- (μ-c).^2 )

# reconstruction term (squared error), as above.
logp_x_z(x, z) = sum((f_ss(z) - x).^2) #* ( 1 // M)

# prior mean for each sample: the current centroid of its own class,
# so samples are drawn towards their class centroid
place_ss = function(Y,c)  
    ret = Array{Float64,2}(undef, 2, length(Y))
    for i in 1:length(Y)
        ret[:,i] = c[:,Y[i]+1]
    end
    return ret
end

# re-estimate the class centroids from the current latent map
# (assumes every mini-batch contains at least one sample of each class)
updateCentroids = function(lm, data, cents)
    new_cents = zeros(size(cents))
    for k in 1:length(data)
        x = first(data[k])
        y = last(data[k])
        for i in 0:9
            μ̂ = lm(x[:,y.==i]);
            new_cents[:,i+1] += Flux.data.(mapslices(mean,μ̂,dims=2))
        end
    end;
    return ( new_cents ./= length(data) )
end

# smallest non-zero pairwise distance between centroids
# (seps[1] is the zero self-distance, hence seps[2])
minSeparation = function(cents)
    seps = sort(unique(pairwise(Euclidean(),cents,dims=2)))
    if ( length(seps) > 1 )
        return seps[2]
    else
        return seps[1]
    end
end

# recentre the centroid cloud on the origin
centerCentroids = function(cents)
    return (cents .- mapslices(mean,cents,dims=2))
end 

# sum ignoring NaNs (0/0 terms from a centroid's interaction with itself)
nansum = function(x)
    return sum(x[.!(isnan.(x))])
end


# net inverse-square "electrostatic" repulsion felt by centroid i from the others
eforce = function(cents,i)
    d = cents .- cents[:,i]
    dd = mapslices(sum,d.^2,dims=1).^(3/2)
    return( mapslices(nansum,  d ./ ( dd) ,dims=2))
end

# push the centroids apart along the repulsion directions; the outermost
# centroid (radius r) stays fixed, inner centroids move further
espread = function(cents,k=1)
    rs = mapslices(my_norm,cents,dims=1)
    r = maximum(rs)
    dc = hcat([eforce(cents,i) for i in 1:10]...)
    return(cents - ((1 .- rs ./ r) .* k) .* dc)
end

# initial centroids for 10 classes
global centroids_ss = zeros((2,10))

# initial "gravity" for strength of attraction to current centroids
global gravity_ss = 1

# loss function to be optimised
loss_ss(X,Y) = ((μ̂, logσ̂) = g_ss(X); ĉ = place_ss(Y,centroids_ss);  # @show(ĉ);
    pxz = logp_x_z(X, z_ss.(μ̂, logσ̂)); # @show(pxz);
    klqp = kl_q_p(ĉ, μ̂, logσ̂); # @show(klqp);
    ls = ( pxz + gravity_ss * klqp) * 1 // M; # @show(ls);
    ls )

# callback, optimiser and parameters
evalcb_ss = throttle(() -> ( p=rand(1:N, M); @show(loss_ss(X[:, p],Y[p])) ), 20)
opt_ss = ADAM()
ps_ss = params(A, μ, logσ, f_ss);
p=rand(1:N, M)
loss_ss(X[:,p],Y[p])
182.02264523243412 (tracked)
minSeparation(centroids_ss)
0.0
# Sample one decoded image per class from the learned model
# (note: z_ss treats its second argument as log σ).
# modelsample() = rand.(Bernoulli.(f(z.(zeros(Dz), zeros(Dz)))))
modelsample( cents, σ ) = [ Flux.data(f_ss(z_ss.(place_ss(i, cents), σ * ones(Dz)))) for i in 0:9 ]
modelsample (generic function with 1 method)
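
The eforce and espread helpers treat the centroids as equal point charges in the plane: each centroid feels an inverse-square repulsion from the others, scaled so that the outermost centroid stays put while inner ones are pushed outwards. A minimal check (a sketch, assuming the cell above): spreading a tight random cluster should typically increase its minimum separation.

# a tight cluster of 10 centroids should (typically) spread apart
c_test = 0.1 .* randn(2, 10)
minSeparation(c_test) < minSeparation(espread(c_test))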
run(`mkdir tmp/animation`)
function learn_ss(expand = 20, contract = 80, width = 8, minsep=1)
    global gravity_ss
    global centroids_ss
    classesMNIST = computeClasses(Y)
    centroids_ss = updateCentroids(lm_ss,trainData,centroids_ss)
    fname = "tmp/animation/vae_ss_"*string(0,pad=3)*".png"
    savefig(plotLatentSpace(lm_ss,trainData,classesMNIST,width),fname)
    # first expand
    gravity_ss = 1
    for i = 1:expand
        ms = minSeparation(centroids_ss)
        @info "epoch = $i , gravity_ss = $gravity_ss , separation = $ms of $minsep "
        centroids_ss = centerCentroids(centroids_ss)
        Flux.train!(loss_ss, ps_ss, trainData, opt_ss, cb=evalcb_ss)
        centroids_ss = updateCentroids(lm_ss,trainData,centroids_ss)
        @info " accuracy " * string(accuracy(lm_ss, p_train, centroids_ss))
        fname = "tmp/animation/vae_ss_"*string(i,pad=3)*".png"
        savefig(plotLatentSpace(lm_ss,trainData,classesMNIST,width),fname)
    end
    # now contract
     for i = expand+1:expand+contract
        gravity_ss = 1 + i - expand
        ms = minSeparation(centroids_ss)
        @info "epoch = $i , gravity_ss = $gravity_ss , separation = $ms of $minsep "
        # centroids_ss = espread(centroids_ss,10)
        centroids_ss = centerCentroids(centroids_ss)
        centroids_ss = espread(centroids_ss)
        Flux.train!(loss_ss, ps_ss, trainData, opt_ss, cb=evalcb_ss)
        centroids_ss = updateCentroids(lm_ss,trainData,centroids_ss)
        @info " accuracy " * string(accuracy(lm_ss, p_train, centroids_ss))
        fname = "tmp/animation/vae_ss_"*string(i,pad=3)*".png"
        savefig(plotLatentSpace(lm_ss,trainData,classesMNIST,width),fname)
    end
end
learn_ss (generic function with 5 methods)

train the SS-VAE: 20 "expand" epochs at unit gravity let the classes spread out, then 80 "contract" epochs ramp the gravity up to pull each class onto its continually re-estimated, repulsion-spread centroid

learn_ss()
accuracy(lm_ss,p_train,centroids_ss)
0.9998
accuracy(lm_ss,p_test,centroids_ss)
0.9058
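A per-digit breakdown of the test accuracy can be obtained with the same nearest-centroid rule (a sketch, assuming the helpers above; uses the hypothetical nearestDigit helper sketched earlier):

# per-digit test accuracy under the nearest-centroid rule
[ (d, mean([ nearestDigit(lm_ss, X[:,i], centroids_ss) == Y[i]
             for i in p_test if Y[i] == d ])) for d in 0:9 ]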
################################# Sample Output ##############################
sample = hcat(img.(modelsample(centroids_ss, 0.001))...)
save("tmp/vae_ss_sample.png", sample)
classesMNIST = computeClasses(Y)
plotLatentSpace(lm_ss,trainData,classesMNIST)
plotLatentSpace(lm_ss,testData,classesMNIST)
savefig(plotLatentSpace(lm_ss,trainData,classesMNIST),"tmp/vae_ss_encoder_train_data.png")
savefig(plotLatentSpace(lm_ss,testData,classesMNIST),"tmp/vae_ss_encoder_test_data.png")
# Show what points on the innermost 2D plane map to in terms of image representations.
# "Granularity" controls how finely these points are sampled.
function plot_decoder(granularity::Int)    
    hcat([vcat(img.([Flux.data(f_ss(z_ss.([i,1-j], zeros(Dz)))) 
                        for j in -10:20/granularity:10])...) 
            for i in -10:20/granularity:10]...)
end;
save("results/vae_ss_decoder.png",plot_decoder(20))
run(`convert tmp/animation/*.png tmp/ani.gif`)
Process(`convert 'tmp/animation/*.png' tmp/ani.gif`, ProcessExited(0))
run(`ls -al tmp`)
Process(`ls -al tmp`, ProcessExited(0))

display an animation of latent space

run(`cp tmp/ani.gif results/`)
Process(`cp tmp/ani.gif results/`, ProcessExited(0))