# mleprovost / Sep 09 2019

# Tutorial 1: Lorenz attractor

using Pkg
Pkg.add(["EnKF", "Distributions", "ProgressMeter", "OrdinaryDiffEq", "LinearAlgebra", "DocStringExtensions"])
using EnKF
using Distributions
using DocStringExtensions
using LinearAlgebra
using ProgressMeter
using OrdinaryDiffEq
using Plots

# We simulate the Lorenz attractor and estimate its state with an ensemble Kalman filter (EnKF).

# Define the Lorenz-63 vector field with the classic parameters σ = 10, ρ = 28, β = 8/3.

"""
    lorenz(du, u, p, t)

In-place right-hand side of the Lorenz-63 system, in the `f(du, u, p, t)` form
used by `ODEProblem`:

    dx/dt = σ (y - x)
    dy/dt = x (ρ - z) - y
    dz/dt = x y - β z

If `p` is a 3-tuple `(σ, ρ, β)` those parameters are used. Any other `p` (for
example the placeholder an `ODEProblem` carries when built without parameters)
selects the classic chaotic values σ = 10, ρ = 28, β = 8/3.

The derivative is written into `du`; the function returns `nothing`.
"""
lorenz(du, u, p, t) = lorenz(du, u, (10.0, 28.0, 8 / 3), t)

function lorenz(du, u, p::NTuple{3,Real}, t)
    σ, ρ, β = p
    du[1] = σ * (u[2] - u[1])
    du[2] = u[1] * (ρ - u[3]) - u[2]
    du[3] = u[1] * u[2] - β * u[3]
    return nothing
end

# Reference ("truth") run: initial condition and integration window.
u0 = [10.0, -5.0, 2.0]
tspan = (0.0, 40.0)

# Fixed step size and the matching fine time grid.
Δt = 1e-2
T = first(tspan):Δt:last(tspan)

# True trajectory from fixed-step RK4; `sol` is later evaluated at arbitrary t.
prob = ODEProblem(lorenz, u0, tspan)
sol = solve(prob, RK4(); adaptive = false, dt = Δt)

# Separate fixed-step RK4 integrator that `fprop` steps manually; it keeps no
# step history (save_everystep = false).
integrator = init(prob, RK4(); adaptive = false, dt = Δt, save_everystep = false)
# Output: t: 0.0 u: 3-element Array{Float64,1}: 10.0 -5.0 2.0
# Quick look at the reference solution.
sol |> plot

# Define the propagation function fprop, which advances each ensemble member with the shared integrator.

function (::PropagationFunction)(t::Float64, ENS::EnsembleState{N, TS}) where {N, TS}
    # Reuse the single global `integrator` for every member:
    # rewind it to time t, load the member, march 50 fixed steps, store the result.
    for k in eachindex(ENS.S)
        set_t!(integrator, t)
        # Copy in and out so the ensemble storage and the integrator's state
        # never share the same array.
        set_u!(integrator, deepcopy(ENS.S[k]))
        foreach(_ -> step!(integrator), 1:50)
        ENS.S[k] = deepcopy(integrator.u)
    end
    return ENS
end
fprop = PropagationFunction()
# Output: PropagationFunction()

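# Define the measurement function m: we measure the sum x + y + z of the state components, and the one-argument method returns the corresponding 1×3 measurement matrix.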

# Observation of a state s: the single measured quantity is s[1] + s[2] + s[3].
(::MeasurementFunction)(t::Float64, s::TS) where {TS} = [s[1] + s[2] + s[3]]

# Matrix form of the same observation: d(x + y + z)/d(x, y, z) = [1 1 1].
(::MeasurementFunction)(t::Float64) = ones(1, 3)

m = MeasurementFunction()
# Output: MeasurementFunction()

# Define the real measurement function z: it always measures the true state, and this measurement is corrupted by the noise ϵ.

function (::RealMeasurementFunction)(t::Float64, ENS::EnsembleState{N, TZ}) where {N, TZ}
    # Evaluate the reference trajectory at t and observe x + y + z exactly;
    # no noise is added in this function.
    truth = sol(t)
    observed = truth[1] + truth[2] + truth[3]
    # NOTE(review): depending on how EnKF.jl implements `fill!` for an
    # EnsembleState, every member may share this one vector. Confirm that later
    # noise steps do not mutate members in place.
    fill!(ENS, [observed])
    return ENS
end
z = RealMeasurementFunction()
# Output: RealMeasurementFunction()

# Define the covariance inflation applied to the ensemble.

# Inflation for the 3-component state: factor 1.05 together with zero-mean
# Gaussian perturbations (presumably multiplicative + additive, per the type
# name). To switch inflation off, use `A = IdentityInflation()` instead.
A = MultiAdditiveInflation(3, 1.05, MvNormal(zeros(3), 1.0I))
# Output: MultiAdditiveInflation{3}(1.05, IsoNormal( dim: 3 μ: [0.0, 0.0, 0.0] Σ: [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0] ) )

# Define the measurement noise ϵ and its covariance.

# Zero-mean Gaussian noise on the single measured quantity.
# NOTE(review): with the Distributions version this notebook was run on, the
# scalar in `3.0I` acted as a standard deviation (the displayed Σ is 9.0).
# Recent Distributions releases treat `λ*I` as the covariance itself — confirm
# which is intended.
ϵ = AdditiveInflation(MvNormal(zeros(1), 3.0I))
# Output: AdditiveInflation{1}(IsoNormal( dim: 1 μ: [0.0] Σ: [9.0] ) )
# Filter configuration passed to ENKF below.
N, NZ = 10, 1    # ensemble size; measurement dimension (m returns one value)
# ENKF switches, presumably toggling inflation, filtering and state augmentation.
isinflated, isfiltered, isaugmented = true, false, false
# Output: false
# Echo the true initial condition, for comparison with the prior ensemble drawn below.
u0
# Output: 3-element Array{Float64,1}: 10.0 -5.0 2.0
# Prior ensemble of N members centred on [20, -10, 10], deliberately away from
# the true initial condition u0 (presumably sampled from the given MvNormal).
ens = initialize(N, MvNormal([20.0, -10.0, 10.0], 2.0 * I))
# Ensemble history; the first snapshot is the prior.
estimation_state = [deepcopy(ens.S)]

# NOTE(review): `tmp` and `true_state` are not used anywhere else in the code shown.
tmp = copy(u0)
true_state = [copy(u0)]
# Output: 1-element Array{Array{Float64,1},1}: [10.0, -5.0, 2.0]
# Filtering function handed to ENKF. NOTE(review): presumably only applied when
# `isfiltered` is true (it is false in this tutorial) — confirm against EnKF.jl.
g = FilteringFunction()
# Output: FilteringFunction()
# Assemble the filter from the pieces defined above (positional order matters).
enkf = ENKF(N, NZ,
            fprop,    # propagation of each member
            A,        # covariance inflation
            g,        # filtering function
            m,        # measurement function
            z,        # real measurement
            ϵ,        # measurement noise
            isinflated, isfiltered, isaugmented)
# Output: ENKF{10,1}(PropagationFunction(), MultiAdditiveInflation{3}(1.05, IsoNormal( dim: 3 μ: [0.0, 0.0, 0.0] Σ: [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0] ) ), FilteringFunction(), MeasurementFunction(), RealMeasurementFunction(), AdditiveInflation{1}(IsoNormal( dim: 1 μ: [0.0] Σ: [9.0] ) ), true, false, false)

## Ensemble Kalman filter estimation

# Assimilation window: fprop advances the integrator 50 fixed steps of Δt, so
# each EnKF cycle spans nsub * Δt time units. Keep `nsub` in sync with fprop.
nsub = 50
Δtobs = nsub * Δt
# Start times of the assimilation cycles. Derived from `tspan` (not a literal
# 40.0) so the last cycle, starting at tspan[end] - Δtobs, ends at tspan[end].
Tsub = first(tspan):Δtobs:(last(tspan) - Δtobs)

@showprogress for t in Tsub
    # `ens` is a top-level variable rebound inside the loop; `global` avoids
    # the soft-scope ambiguity when this file runs as a script.
    global ens
    _, ens, _ = enkf(t, Δtobs, ens)
    # Snapshot the ensemble: fprop overwrites ens.S entries in place.
    push!(estimation_state, deepcopy(ens.S))
end

# Truth on the fine grid T and ensemble mean per snapshot, one column per time.
# `reduce(hcat, v)` avoids splatting thousands of vectors into hcat.
s = reduce(hcat, sol(T).u)
ŝ = reduce(hcat, mean.(estimation_state))

# Three stacked panels (x, y, z versus t): truth as lines, EnKF means as markers.
# The last snapshot is dropped so ŝ has one column per entry of Tsub.
plt = plot(layout = (3, 1), legend = true)
for (k, name) in enumerate(("x", "y", "z"))
    plot!(plt[k], T, s[k, :], linewidth = 2, label = "truth")
    scatter!(plt[k], Tsub, ŝ[k, 1:end-1], linewidth = 2, markersize = 3,
             label = "EnKF mean", xlabel = "t", ylabel = name, linestyle = :dash)
end
plt
# Single-panel view of the x component: truth curve with EnKF means as markers.
plot(T, s[1, :]; linewidth = 3, label = "truth")
scatter!(Tsub, ŝ[1, 1:end-1]; linewidth = 3, label = "EnKF mean",
         xlabel = "t", ylabel = "x", linestyle = :dash)
# 3-D phase portrait: reference trajectory vs. the path of the EnKF ensemble
# mean. There is no time axis here, so the line and the markers use every
# snapshot (prior included), keeping them consistent with each other.
plot(s[1, :], s[2, :], s[3, :], linewidth = 2, label = "truth", legend = true)
plot!(ŝ[1, :], ŝ[2, :], ŝ[3, :], linewidth = 2, label = "EnKF mean",
      xlabel = "x", ylabel = "y", zlabel = "z", linestyle = :solid)
# Markers on the same estimates; `label = ""` keeps a single "EnKF mean"
# legend entry instead of a duplicate.
scatter!(ŝ[1, :], ŝ[2, :], ŝ[3, :], linewidth = 2, label = "",
         xlabel = "x", ylabel = "y", zlabel = "z", linestyle = :solid)