Hugh Murrell / Aug 15 2019
Chapter 6, Stochastic Gradient Descent
Stochastic Gradient Descent for optimising a loss function.
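The idea, roughly: instead of computing the gradient of the loss over the whole data set at every step, each iteration samples a small random batch and moves the parameters against that batch's gradient, θ ← θ − η_t · ∇L_batch(θ)/n, where n is the batch size and the step size η_t shrinks as the iterations proceed. The sgd function later in this chapter implements an update of this form.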
Generate some data
# generate 100 data points scattered around a straight line and plot
using Plots, LinearAlgebra
regX = rand(100)
regY = 50 .+ 100 * regX + 2 * randn(100);
scatter(regX, regY, fmt = :png, legend=:bottomright, label="data")
Compute an exact solution to the regression problem
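The cell below solves the regularised normal equations directly: with design matrix ϕ = [1 x], the estimate is θ = (ϕ'ϕ + δ²I)⁻¹ ϕ'y, where the small ridge term δ² simply keeps the matrix safely invertible.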
# exact solution
ϕ = hcat(regX.^0, regX.^1)
δ = 0.01
θ = inv(ϕ'ϕ + δ^2*I)*ϕ'regY
2-element Array{Float64,1}:
50.75614231733681
98.02203291532464
Define a loss function
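The loss below is the residual sum of squares, L(θ) = (y − Xθ)'(y − Xθ); differentiating with respect to θ gives the gradient ∇L(θ) = −2X'y + 2X'Xθ, which is what the function returns alongside the cost.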
function loss(X, y, θ)
    cost = (y .- X*θ)' * (y .- X*θ)   # residual sum of squares
    grad = -2 * X'y + 2 * X'X * θ     # gradient with respect to θ
    return (cost, grad)
end
loss (generic function with 1 method)
Visualise the loss basin
using Plots
xr = range(40, stop=60, length=10)
yr = range(90, stop=110, length=10)
X = hcat(ones(length(regX)), regX)
f(x,y) = loss(X, regY, [x y]')[1][1]
# scatter(regX, regY, regZ, label="data", colour="blue", fmt = :png)
plot(xr, yr, f, st=:wireframe, fmt = :png, label="loss", colour="blue")
# evaluate the loss near the fitted parameters
f(50.11183065165801, 99.49152488683228)
433.99925702524683
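This cost is in line with the noise used to generate the data: with 100 points and noise standard deviation 2, the expected residual sum of squares near the true parameters is about 100 × 2² = 400, so a value around 434 is unsurprising.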
Code to perform SGD
using StatsBase

function sgd(loss, X, y, θ_start, η, n, num_iters)
    #=
    Arguments:
    loss      -- the function to optimise; it takes a data set X, y and a
                 parameter set θ and yields two outputs, a cost and the
                 gradient with respect to θ
    θ_start   -- the initial estimate for θ
    η         -- the learning rate
    n         -- batch size
    num_iters -- total iterations to run SGD for
    Returns:
    θ    -- the parameter value after SGD finishes
    path -- estimates for the optimal θ along the way
    =#
    data = [(X[i,:], y[i]) for i in 1:length(y)]
    θ = θ_start
    path = θ
    for iter in 1:num_iters
        s = sample(data, n)            # draw a random mini-batch of size n
        Xs = hcat(first.(s)...)'
        ys = last.(s)
        cost, grad = loss(Xs, ys, θ)
        θ = θ .- ((η * num_iters/(num_iters + iter) * grad) / n)   # decaying step size
        path = hcat(path, θ)
    end
    return (θ, path)
end
sgd (generic function with 1 method)
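Note the decaying step-size schedule inside the loop, η * num_iters/(num_iters + iter). As a quick illustration (a sketch, assuming the same η = 0.1 and num_iters = 500 used in the next cell), the effective step size falls from roughly 0.1 at the first iteration to 0.05 at the last:
η = 0.1; num_iters = 500
η * num_iters / (num_iters + 1)          # ≈ 0.0998 at iteration 1
η * num_iters / (num_iters + num_iters)  # = 0.05 at iteration 500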
Run the code
X = hcat(ones(length(regX)), regX)
y = regY
(t,p) = sgd(loss, X, y, [40 90]', 0.1, 1, 500)
([50.9235; 98.2238], [40.0 42.5895 … 50.7868 50.9235; 90.0 91.0247 … 98.1147 98.2238])
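As a quick sanity check (a sketch, assuming the exact solution θ computed earlier is still in scope), the SGD estimate can be compared with the regularised exact solution and its final cost inspected:
println("exact solution: ", θ')
println("sgd estimate:   ", t')
println("final cost:     ", loss(X, regY, t)[1][1])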
Visualise the descent in the loss basin
scatter!(p[1,:], p[2,:], f.(p[1,:], p[2,:]), label="sgd_path")
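One further diagnostic, not part of the original notebook (the name costs below is new): plotting the cost along the recorded path shows how quickly the descent settles into the basin.
costs = [loss(X, regY, p[:,k])[1] for k in 1:size(p,2)]
plot(costs, xlabel="iteration", ylabel="cost", label="cost along sgd path", fmt = :png)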