Linear Regression in 3 Forms

Linear regression can be implemented in a few different ways, each trading off accuracy against speed. This post by Steven G. Johnson on the Julia Discourse site gives a good summary.

Here is a screenshot if you don't want to link out.

Well- and ill-conditioned

For a reference on what conditioning is and what well- and ill-conditioned mean, see this Wikipedia article. If you have access to the book Numerical Linear Algebra by Trefethen and Bau, you'll find section III useful. (In the screenshot, Johnson also cites this book for QR factorization, the basis of how Julia handles A \ b; that material is in section II.)
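To make "squaring the condition number" concrete, here is a small sketch using a toy 5×2 design matrix chosen for illustration (it is not from the post). `cond` from LinearAlgebra computes the 2-norm condition number as the ratio of the largest to smallest singular value. Because the singular values of X'X are the squares of those of X, cond(X'X) = cond(X)^2, and since roughly log10 of the condition number is the number of decimal digits you can lose to rounding, the digits at risk double.

```julia
using LinearAlgebra

# Toy design matrix (illustrative only): an intercept column and x = 1, 2, ..., 5
X = [ones(5) collect(1.0:5.0)]

κ_X   = cond(X)          # 2-norm condition number: σ_max / σ_min
κ_XtX = cond(X' * X)     # condition number of the normal-equations matrix

# Singular values of X'X are the squares of those of X, so κ_XtX ≈ κ_X^2
println("cond(X)   = ", κ_X,   "  (≈ ", round(log10(κ_X);   digits = 1), " digits at risk)")
println("cond(X'X) = ", κ_XtX, "  (≈ ", round(log10(κ_XtX); digits = 1), " digits at risk)")
```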

Some Code

The three forms of linear regression are worth remembering if you are implementing your own function, and Julia makes each of them a one-liner. Here are the three forms.

using LinearAlgebra, BenchmarkTools
# set up some data to regress on: a noise-free line y = 10 + 0.3x, so the true coefficients are [10, 0.3]
N = 10000
xn = rand(N);
Xp = [ones(N) xn];
yp = (10 .+ xn .* 0.3);
(sum(Xp), sum(yp))
(15008.1, 1.01502e5)
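The `(sum(Xp), sum(yp))` line is a cheap fingerprint of the data, and it can be predicted by hand: the column of ones contributes N to sum(Xp), and summing y = 10 + 0.3x gives 10N + 0.3·sum(xn). With N = 10000 and sum(xn) ≈ 5008.1, that is 15008.1 and about 101502.4 (≈ 1.01502e5), matching the output above. Because `rand` is unseeded, your numbers will differ slightly between runs. The sketch below adds `Random.seed!` for reproducibility (the seed value 1234 is an arbitrary assumption) and checks both identities.

```julia
using Random

Random.seed!(1234)           # arbitrary seed, only for reproducible runs

N  = 10000
xn = rand(N)
Xp = [ones(N) xn]            # design matrix: intercept column + regressor
yp = 10 .+ xn .* 0.3         # noise-free line: intercept 10, slope 0.3

# Predicted fingerprints: the ones column sums to N; y sums to 10N + 0.3*sum(x)
println(sum(Xp), " ≈ ", N + sum(xn))
println(sum(yp), " ≈ ", 10N + 0.3 * sum(xn))
```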

Fast but not accurate

# normal equations approach
#fastest but least accurate; doubles the number of digits you lose to rounding.
#only use when you know you have well-conditioned matrices
function linreg1(X, y)
    #β_hat
    return (X' * X) \ (X' * y)
end
linreg1 (generic function with 1 method)
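On a square dense matrix, Julia's `\` checks for triangular structure and otherwise uses an LU factorization, even though X'X is symmetric positive definite whenever X has full column rank. A common variation, sketched here as a hypothetical `linreg1_chol` that is not from the original post, uses a Cholesky factorization instead. For matrices with many columns Cholesky needs roughly half the flops of LU, but it still solves the normal equations, so it inherits the same squared condition number.

```julia
using LinearAlgebra

# Hypothetical variant of linreg1: solve the normal equations (X'X)β = X'y
# with a Cholesky factorization. `Symmetric` tells `cholesky` to use only
# one triangle of the (numerically) symmetric matrix X'X.
function linreg1_chol(X, y)
    F = cholesky(Symmetric(X' * X))   # throws PosDefException if X'X is not positive definite
    return F \ (X' * y)
end

# Usage on a tiny noise-free example: y = 10 + 0.3x
X = [ones(5) collect(1.0:5.0)]
y = 10 .+ 0.3 .* X[:, 2]
β_hat = linreg1_chol(X, y)            # close to [10.0, 0.3]
```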

Not the fastest and not the least accurate

#QR factorization
#slower than linreg1 but uses pivoted QR factorization.
#more accurate for badly conditioned matrices than "normal equations" approach
#it does not square the condition number
function linreg2(X, y)
    #β_hat
    return X \ y
end
linreg2 (generic function with 1 method)
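What `X \ y` does for a tall matrix can be spelled out by hand. Writing X = QR, the least-squares solution satisfies Rβ = Q'y, and the triangular solve works with R, whose condition number equals that of X rather than its square. The sketch below (hypothetical name `linreg_qr`, not from the post) uses plain Householder QR without the column pivoting that `\` adds, so it assumes X has full column rank.

```julia
using LinearAlgebra

# Least squares via an explicit QR factorization X = Q*R (Householder, unpivoted).
# Assumes X has full column rank; `X \ y` additionally pivots columns.
function linreg_qr(X, y)
    F = qr(X)                          # compact Householder QR
    n = size(X, 2)
    Qty = (F.Q' * y)[1:n]              # first n entries of Q'y
    return UpperTriangular(F.R) \ Qty  # back substitution on the n×n factor R
end

X = [ones(5) collect(1.0:5.0)]
y = 10 .+ 0.3 .* X[:, 2]
β_qr = linreg_qr(X, y)
```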

Slow but accurate

#SVD approach
#uses SVD to apply the pseudo-inverse; the slowest of the linreg* approaches
#most accurate of linreg*
function linreg3(X, y)
    #β_hat
    return pinv(X) * y
end
linreg3 (generic function with 1 method)
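Similarly, `pinv(X) * y` can be written out with the thin SVD X = UΣV', giving β = VΣ⁻¹U'y. `pinv` additionally zeroes singular values below a relative tolerance, which is what makes it robust for rank-deficient X. The hypothetical `linreg_svd` below skips that truncation and therefore assumes all singular values are safely nonzero; it also avoids forming the pseudo-inverse matrix explicitly.

```julia
using LinearAlgebra

# Least squares via the thin SVD X = U*Diagonal(S)*V', so β = V * ((U'y) ./ S).
# Unlike `pinv`, no singular values are truncated: assumes X has full column rank.
function linreg_svd(X, y)
    F = svd(X)                        # thin SVD by default: U is m×n, S has n entries
    return F.V * ((F.U' * y) ./ F.S)
end

X = [ones(5) collect(1.0:5.0)]
y = 10 .+ 0.3 .* X[:, 2]
β_svd = linreg_svd(X, y)
```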

Benchmarks

println("linreg1: ", linreg1(Xp, yp))
@benchmark linreg1(Xp, yp)
BenchmarkTools.Trial:
  memory estimate:  544 bytes
  allocs estimate:  6
  --------------
  minimum time:     41.702 μs (0.00% GC)
  median time:      42.995 μs (0.00% GC)
  mean time:        47.114 μs (0.00% GC)
  maximum time:     3.822 ms (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1
#show that we are not modifying the inputs
(sum(Xp), sum(yp))
(15008.1, 1.01502e5)
println("linreg2: ", linreg2(Xp, yp))
@benchmark linreg2(Xp, yp)
BenchmarkTools.Trial:
  memory estimate:  303.69 KiB
  allocs estimate:  45
  --------------
  minimum time:     149.918 μs (0.00% GC)
  median time:      162.451 μs (0.00% GC)
  mean time:        187.496 μs (4.83% GC)
  maximum time:     8.394 ms (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1
#show that we are not modifying the inputs
(sum(Xp), sum(yp))
(15008.1, 1.01502e5)
println("linreg3: ", linreg3(Xp, yp))
@benchmark linreg3(Xp, yp)
BenchmarkTools.Trial:
  memory estimate:  636.22 KiB
  allocs estimate:  32
  --------------
  minimum time:     369.101 μs (0.00% GC)
  median time:      399.361 μs (0.00% GC)
  mean time:        461.574 μs (3.83% GC)
  maximum time:     10.714 ms (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1
#show that we are not modifying the inputs
(sum(Xp), sum(yp))
(15008.1, 1.01502e5)
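The benchmarks above measure speed only, on a well-conditioned problem. To see the accuracy side of the trade-off, the sketch below fits a deliberately ill-conditioned problem: a degree-10 polynomial in the monomial basis at 50 equally spaced points on [0, 1], with known coefficients (all ones, an arbitrary choice). The three expressions are the bodies of linreg1, linreg2 and linreg3. Expect the normal-equations error to be many orders of magnitude larger than the QR and SVD errors; exact values depend on your platform's BLAS/LAPACK, so none are quoted here.

```julia
using LinearAlgebra

# Sketch: accuracy of the three forms on an ill-conditioned least-squares problem.
# Degree-10 polynomial in the monomial basis at 50 equally spaced points on [0, 1].
x = range(0, 1; length = 50)
X = [xi^j for xi in x, j in 0:10]   # 50×11 Vandermonde-type matrix
β_true = ones(11)                   # arbitrary known coefficients
y = X * β_true                      # noise-free data

relerr(β) = norm(β - β_true) / norm(β_true)

err_normal = relerr((X' * X) \ (X' * y))  # body of linreg1 (normal equations)
err_qr     = relerr(X \ y)                # body of linreg2 (pivoted QR)
err_svd    = relerr(pinv(X) * y)          # body of linreg3 (SVD pseudo-inverse)

println("cond(X)   = ", cond(X))
println("cond(X'X) = ", cond(X' * X))
println("relative error, normal equations: ", err_normal)
println("relative error, QR:               ", err_qr)
println("relative error, SVD/pinv:         ", err_svd)
```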
