Hugh Murrell / Aug 15 2019

Chapter 6, Stochastic Gradient Descent

Stochastic Gradient Descent for optimising a loss function.

Generate some data

# generate 100 data points scattered around a straight line and plot
using Plots, LinearAlgebra
regX = rand(100)
regY = 50 .+ 100 * regX + 2 * randn(100);
scatter(regX, regY, fmt = :png, legend=:bottomright, label="data")

Compute the exact (closed-form) solution to the regression problem, with a tiny ridge term δ²I added for numerical stability

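# exact solution: design matrix with a column of ones (intercept) and regX (slope)
ϕ = hcat(regX.^0, regX.^1)
δ = 0.01   # small ridge term keeps ϕ'ϕ + δ²I well conditioned
θ = inv(ϕ'ϕ + δ^2*I)*ϕ'regY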
2-element Array{Float64,1}:
 50.75614231733681
 98.02203291532464
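Minimising the squared error ‖y − ϕθ‖² plus a small ridge penalty δ²‖θ‖² leads to the normal equations (ϕᵀϕ + δ²I)θ = ϕᵀy, which is the system solved above. With δ = 0.01 the penalty is negligible, so the result should match ordinary least squares to several decimal places. The self-contained sketch below checks this, assuming a fixed random seed and the same data model as above; it solves the system with `\` rather than forming `inv`, which is generally the numerically preferred choice.

```julia
using LinearAlgebra, Random

Random.seed!(1234)                       # assumed seed, for reproducibility only
regX = rand(100)
regY = 50 .+ 100 * regX + 2 * randn(100)

ϕ = hcat(regX.^0, regX.^1)               # design matrix: a column of ones and regX
δ = 0.01

# ridge-regularised normal equations, solved without an explicit inverse
θ_ridge = (ϕ' * ϕ + δ^2 * I) \ (ϕ' * regY)

# plain least squares (backslash on a tall matrix uses a QR factorisation)
θ_ls = ϕ \ regY

println("ridge:         ", θ_ridge)
println("least squares: ", θ_ls)
```

Both estimates should sit near the intercept 50 and slope 100 used to generate the data.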

Define a loss function: the sum of squared residuals, returned together with its gradient with respect to θ

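function loss(X,y,θ)
    # sum of squared residuals between the targets y and the predictions Xθ
    cost = (y .- X*θ)'*(y .- X*θ)
    # gradient of the cost with respect to θ
    grad = - 2 * X'y + 2*X'X*θ
    return (cost, grad)
end
loss (generic function with 1 method)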
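Expanding the cost gives J(θ) = (y − Xθ)ᵀ(y − Xθ) = yᵀy − 2θᵀXᵀy + θᵀXᵀXθ, and differentiating term by term gives ∇J(θ) = −2Xᵀy + 2XᵀXθ, which is the `grad` returned above. A central finite-difference check is a cheap way to confirm such a formula. The sketch below is self-contained: it restates `loss` for a vector θ, and both the data and the helper `fd_grad` are hypothetical, introduced only for illustration.

```julia
using LinearAlgebra, Random

# squared-error loss and its analytic gradient (θ as a plain vector)
function loss(X, y, θ)
    r = y .- X * θ
    cost = dot(r, r)
    grad = -2 * X' * y + 2 * X' * X * θ
    return (cost, grad)
end

# hypothetical helper: central finite-difference estimate of the gradient of the cost
function fd_grad(X, y, θ; h = 1e-5)
    g = zeros(length(θ))
    for k in eachindex(θ)
        e = zeros(length(θ))
        e[k] = h
        g[k] = (loss(X, y, θ .+ e)[1] - loss(X, y, θ .- e)[1]) / (2h)
    end
    return g
end

Random.seed!(42)                         # assumed seed and data, for illustration
x = rand(20)
X = hcat(ones(20), x)
y = 50 .+ 100 * x + 2 * randn(20)
θ = [40.0, 90.0]

analytic = loss(X, y, θ)[2]
numeric = fd_grad(X, y, θ)
println("max abs difference: ", maximum(abs.(analytic - numeric)))
```

Because the cost is quadratic, the central difference is exact up to rounding, so any real discrepancy would point to an error in the gradient formula.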

Visualise the loss basin

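using Plots
X = hcat(ones(length(regX)), regX)
# loss as a function of the two parameters: intercept x and slope y
f(x,y) = loss(X,regY,[x y]')[1][1]
# parameter ranges to plot over, centred near the exact solution (assumed)
xr = range(30, 70, length=40)
yr = range(80, 120, length=40)
plot(xr,yr,f,st=:wireframe, fmt = :png, label="loss", colour="blue")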
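Because the loss is quadratic in θ, its Hessian is the constant matrix 2XᵀX. When the columns of X are linearly independent this matrix is positive definite, so the basin is a bowl with a single minimum. The ratio of its largest to smallest eigenvalue measures how elongated the bowl is, which governs how slowly gradient steps progress along the flattest direction. A minimal sketch, assuming data drawn like the data above with a hypothetical seed:

```julia
using LinearAlgebra, Random

Random.seed!(7)                          # assumed seed
regX = rand(100)
X = hcat(ones(length(regX)), regX)

H = 2 * X' * X                           # Hessian of the squared-error loss
λ = eigvals(Symmetric(H))                # real eigenvalues, in ascending order
println("eigenvalues: ", λ)
println("condition number: ", λ[end] / λ[1])
```

For regX uniform on [0, 1] the ratio comes out roughly in the high teens, which is why the wireframe looks like a long, narrow valley rather than a round bowl.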

Code to perform SGD

using StatsBase

"""
    sgd(loss, X, y, θ_start, η, n, num_iters)

Arguments:
  loss      -- the function to optimise; it takes a data set X, y and a
               parameter set θ, and returns two outputs: the cost and
               the gradient with respect to θ
  θ_start   -- the initial guesstimate for θ
  η         -- the learning rate
  n         -- batch size
  num_iters -- total number of iterations to run SGD for

Returns:
  θ    -- the parameter value after SGD finishes
  path -- the estimates of the optimal θ along the way, one column per step
"""
function sgd(loss, X, y, θ_start, η, n, num_iters)
    # pair each row of X with its target so that rows are sampled together
    data = [(X[i,:],y[i]) for i in 1:length(y)]
    θ = θ_start
    path = θ
    for iter in 1:num_iters
        # draw a mini-batch of n (row, target) pairs, with replacement
        s = sample(data,n)
        Xs = hcat(first.(s)...)'
        ys = last.(s)
        cost, grad = loss(Xs,ys,θ)
        # step against the mean batch gradient with a slowly decaying learning rate
        θ = θ .- ( ( η * num_iters/(num_iters+iter) * grad) / n )
        path = hcat(path,θ)
    end
    return (θ,path)
end
sgd (generic function with 1 method)
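In the update line, the cost summed over the mini-batch is divided by n, so each step follows the mean gradient of the batch. The factor num_iters/(num_iters + iter) decays the learning rate from just under η at the first iteration to exactly η/2 at the last. The short sketch below tabulates that schedule for η = 0.1 and num_iters = 500, the values used in the run below; the name `step_size` is hypothetical.

```julia
# hypothetical helper: the effective learning rate used by sgd at iteration iter
step_size(η, num_iters, iter) = η * num_iters / (num_iters + iter)

η, num_iters = 0.1, 500
for iter in (1, 100, 250, 500)
    println(iter, " => ", step_size(η, num_iters, iter))
end
```

Since the rate only halves over the run, the iterates never settle exactly: with single-sample batches they keep jittering in a small neighbourhood of the minimum.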

Run the code

X = hcat(ones(length(regX)), regX)
y = regY
# start at θ = [40, 90] with learning rate η = 0.1, batch size 1 and 500 iterations
(t,p) = sgd(loss, X, y, [40 90]', 0.1, 1, 500)
([50.9235; 98.2238], [40.0 42.5895 … 50.7868 50.9235; 90.0 91.0247 … 98.1147 98.2238])
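The SGD estimate lands close to, but not exactly on, the closed-form solution computed earlier. The self-contained sketch below repeats the experiment and measures that gap. It assumes a fixed seed, and it draws rows with `rand(1:N, n)` from the standard library instead of `StatsBase.sample`; both sample with replacement. The function name `sgd_sketch` is hypothetical, and it keeps the same decaying step size as `sgd` above.

```julia
using LinearAlgebra, Random

Random.seed!(2019)                        # assumed seed
N = 100
regX = rand(N)
regY = 50 .+ 100 * regX + 2 * randn(N)
X = hcat(ones(N), regX)

# squared-error loss over a (mini-)batch and its gradient
function loss(X, y, θ)
    r = y .- X * θ
    return (dot(r, r), -2 * X' * y + 2 * X' * X * θ)
end

# hypothetical mini-batch SGD with the same step-size schedule as sgd;
# rows are drawn with replacement via rand(1:length(y), n)
function sgd_sketch(loss, X, y, θ_start, η, n, num_iters)
    θ = copy(θ_start)
    for iter in 1:num_iters
        idx = rand(1:length(y), n)
        _, grad = loss(X[idx, :], y[idx], θ)
        θ = θ .- (η * num_iters / (num_iters + iter) * grad) / n
    end
    return θ
end

θ_exact = X \ regY
θ_sgd = sgd_sketch(loss, X, regY, [40.0, 90.0], 0.1, 1, 500)
println("exact: ", θ_exact)
println("sgd:   ", θ_sgd)
println("gap:   ", norm(θ_sgd - θ_exact))
```

The remaining gap is typically well under one unit. It reflects the noise of single-sample gradients rather than a failure to converge, and it shrinks if the batch size n is increased or the learning rate is decayed further.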

Visualise the descent in the loss basin
