Solving ODEs in Julia
Source: https://github.com/JuliaDiffEq/DiffEqTutorials.jl
Ordinary Differential Equations (ODEs)
This notebook will get you started with DifferentialEquations.jl by introducing you to the functionality for solving ordinary differential equations (ODEs). The corresponding documentation page is the ODE tutorial. While some of the syntax may be different for other types of equations, the same general principles hold in each case. Our goal is to give a gentle and thorough introduction that highlights these principles in a way that will help you generalize what you have learned.
Background
If you are new to the study of differential equations, it can be helpful to do a quick background read on the definition of ordinary differential equations. We define an ordinary differential equation as an equation which describes the way that a variable $u$ changes, that is

$$u' = f(u,p,t)$$

where $p$ are the parameters of the model, $t$ is the time variable, and $f$ is the nonlinear model of how $u$ changes. The initial value problem also includes the information about the starting value:

$$u(t_0) = u_0$$
Together, if you know the starting value and you know how the value will change with time, then you know what the value will be at any time point in the future. This is the intuitive definition of a differential equation.
First Model: Exponential Growth
Our first model will be the canonical exponential growth model. This model says that the rate of change is proportional to the current value, and is this:
$$u' = au$$
where we have a starting value $u(0) = u_0$. Let's say we put 1 dollar into Bitcoin, which is increasing at a rate of $98\%$ per year. Then calling now $t = 0$ and measuring time in years, our model is:
$$u' = 0.98u$$
and $u(0) = 1.0$. We encode this into Julia by noticing that, in this setup, we match the general form when
f(u,p,t) = 0.98u
with $u_0 = 1.0$. If we want to solve this model on a time span from t=0.0
to t=1.0
, then we define an ODEProblem
by specifying this function f
, this initial condition u0
, and this time span as follows:
using DifferentialEquations
f(u,p,t) = 0.98u
u0 = 1.0
tspan = (0.0,1.0)
prob = ODEProblem(f,u0,tspan)
To solve our ODEProblem
we use the command solve
.
sol = solve(prob)
and that's it: we have successfully solved our first ODE!
Analyzing the Solution
Of course, the solution type is not interesting in and of itself. We want to understand the solution! The documentation page which explains in detail the functions for analyzing the solution
is the Solution Handling page. Here we will describe some of the basics. You can plot the solution using the plot recipe provided by Plots.jl:
using Plots; gr()
plot(sol)
From the picture we see that the solution is an exponential curve, which matches our intuition. As a plot recipe, we can annotate the result using any of the Plots.jl attributes. For example:
plot(sol, linewidth=5,
title="Solution to the linear ODE with a thick line",
xaxis="Time (t)", yaxis="u(t) (in μm)",
label="My Thick Line!") # legend=false
Using the mutating plot!
command we can add other pieces to our plot. For this ODE we know that the true solution is $u(t) = u_0 e^{0.98t}$, so let's add some of the true solution to our plot:
plot!(sol.t, t->1.0*exp(0.98t), lw=3, ls=:dash, label="True Solution!")
In the previous command I demonstrated sol.t
, which grabs the array of time points that the solution was saved at:
sol.t
We can get the array of solution values using sol.u
:
sol.u
sol.u[i]
is the value of the solution at time sol.t[i]
. We can compute arrays of functions of the solution values using standard comprehensions, like:
[t+u for (u,t) in tuples(sol)]
However, one interesting feature is that, by default, the solution is a continuous function. If we check the print out again:
sol
you see that it says that the solution has an order-changing interpolation. The default algorithm automatically switches between methods in order to handle all types of problems. For non-stiff equations (like the one we are solving), it is a continuous function of 4th order accuracy. We can call the solution as a function of time sol(t)
. For example, to get the value at t=0.45
, we can use the command:
sol(0.45)
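As a quick sanity check (a sketch; the exact error depends on the tolerances), we can compare the interpolated value against the analytic solution $u(t) = e^{0.98t}$:
# Compare the interpolated value at t = 0.45 with the analytic solution exp(0.98*0.45)
approx = sol(0.45)
exact  = exp(0.98 * 0.45)
abs(approx - exact) # small, on the order of the solver's error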
Controlling the Solver
DifferentialEquations.jl has a common set of solver controls among its algorithms which can be found at the Common Solver Options page. We will detail some of the most widely used options.
Tolerances
The most useful options are the tolerances abstol
and reltol
. These tell the internal adaptive time stepping engine how precise of a solution you want. Generally, reltol
is the relative accuracy while abstol
is the accuracy when u
is near zero. These tolerances are local tolerances and thus are not global guarantees. However, a good rule of thumb is that the total solution accuracy is 1-2 digits less than the relative tolerances. Thus for the defaults abstol=1e-6
and reltol=1e-3
, you can expect a global accuracy of about 1-2 digits. If we want to get around 6 digits of accuracy, we can use the commands:
sol = solve(prob,abstol=1e-8,reltol=1e-8)
Now we can see no visible difference against the true solution:
plot(sol)
plot!(sol.t, t->1.0*exp(0.98t),lw=3,ls=:dash,label="True Solution!")
Notice that by decreasing the tolerance, the number of steps the solver had to take was 9
instead of the previous 5
. There is a trade off between accuracy and speed, and it is up to you to determine what is the right balance for your problem.
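As a small illustration of this trade-off (a sketch; the exact counts and errors will vary), we can compare the number of saved time points and the endpoint error for a loose versus a tight tolerance:
# Compare saved time points and endpoint error for a loose vs. a tight tolerance
sol_loose = solve(prob, reltol=1e-3, abstol=1e-6)
sol_tight = solve(prob, reltol=1e-8, abstol=1e-8)
exact = exp(0.98) # true solution at t = 1
println(length(sol_loose.t), " points, error ", abs(sol_loose[end] - exact))
println(length(sol_tight.t), " points, error ", abs(sol_tight[end] - exact))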
Saveat
Another common option is to use saveat
to make the solver save at specific time points. For example, if we want the solution at an even grid of 0.1 time unit, we would use the command:
sol = solve(prob, saveat=0.1)
Notice that when saveat
is used the continuous output variables are no longer saved and thus sol(t)
, the interpolation, is only first order. We can save at an uneven grid of points by passing a collection of values to saveat
. For example:
sol = solve(prob,saveat=[0.2,0.7,0.9])
If we need to reduce the amount of saving, we can also turn off the continuous output directly via dense=false
:
sol = solve(prob,dense=false)
and to turn off all intermediate saving we can use save_everystep=false
:
sol = solve(prob,save_everystep=false)
If we want to solve and only save the final value, we can even set save_start=false
.
sol = solve(prob,save_everystep=false,save_start = false)
Note that similarly on the other side there is save_end=false
.
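These saving options compose; as a small sketch, the following keeps only two interior points and drops both endpoints:
# Keep only two interior time points; drop the start and end values
sol = solve(prob, saveat=[0.25, 0.75], save_start=false, save_end=false)
sol.t # contains just the requested points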
More advanced saving behaviors, such as saving functionals of the solution, are handled via the SavingCallback
in the Callback Library which will be addressed later in the tutorial.
Choosing Solver Algorithms
There is no best algorithm for numerically solving a differential equation. When you call solve(prob)
, DifferentialEquations.jl makes a guess at a good algorithm for your problem, given the properties that you ask for (the tolerances, the saving information, etc.). However, in many cases you may want more direct control. A later notebook will help introduce the various algorithms in DifferentialEquations.jl, but for now let's introduce the syntax.
The most crucial determining factor in choosing a numerical method is the stiffness of the model. Stiffness is roughly characterized by the Jacobian of f
with large eigenvalues. That's quite mathematical, and we can think of it more intuitively: if you have big numbers in f
(like parameters of order 1e5
), then it's probably stiff. Or, as the creator of the MATLAB ODE Suite, Lawrence Shampine, likes to define it, if the standard algorithms are slow, then it's stiff. We will go into more depth about diagnosing stiffness in a later tutorial, but for now note that if you believe your model may be stiff, you can hint this to the algorithm chooser via alg_hints = [:stiff]
.
sol = solve(prob,alg_hints=[:stiff])
Stiff algorithms have to solve implicit equations and linear systems at each step so they should only be used when required.
If we want to choose an algorithm directly, you can pass the algorithm type after the problem as solve(prob,alg)
. For example, let's solve this problem using the Tsit5()
algorithm, and just for show let's change the relative tolerance to 1e-6
at the same time:
sol = solve(prob,Tsit5(),reltol=1e-6)
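As a sketch of the syntax (the step counts here are purely illustrative), we can also compare the default chooser's pick against an explicitly requested high-order method at a tight tolerance:
# Default algorithm choice vs. an explicit high-order non-stiff method
sol_default = solve(prob)
sol_vern9   = solve(prob, Vern9(), reltol=1e-9, abstol=1e-9)
(length(sol_default.t), length(sol_vern9.t))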
Systems of ODEs: The Lorenz Equation
Now let's move to a system of ODEs. The Lorenz equation is the famous "butterfly attractor" that spawned chaos theory. It is defined by the system of ODEs:
$$\begin{aligned} \frac{dx}{dt} &= \sigma (y - x) \\ \frac{dy}{dt} &= x (\rho - z) - y \\ \frac{dz}{dt} &= xy - \beta z \end{aligned}$$
To define a system of differential equations in DifferentialEquations.jl, we define our f
as a vector function with a vector initial condition. Thus, for the vector u = [x,y,z]'
, we have the derivative function:
function lorenz!(du,u,p,t)
σ,ρ,β = p
du[1] = σ*(u[2]-u[1])
du[2] = u[1]*(ρ-u[3]) - u[2]
du[3] = u[1]*u[2] - β*u[3]
end
Notice here we used the in-place format which writes the output to the preallocated vector du
. For systems of equations the in-place format is faster. We use the initial condition $u_0 = [1.0, 0.0, 0.0]$ as follows:
u0 = [1.0,0.0,0.0]
Lastly, for this model we made use of the parameters p
. We need to set this value in the ODEProblem
as well. For our model we want to solve using the parameters $\sigma = 10$, $\rho = 28$, and $\beta = 8/3$, and thus we build the parameter collection:
p = (10,28,8/3) # we could also make this an array, or any other sequence type!
Now we generate the ODEProblem
type. In this case, since we have parameters, we add the parameter values to the end of the constructor call. Let's solve this on a time span of t=0
to t=100
:
tspan = (0.0,100.0)
prob = ODEProblem(lorenz!,u0,tspan,p)
Now, just as before, we solve the problem:
sol = solve(prob)
The same solution handling features apply to this case. Thus sol.t
stores the time points and sol.u
is an array storing the solution at the corresponding time points.
However, there are a few extra features which are good to know when dealing with systems of equations. First of all, sol
also acts like an array. sol[i]
returns the solution at the i
th time point.
println(sol.t[10], ",", sol[10])
Additionally, the solution acts like a matrix where sol[j,i]
is the value of the j
th variable at time i
:
sol[2,10]
We can get a real matrix by performing a conversion:
A = Array(sol)
This is the same as sol, i.e. sol[i,j] = A[i,j]
, but now it's a true matrix. Plotting will by default show the time series for each variable:
plot(sol)
If we instead want to plot values against each other, we can use the vars
command. Let's plot variable 1
against variable 2
against variable 3
:
plot(sol,vars=(1,2,3), lw=1)
This is the classic Lorenz attractor plot, where the x
axis is u[1]
, the y
axis is u[2]
, and the z
axis is u[3]
. Note that the plot recipe by default uses the interpolation, but we can turn this off:
plot(sol,vars=(1,2,3),denseplot=false)
Yikes! This shows how calculating the continuous solution has saved a lot of computational effort by computing only a sparse solution and filling in the values! Note that in vars, 0=time
, and thus we can plot the time series of a single component like:
plot(sol,vars=(0,2))
ModelingToolkit: A DSL for DE systems
Note: ParameterizedFunctions.jl is being replaced by ModelingToolkit.jl
In many cases you may be defining a lot of functions with parameters. There exists the domain-specific language (DSL) defined by the ModelingToolkit package for helping with this common problem. For example, we can define the Lorenz equation above as follows:
using ModelingToolkit
# Define some variables
# Same as t, σ, ρ, β = [Variable(s)(t; known=true) for s in [:t, :σ, :ρ, :β]]
@parameters t σ ρ β
# Same as x, y, z = [Variable(s)(t) for s in [:x, :y, :z]]
@variables x(t) y(t) z(t)
# Same as D = Differential(t)
@derivatives D'~t
# Define equations (system of ODEs)
eqs = [D(x) ~ σ*(y-x),
D(y) ~ x*(ρ-z)-y,
D(z) ~ x*y - β*z]
Each operation builds an Operation
type, and thus eqs
is an array of Operation
and Variable
s. This holds a tree of the full system that can be analyzed by other programs. We can turn this into an ODESystem
and a ODEFunction
via:
de = ODESystem(eqs)
lorenz = ODEFunction(de, [x,y,z], [σ,ρ,β])
We can then use the result just like an ODE function from before:
u0 = [1.0,0.0,0.0]
p = (10,28,8/3)
tspan = (0.0,100.0)
prob = ODEProblem(lorenz,u0,tspan,p)
sol = solve(prob)
plot(sol, vars=(1, 2, 3), lw=1)
We can see the generated low-level code via generate_function
to see how much behind-the-scenes work is done by this toolkit.
generate_function(de, [x,y,z], [σ,ρ,β])
Not only is the DSL convenient syntax, but it does some magic behind the scenes. For example, further parts of the tutorial will describe how solvers for stiff differential equations have to make use of the Jacobian in calculations. Here, the DSL uses symbolic differentiation to automatically derive that function:
generate_jacobian(de)
The DSL can derive many other functions; this ability is used to speed up the solvers. An extension to DifferentialEquations.jl, Latexify.jl, allows you to extract these pieces as LaTeX expressions.
Internal Types
The last basic user-interface feature to explore is the choice of types. DifferentialEquations.jl respects your input types to determine the internal types that are used. Thus, since in the previous cases we used Float64
values for the initial condition, this meant that the internal values would be solved using Float64
. We made sure that time was specified via Float64
values, meaning that time steps would utilize 64-bit floats as well. But, by simply changing these types we can change what is used internally.
As a quick example, let's say we want to solve an ODE defined by a matrix. To do this, we can simply use a matrix as input.
A = [1. 0 0 -5
4 -2 4 -3
-4 0 0 1
5 -2 2 3]
u0 = rand(4,2)
tspan = (0.0,1.0)
f(u,p,t) = A*u
prob = ODEProblem(f,u0,tspan)
sol = solve(prob)
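Since this problem is linear, its exact solution is the matrix exponential $u(t) = e^{At}u_0$, which gives a quick, hedged sanity check using LinearAlgebra's exp:
using LinearAlgebra
# The linear ODE u' = A*u has the exact solution exp(A*t)*u0
maximum(abs.(sol(1.0) .- exp(A*1.0)*u0)) # small, on the order of the solver tolerance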
There is no real difference from what we did before, but now in this case u0
is a 4x2
matrix. Because of that, the solution at each time point is a matrix:
sol[3]
In DifferentialEquations.jl, you can use any type that defines +
, -
, *
, /
, and has an appropriate norm
. For example, if we want arbitrary precision floating point numbers, we can change the input to be a matrix of BigFloat
:
big_u0 = big.(u0)
and we can solve the ODEProblem
with arbitrary precision numbers by using that initial condition:
prob = ODEProblem(f,big_u0,tspan)
sol = solve(prob)
sol[1,3]
To really make use of this, we would want to change abstol
and reltol
to be small! Notice that the type for "time" is different than the type for the dependent variables, and this can be used to optimize the algorithm via keeping multiple precisions. We can convert time to be arbitrary precision as well by defining our time span with BigFloat
variables:
prob = ODEProblem(f,big_u0,big.(tspan))
sol = solve(prob, abstol=big(1e-12), reltol=big(1e-12))
Let's end by showing a more complicated use of types. For small arrays, it's usually faster to do operations on static arrays via the package StaticArrays.jl. The syntax is similar to that of normal arrays, but for these special arrays we utilize the @SMatrix
macro to indicate we want to create a static array.
using StaticArrays
A = @SMatrix [ 1.0  0.0 0.0 -5.0
               4.0 -2.0 4.0 -3.0
              -4.0  0.0 0.0  1.0
               5.0 -2.0 2.0  3.0]
u0 = @SMatrix rand(4,2)
tspan = (0.0,1.0)
f(u,p,t) = A*u
prob = ODEProblem(f,u0,tspan)
sol = solve(prob)
sol[3]
Conclusion
These are the basic controls in DifferentialEquations.jl. All equations are defined via a problem type, and the solve
command is used with an algorithm choice (or the default) to get a solution. Every solution acts the same, like an array sol[i]
with sol.t[i]
, and also like a continuous function sol(t)
with a nice plot command plot(sol)
. The Common Solver Options can be used to control the solver for any equation type. Lastly, the types used in the numerical solving are determined by the input types, and this can be used to solve with arbitrary precision and add additional optimizations (this can be used to solve via GPUs for example!). While this was shown on ODEs, these techniques generalize to other types of equations as well.
Choosing an ODE Algorithm
While the default algorithms, along with alg_hints = [:stiff]
, will suffice in most cases, there are times when you may need to exert more control. The purpose of this part of the tutorial is to introduce you to some of the most widely used algorithm choices and when they should be used. The corresponding page of the documentation is the ODE Solvers page which goes into more depth.
Diagnosing Stiffness
One of the key things to know for algorithm choices is whether your problem is stiff. Let's take for example the driven Van Der Pol equation
using DifferentialEquations, ModelingToolkit
@parameters t μ
@variables x(t) y(t)
@derivatives D'~t
# Define equations (system of ODEs)
eqs = [D(x) ~ 1*y,
D(y) ~ μ*((1-x^2)*y - x)]
# Convert it to a proper function
de = ODESystem(eqs)
van = ODEFunction(de, [x,y], [μ])
prob = ODEProblem(van,[0.0,2.0],(0.0,6.3),1e6)
One indicating factor that should alert you to the fact that this model may be stiff is the fact that the parameter is 1e6
: large parameters generally mean stiff models. If we try to solve this with the default method:
sol = solve(prob,Tsit5());
Here it shows that maximum iterations were reached. Another thing that can happen is that the solution can return that the solver was unstable (exploded to infinity) or that dt
became too small. If these happen, the first thing to do is to check that your model is correct. It could very well be that you made an error that causes the model to be unstable!
If the model is the problem, then stiffness could be the reason. We can thus hint to the solver to use an appropriate method:
sol = solve(prob, alg_hints=[:stiff])
The Recommended Methods
When picking a method, the general rules are as follows:
Higher order is more efficient at lower tolerances, lower order is more efficient at higher tolerances.
Adaptivity is essential in most real-world scenarios.
Runge-Kutta methods do well with non-stiff equations, Rosenbrock methods do well with small stiff equations, BDF methods do well with large stiff equations.
While there are always exceptions to the rule, those are good guiding principles. Based on those, a simple way to choose methods is:
The default is Tsit5(), a non-stiff Runge-Kutta method of order 5
If you use low tolerances (1e-8), try Vern7() or Vern9()
If you use high tolerances, try BS3()
If the problem is stiff, try Rosenbrock23(), Rodas5(), or CVODE_BDF()
If you don't know, use AutoTsit5(Rosenbrock23()) or AutoVern9(Rodas5()).
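For instance (a sketch, reusing the stiff Van der Pol problem defined above), the auto-switching methods are passed just like any other algorithm:
# Auto-switching methods pair a cheap non-stiff method with a stiff fallback
sol = solve(prob, AutoTsit5(Rosenbrock23()))
sol = solve(prob, AutoVern9(Rodas5()))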
Comparison to other Software
If you are familiar with MATLAB, SciPy, or R's deSolve, here's a quick translation guide to help transfer your knowledge over.
ode23 -> BS3()
ode45/dopri5 -> DP5(), though in most cases Tsit5() is more efficient
ode23s -> Rosenbrock23(), though in most cases Rodas4() is more efficient
ode113 -> VCABM(), though in many cases Vern7() is more efficient
dop853 -> DP8(), though in most cases Vern7() is more efficient
ode15s/vode -> QNDF(), though in many cases CVODE_BDF(), Rodas4(), or radau() are more efficient
ode23t -> Trapezoid() for efficiency and GenericTrapezoid() for robustness
ode23tb -> TRBDF2
lsoda -> lsoda() (requires ]add LSODA; using LSODA)
ode15i -> IDA(), though in many cases Rodas4() can handle the DAE and is significantly more efficient
Optimizing DiffEq Code
In this section we will walk through some of the main tools for optimizing your code in order to solve problems efficiently with DifferentialEquations.jl. User-side optimizations are important because, for sufficiently difficult problems, most of the time will be spent inside of your f
function, the function you are trying to solve. "Efficient" integrators are those that reduce the required number of f
calls to hit the error tolerance. The main ideas for optimizing your DiffEq code, or any Julia function, are the following:
Make it non-allocating
Use StaticArrays for small arrays
Use broadcast fusion
Make it type-stable
Reduce redundant calculations
Make use of BLAS calls
Optimize algorithm choice
Optimizing Small Systems (<100 DEs)
Let's take the classic Lorenz system from before. Let's start by naively writing the system in its out-of-place form:
function lorenz2(u,p,t)
dx = 10.0*(u[2]-u[1])
dy = u[1]*(28.0-u[3]) - u[2]
dz = u[1]*u[2] - (8/3)*u[3]
[dx,dy,dz]
end
Here, lorenz2 returns an object, [dx,dy,dz], which is created within the body of lorenz2.
This is a common code pattern from high-level languages like MATLAB, SciPy, or R's deSolve. However, the issue with this form is that it allocates a vector, [dx,dy,dz]
, at each step. Let's benchmark the solution process with this choice of function:
using DifferentialEquations, BenchmarkTools
u0 = [1.0;0.0;0.0]
tspan = (0.0,100.0)
prob = ODEProblem(lorenz2,u0,tspan)
@benchmark solve(prob,Tsit5())
The BenchmarkTools package's @benchmark
runs the code multiple times to get an accurate measurement. The minimum time is the time it takes when your OS and other background processes aren't getting in the way. Notice that in this case it takes about 5ms to solve and allocates around 11.11 MiB. However, if we were to use this inside of a real user code we'd see a lot of time spent doing garbage collection (GC) to clean up all of the arrays we made. Even if we turn off saving we have these allocations.
@benchmark solve(prob,Tsit5(),save_everystep=false)
The problem of course is that arrays are created every time our derivative function is called. This function is called multiple times per step and is thus the main source of memory usage. To fix this, we can use the in-place form to make our code non-allocating:
function lorenz2!(du,u,p,t)
du[1] = 10.0*(u[2]-u[1])
du[2] = u[1]*(28.0-u[3]) - u[2]
du[3] = u[1]*u[2] - (8/3)*u[3]
end
Here, instead of creating an array each time, we utilized the cache array du
. When the inplace form is used, DifferentialEquations.jl takes a different internal route that minimizes the internal allocations as well. When we benchmark this function, we will see quite a difference.
u0 = [1.0;0.0;0.0]
tspan = (0.0,100.0)
prob = ODEProblem(lorenz2!,u0,tspan)
@benchmark solve(prob,Tsit5())
@benchmark solve(prob,Tsit5(),save_everystep=false)
Notice there are still some allocations and this is due to the construction of the integration cache. But this doesn't scale with the problem size since that's all just setup allocations.
tspan = (0.0,500.0) # 5x longer than before
prob = ODEProblem(lorenz2!,u0,tspan)
@benchmark solve(prob,Tsit5(),save_everystep=false)
But if the system is small we can optimize even more.
Allocations are only expensive if they are "heap allocations". For a more in-depth definition of heap allocations, there are a lot of sources online. But a good working definition is that heap allocations are variable-sized slabs of memory which have to be pointed to, and this pointer indirection costs time. Additionally, the heap has to be managed and the garbage collector has to actively keep track of what's on the heap.
However, there's an alternative to heap allocations, known as stack allocations. The stack is statically-sized (known at compile time) and thus its accesses are quick. Additionally, the exact block of memory is known in advance by the compiler, and thus re-using the memory is cheap. This means that allocating on the stack has essentially no cost!
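A tiny illustration of the difference (a sketch using BenchmarkTools, which was loaded earlier): constructing a Vector heap-allocates, while constructing a Tuple does not:
using BenchmarkTools
make_vector() = [1.0, 2.0, 3.0] # heap-allocated Vector
make_tuple()  = (1.0, 2.0, 3.0) # stack-allocated Tuple
@btime make_vector(); # reports one allocation
@btime make_tuple();  # reports zero allocations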
Arrays have to be heap allocated because their size (and thus the amount of memory they take up) is determined at runtime. But there are structures in Julia which are stack-allocated. struct
s for example are stack-allocated "value-type"s. Tuple
s are a stack-allocated collection. The most useful data structure for DiffEq though is the StaticArray
from the package StaticArrays.jl. These arrays have their length determined at compile-time. They are created using macros attached to normal array expressions, for example:
using StaticArrays
A = @SVector [2.0,3.0,5.0]
Notice that the 3
after SVector
gives the size of the SVector
. It cannot be changed. Additionally, SVector
s are immutable, so we have to create a new SVector
to change values. But remember, we don't have to worry about allocations because this data structure is stack-allocated. SArray
s have a lot of extra optimizations as well: they have fast matrix multiplication, fast QR factorizations, etc. which directly make use of the information about the size of the array. Thus, when possible they should be used.
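For example (a small sketch), a static matrix-vector product runs without any heap allocations:
using StaticArrays, BenchmarkTools
sA = @SMatrix rand(3,3)
sb = @SVector rand(3)
@btime $sA * $sb; # fast, with zero heap allocations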
Unfortunately static arrays can only be used for sufficiently small arrays. After a certain size, they are forced to heap allocate after some instructions and their compile time balloons. Thus static arrays shouldn't be used if your system has more than 100 variables. Additionally, only the native Julia algorithms can fully utilize static arrays.
Let's optimize the Lorenz system
using static arrays. Note that in this case, we want to use the out-of-place allocating form, but this time we want to output a static array:
function lorenz_static(u,p,t)
dx = 10.0*(u[2]-u[1])
dy = u[1]*(28.0-u[3]) - u[2]
dz = u[1]*u[2] - (8/3)*u[3]
@SVector [dx,dy,dz]
end
To make the solver internally use static arrays, we simply give it a static array as the initial condition.
u0 = @SVector [1.0,0.0,0.0]
tspan = (0.0,100.0)
prob = ODEProblem(lorenz_static,u0,tspan)
@benchmark solve(prob,Tsit5())
@benchmark solve(prob,Tsit5(),save_everystep=false)
And that's pretty much all there is to it. With static arrays you don't have to worry about allocating, so use operations like *
and don't worry about fusing operations (discussed in the next section). Write the "vectorized code" style of R/MATLAB/Python and your code in this case will be fast, or work directly with the numbers/values.
Optimizing Large Systems
Interlude: Managing Allocations with Broadcast Fusion
When your system is sufficiently large, or you have to make use of a non-native Julia algorithm, you have to make use of Array
s. In order to use arrays in the most efficient manner, you need to be careful about temporary allocations. Vectorized calculations naturally have plenty of temporary array allocations. This is because a vectorized calculation outputs a vector. Thus:
A = rand(1000,1000); B = rand(1000,1000); C = rand(1000,1000)
test(A,B,C) = A + B + C
test(A,B,C)
That expression A + B + C
creates 2 arrays. It first creates one for the output of A + B
, then uses that result array to + C
to get the final result. 2 arrays! We don't want that! The first thing to do to fix this is to use broadcast fusion. Broadcast fusion puts expressions together. For example, instead of doing the +
operations separately, if we were to add them all at the same time, then we would only have a single array that's created. For example:
test2(A,B,C) = map((a,b,c)->a+b+c,A,B,C)
test2(A,B,C)
Puts the whole expression into a single function call, and thus only one array is required to store output. This is the same as writing the loop:
function test3(A,B,C)
D = similar(A)
for i in eachindex(A)
D[i] = A[i] + B[i] + C[i]
end
D
end
test3(A,B,C)
However, Julia's broadcast is syntactic sugar for this. If multiple expressions have a .
, then it will put those vectorized operations together. Thus:
test4(A,B,C) = A .+ B .+ C
test4(A,B,C)
is a version with only 1 array created (the output). Note that .
s can be used with function calls as well
sin.(A) .+ sin.(B)
Also, the @.
macro applys a dot to every operator:In [ ]:
test5(A,B,C) = @. A + B + C #only one array allocated
test5(A,B,C)
Using these tools we can get rid of our intermediate array allocations for many vectorized function calls. But we are still allocating the output array. To get rid of that allocation, we can instead use mutation. Mutating broadcast is done via .=
. For example, if we pre-allocate the output:
D = zeros(1000,1000);
Then we can keep re-using this cache for subsequent calculations. The mutating broadcasting form is:
test6!(D,A,B,C) = D .= A .+ B .+ C #only one array allocated
test6!(D,A,B,C)
If we use @.
before the =
, then it will turn it into .=
:
test7!(D,A,B,C) = @. D = A + B + C #only one array allocated
test7!(D,A,B,C)
Notice that in this case, there is no "output", and instead the values inside of D
are what are changed (like with the DiffEq inplace function). Many Julia functions have a mutating form which is denoted with a !
. For example, the mutating form of the map
is map!
:
test8!(D,A,B,C) = map!((a,b,c)->a+b+c,D,A,B,C)
test8!(D,A,B,C)
Some operations require using an alternate mutating form in order to be fast. For example, matrix multiplication via *
allocates a temporary
A*B
Instead, we can use the mutating form mul!
into a cache array to avoid allocating the output:
using LinearAlgebra
mul!(D,A,B) # same as D = A * B
For repeated calculations this reduced allocation can stop GC cycles and thus lead to more efficient code. Additionally, we can fuse together higher level linear algebra operations using BLAS. The package SugarBLAS.jl makes it easy to write higher level operations like alpha*B*A + beta*C
as mutating BLAS calls.
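For example (a hedged sketch with illustrative names), Julia's built-in five-argument mul! fuses an alpha*A*B + beta*C style update into a single in-place BLAS call:
using LinearAlgebra
X = rand(100,100); Y = rand(100,100); Z = rand(100,100)
# In-place fused update: Z = 2.0*X*Y + 3.0*Z, one gemm call, no temporaries
mul!(Z, X, Y, 2.0, 3.0)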
Example Optimization: Gierer-Meinhardt Reaction-Diffusion PDE Discretization
Let's optimize the solution of a Reaction-Diffusion PDE's discretization. In its discretized form, this is the ODE:
$$\begin{aligned} u_t &= D_1 (A_y u + u A_x) + \frac{a u^2}{v} + \bar{u} - \alpha u \\ v_t &= D_2 (A_y v + v A_x) + a u^2 - \beta v \end{aligned}$$
where u, v, and A are matrices. Here, we will use the simplified version where A is the tridiagonal stencil [1,−2,1], i.e. it's the 2D discretization of the Laplacian. The native code would be something along the lines of:
# Generate the constants
p = (1.0,1.0,1.0,10.0,0.001,100.0) # a,α,ubar,β,D1,D2
N = 100
Ax = Array(Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1]))
Ay = copy(Ax)
Ax[2,1] = 2.0
Ax[end-1,end] = 2.0
Ay[1,2] = 2.0
Ay[end,end-1] = 2.0
function basic_version!(dr,r,p,t)
a,α,ubar,β,D1,D2 = p
u = r[:,:,1]
v = r[:,:,2]
Du = D1*(Ay*u + u*Ax)
Dv = D2*(Ay*v + v*Ax)
dr[:,:,1] = Du .+ a.*u.*u./v .+ ubar .- α*u
dr[:,:,2] = Dv .+ a.*u.*u .- β*v
end
a,α,ubar,β,D1,D2 = p
uss = (ubar+β)/α
vss = (a/β)*uss^2
r0 = zeros(100,100,2)
r0[:,:,1] .= uss.+0.1.*rand.()
r0[:,:,2] .= vss
prob = ODEProblem(basic_version!,r0,(0.0,0.1),p)
In this version we have encoded our initial condition to be a 3-dimensional array, with r[:,:,1] being the u part and r[:,:,2] being the v part.
@benchmark solve(prob,Tsit5())
While this version isn't very efficient, it's a fine starting point: we recommend writing the "high-level" code first, and iteratively optimizing it!
The first thing that we can do is get rid of the slicing allocations. The operation r[:,:,1]
creates a temporary array instead of a "view", i.e. a pointer to the already existing memory. To make it a view, add @view
. Note that we have to be careful with views because they point to the same memory, and thus changing a view changes the original values:
A = rand(4)
print(A)
B = @view A[1:3]
B[2] = 2
print(A)
Notice that changing B
changed A
. This is something to be careful of, but at the same time we want to use this since we want to modify the output dr
. Additionally, the last statement is a purely element-wise operation, and thus we can make use of broadcast fusion there. Let's rewrite basic_version!
to avoid slicing allocations and to use broadcast fusion:
function gm2!(dr,r,p,t)
a,α,ubar,β,D1,D2 = p
u = @view r[:,:,1]
v = @view r[:,:,2]
du = @view dr[:,:,1]
dv = @view dr[:,:,2]
Du = D1*(Ay*u + u*Ax)
Dv = D2*(Ay*v + v*Ax)
@. du = Du + a.*u.*u./v + ubar - α*u
@. dv = Dv + a.*u.*u - β*v
end
prob = ODEProblem(gm2!,r0,(0.0,0.1),p)
@benchmark solve(prob,Tsit5())
Now, most of the allocations are taking place in Du = D1*(Ay*u + u*Ax)
since those operations are vectorized and not mutating. We should instead replace the matrix multiplications with mul!
. When doing so, we will need to have cache variables to write into. This looks like:
Ayu = zeros(N,N)
uAx = zeros(N,N)
Du = zeros(N,N)
Ayv = zeros(N,N)
vAx = zeros(N,N)
Dv = zeros(N,N)
function gm3!(dr,r,p,t)
a,α,ubar,β,D1,D2 = p
u = @view r[:,:,1]
v = @view r[:,:,2]
du = @view dr[:,:,1]
dv = @view dr[:,:,2]
mul!(Ayu,Ay,u)
mul!(uAx,u,Ax)
mul!(Ayv,Ay,v)
mul!(vAx,v,Ax)
@. Du = D1*(Ayu + uAx)
@. Dv = D2*(Ayv + vAx)
@. du = Du + a*u*u./v + ubar - α*u
@. dv = Dv + a*u*u - β*v
end
prob = ODEProblem(gm3!,r0,(0.0,0.1),p)
@benchmark solve(prob,Tsit5())
But our temporary variables are global variables. We need to either declare the caches as const
or localize them. We can localize them by adding them to the parameters, p
. It's easier for the compiler to reason about local variables than global variables. Localizing variables helps to ensure type stability.
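An alternative sketch (with illustrative names) is to capture the caches in a closure; since the captured arrays are only mutated, never reassigned, the closure stays type-stable. The approach this tutorial takes, stashing the caches in p, follows below.
using LinearAlgebra
# Sketch: build the RHS as a closure over locally allocated caches
function make_gm(Ax, Ay, N)
  Ayu, uAx, Du = zeros(N,N), zeros(N,N), zeros(N,N)
  Ayv, vAx, Dv = zeros(N,N), zeros(N,N), zeros(N,N)
  function gm_closure!(dr, r, p, t)
    a,α,ubar,β,D1,D2 = p
    u = @view r[:,:,1]
    v = @view r[:,:,2]
    du = @view dr[:,:,1]
    dv = @view dr[:,:,2]
    mul!(Ayu,Ay,u); mul!(uAx,u,Ax)
    mul!(Ayv,Ay,v); mul!(vAx,v,Ax)
    @. Du = D1*(Ayu + uAx)
    @. Dv = D2*(Ayv + vAx)
    @. du = Du + a*u*u/v + ubar - α*u
    @. dv = Dv + a*u*u - β*v
  end
  return gm_closure!
end
# Usage sketch: the original 6-element parameter tuple still works here
# prob = ODEProblem(make_gm(Ax,Ay,N), r0, (0.0,0.1), (1.0,1.0,1.0,10.0,0.001,100.0))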
p = (1.0,1.0,1.0,10.0,0.001,100.0,Ayu,uAx,Du,Ayv,vAx,Dv) # a,α,ubar,β,D1,D2
function gm4!(dr,r,p,t)
a,α,ubar,β,D1,D2,Ayu,uAx,Du,Ayv,vAx,Dv = p
u = @view r[:,:,1]
v = @view r[:,:,2]
du = @view dr[:,:,1]
dv = @view dr[:,:,2]
mul!(Ayu,Ay,u)
mul!(uAx,u,Ax)
mul!(Ayv,Ay,v)
mul!(vAx,v,Ax)
@. Du = D1*(Ayu + uAx)
@. Dv = D2*(Ayv + vAx)
@. du = Du + a*u*u./v + ubar - α*u
@. dv = Dv + a*u*u - β*v
end
prob = ODEProblem(gm4!,r0,(0.0,0.1),p)
@benchmark solve(prob,Tsit5())
We could then use the BLAS gemm call
to optimize the matrix multiplications some more, but instead let's devectorize the stencil.
p = (1.0,1.0,1.0,10.0,0.001,100.0,N)
function fast_gm!(du,u,p,t)
a,α,ubar,β,D1,D2,N = p
for j in 2:N-1, i in 2:N-1
du[i,j,1] = D1*(u[i-1,j,1] + u[i+1,j,1] + u[i,j+1,1] + u[i,j-1,1] - 4u[i,j,1]) +
a*u[i,j,1]^2/u[i,j,2] + ubar - α*u[i,j,1]
end
for j in 2:N-1, i in 2:N-1
du[i,j,2] = D2*(u[i-1,j,2] + u[i+1,j,2] + u[i,j+1,2] + u[i,j-1,2] - 4u[i,j,2]) +
a*u[i,j,1]^2 - β*u[i,j,2]
end
for j in 2:N-1
i = 1
du[1,j,1] = D1*(2u[i+1,j,1] + u[i,j+1,1] + u[i,j-1,1] - 4u[i,j,1]) +
a*u[i,j,1]^2/u[i,j,2] + ubar - α*u[i,j,1]
end
for j in 2:N-1
i = 1
du[1,j,2] = D2*(2u[i+1,j,2] + u[i,j+1,2] + u[i,j-1,2] - 4u[i,j,2]) +
a*u[i,j,1]^2 - β*u[i,j,2]
end
for j in 2:N-1
i = N
du[end,j,1] = D1*(2u[i-1,j,1] + u[i,j+1,1] + u[i,j-1,1] - 4u[i,j,1]) +
a*u[i,j,1]^2/u[i,j,2] + ubar - α*u[i,j,1]
end
for j in 2:N-1
i = N
du[end,j,2] = D2*(2u[i-1,j,2] + u[i,j+1,2] + u[i,j-1,2] - 4u[i,j,2]) +
a*u[i,j,1]^2 - β*u[i,j,2]
end
for i in 2:N-1
j = 1
du[i,1,1] = D1*(u[i-1,j,1] + u[i+1,j,1] + 2u[i,j+1,1] - 4u[i,j,1]) +
a*u[i,j,1]^2/u[i,j,2] + ubar - α*u[i,j,1]
end
for i in 2:N-1
j = 1
du[i,1,2] = D2*(u[i-1,j,2] + u[i+1,j,2] + 2u[i,j+1,2] - 4u[i,j,2]) +
a*u[i,j,1]^2 - β*u[i,j,2]
end
for i in 2:N-1
j = N
du[i,end,1] = D1*(u[i-1,j,1] + u[i+1,j,1] + 2u[i,j-1,1] - 4u[i,j,1]) +
a*u[i,j,1]^2/u[i,j,2] + ubar - α*u[i,j,1]
end
for i in 2:N-1
j = N
du[i,end,2] = D2*(u[i-1,j,2] + u[i+1,j,2] + 2u[i,j-1,2] - 4u[i,j,2]) +
a*u[i,j,1]^2 - β*u[i,j,2]
end
begin
i = 1; j = 1
du[1,1,1] = D1*(2u[i+1,j,1] + 2u[i,j+1,1] - 4u[i,j,1]) +
a*u[i,j,1]^2/u[i,j,2] + ubar - α*u[i,j,1]
du[1,1,2] = D2*(2u[i+1,j,2] + 2u[i,j+1,2] - 4u[i,j,2]) +
a*u[i,j,1]^2 - β*u[i,j,2]
i = 1; j = N
du[1,N,1] = D1*(2u[i+1,j,1] + 2u[i,j-1,1] - 4u[i,j,1]) +
a*u[i,j,1]^2/u[i,j,2] + ubar - α*u[i,j,1]
du[1,N,2] = D2*(2u[i+1,j,2] + 2u[i,j-1,2] - 4u[i,j,2]) +
a*u[i,j,1]^2 - β*u[i,j,2]
i = N; j = 1
du[N,1,1] = D1*(2u[i-1,j,1] + 2u[i,j+1,1] - 4u[i,j,1]) +
a*u[i,j,1]^2/u[i,j,2] + ubar - α*u[i,j,1]
du[N,1,2] = D2*(2u[i-1,j,2] + 2u[i,j+1,2] - 4u[i,j,2]) +
a*u[i,j,1]^2 - β*u[i,j,2]
i = N; j = N
du[end,end,1] = D1*(2u[i-1,j,1] + 2u[i,j-1,1] - 4u[i,j,1]) +
a*u[i,j,1]^2/u[i,j,2] + ubar - α*u[i,j,1]
du[end,end,2] = D2*(2u[i-1,j,2] + 2u[i,j-1,2] - 4u[i,j,2]) +
a*u[i,j,1]^2 - β*u[i,j,2]
end
end
prob = ODEProblem(fast_gm!,r0,(0.0,0.1),p)
@benchmark solve(prob,Tsit5())
Lastly, we can do other things like multithread the main loops, but these optimizations get the last 2x-3x out. The main optimizations which apply everywhere are the ones we just performed (though the last one only works if your matrix is a stencil. This is known as a matrix-free implementation of the PDE discretization).
This gets us to about 8x faster than our original MATLAB/SciPy/R vectorized style code!
The last thing to do is then optimize our algorithm choice. We have been using Tsit5()
as our test algorithm, but in reality this problem is a stiff PDE discretization and thus one recommendation is to use CVODE_BDF()
. However, instead of using the default dense Jacobian, we should make use of the sparse Jacobian afforded by the problem. The Jacobian is the matrix df/dr , where r is read by the linear index (i.e. down columns). But since the u variables depend on the v, the band size here is large, and thus this will not do well with a Banded Jacobian solver. Instead, we utilize sparse Jacobian algorithms. CVODE_BDF
allows us to use a sparse Newton-Krylov solver by setting linear_solver = :GMRES
(see the solver documentation), and thus we can solve this problem efficiently. Let's see how this scales as we increase the integration time.
prob = ODEProblem(fast_gm!,r0,(0.0,10.0),p)
@benchmark solve(prob,Tsit5())
using Pkg; Pkg.add("Sundials")
using Sundials
@benchmark solve(prob,CVODE_BDF(linear_solver=:GMRES))
prob = ODEProblem(fast_gm!,r0,(0.0,100.0),p)
# Will go out of memory if we don't turn off `save_everystep`!
@benchmark solve(prob,Tsit5(),save_everystep=false)
@benchmark solve(prob,CVODE_BDF(linear_solver=:GMRES))
Now let's check the allocation growth.
@benchmark solve(prob,CVODE_BDF(linear_solver=:GMRES),save_everystep=false)
prob = ODEProblem(fast_gm!,r0,(0.0,500.0),p)
@benchmark solve(prob,CVODE_BDF(linear_solver=:GMRES),save_everystep=false)
Notice that we've eliminated almost all allocations, allowing the computation to scale to longer integration times without hitting garbage collection and slowing down.
Why is CVODE_BDF
doing well? What's happening is that, because the problem is stiff, the number of steps required by the explicit Runge-Kutta method grows rapidly, whereas CVODE_BDF
is taking large steps. Additionally, the GMRES
linear solver form is quite an efficient way to solve the implicit system in this case. This is problem-dependent, and in many cases using a Krylov method effectively requires a preconditioner, so you need to play around with testing other algorithms and linear solvers to find out what works best with your problem.
Conclusion
Julia gives you the tools to optimize the solver "all the way", but you need to make use of it. The main thing to avoid is temporary allocations. For small systems, this is effectively done via static arrays. For large systems, this is done via in-place operations and cache arrays. Either way, the resulting solution can be immensely sped up over vectorized formulations by using these principles.
Callbacks and Events
In working with a differential equation, our system will evolve through many states. Particular states of the system may be of interest to us, and we say that an "event" is triggered when our system reaches these states. For example, events may include the moment when our system reaches a particular temperature or velocity. We handle these events with callbacks, which tell us what to do once an event has been triggered.
These callbacks allow for a lot more than event handling, however. For example, we can use callbacks to achieve high-level behavior like exactly preserving conservation laws and saving the trace of a matrix at pre-defined time points. This extra functionality allows us to use the callback system as a modding system for the DiffEq ecosystem's solvers.
This tutorial is an introduction to the callback and event handling system in DifferentialEquations.jl, documented in the Event Handling and Callback Functions page of the documentation. We will also introduce you to some of the most widely used callbacks in the Callback Library, which is a library of pre-built mods.
Events and Continuous Callbacks
Event handling is done through continuous callbacks. Callbacks take a function, condition
, which triggers an affect!
when condition == 0
. These callbacks are called "continuous" because they will utilize rootfinding on the interpolation to find the "exact" time point at which the condition takes place and apply the affect!
at that time point.
Let's use a bouncing ball as a simple system to explain events and callbacks. Let's take Newton's model of a ball falling towards the Earth's surface via a gravitational constant g
. In this case, the velocity is changing via -g
, and position is changing via the velocity. Therefore we get the system of ODEs:
using DifferentialEquations, ModelingToolkit
# Define some variables
@parameters t g
# Same as y, v = [Variable(s)(t) for s in [:y, :v]]
@variables y(t) v(t)
# Same as D = Differential(t)
@derivatives D'~t
# Define equations (system of ODEs)
eqs = [D(y) ~ v,
D(v) ~ -g]
de = ODESystem(eqs)
ball! = ODEFunction(de, [y, v], [g])
We want the callback to trigger when y=0
since that's when the ball will hit the Earth's surface (our event). We do this with the condition
# Signature of condition function
function condition(u, t, integrator)
u[1]
end
Recall that the condition
will trigger when it evaluates to zero, and here it will evaluate to zero when u[1] == 0
, which occurs when the position y == 0
. Now we have to say what we want the callback to do. Callbacks make use of the Integrator Interface. Instead of giving a full description, a quick and usable rundown is:
Values are stored in integrator.u
Times are stored in integrator.t
The parameters are stored in integrator.p
integrator(t) performs an interpolation in the current interval between integrator.tprev and integrator.t (and allows extrapolation)
User-defined options (tolerances, etc.) are stored in integrator.opts
integrator.sol is the current solution object. Note that integrator.sol.prob is the current problem
While there's a lot more on the integrator interface page, that's a working knowledge of what's there.
What we want to do with our affect!
is to "make the ball bounce". Mathematically speaking, the ball bounces when the sign of the velocity flips. As an added behavior, let's also use a small friction constant to dampen the ball's velocity. This way only a percentage of the velocity will be retained when the event is triggered and the callback is used. We'll define this behavior in the affect!
function:
# Signature of affect! function
function affect!(integrator)
integrator.u[2] = -integrator.p[2] * integrator.u[2]
end
integrator.u[2]
is the second value of our model, which is v
or velocity, and integrator.p[2]
is our friction coefficient.
Therefore affect!
can be read as follows: affect!
will take the current value of velocity, and multiply it by -1
multiplied by our friction coefficient. Therefore the ball will change direction and its velocity will dampen when affect!
is called.
Now let's build the ContinuousCallback
bounce_cb = ContinuousCallback(condition,affect!)
Now let's make an ODEProblem
which has our callback
u0 = [50.0,0.0]
tspan = (0.0,15.0)
p = (9.8,0.9)
prob = ODEProblem(ball!,u0,tspan,p,callback=bounce_cb)
Notice that we chose a friction constant of 0.9
. Now we can solve the problem and plot the solution as we normally would:
sol = solve(prob,Tsit5())
using Plots; gr()
plot(sol)
and tada, the ball bounces! Notice that the ContinuousCallback
is using the interpolation to apply the effect "exactly" when y == 0
. This is crucial for model correctness, and thus when this property is needed a ContinuousCallback
should be used.
Discrete Callbacks
A discrete callback checks a condition
after every integration step and, if true, it will apply an affect!
. For example, let's say that at time t=2
we want to include that a kid kicked the ball, adding 50
to the current velocity. This kind of situation, where we want to add a specific behavior which does not require rootfinding, is a good candidate for a DiscreteCallback
. In this case, the condition
is a boolean for whether to apply the affect!
, so:
function condition_kick(u,t,integrator)
t == 2
end
We want the kick to occur at t=2
, so we check for that time point. When we are at this time point, we want to do
function affect_kick!(integrator)
integrator.u[2] += 50
end
Now we build the problem as before
kick_cb = DiscreteCallback(condition_kick,affect_kick!)
u0 = [50.0,0.0]
tspan = (0.0,10.0)
p = (9.8,0.9)
prob = ODEProblem(ball!,u0,tspan,p,callback=kick_cb)
Note that, since we are requiring our effect at exactly the time t=2
, we need to tell the integration scheme to step at exactly t=2
to apply this callback. This is done via the option tstops
, which is like saveat
but means "stop at these values".
sol = solve(prob,Tsit5(),tstops=[2.0])
plot(sol)
Note that this example could've been done with a ContinuousCallback
by checking the condition t-2
.
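As a hedged sketch of that alternative (the new names here are just for illustration), root-finding on t - 2 removes the need for tstops:
# ContinuousCallback alternative: the condition t - 2 crosses zero at the kick time
kick_condition(u, t, integrator) = t - 2
kick_cont_cb = ContinuousCallback(kick_condition, affect_kick!)
prob2 = ODEProblem(ball!, u0, tspan, p, callback=kick_cont_cb)
sol2 = solve(prob2, Tsit5()) # no tstops needed; the rootfinder locates t = 2
plot(sol2)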
Merging Callbacks with Callback Sets
In some cases you may want to merge callbacks to build up more complex behavior. In our previous result, notice that the model is unphysical because the ball goes below zero! What we really need to do is add the bounce callback together with the kick. This can be achieved through the CallbackSet
cb = CallbackSet(bounce_cb,kick_cb)
A CallbackSet
merges their behavior together. The logic is as follows. In a given interval, if there are multiple continuous callbacks that would trigger, only the one that triggers at the earliest time is used. The time is pulled back to where that continuous callback is triggered, and then the DiscreteCallback
s in the callback set are called in order.
u0 = [50.0,0.0]
tspan = (0.0,15.0)
p = (9.8,0.9)
prob = ODEProblem(ball!,u0,tspan,p,callback=cb)
sol = solve(prob,Tsit5(),tstops=[2.0])
plot(sol)
Notice that we have now merged the behaviors. We can then nest this as deep as we like.
Integration Termination and Directional Handling
Let's look at another model now: the model of the Harmonic Oscillator. We can write this using ParameterizedFunctions.jl.
using Pkg
pkg"add ParameterizedFunctions"
using DifferentialEquations, ParameterizedFunctions
harmonic! = @ode_def HarmonicOscillator begin
dv = -x
dx = v
end
u0 = [1.,0.]
tspan = (0.0,10.0)
prob = ODEProblem(harmonic!,u0,tspan)
sol = solve(prob)
plot(sol)
Let's instead stop the integration when a condition is met. From the Integrator Interface stepping controls we see that terminate!(integrator)
will cause the integration to end. So our new affect!
is simply:
function terminate_affect!(integrator)
terminate!(integrator)
end
Let's first stop the integration when the particle moves back to x=0
. This means we want to use the condition
function terminate_condition(u,t,integrator)
u[2]
end
terminate_cb = ContinuousCallback(terminate_condition,terminate_affect!)
Note that instead of adding callbacks to the problem, we can also add them to the solve
command. This will automatically form a CallbackSet
with any problem-related callbacks and naturally allows you to distinguish between model features and integration controls
sol = solve(prob,callback=terminate_cb)
plot(sol)
Notice that the harmonic oscillator's true solution here is sine
and cosine
, and thus we would expect this return to zero to happen at t=π
sol.t[end]
This is one way to approximate π! Lower tolerances and arbitrary precision numbers can make this more exact, but let's not look at that. Instead, what if we wanted to halt the integration after exactly one cycle? To do so we would need to ignore the first zero-crossing. Luckily in these types of scenarios there's usually a structure to the problem that can be exploited. Here, we only want to trigger the affect!
when crossing from negative to positive, and not when crossing from positive to negative. In other words, we want our affect!
to only occur on upcrossings.
If the ContinuousCallback
constructor is given a single affect!
, it will occur on both upcrossings and downcrossings. If there are two affect!
s given, then the first is for upcrossings and the second is for downcrossings. An affect!
can be ignored by using nothing
. Together, the "upcrossing-only" version of the effect means that the first affect!
is what we defined above and the second is nothing
. Therefore we want:
terminate_upcrossing_cb = ContinuousCallback(terminate_condition,terminate_affect!,nothing)
sol = solve(prob,callback=terminate_upcrossing_cb)
plot(sol)
Callback Library
As you can see, callbacks can be very useful and through CallbackSets
we can merge together various behaviors. Because of this utility, there is a library of pre-built callbacks known as the Callback Library. We will walk through a few examples where these callbacks can come in handy.
Manifold Projection
One callback is the manifold projection callback. Essentially, you can define any manifold g(sol)=0
which the solution must live on, and cause the integration to project to that manifold after every step. As an example, let's see what happens if we naively run the harmonic oscillator for a long time:
tspan = (0.0,10000.0)
prob = ODEProblem(harmonic!,u0,tspan)
sol = solve(prob)
pyplot(fmt=:png) # Make it a PNG instead of an SVG since there's a lot of points!
plot(sol,vars=(1,2))
plot(sol,vars=(0,1),denseplot=false)
Notice that what's going on is that the numerical solution is drifting from the true solution over this long time scale. This is because the integrator is not conserving energy.
plot(sol.t,[u[2]^2 + u[1]^2 for u in sol.u],legend=:topleft) # Energy ~ x^2 + v^2
Some integration techniques like symplectic integrators are designed to mitigate this issue, but instead let's tackle the problem by enforcing conservation of energy. To do so, we define our manifold as the one where energy equals 1 (since that holds in the initial condition), that is:
function g2(resid,u,p,t)
resid[1] = u[2]^2 + u[1]^2 - 1
resid[2] = 0
end
Here the residual measures how far from our desired energy we are, and the number of conditions matches the size of our system (we ignored the second one by making the residual 0). Thus we define a ManifoldProjection
callback and add that to the solver:
cb = ManifoldProjection(g2)
sol = solve(prob,callback=cb)
plot(sol,vars=(1,2))
plot(sol,vars=(0,1),denseplot=false)
Now we have "perfect" energy conservation, where if it's ever violated too much the solution will get projected back to energy=1
u1,u2 = sol[500]
u2^2 + u1^2
While choosing different integration schemes and using lower tolerances can achieve this effect as well, this can be a nice way to enforce physical constraints and is thus used in many disciplines like molecular dynamics. Another such domain-constraining callback is PositiveDomain()
which can be used to enforce positivity of the variables.
SavingCallback
The SavingCallback
can be used to allow for special saving behavior. Let's take a linear ODE defined on a system of 1000x1000 matrices
prob = ODEProblem((du,u,p,t)->du.=u,rand(1000,1000),(0.0,1.0))
In fields like quantum mechanics you may only want to know specific properties of the solution such as the trace or the norm of the matrix. Saving all of the 1000x1000 matrices can be a costly way to get this information! Instead, we can use the SavingCallback
to save the trace
and norm
at specified times. To do so, we first define our SavedValues
cache. Our time is in terms of Float64
, and we want to save tuples of Float64
s (one for the trace
and one for the norm
), and thus we generate the cache as:
saved_values = SavedValues(Float64, Tuple{Float64,Float64})
Now we define the SavingCallback
by giving it a function of (u,t,integrator)
that returns the values to save, and the cache
using LinearAlgebra
cb = SavingCallback((u,t,integrator)->(tr(u),norm(u)), saved_values)
Here we take u
and save (tr(u),norm(u))
. When we solve with this callback
# Turn off normal saving
sol = solve(prob, Tsit5(), callback=cb, save_everystep=false, save_start=false, save_end = false)
Our values are stored in our saved_values
variable
saved_values.t
saved_values.saveval
By default this happened only at the solver's steps. But the SavingCallback
has similar controls as the integrator. For example, if we want to save at every 0.1
seconds, we can do so using saveat
saved_values = SavedValues(Float64, Tuple{Float64,Float64}) # New cache
cb = SavingCallback((u,t,integrator)->(tr(u),norm(u)), saved_values, saveat = 0.0:0.1:1.0)
# Turn off normal saving
sol = solve(prob, Tsit5(), callback=cb, save_everystep=false, save_start=false, save_end = false)
saved_values.t
saved_values.saveval
Formatting Plots
Since the plotting functionality is implemented as a recipe to Plots.jl, all of the options open to Plots.jl can be used in our plots. In addition, there are special features specifically for differential equation plots. This tutorial will teach some of the most commonly used options. Let's first get the solution to some ODE. Here I will use the Lorenz ordinary differential equations. As with all commands in DifferentialEquations.jl, we get a plot of the solution by calling solve
on the problem, and plot
on the solution
using DifferentialEquations, Plots, ParameterizedFunctions
lorenz3 = @ode_def begin
dx = σ*(y-x)
dy = ρ*x-y-x*z
dz = x*y-β*z
end σ β ρ
p = [10.0,8/3,28]
u0 = [1., 5., 10.]
tspan = (0., 100.)
prob = ODEProblem(lorenz3, u0, tspan, p)
sol = solve(prob)
plot(sol)
Now let's change it to a phase plot. As discussed in the plot functions page, we can use the vars
command to choose the variables to plot. Let's plot variable x
vs variable y
vs variable z
plot(sol,vars=(1, 2, 3))
# Not available in ModelingToolkit.jl yet
# plot(sol,vars=[:x])
Notice that we were able to use the variable names because we had defined the problem with the macro. But in general, we can use the indices. The previous plots would be
plot(sol,vars=(1,2,3))
plot(sol,vars=[1])
Common options are to add titles, axis, and labels. For example
plot(sol,linewidth=5,
title="Solution to the linear ODE with a thick line",
xaxis="Time (t)",yaxis="u(t) (in mm)",label=["X","Y","Z"])
Notice that series recipes apply to the solution type as well. For example, we can use a scatter plot on the timeseries
# scatter(sol,vars=[:x])
This shows that the recipe is using the interpolation to smooth the plot. It becomes abundantly clear when we turn it off using denseplot=false
plot(sol,vars=(1,2,3),denseplot=false)
When this is done, only the values the timestep hits are plotted. Using the interpolation usually results in a much nicer looking plot so it's recommended, and since the interpolations have similar orders to the numerical methods, their results are trustworthy on the full interval. We can control the number of points used in the interpolation's plot using the plotdensity
command
plot(sol,vars=(1,2,3),plotdensity=100)
That's plotting the entire solution using 100 points spaced evenly in time.
plot(sol,vars=(1,2,3),plotdensity=10000)
That's more like it! By default it uses 100*length(sol)
, where the length is the number of internal steps it had to take. This heuristic usually does well, but for unusually difficult equations it can be relaxed (since the solver will take small steps), and for equations with events / discontinuities raising the plot density can help resolve the discontinuity.
Lastly notice that we can compose plots. Let's show where the 100 points are using a scatter plot:
plot(sol,vars=(1,2,3))
scatter!(sol,vars=(1,2,3),plotdensity=100)
We can instead work with an explicit plot object. This form can be better for building a complex plot in a loop.
p = plot(sol,vars=(1,2,3))
scatter!(p,sol,vars=(1,2,3),plotdensity=100)
title!("I added a title")