mleprovost / Sep 09 2019

Tutorial 1: Lorenz attractor

using Pkg
Pkg.add(["EnKF", "Distributions", "ProgressMeter", "OrdinaryDiffEq", "LinearAlgebra", "DocStringExtensions"])
using EnKF
using Distributions
using DocStringExtensions
using LinearAlgebra
using ProgressMeter
using OrdinaryDiffEq
using Plots

We are interested in simulating the Lorenz attractor.

Define the Lorenz system (with σ = 10, ρ = 28, β = 8/3) and the parameters of the simulation

"""
    lorenz(du, u, p, t)

In-place right-hand side of the Lorenz system.

Writes the time derivative of the three-component state `u` into `du`, using the fixed
coefficients σ = 10, ρ = 28 and β = 8/3. The arguments `p` and `t` are accepted but not
used. The value of the last assignment (the third derivative component) is returned.
"""
function lorenz(du, u, p, t)
    σ, ρ, β = 10.0, 28.0, 8 / 3
    # Components are written one at a time, each reading `u` at that moment.
    du[1] = σ * (u[2] - u[1])
    du[2] = u[1] * (ρ - u[3]) - u[2]
    du[3] = u[1] * u[2] - β * u[3]
end

# Integration window and initial condition of the reference ("truth") trajectory.
tspan = (0.0, 40.0)
u0 = [10.0, -5.0, 2.0]

# Fixed time step and the grid of output times spanning `tspan`.
Δt = 1e-2
T = first(tspan):Δt:last(tspan)

# Reference solution: RK4 with adaptivity switched off, i.e. a constant step Δt.
prob = ODEProblem(lorenz, u0, tspan)
sol = solve(prob, RK4(); adaptive = false, dt = Δt)

# Stand-alone integrator on the same problem and step size. The propagation function
# re-seats it (time and state) for every ensemble member and steps it manually.
integrator = init(prob, RK4(); adaptive = false, dt = Δt, save_everystep = false)
t: 0.0 u: 3-element Array{Float64,1}: 10.0 -5.0 2.0
# Plot the reference solution `sol` returned by `solve`.
plot(sol)

Define propagation function fprop

# Propagation step of the filter: advance every member of `ENS` in place, starting from
# time `t`, and return the same `ENS`.
#
# The global `integrator` is reused for all members. For each one it is reset to time `t`
# and to a private copy of the member's state, stepped 50 times, and its final state is
# copied back into the ensemble. With the fixed step the integrator was created with
# (dt = Δt), 50 steps correspond to the 50*Δt interval used in the assimilation loop.
function (::PropagationFunction)(t::Float64, ENS::EnsembleState{N, TS}) where {N, TS}
    for idx in eachindex(ENS.S)
        set_t!(integrator, t)
        # Hand the integrator its own copy so stepping never aliases the stored member.
        set_u!(integrator, deepcopy(ENS.S[idx]))

        for _ in 1:50
            step!(integrator)
        end

        # Copy out: `integrator.u` is reassigned for the next member.
        ENS.S[idx] = deepcopy(integrator.u)
    end

    return ENS
end
fprop = PropagationFunction()
PropagationFunction()

Define measurement function m

# Measurement of a single state `s` at time `t`: a one-element vector holding the sum of
# the first three components of `s`. The time argument is not used.
function (::MeasurementFunction)(t::Float64, s::TS) where TS
    total = s[1] + s[2] + s[3]
    return [total]
end

# Matrix form of the same measurement: the 1×3 row of ones, whose product with a
# three-component state is the sum of its components. The time argument is not used.
(::MeasurementFunction)(t::Float64) = [1.0 1.0 1.0]

m = MeasurementFunction()
MeasurementFunction()

Define the real measurement function z: it always measures the true state, and the measurement is then corrupted by the noise ϵ

# "Real" measurement at time `t`: evaluate the reference solution `sol` at `t`, sum its
# three components, and fill `ENS` with that one-element measurement. Returns `ENS`.
# No noise is added inside this function.
function (::RealMeasurementFunction)(t::Float64, ENS::EnsembleState{N, TZ}) where {N, TZ}
    truth = sol(t)
    observation = truth[1] + truth[2] + truth[3]
    fill!(ENS, [observation])
    return ENS
end
z = RealMeasurementFunction()
RealMeasurementFunction()

Define covariance inflation

# Covariance inflation for the 3-component state, built from the factor 1.05 and a
# zero-mean 3-dimensional Gaussian.
# NOTE(review): the exact role of each constructor argument is defined by EnKF.jl — confirm.
A = MultiAdditiveInflation(3, 1.05, MvNormal(zeros(3), 1.0*I))
# Alternative kept for reference: no inflation.
# A = IdentityInflation()
MultiAdditiveInflation{3}(1.05, IsoNormal( dim: 3 μ: [0.0, 0.0, 0.0] Σ: [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0] ) )

Define noise covariance

# Zero-mean Gaussian noise on the one-dimensional measurement.
# NOTE(review): the output recorded for this line reports Σ = [9.0], i.e. the 3.0 was taken
# as a standard deviation by the Distributions version used; other versions may interpret
# `3.0*I` as the covariance itself — confirm against the installed version.
ϵ = AdditiveInflation(MvNormal(zeros(1), 3.0*I))
AdditiveInflation{1}(IsoNormal( dim: 1 μ: [0.0] Σ: [9.0] ) )
# Number of ensemble members (passed to `initialize` and to the `ENKF` constructor).
N = 10
# Size of the measurement vector; the measurement function returns a 1-element vector.
NZ = 1
# Switches handed positionally to the `ENKF` constructor.
# NOTE(review): presumably they enable covariance inflation, state filtering and state
# augmentation respectively — confirm against EnKF.jl.
isinflated = true
isfiltered = false
isaugmented = false
false
u0
3-element Array{Float64,1}: 10.0 -5.0 2.0
# Initial ensemble of N members drawn from a Gaussian centred at [20, -10, 10], which is
# deliberately different from the true initial condition u0 = [10, -5, 2].
ens = initialize(N, MvNormal([20.0, -10.0, 10.0], 2.0*I))
# History of ensemble states; the first entry is the initial ensemble. A deep copy is
# stored because `ens.S` is modified in place by the propagation function.
estimation_state = [deepcopy(ens.S)]

# Copies of the true initial condition.
# NOTE(review): `tmp` and `true_state` are not referenced again in the visible part of
# this file — confirm whether they are still needed.
tmp = deepcopy(u0)
true_state = [deepcopy(u0)]
1-element Array{Array{Float64,1},1}: [10.0, -5.0, 2.0]
# Filtering function handed to the `ENKF` constructor.
# NOTE(review): with `isfiltered = false` it is presumably never applied — confirm.
g = FilteringFunction()
FilteringFunction()
# Assemble the filter from, in order: ensemble size, measurement size, propagation
# function, covariance inflation, filtering function, measurement function, real
# measurement function, measurement noise, and the three boolean switches.
enkf = ENKF(N, NZ, fprop, A, g, m, z, ϵ, isinflated, isfiltered, isaugmented)
ENKF{10,1}(PropagationFunction(), MultiAdditiveInflation{3}(1.05, IsoNormal( dim: 3 μ: [0.0, 0.0, 0.0] Σ: [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0] ) ), FilteringFunction(), MeasurementFunction(), RealMeasurementFunction(), AdditiveInflation{1}(IsoNormal( dim: 1 μ: [0.0] Σ: [9.0] ) ), true, false, false)

Ensemble Kalman filter estimation

# Assimilation times: one filter cycle every 50 integrator steps. The last cycle starts
# at 40.0 - 50*Δt so that it ends at the final time of the reference solution.
Δt = 1e-2
Tsub = 0.0:50*Δt:40.0-50*Δt

@showprogress for t in Tsub
    # `ens` is a global rebound on every cycle.
    global ens

    # One filter cycle over an interval of 50*Δt starting at `t`; only the updated
    # ensemble (second returned value) is kept.
    _, ens, _ = enkf(t, 50*Δt, ens)

    # Store a snapshot: the ensemble is updated in place on the next cycle.
    push!(estimation_state, deepcopy(ens.S))
end
# Post-processing.
# `s`: reference solution sampled on the grid `T`, one column per time.
# `ŝ`: ensemble mean of every stored ensemble, one column per entry of
#      `estimation_state` (initial ensemble first, then one per assimilation cycle).
# `reduce(hcat, …)` concatenates the columns without splatting thousands of arguments.
s = reduce(hcat, sol(T).u)
ŝ = reduce(hcat, mean.(estimation_state))

# Component-wise comparison of the truth and the EnKF mean.
# `estimation_state` has one more entry than `Tsub` (the state after the final cycle),
# so the last column of `ŝ` is dropped to match the assimilation times.
plt = plot(layout = (3, 1), legend = true)
plot!(plt[1], T, s[1, 1:end], linewidth = 2, label = "truth")
scatter!(plt[1], Tsub, ŝ[1, 1:end-1], linewidth = 2, markersize = 3, label = "EnKF mean", xlabel = "t", ylabel = "x", linestyle = :dash)

plot!(plt[2], T, s[2, 1:end], linewidth = 2, label = "truth")
scatter!(plt[2], Tsub, ŝ[2, 1:end-1], linewidth = 2, markersize = 3, label = "EnKF mean", xlabel = "t", ylabel = "y", linestyle = :dash)

plot!(plt[3], T, s[3, 1:end], linewidth = 2, label = "truth")
scatter!(plt[3], Tsub, ŝ[3, 1:end-1], linewidth = 2, markersize = 3, label = "EnKF mean", xlabel = "t", ylabel = "z", linestyle = :dash)

# First component only, on a single axis.
plot(T, s[1, :], linewidth = 3, label = "truth")
scatter!(Tsub, ŝ[1, 1:end-1], linewidth = 3, label = "EnKF mean", xlabel = "t", ylabel = "x", linestyle = :dash)

# Trajectories in state space: truth, then the EnKF mean as a line and as markers.
plot(s[1, :], s[2, :], s[3, :], linewidth = 2, label = "truth", legend = true)
plot!(ŝ[1, 1:end-1], ŝ[2, 1:end-1], ŝ[3, 1:end-1], linewidth = 2, label = "EnKF mean", xlabel = "x",
    ylabel = "y", zlabel = "z", linestyle = :solid)
scatter!(ŝ[1, :], ŝ[2, :], ŝ[3, :], linewidth = 2, label = "EnKF mean", xlabel = "x",
    ylabel = "y", zlabel = "z", linestyle = :solid)