How to write a program differentiably
In this blog post, I will show, step by step, how to write a linear algebra function, sparse matrix-vector multiplication, differentiably by converting an existing irreversible program in SparseArrays into a reversible one.
First of all, install NiLang (v0.9), a reversible embedded domain-specific language (eDSL) in Julia, by typing the following in the Julia REPL:
```
]add NiLang
```
Step 1: Find the function that you want to differentiate
As an example, I will rewrite the sparse matrix-vector multiplication function `mul!` defined in SparseArrays.
Below is a commented version of this program. I put remarks above the "evil" statements that make the program irreversible, so that we can fix them later.
```julia
function mul!(C::StridedVecOrMat, A::AbstractSparseMatrixCSC, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}, α::Number, β::Number)
    # The following 3 lines are not allowed in reversible programming as written,
    # because a reversible program cannot reverse through an error interruption.
    # In NiLang, we wrap them in `@safe` so that they run unchanged in both directions.
    size(A, 2) == size(B, 1) || throw(DimensionMismatch())
    size(A, 1) == size(C, 1) || throw(DimensionMismatch())
    size(B, 2) == size(C, 2) || throw(DimensionMismatch())
    # The assignment `nzv = nonzeros(A)` is not reversible. Use the ancilla
    # statement `nzv ← nonzeros(A)` instead; the difference is that the latter
    # automatically appends the ancilla deallocation statement `nzv → nonzeros(A)`
    # at the end of the current scope.
    nzv = nonzeros(A)
    rv = rowvals(A)
    # This `if` statement looks reversible, since `β` is not changed inside the
    # branch, but the compiler does not know that. Adding the trivial postcondition
    # `~` (same as the precondition) tells it so: `if (β != 1, ~) ... end`.
    if β != 1
        # `rmul!` (or `*=`) and `fill!` are not reversible; one has to allocate new
        # memory instead. To simplify the discussion, we forbid `β != 1` below.
        β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
    end
    for k = 1:size(C, 2)
        for col = 1:size(A, 2)
            # Assignment is not allowed; allocate an ancilla with
            # `αxj ← zero(T)` and accumulate into it with `+=` instead.
            αxj = B[col,k] * α
            for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
                C[rv[j], k] += nzv[j]*αxj
            end
        end
    end
    # There is no need to return `C` in a reversible function, because
    # a reversible function returns its input arguments automatically.
    C
end
```
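The inner loop relies on the compressed sparse column (CSC) layout: the stored entries of column `col` sit at positions `getcolptr(A)[col]` through `getcolptr(A)[col + 1] - 1` of `nonzeros(A)`, and their row indices sit at the same positions of `rowvals(A)`. As a minimal plain-Julia sketch of that traversal (no NiLang; the helper name `csc_matvec_add!` is hypothetical, and only a vector `b` with `β = 1` is covered), the loop below computes `C += α * A * b`:

```julia
using SparseArrays: sparse, nonzeros, rowvals, getcolptr

# Hypothetical helper: accumulate C += α * A * b for a CSC matrix A, visiting
# each stored entry once, column by column (the loop structure of `mul!`).
function csc_matvec_add!(C::AbstractVector, A, b::AbstractVector, α::Number)
    nzv = nonzeros(A)        # stored values, column-major
    rv = rowvals(A)          # row index of each stored value
    colptr = getcolptr(A)    # column `col` owns entries colptr[col]:colptr[col+1]-1
    for col in 1:size(A, 2)
        αxj = b[col] * α     # scale factor shared by the whole column
        for j in colptr[col]:(colptr[col + 1] - 1)
            C[rv[j]] += nzv[j] * αxj
        end
    end
    return C
end

# A 3×3 matrix with stored entries (1,1)=1, (3,1)=2, (2,2)=3, (3,3)=4
A = sparse([1, 3, 2, 3], [1, 1, 2, 3], [1.0, 2.0, 3.0, 4.0], 3, 3)
b = [1.0, 2.0, 3.0]
C = csc_matvec_add!(zeros(3), A, b, 0.5)   # [0.5, 3.0, 7.0]
```

Only stored entries are touched, which is why the cost scales with the number of nonzeros rather than with `size(A, 1) * size(A, 2)`.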
Step 2: Modify the irreversible statements
After fixing these statements, the reversible program looks like this:
```julia
using NiLang
using SparseArrays: SparseMatrixCSC, AbstractSparseMatrix, nonzeros, rowvals, getcolptr

@i function mul!(C::StridedVecOrMat, A::AbstractSparseMatrix, B::StridedVector{T}, α::Number, β::Number) where T
    # `@safe` statements run unchanged in both the forward and the backward pass
    @safe size(A, 2) == size(B, 1) || throw(DimensionMismatch())
    @safe size(A, 1) == size(C, 1) || throw(DimensionMismatch())
    @safe size(B, 2) == size(C, 2) || throw(DimensionMismatch())
    # `@routine` allocates the ancillas; the matching `~@routine` at the end undoes it
    @routine begin
        nzv ← nonzeros(A)
        rv ← rowvals(A)
    end
    if (β != 1, ~)
        @safe error("only β = 1 is supported, got β = $(β).")
    end
    # `@invcheckoff` turns off the reversibility checks inside the loop to improve performance
    @invcheckoff for k = 1:size(C, 2)
        for col = 1:size(A, 2)
            @routine begin
                αxj ← zero(T)
                αxj += B[col,k] * α
            end
            for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
                C[rv[j], k] += nzv[j]*αxj
            end
            ~@routine
        end
    end
    ~@routine
end
```
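Why does an `if` need a postcondition such as the `~` in `if (β != 1, ~)`? To undo a branch, the inverse must know whether the branch ran, but it only sees the state after the forward pass. The plain-Julia sketch below does not use NiLang, and its function names `f_bad`, `f_good` and `f_good_inv` are hypothetical. It contrasts a branch that loses this information with one that, like the `β` branch in `mul!`, keeps it:

```julia
# Hypothetical example 1: the branch body changes the variable it tests,
# so the output alone no longer tells whether the branch ran.
f_bad(x) = x > 0 ? x - 10 : x

# Hypothetical example 2: like the `β` branch in `mul!`, the body never
# changes `β`, so the condition evaluated after the branch (the postcondition)
# equals the one evaluated before it (the precondition).
f_good(c, β) = β != 1 ? (c + β, β) : (c, β)

# Inverse of f_good: re-evaluating `β != 1` on the output tells us which
# branch ran, so exactly that branch is undone.
f_good_inv(c, β) = β != 1 ? (c - β, β) : (c, β)

collision = (f_bad(5), f_bad(-5))   # (-5, -5): two inputs, one output
roundtrips = (f_good_inv(f_good(3, 2)...), f_good_inv(f_good(3, 1)...))   # ((3, 2), (3, 1))
```

Since `f_bad` maps two inputs to the same output, no inverse exists for it; for `f_good`, the postcondition equals the precondition, which is exactly what `~` asserts.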
Step 3: Test and benchmark your program
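```julia
using SparseArrays: sprand
import SparseArrays
using BenchmarkTools

n = 1000
sp1 = sprand(ComplexF64, n, n, 0.1)
v = randn(ComplexF64, n)
out = zero(v);

# check that both implementations agree
# (the reversible `mul!` returns all its input arguments as a tuple)
@assert mul!(copy(out), sp1, v, 0.5+0im, 1)[1] ≈ SparseArrays.mul!(copy(out), sp1, v, 0.5+0im, 1)

# the original irreversible implementation in SparseArrays
@benchmark SparseArrays.mul!($(copy(out)), $sp1, $v, 0.5+0im, 1)
# the reversible implementation from Step 2
@benchmark mul!($(copy(out)), $(sp1), $v, 0.5+0im, 1)
```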
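Because every update in `mul!` is an accumulation, running it backward only has to undo those accumulations. As a rough plain-Julia picture of what `~mul!` computes (this is not NiLang's generated code; `fwd!` and `bwd!` are hypothetical names, and only the vector, `β = 1` case is covered), the sketch below pairs a forward pass with a backward pass that visits the loops in reverse order and turns every `+=` on `C` into `-=`. The `@assert iszero(αxj)` lines mimic the ancilla deallocation check, the kind of reversibility check that Step 2 turns off inside the loop for performance:

```julia
using SparseArrays: sparse, nonzeros, rowvals, getcolptr

# Hypothetical forward pass: C += α * A * b (the β == 1 case).
function fwd!(C, A, b, α)
    nzv, rv, colptr = nonzeros(A), rowvals(A), getcolptr(A)
    for col in 1:size(A, 2)
        αxj = zero(eltype(C))          # allocate the ancilla as zero
        αxj += b[col] * α              # compute
        for j in colptr[col]:(colptr[col + 1] - 1)
            C[rv[j]] += nzv[j] * αxj   # accumulate into the output
        end
        αxj -= b[col] * α              # uncompute
        @assert iszero(αxj)            # deallocation check: ancilla is zero again
    end
    return C
end

# Hypothetical backward pass: same ancilla compute/uncompute, loops reversed,
# and the accumulation into C undone with `-=`.
function bwd!(C, A, b, α)
    nzv, rv, colptr = nonzeros(A), rowvals(A), getcolptr(A)
    for col in size(A, 2):-1:1
        αxj = zero(eltype(C))
        αxj += b[col] * α
        for j in (colptr[col + 1] - 1):-1:colptr[col]
            C[rv[j]] -= nzv[j] * αxj
        end
        αxj -= b[col] * α
        @assert iszero(αxj)
    end
    return C
end

A = sparse([1, 3, 2, 3], [1, 1, 2, 3], [1.0, 2.0, 3.0, 4.0], 3, 3)
b = [1.0, 2.0, 3.0]
C0 = [10.0, 20.0, 30.0]
C1 = fwd!(copy(C0), A, b, 0.5)   # C0 + 0.5 * A * b
C2 = bwd!(copy(C1), A, b, 0.5)   # back to C0, up to floating-point rounding
```

In general `(c + x) - x` need not equal `c` bit for bit, so the round trip is exact only up to floating-point rounding.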
A reversible program written in this way is also differentiable. We can benchmark the automatic differentiation by typing
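```julia
using NiLang.AD

# `~mul!` runs the program backward; with `GVar` inputs it back-propagates the gradients
@benchmark (~mul!)($(GVar(copy(out))), $(GVar(sp1)), $(GVar(v)), $(GVar(0.5)), 1)
```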
Here, we used the `GVar` type to store the gradients; it is similar to the `Dual` type in ForwardDiff.jl, but for reverse mode. The above benchmark measures the performance of back-propagating the gradients, and it also shows that NiLang supports complex-valued autodiff. In practice, one usually defines a real-valued output as the loss function. We recommend that readers read our paper for more details.
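To know what the backward pass should produce, it helps to write the gradients down. For a real-valued loss `L = dot(w, C_out)` with `C_out = C + α * A * b` (the `β = 1`, vector case), the gradients are `∂L/∂C = w`, `∂L/∂b = α * A' * w`, `∂L/∂α = dot(w, A * b)`, and `∂L/∂A[i, col] = α * w[i] * b[col]` for each stored entry. The plain-Julia sketch below (no NiLang; `loss` and `loss_grads` are hypothetical names, and real-valued inputs are assumed) computes these by hand and checks two of them against central finite differences. Complex inputs need extra care with conjugation, which this sketch does not cover.

```julia
using LinearAlgebra: dot
using SparseArrays: sparse, nonzeros, rowvals, getcolptr

# Loss for the β == 1, vector case: L = dot(w, C + α * A * b).
loss(C, A, b, α, w) = dot(w, C .+ α .* (A * b))

# Hypothetical hand-derived gradients of `loss` (real-valued inputs assumed).
function loss_grads(C, A, b, α, w)
    rv, colptr = rowvals(A), getcolptr(A)
    gC = copy(w)                  # ∂L/∂C = w
    gb = α .* (A' * w)            # ∂L/∂b = α Aᵀ w
    gα = dot(w, A * b)            # ∂L/∂α = wᵀ A b
    gnzv = similar(nonzeros(A))   # ∂L/∂A[i, col] = α w[i] b[col], stored entries only
    for col in 1:size(A, 2), j in colptr[col]:(colptr[col + 1] - 1)
        gnzv[j] = α * w[rv[j]] * b[col]
    end
    return gC, gb, gα, gnzv
end

A = sparse([1, 3, 2, 3], [1, 1, 2, 3], [1.0, 2.0, 3.0, 4.0], 3, 3)
b = [1.0, 2.0, 3.0]; C = zeros(3); α = 0.5; w = [1.0, -1.0, 2.0]
gC, gb, gα, gnzv = loss_grads(C, A, b, α, w)

# Central finite differences for ∂L/∂b[2] and ∂L/∂α.
ε = 1e-6
e2 = [0.0, ε, 0.0]
fd_b2 = (loss(C, A, b .+ e2, α, w) - loss(C, A, b .- e2, α, w)) / (2ε)
fd_α = (loss(C, A, b, α + ε, w) - loss(C, A, b, α - ε, w)) / (2ε)
```

Since the loss is linear in each input, the finite differences should agree with the hand-derived gradients up to rounding.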
Debugging
NiLang is still experimental, with a lot of unexpected "features". Feel free to join the discussion with me on the Julia Slack; I appear in the #autodiff and #reversible-computing channels quite often.