Hugh Murrell / Aug 15 2019
Chapter 5, Basis Functions
Linear regression with non-linear basis functions
Generate some data
# generate some non-linear test data in 1D
using Plots, LinearAlgebra
regX = rand(100)
regY = 5 .+ 1 * sin.(regX .* pi) + 0.02 * randn(100)
scatter(regX, regY, fmt = :png, legend=:bottomright, label="data")
# although this data is sinusoidal,
# we will try to fit a parabola
# using the basis functions ϕ(x) = [1, x, x^2]
ϕ = hcat(regX.^0, regX.^1, regX.^2)
# δ^2*I is a small ridge term that keeps ϕ'ϕ + δ^2*I invertible
δ = 0.001
θ = inv(ϕ'ϕ + δ^2*I)*ϕ'regY
3-element Array{Float64,1}:
4.971034591029891
4.034932009738952
-4.046018918075333
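The formula for θ above can be read as the closed-form solution of a regularised (ridge) least-squares problem. Writing Φ for the design matrix `ϕ` (one row per sample, one column per basis function) and y for `regY`, setting the gradient of the penalised squared error to zero gives:

```latex
\begin{aligned}
J(\theta) &= \lVert y - \Phi\theta \rVert^2 + \delta^2 \lVert \theta \rVert^2 \\
\nabla_\theta J &= -2\,\Phi^\top (y - \Phi\theta) + 2\,\delta^2 \theta = 0 \\
\Longrightarrow\quad (\Phi^\top \Phi + \delta^2 I)\,\theta &= \Phi^\top y \\
\Longrightarrow\quad \theta &= (\Phi^\top \Phi + \delta^2 I)^{-1} \Phi^\top y
\end{aligned}
```

With δ = 0.001 the penalty weight is δ² = 10⁻⁶, so the result is essentially the ordinary least-squares fit; the δ²I term mainly guards against ϕ'ϕ being singular or badly conditioned.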
# check the fit
plot!(x->θ[1]+θ[2]*x+θ[3]*x^2, 0, 1, label="fit")
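The same fit can be written so that the polynomial degree is a parameter and the model is evaluated through its basis functions rather than through a hand-written polynomial. This is a self-contained sketch, not the notebook's own code: the helper names `polybasis`, `ridgefit` and `polypredict` are hypothetical, and it solves the regularised normal equations with `\` instead of `inv`, which gives the same θ up to rounding and is the more usual choice numerically. It checks itself on noise-free samples of the parabola 2 + 3x - x².

```julia
using LinearAlgebra

# Hypothetical helper: design matrix with columns x.^0, x.^1, ..., x.^d
# (generalises hcat(regX.^0, regX.^1, regX.^2) to any degree d).
polybasis(x::AbstractVector, d::Integer) = hcat([x .^ k for k in 0:d]...)

# Hypothetical helper: solve (ΦᵀΦ + δ²I) θ = Φᵀy with `\`
# rather than forming the explicit inverse.
ridgefit(Φ::AbstractMatrix, y::AbstractVector, δ::Real) = (Φ' * Φ + δ^2 * I) \ (Φ' * y)

# Hypothetical helper: evaluate the fitted polynomial at new points by
# building their basis rows; the degree is recovered from length(θ).
polypredict(x::AbstractVector, θ::AbstractVector) = polybasis(x, length(θ) - 1) * θ

# Noise-free data from the known parabola 2 + 3x - x^2 on [0, 1].
xs = collect(range(0, 1; length = 20))
ys = 2 .+ 3 .* xs .- xs .^ 2

# Fit with the same δ as the notebook; θq should be close to [2, 3, -1].
θq = ridgefit(polybasis(xs, 2), ys, 1e-3)
```

Because `polypredict` rebuilds the basis from `length(θ)`, the same three functions cover the degree-2 fit above or any other degree.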
In two dimensions
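# let's try this in two dimensions:
# first generate some data scattered about a quadratic surface
# (note: 0.1 * (randn(1000) .- 0.5) is noise with mean -0.05, not 0,
#  which is why the fitted intercept below comes out near 9.95 rather than 10)
using Plots
regX = 2 .* (rand(1000) .- 0.5)
regY = 2 .* (rand(1000) .- 0.5)
regZ = 10 .+ ( ( regX.^2 + regY.^2) .+ 0.1 * (randn(1000) .- 0.5) )
scatter(regX, regY, regZ, fmt = :png, legend=:bottomright, label="data")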
## we will try to fit a quadratic surface
## using the basis functions ϕ(x, y) = [1, x, y, x^2, y^2, xy]
ϕ = hcat(regX.^0, regX.^1, regY.^1, regX.^2, regY.^2, regX .* regY )
δ = 0.001
θ = inv(ϕ'ϕ + δ^2*I)*ϕ'regZ
6-element Array{Float64,1}:
9.952766596111731
0.0017794404747676351
-0.012882384829761007
0.9953439561468225
1.0083998710280424
-0.019767544018164718
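For the two-dimensional case the feature map can likewise be written once and reused both to build the design matrix and to evaluate the fitted surface, instead of spelling out f(x, y) term by term. This is a minimal sketch under stated assumptions: `quadfeatures`, `quaddesign` and `quadsurface` are hypothetical names, the data are noise-free samples of 10 + x² + y² generated with a fixed seed, and the ridge weight 10⁻⁶ corresponds to δ = 0.001 as above.

```julia
using LinearAlgebra, Random

# Hypothetical feature map for a full quadratic in two variables, in the
# same column order as the design matrix above: [1, x, y, x^2, y^2, xy].
quadfeatures(x::Real, y::Real) = [1.0, x, y, x^2, y^2, x * y]

# Design matrix: one row quadfeatures(x_i, y_i)' per sample.
quaddesign(xs, ys) = permutedims(hcat(quadfeatures.(xs, ys)...))

# Fitted surface as a closure: the dot product of the features with θ.
quadsurface(θ) = (x, y) -> dot(quadfeatures(x, y), θ)

# Noise-free samples of z = 10 + x^2 + y^2 on [-1, 1]^2 (seed chosen arbitrarily).
Random.seed!(1)
xs = 2 .* (rand(200) .- 0.5)
ys = 2 .* (rand(200) .- 0.5)
zs = 10 .+ xs .^ 2 .+ ys .^ 2

# Regularised normal equations with δ^2 = 1e-6.
Φ2 = quaddesign(xs, ys)
θ2 = (Φ2' * Φ2 + 1e-6 * I) \ (Φ2' * zs)

# θ2 should be close to [10, 0, 0, 1, 1, 0]; g evaluates the fitted surface.
g = quadsurface(θ2)
```

Keeping the feature map in one function means the column order used for fitting and the one used for evaluation cannot drift apart.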
using Plots
x = range(-1, stop=1, length=10)
y = range(-1, stop=1, length=10)
f(x,y) = θ[1] + θ[2]*x + θ[3]*y + θ[4]*x^2 + θ[5]*y^2 + θ[6]*x*y
scatter(regX, regY, regZ, label="data", colour="blue", fmt = :png)
my_cg = cgrad([:red,:blue])
plot!(x, y, f, st=:wireframe, c=my_cg, alpha=0.1, camera=(-30,30))
Radial basis functions
# generate some non-linear test data in 1D
using Plots
regX = rand(50)
regY = 5 .+ regX .+ sin.(regX * 3*pi) + 0.2 * randn(50)
scatter(regX, regY, fmt = :png, legend=:bottomright, label="data")
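# Gaussian radial basis function centred at μ with width parameter λ
rbf(x,μ,λ) = exp.(-(x.-μ).^2 ./ λ)
# design matrix: a constant column plus one RBF column per centre μ = 0, 0.1, ..., 1
ϕ = regX.^0
λ = 0.1
for μ in 0:0.1:1
    # `global` lets this loop also run as a script, not only in a notebook
    global ϕ = hcat(ϕ, rbf(regX, μ, λ))
end
δ = 0.01
θ = inv(ϕ'ϕ + δ^2*I)*ϕ'regY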
12-element Array{Float64,1}:
2.4162079835223267
2.3945809165597893
-2.5404320322995773
1.5689415002052556
4.796198306394217
-0.3950839098833967
-4.321381562767783
0.011785626062192023
3.7691676030517556
1.5557413129427005
-0.10555114402086474
1.3499522144848015
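Written out, the model fitted above uses eleven Gaussian bumps with centres 0, 0.1, ..., 1 and a shared width λ, plus a constant term, which is why θ has 12 entries. In the notation below (chosen here to match the 1-based indexing of `θ` in Julia):

```latex
\begin{aligned}
\phi_j(x) &= \exp\!\left(-\frac{(x-\mu_j)^2}{\lambda}\right), \qquad \mu_j = 0.1\,j, \quad j = 0, \dots, 10 \\
f(x) &= \theta_1 + \sum_{j=0}^{10} \theta_{j+2}\,\phi_j(x)
\end{aligned}
```

Compared with the usual Gaussian exp(-(x - μ)²/(2σ²)), λ plays the role of 2σ², so λ = 0.1 corresponds to σ ≈ 0.22. Neighbouring bumps therefore overlap heavily: a point 0.1 away from a centre still sees exp(-0.1) ≈ 0.90.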
f(x) = θ[1] + sum([θ[i+2] * rbf(x,0.1*i,λ) for i=0:1:10])
f (generic function with 2 methods)
f(0)
5.110309575015851
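The RBF model can also be evaluated at many points at once by reusing the design-matrix construction, and checked against a scalar version that mirrors `f(x)` above. This is a self-contained sketch rather than the notebook's own cell: `rbfdesign`, `rbfpredict` and `fscalar` are hypothetical names, the data are regenerated with a fixed seed, and the system is solved with `\` rather than `inv`.

```julia
using LinearAlgebra, Random

# Gaussian RBF, same formula as rbf above; works on scalars or vectors.
rbf(x, μ, λ) = exp.(-(x .- μ) .^ 2 ./ λ)

# Hypothetical helper: constant column followed by one RBF column per centre.
rbfdesign(x::AbstractVector, centres, λ) =
    hcat(ones(length(x)), (rbf(x, μ, λ) for μ in centres)...)

# Hypothetical helper: vectorised prediction as a single matrix-vector product.
rbfpredict(x::AbstractVector, θ, centres, λ) = rbfdesign(x, centres, λ) * θ

# Regenerate data like the notebook's (seed chosen arbitrarily).
Random.seed!(2)
centres = 0:0.1:1
λ = 0.1
xs = rand(50)
ys = 5 .+ xs .+ sin.(xs .* 3π) .+ 0.2 .* randn(50)

# Fit with δ = 0.01 as in the notebook.
Φr = rbfdesign(xs, centres, λ)
θr = (Φr' * Φr + 0.01^2 * I) \ (Φr' * ys)

# Scalar version mirroring f(x): θr[1] is the constant, θr[j + 1] weights centre j.
fscalar(x) = θr[1] + sum(θr[j + 1] * rbf(x, μ, λ) for (j, μ) in enumerate(centres))
```

Evaluating through the design matrix turns prediction at n query points into one n × 12 matrix-vector product, and keeps the centres in one place (`centres`) instead of repeating `0.1*i` inside the model function.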
# check the fit
plot!(x -> f(x), 0, 1, label="fit")