Hugh Murrell / Aug 15 2019
Chapter 11, Autoencoders
Autoencoders
based on code by Mike J Innes, found at
https://github.com/FluxML/model-zoo/blob/master/vision/mnist/vae.jl
Load the packages
using Flux, Flux.Data.MNIST, Statistics
using Flux: throttle, params
using Flux: @epochs, mse
using Distances   # assumed import: pairwise and Euclidean are used by minSeparation below
using Plots
using Colors
using Images
Load the data
5000 random training samples and 5000 random test samples
# Load test and training data of size N,
# and partition training data into mini-batches of M.
X = float.(hcat(vec.(MNIST.images())...)) # .> 0.5 # uncomment this for binary data
Y = MNIST.labels()
N, M = 5000, 100
p_train = rand(1:size(X,2), N)
p_test = rand(1:size(X,2), N)
trainData = [(X[:, p_train[i]], Y[p_train[i]]) for i in Iterators.partition(1:N, M)];
testData = [(X[:, p_test[i]], Y[p_test[i]]) for i in Iterators.partition(1:N, M)];
(N, M)
(5000, 100)
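As a quick check (a sketch, assuming the cell above has run): each element of trainData pairs a 784×100 matrix of pixel columns with its 100 matching labels.
x1, y1 = trainData[1]  # first mini-batch
size(x1), length(y1)   # expect ((784, 100), 100)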
Inspect the first five MNIST samples
img(x) = Gray.(reshape(x, 28, 28))
[img(X[:,i]) for i in 1:5]
Display and helper functions
# Show where the innermost layer maps samples on a 2D plot,
# coloured by each of the classes in the dataset.
# Includes a function for computing the centroids of each class.
my_norm(v) = sqrt(sum(v.^2))

function accuracy(lm, range, centroids)
    sum([argmin(mapslices(my_norm, centroids .- lm(X[:,i]).data, dims=1))[2] - 1 == Y[i]
         for i in range]) / length(range)
end

function computeClasses(Y)
    sort(unique(Y))
end

function computeCentroids(lm, data, classes)
    centroids = zeros((2, length(classes)))
    for k in 1:length(data)
        x = first(data[k])
        y = last(data[k])
        for i in 1:length(classes)
            points = lm(x[:, y .== classes[i]])
            centroids[:,i] += Flux.data.(mapslices(mean, points, dims=2))
        end
    end
    centroids ./= length(data)
    return centroids
end

function plotLatentSpace(lm, data, classes, width=0)
    centroids = computeCentroids(lm, data, classes)
    colors = distinguishable_colors(12, [RGB(1,1,1)])[3:12]
    if width > 0
        p = scatter(xlim=(-width,width), ylim=(-width,width))
    else
        p = scatter()
    end
    # plot the centroids first so that they pick up the class labels
    for lab = 1:length(classes)
        scatter!([centroids[1,lab]], [centroids[2,lab]],
                 label=classes[lab], markersize=[9], markercolor=[colors[lab]])
    end
    # then the latent-space points for each class
    for k in 1:length(data)
        x = first(data[k])
        y = last(data[k])
        for i in 1:length(classes)
            points = Flux.data(lm(x[:, y .== classes[i]]))
            scatter!(points[1,:], points[2,:], label="", markercolor=[colors[i]])
        end
    end
    # and the centroids again so they sit on top of the points
    for lab = 1:length(classes)
        scatter!([centroids[1,lab]], [centroids[2,lab]],
                 label="", markersize=[9], markercolor=[colors[lab]])
    end
    scatter!(centroids[1,:], centroids[2,:], label="centroids")
    return p
end
plotLatentSpace (generic function with 2 methods)
Simple Auto-Encoder
# Latent dimensionality, # hidden units.
Dz, Dh = 2, 500
# encoder
g = Chain(Dense(28^2, Dh, tanh), Dense(Dh, Dz))
# decoder
f = Chain(Dense(Dz, Dh, tanh), Dense(Dh, 28^2, σ))
# model
sae = Chain(g, f)
# loss
loss(X, Y) = mse(sae(X), X)
# callback
evalcb = throttle(() -> (p = rand(1:N, M); @show(loss(X[:,p], Y[p]))), 20)
# optimization
opt = ADAM()
# parameters
ps = params(sae);
p = rand(1:N, M)
loss(X[:,p], Y[p])
0.23123043770487028 (tracked)
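To see what the untrained SAE produces, here is a minimal sketch (not part of the original run) that stacks five originals above their reconstructions; before training, the bottom row should look like noise.
recon = Flux.data(sae(X[:, 1:5]))               # reconstructions from the untrained model
vcat(hcat(img.([X[:,i] for i in 1:5])...),      # originals on top
     hcat(img.([recon[:,i] for i in 1:5])...))  # reconstructions below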
Train the SAE
@epochs 100 Flux.train!(loss, params(sae), trainData, opt, cb = evalcb)
classesMNIST = computeClasses(Y)
10-element Array{Int64,1}:
0
1
2
3
4
5
6
7
8
9
centroidsSAE = computeCentroids(g,trainData,classesMNIST)
2×10 Array{Float64,2}:
5.83981 -40.3121 -20.9355 -16.0359 … -6.61001 -19.5566 -4.39321
55.4499 -9.28185 33.9844 35.0697 0.977353 21.3116 5.72353
plotLatentSpace(g,trainData,classesMNIST)
plotLatentSpace(g,testData,classesMNIST)
accuracy(g, p_train, centroidsSAE)
0.5666
accuracy(g, p_test, centroidsSAE)
0.5534
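To make the nearest-centroid rule inside accuracy concrete, here is a one-sample sketch (assuming the variables above; the choice of index is arbitrary): compute the distance from the encoded sample to each centroid and take the closest column as the predicted class.
i = p_train[1]
dists = mapslices(my_norm, centroidsSAE .- Flux.data(g(X[:,i])), dims=1)
(argmin(dists)[2] - 1, Y[i])  # predicted class vs true label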
savefig(plotLatentSpace(g, trainData, classesMNIST), "tmp/sae_latent_space_train_data.png")
savefig(plotLatentSpace(g, testData, classesMNIST), "tmp/sae_latent_space_test_data.png")
function my_sampleSAE(centroids)
    [f(centroids[:,i]) for i in 1:length(classesMNIST)]
end
sample = hcat(img.(my_sampleSAE(centroidsSAE))...)
save("tmp/sae_sample.png", sample)
Variational Auto-Encoder
################################# Define Model #################################
# Latent dimensionality, # hidden units.
Dz, Dh = 2, 500
# Components of the encoder; the code now has a mean and a variance part.
A, μ, logσ = Dense(28^2, Dh, tanh), Dense(Dh, Dz), Dense(Dh, Dz)
g_vae(X) = (h = A(X); (μ(h), logσ(h)))
# Latent-space map of input data
lm_vae(X) = g_vae(X)[1]
# sample from latent space
z_vae(μ, logσ) = μ + exp(logσ) * randn(Float32)
# now the decoder
f_vae = Chain(Dense(Dz, Dh, tanh), Dense(Dh, 28^2, σ))
# A numerically stable logpdf for a Bernoulli distribution close to 1 or 0.
logpdfBernoulli(p, y) = y * log(p + eps(Float32)) + (1f0 - y) * log(1 - p + eps(Float32))
# KL-divergence between the approximate posterior and a N(c, 1) prior.
kl_q_p(c, μ, logσ) = -0.5f0 * sum(1f0 .+ logσ .- (exp.(logσ)).^2 .- (μ - c).^2)
# conditional probability of data given latents
logp_x_z(x, z) = sum((f_vae(z) - x).^2) # * (1 // M)
# try to encourage all classes to group near the origin
place_vae = function (Y)
    ret = Array{Float64,2}(undef, 2, length(Y))
    for i in 1:length(Y)
        ret[:,i] = zeros((2,1))
    end
    return ret
end
# loss function to be optimised
loss_vae(X, Y) = ((μ̂, logσ̂) = g_vae(X);
    ĉ = place_vae(Y);                    # @show(ĉ);
    pxz = logp_x_z(X, z_vae.(μ̂, logσ̂));  # @show(pxz);
    klqp = kl_q_p(ĉ, μ̂, logσ̂);           # @show(klqp);
    ls = (pxz + klqp) * 1 // M;          # @show(ls);
    ls)
# callback, optimiser and parameters
evalcb_vae = throttle(() -> (p = rand(1:N, M); @show(loss_vae(X[:,p], Y[p]))), 20)
opt_vae = ADAM()
ps_vae = params(A, μ, logσ, f_vae);
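For reference, the closed form of the KL divergence between the encoder's Gaussian N(μ, σ²), with σ = exp(logσ), and the N(c, 1) prior, summed over the latent dimensions, is

$$ D_{\mathrm{KL}}\big(\mathcal{N}(\mu,\sigma^{2})\,\|\,\mathcal{N}(c,1)\big) = \tfrac{1}{2}\sum\big(\sigma^{2} + (\mu - c)^{2} - 1 - 2\log\sigma\big). $$

kl_q_p above follows this form except that the last term penalises logσ rather than 2 logσ.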
p = rand(1:N, M)
loss_vae(X[:,p], Y[p])
183.54863578556768 (tracked)
Train the VAE
@epochs 100 Flux.train!(loss_vae, ps_vae, trainData, opt_vae, cb = evalcb_vae)
classesMNIST = computeClasses(Y)
plotLatentSpace(lm_vae, trainData, classesMNIST)
plotLatentSpace(lm_vae,testData,classesMNIST)
centroidsVAE = computeCentroids(lm_vae,trainData,classesMNIST)
2×10 Array{Float64,2}:
2.50798 -0.235618 0.662205 2.04968 … -2.66616 1.07081 -1.54038
2.40422 -2.6246 0.141147 -0.690347 -0.166533 -0.400606 0.474194
accuracy(lm_vae, p_train, centroidsVAE)
0.566
accuracy(lm_vae, p_test, centroidsVAE)
0.5356
savefig(plotLatentSpace(lm_vae, trainData, classesMNIST), "tmp/vae_latent_space_train_data.png")
savefig(plotLatentSpace(lm_vae, testData, classesMNIST), "tmp/vae_latent_space_test_data.png")
Semi-Supervised VAE
################################# Define Model #################################
# Latent dimensionality, # hidden units.
Dz, Dh = 2, 500
# Components of the encoder; the code now has a mean and a variance part.
A, μ, logσ = Dense(28^2, Dh, tanh), Dense(Dh, Dz), Dense(Dh, Dz)
g_ss(X) = (h = A(X); (μ(h), logσ(h)))
# Latent-space map of input data
lm_ss(X) = g_ss(X)[1]
# sample from latent space
z_ss(μ, logσ) = μ + exp(logσ) * randn(Float32)
# now the decoder
f_ss = Chain(Dense(Dz, Dh, tanh), Dense(Dh, 28^2, σ))
# A numerically stable logpdf for a Bernoulli distribution close to 1 or 0.
logpdfBernoulli(p, y) = y * log(p + eps(Float32)) + (1f0 - y) * log(1 - p + eps(Float32))
# KL-divergence between the approximate posterior and a N(c, 1) prior.
kl_q_p(c, μ, logσ) = -0.5f0 * sum(1f0 .+ logσ .- (exp.(logσ)).^2 .- (μ - c).^2)
# conditional probability of data given latents
logp_x_z(x, z) = sum((f_ss(z) - x).^2) # * (1 // M)
# try to encourage each class to group near its current class centroid
place_ss = function (Y, c)
    ret = Array{Float64,2}(undef, 2, length(Y))
    for i in 1:length(Y)
        ret[:,i] = c[:, Y[i]+1]
    end
    return ret
end
updateCentroids = function (lm, data, cents)
    new_cents = zeros(size(cents))
    for k in 1:length(data)
        x = first(data[k])
        y = last(data[k])
        for i in 0:9
            μ̂ = lm(x[:, y .== i])
            new_cents[:,i+1] += Flux.data.(mapslices(mean, μ̂, dims=2))
        end
    end
    return (new_cents ./= length(data))
end
# smallest non-zero pairwise distance between centroids
minSeparation = function (cents)
    seps = sort(unique(pairwise(Euclidean(), cents, dims=2)))
    if length(seps) > 1
        return seps[2]
    else
        return seps[1]
    end
end
centerCentroids = function (cents)
    return cents .- mapslices(mean, cents, dims=2)
end
nansum = function (x)
    return sum(x[.!(isnan.(x))])
end
# inverse-square "electrostatic" force on centroid i from all the others
eforce = function (cents, i)
    d = cents .- cents[:,i]
    dd = mapslices(sum, d.^2, dims=1).^(3/2)
    return mapslices(nansum, d ./ dd, dims=2)
end
# push centroids apart along the force field, holding the outermost fixed
espread = function (cents, k=1)
    rs = mapslices(my_norm, cents, dims=1)
    r = maximum(rs)
    dc = hcat([eforce(cents, i) for i in 1:10]...)
    return cents - ((1 .- rs ./ r) .* k) .* dc
end
# initial centroids for 10 classes
global centroids_ss = zeros((2,10))
# initial "gravity" for strength of attraction to current centroids
global gravity_ss = 1
# loss function to be optimised
loss_ss(X, Y) = ((μ̂, logσ̂) = g_ss(X);
    ĉ = place_ss(Y, centroids_ss);           # @show(ĉ);
    pxz = logp_x_z(X, z_ss.(μ̂, logσ̂));       # @show(pxz);
    klqp = kl_q_p(ĉ, μ̂, logσ̂);               # @show(klqp);
    ls = (pxz + gravity_ss * klqp) * 1 // M; # @show(ls);
    ls)
# callback, optimiser and parameters
evalcb_ss = throttle(() -> (p = rand(1:N, M); @show(loss_ss(X[:,p], Y[p]))), 20)
opt_ss = ADAM()
ps_ss = params(A, μ, logσ, f_ss);
p = rand(1:N, M)
loss_ss(X[:,p], Y[p])
182.02264523243412 (tracked)
minSeparation(centroids_ss)
0.0
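Before espread appears in training, a toy sketch (hypothetical input, not from the original notebook) shows its effect: ten centroids crowded near the origin are pushed apart by the inverse-square repulsion, while the outermost centroid is held fixed.
toy = hcat(zeros(2), 0.1 .* randn(2, 9))    # ten centroids bunched at the origin
spread = espread(toy)
minSeparation(toy) < minSeparation(spread)  # separation should increase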
# Sample from the learned model.
# modelsample() = rand.(Bernoulli.(f(z.(zeros(Dz), zeros(Dz)))))
modelsample(cents, σ) = [f_ss(z_ss.(place_ss(i, cents), σ * ones(Dz))) for i in 0:9]
modelsample (generic function with 1 method)
run(`mkdir tmp/animation`)
function learn_ss(expand=20, contract=80, width=8, minsep=1)
    global gravity_ss
    global centroids_ss
    classesMNIST = computeClasses(Y)
    centroids_ss = updateCentroids(lm_ss, trainData, centroids_ss)
    fname = "tmp/animation/vae_ss_" * string(0, pad=3) * ".png"
    savefig(plotLatentSpace(lm_ss, trainData, classesMNIST, width), fname)
    # first expand
    gravity_ss = 1
    for i = 1:expand
        ms = minSeparation(centroids_ss)
        @info "epoch = $i , gravity_ss = $gravity_ss , separation = $ms of $minsep "
        centroids_ss = centerCentroids(centroids_ss)
        Flux.train!(loss_ss, ps_ss, trainData, opt_ss, cb=evalcb_ss)
        centroids_ss = updateCentroids(lm_ss, trainData, centroids_ss)
        @info " accuracy " * string(accuracy(lm_ss, p_train, centroids_ss))
        fname = "tmp/animation/vae_ss_" * string(i, pad=3) * ".png"
        savefig(plotLatentSpace(lm_ss, trainData, classesMNIST, width), fname)
    end
    # now contract
    for i = expand+1:expand+contract
        gravity_ss = 1 + i - expand
        ms = minSeparation(centroids_ss)
        @info "epoch = $i , gravity_ss = $gravity_ss , separation = $ms of $minsep "
        # centroids_ss = espread(centroids_ss,10)
        centroids_ss = centerCentroids(centroids_ss)
        centroids_ss = espread(centroids_ss)
        Flux.train!(loss_ss, ps_ss, trainData, opt_ss, cb=evalcb_ss)
        centroids_ss = updateCentroids(lm_ss, trainData, centroids_ss)
        @info " accuracy " * string(accuracy(lm_ss, p_train, centroids_ss))
        fname = "tmp/animation/vae_ss_" * string(i, pad=3) * ".png"
        savefig(plotLatentSpace(lm_ss, trainData, classesMNIST, width), fname)
    end
end
learn_ss (generic function with 5 methods)
Train the SS-VAE
learn_ss()
accuracy(lm_ss,p_train,centroids_ss)
0.9998
accuracy(lm_ss,p_test,centroids_ss)
0.9058
################################# Sample Output ##############################
sample = hcat(img.(modelsample(centroids_ss, 0.001))...)
save("tmp/vae_ss_sample.png", sample)
classesMNIST = computeClasses(Y)
plotLatentSpace(lm_ss, trainData, classesMNIST)
plotLatentSpace(lm_ss,testData,classesMNIST)
savefig(plotLatentSpace(lm_ss, trainData, classesMNIST), "tmp/vae_ss_encoder_train_data.png")
savefig(plotLatentSpace(lm_ss, testData, classesMNIST), "tmp/vae_ss_encoder_test_data.png")
# Show what points on the innermost 2D plane map to in terms of image representations.
# "Granularity" controls how finely these points are sampled.
function plot_decoder(granularity::Int)
    hcat([vcat(img.([Flux.data(f_ss(z_ss.([i, 1-j], zeros(Dz)))) for j in -10:20/granularity:10])...)
          for i in -10:20/granularity:10]...)
end;
save("results/vae_ss_decoder.png",plot_decoder(20))
run(`convert tmp/animation/*.png tmp/ani.gif`)
Process(`convert 'tmp/animation/*.png' tmp/ani.gif`, ProcessExited(0))
run(`ls -al tmp`)
Process(`ls -al tmp`, ProcessExited(0))
Display an animation of latent space
run(`cp tmp/ani.gif results/`)
Process(`cp tmp/ani.gif results/`, ProcessExited(0))