Hugh Murrell / Aug 15 2019

Chapter 5, Basis Functions

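Linear regression with non-linear basis functions: the model $y \approx \sum_k \theta_k \phi_k(x)$ is non-linear in $x$ but still linear in the parameters $\theta$, so $\theta$ can be found by (regularised) least squares, $\theta = (\Phi^\top\Phi + \delta^2 I)^{-1}\Phi^\top y$.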

Generate some data

# generate some non-linear test data in 1D
using Plots, LinearAlgebra
regX = rand(100)
regY = 5 .+ sin.(regX .* pi) .+ 0.02 .* randn(100)
scatter(regX, regY, fmt = :png, legend=:bottomright, label="data")
# although this data is sinusoidal
# we will try to fit a parabola
# using basis functions: ϕ(x) = [1, x, x^2]
ϕ = hcat(regX.^0, regX.^1, regX.^2)
# δ is a small ridge (Tikhonov) penalty: adding δ²I to the positive
# semi-definite matrix ϕᵀϕ keeps the system below well-posed
δ = 0.001
# solve the regularised normal equations (ϕᵀϕ + δ²I) θ = ϕᵀy directly with `\`,
# which is cheaper and numerically more stable than forming inv(...) explicitly
θ = (ϕ'ϕ + δ^2*I) \ (ϕ'regY)
3-element Array{Float64,1}: 4.971034591029891 4.034932009738952 -4.046018918075333
# overlay the fitted quadratic θ₁ + θ₂x + θ₃x² on [0, 1] to check the fit
plot!(t -> θ[1] + θ[2] * t + θ[3] * t^2, 0, 1; label = "fit")

Basis functions in two dimensions

# let's try this in two dimensions:
# first generate some data scattered about the paraboloid z = 10 + x^2 + y^2
using Plots, LinearAlgebra
regX = 2 .* (rand(1000) .- 0.5)   # centred uniform samples on [-1, 1)
regY = 2 .* (rand(1000) .- 0.5)   # centred uniform samples on [-1, 1)
# randn is already zero-mean, so it is only scaled here. Subtracting 0.5
# (the centring used for rand above) would bias the noise by -0.05 and
# pull the fitted intercept from 10 down to about 9.95.
regZ = 10 .+ ((regX.^2 .+ regY.^2) .+ 0.1 .* randn(1000))
scatter(regX, regY, regZ, fmt = :png, legend=:bottomright, label="data")
## we will try to fit a paraboloid (a quadratic surface)
## using basis functions: ϕ(x, y) = [1, x, y, x^2, y^2, xy]
ϕ = hcat(regX.^0, regX.^1, regY.^1, regX.^2, regY.^2, regX .* regY)
δ = 0.001
# ridge-regularised normal equations (ϕᵀϕ + δ²I) θ = ϕᵀz, solved directly
# with `\` rather than by multiplying with an explicit inverse
θ = (ϕ'ϕ + δ^2*I) \ (ϕ'regZ)
6-element Array{Float64,1}: 9.952766596111731 0.0017794404747676351 -0.012882384829761007 0.9953439561468225 1.0083998710280424 -0.019767544018164718
using Plots
# fitted quadratic surface z = θ₁ + θ₂x + θ₃y + θ₄x² + θ₅y² + θ₆xy
f(x, y) = θ[1] + θ[2]*x + θ[3]*y + θ[4]*x^2 + θ[5]*y^2 + θ[6]*x*y
# 10 sample points per axis on [-1, 1] for the wireframe
x = range(-1, stop = 1, length = 10)
y = range(-1, stop = 1, length = 10)
# redraw the data, then overlay the fitted surface as a translucent wireframe
scatter(regX, regY, regZ; label = "data", colour = "blue", fmt = :png)
my_cg = cgrad([:red, :blue])
plot!(x, y, f; st = :wireframe, c = my_cg, alpha = 0.1, camera = (-30, 30))

Radial Basis Functions

# generate some non-linear test data in 1D
using Plots, LinearAlgebra
regX = rand(50)
regY = 5 .+ regX .+ sin.(regX *  3*pi) + 0.2 * randn(50)
scatter(regX, regY, fmt = :png, legend=:bottomright, label="data")
# Gaussian radial basis function centred at μ with width parameter λ;
# it broadcasts over x, so x may be a scalar or a vector
rbf(x,μ,λ) = exp.(-(x.-μ).^2 ./ λ)
λ = 0.1
# design matrix: a constant column followed by one RBF column per centre
# μ in 0:0.1:1 (11 centres, so 12 columns in total). It is built in a single
# expression because reassigning the global ϕ inside a top-level `for` loop
# falls foul of Julia's soft-scope rules when the file is run as a script
# (ϕ becomes a new local and `hcat(ϕ, ...)` throws UndefVarError).
ϕ = hcat(regX.^0, [rbf(regX, μ, λ) for μ in 0:0.1:1]...)
δ = 0.01
# ridge-regularised normal equations (ϕᵀϕ + δ²I) θ = ϕᵀy, solved with `\`
θ = (ϕ'ϕ + δ^2*I) \ (ϕ'regY)
12-element Array{Float64,1}: 2.4162079835223267 2.3945809165597893 -2.5404320322995773 1.5689415002052556 4.796198306394217 -0.3950839098833967 -4.321381562767783 0.011785626062192023 3.7691676030517556 1.5557413129427005 -0.10555114402086474 1.3499522144848015
# Fitted RBF model: intercept θ[1] plus one weighted Gaussian bump per centre
# μ in 0:0.1:1 (weights θ[2:12]). Iterating over the same range that built the
# design matrix gives bit-identical centres (0.1*i can differ in the last bit,
# e.g. 0.1*3 != 0.3). Using `.+` lets x be a scalar or a vector.
f(x) = θ[1] .+ sum(θ[k+1] .* rbf(x, μ, λ) for (k, μ) in enumerate(0:0.1:1))
f (generic function with 2 methods)
# fitted value at x = 0, the left edge of the sampled interval
0 |> f
5.110309575015851
# overlay the fitted RBF curve on [0, 1] to check the fit
plot!(f, 0, 1; label = "fit")