How to write a program differentiably

In this blog post, I will show, step by step, how to write a linear algebra function, sparse matrix-vector multiplication, differentiably by converting an existing irreversible program in SparseArrays into a reversible one.
First of all, you need to install NiLang (v0.9), a reversible embedded domain-specific language (eDSL), by typing the following in the Julia REPL:
```
]add NiLang
```
Step 1: Find the function that you want to differentiate
As an example, I will rewrite the sparse matrix-vector multiplication function `mul!` defined in SparseArrays.
The following is a commented version of this program. I put some remarks above the "evil" statements that make a program irreversible so that we can fix them later.
```julia
function mul!(C::StridedVecOrMat, A::AbstractSparseMatrixCSC, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}, α::Number, β::Number)
    # 1. the following 3 lines are not allowed in reversible programming,
    # because a reversible program cannot reverse through an error interruption.
    size(A, 2) == size(B, 1) || throw(DimensionMismatch())
    size(A, 1) == size(C, 1) || throw(DimensionMismatch())
    size(B, 2) == size(C, 2) || throw(DimensionMismatch())
    # The assignment `nzv = nonzeros(A)` is not reversible. One should use the ancilla statement `nzv ← nonzeros(A)` instead.
    # The difference is that the latter automatically appends an ancilla deallocation statement `nzv → nonzeros(A)` at the end of the current scope.
    nzv = nonzeros(A)
    rv = rowvals(A)
    # This `if` statement looks reversible since `β` is not changed while executing the branch. However, the compiler does not know that.
    # We can simply add a trivial postcondition (`~`) to tell the compiler, i.e. `if (β != 1, ~) ... end`.
    if β != 1
        # `rmul!` (or `*=`) and `fill!` are not reversible; one has to allocate new memory instead.
        # In the following, we forbid the case of `β != 1` to simplify the discussion.
        β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
    end
    for k = 1:size(C, 2)
        for col = 1:size(A, 2)
            # assignment is not allowed, use `+=` instead
            αxj = B[col,k] * α
            for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
                C[rv[j], k] += nzv[j]*αxj
            end
        end
    end
    # There is no need to write `return` in a reversible function, because a reversible function returns its input arguments automatically.
    C
end
```
Step 2: Modify the irreversible statements
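The remarks above boil down to two rules: a statement must not destroy information, and every temporary (ancilla) must be restored to its initial value before it is deallocated, as in the pair `αxj ← zero(T)` / `αxj → zero(T)`. Both rules can be seen in plain Julia, without NiLang. In the sketch below, `scaled_column_update!` and `undo_scaled_column_update!` are hypothetical helpers for illustration only: `fill!` maps different arrays to the same result, so it has no inverse, while the update `c[r] += v * αxj` is undone by `-=` with the same right-hand side, and the temporary `αxj` is computed, used, and then uncomputed. Integers are used so that the round trip is exact; with floating-point numbers it holds only up to rounding.

```julia
# Plain-Julia sketch (no NiLang needed); both helpers are hypothetical.

# 1. `fill!` is not reversible: different inputs are mapped to the same
#    output, so no inverse can recover the original contents.
c1 = fill!([1.0, 2.0], 0.0)
c2 = fill!([3.0, 4.0], 0.0)
c1 == c2   # true

# 2. The temporary `αxj` is computed, used, and then uncomputed, so it is
#    zero again at the end, which is what deallocating it as `zero(T)` requires.
function scaled_column_update!(c, rows, vals, bk, α)
    αxj = zero(bk * α)             # allocate the ancilla as zero
    αxj += bk * α                  # compute
    for (r, v) in zip(rows, vals)
        c[r] += v * αxj            # use: accumulate into the output
    end
    αxj -= bk * α                  # uncompute: undo the compute step
    @assert iszero(αxj)            # the ancilla is clean again
    return c
end

# 3. The inverse replaces every `+=` by `-=` and runs the statements in reverse order.
function undo_scaled_column_update!(c, rows, vals, bk, α)
    αxj = zero(bk * α)
    αxj += bk * α
    for i in reverse(eachindex(rows))
        c[rows[i]] -= vals[i] * αxj
    end
    αxj -= bk * α
    @assert iszero(αxj)
    return c
end

c = [1, 2, 3]
after = copy(scaled_column_update!(c, [1, 3], [10, 20], 2, 3))  # [61, 2, 123]
undo_scaled_column_update!(c, [1, 3], [10, 20], 2, 3)
c == [1, 2, 3]   # true: the original input is recovered
```

This compute, use, uncompute pattern is what keeps the loop body of `mul!` reversible: only `C` changes, and every temporary leaves no trace.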
After fixing the statements marked above, your reversible program will look like this:
```julia
using NiLang
using SparseArrays: SparseMatrixCSC, AbstractSparseMatrix, nonzeros, rowvals, getcolptr

@i function mul!(C::StridedVecOrMat, A::AbstractSparseMatrix, B::StridedVector{T}, α::Number, β::Number) where T
    # `@safe` marks a statement that does not modify any variable (here, the size checks)
    @safe size(A, 2) == size(B, 1) || throw(DimensionMismatch())
    @safe size(A, 1) == size(C, 1) || throw(DimensionMismatch())
    @safe size(B, 2) == size(C, 2) || throw(DimensionMismatch())
    # `@routine` records how the ancillas are computed; `~@routine` later runs it backward to uncompute them
    @routine begin
        nzv ← nonzeros(A)
        rv ← rowvals(A)
    end
    if (β != 1, ~)
        @safe error("only β = 1 is supported, got β = $(β).")
    end
    # Here, we close the reversibility check inside the loop to increase performance
    @invcheckoff for k = 1:size(C, 2)
        for col = 1:size(A, 2)
            @routine begin
                αxj ← zero(T)
                αxj += B[col,k] * α
            end
            for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
                C[rv[j], k] += nzv[j]*αxj
            end
            ~@routine
        end
    end
    ~@routine
end
```
Step 3: Test and benchmark your program
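Before benchmarking, it helps to check that the loop computes `C += α * A * B`. The sketch below needs only the standard libraries, not NiLang: `csc_muladd!` is a hypothetical, non-reversible transcription of the loop in the reversible `mul!` (restricted to β = 1). It uses the exported `nzrange(A, col)` in place of the `getcolptr` range, and its result is compared with the five-argument `LinearAlgebra.mul!` (which dispatches to the SparseArrays implementation) and with a dense reference. The same comparison can be applied to the output of the reversible `mul!`.

```julia
using SparseArrays, LinearAlgebra, Random

# Hypothetical plain-Julia (non-reversible) version of the loop in the
# reversible `mul!`, restricted to β = 1: computes C += α * A * B in place.
function csc_muladd!(C::AbstractVecOrMat, A::SparseMatrixCSC, B::AbstractVecOrMat, α::Number)
    size(A, 2) == size(B, 1) || throw(DimensionMismatch())
    size(A, 1) == size(C, 1) || throw(DimensionMismatch())
    size(B, 2) == size(C, 2) || throw(DimensionMismatch())
    nzv = nonzeros(A)
    rv = rowvals(A)
    for k in 1:size(C, 2), col in 1:size(A, 2)
        αxj = B[col, k] * α
        # nzrange(A, col) is the range getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
        for j in nzrange(A, col)
            C[rv[j], k] += nzv[j] * αxj
        end
    end
    return C
end

Random.seed!(1234)
n = 200
A = sprand(ComplexF64, n, n, 0.1)
v = randn(ComplexF64, n)
C0 = randn(ComplexF64, n)
α = 0.5 + 0im

ours = csc_muladd!(copy(C0), A, v, α)
ref_sparse = LinearAlgebra.mul!(copy(C0), A, v, α, 1)   # five-argument mul!, β = 1
ref_dense = C0 + α * (Matrix(A) * v)                    # dense reference
```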
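To benchmark, we compare our implementation with the original one in SparseArrays:
```julia
using SparseArrays: sprand
import SparseArrays
using BenchmarkTools

n = 1000
sp1 = sprand(ComplexF64, n, n, 0.1)
v = randn(ComplexF64, n)
out = zero(v)

# the original (irreversible) implementation in SparseArrays
@benchmark SparseArrays.mul!($(copy(out)), $sp1, $v, 0.5+0im, 1)

# the reversible implementation defined above
@benchmark mul!($(copy(out)), $(sp1), $v, 0.5+0im, 1)
```
A reversible program written in this way is also differentiable. We can benchmark the automatic differentiation by typing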
```julia
using NiLang.AD

@benchmark (~mul!)($(GVar(copy(out))), $(GVar(sp1)), $(GVar(v)), $(GVar(0.5)), 1)
```
Here, we use the `GVar` type to store the gradients. It is similar to the `Dual` type in ForwardDiff.jl, but for reverse mode. The above benchmark measures the performance of back-propagating the gradients. Note that NiLang also supports complex-valued autodiff. In practice, one usually defines a real-valued output as the loss function. We recommend reading our paper for more details.
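The benchmark above runs `~mul!` on `GVar`-wrapped inputs, and NiLang derives the backward pass automatically. To make concrete what that backward pass has to compute, here is a hand-written sketch in plain Julia (no NiLang), restricted to real numbers for simplicity. It assumes the real-valued loss L = w ⋅ C for a fixed, randomly chosen weight vector `w` (an illustrative choice, not something NiLang prescribes), and the helper `csc_muladd_grad` is hypothetical. Every forward statement `C[i] += A[i, col] * (x[col] * α)` contributes one term to each of ∂L/∂A, ∂L/∂x and ∂L/∂α; the results are checked against the closed forms α Aᵀw and w ⋅ (Ax) and against central finite differences.

```julia
using SparseArrays, LinearAlgebra, Random

# Hand-written reverse pass (real numbers only) for the forward statements
#     C[i] += A[i, col] * (x[col] * α)   for every stored entry A[i, col],
# with the real-valued loss L = dot(w, C), so that ∂L/∂C = w.
# `csc_muladd_grad` is a hypothetical helper, not a NiLang function.
function csc_muladd_grad(w::AbstractVector, A::SparseMatrixCSC, x::AbstractVector, α::Real)
    nzv, rv = nonzeros(A), rowvals(A)
    gx = zeros(length(x))          # ∂L/∂x
    gnzv = zeros(length(nzv))      # ∂L/∂nonzeros(A), same layout as nonzeros(A)
    gα = 0.0                       # ∂L/∂α
    for col in 1:size(A, 2), j in nzrange(A, col)
        i = rv[j]
        gnzv[j] += w[i] * x[col] * α
        gx[col] += w[i] * nzv[j] * α
        gα += w[i] * nzv[j] * x[col]
    end
    return gx, gnzv, gα
end

Random.seed!(0)
n = 30
A = sprand(n, n, 0.2)
x, w, C0 = randn(n), randn(n), randn(n)
α = 0.7
loss(A, x, α) = dot(w, C0 + α * (A * x))   # L = w ⋅ (C + α A x)

gx, gnzv, gα = csc_muladd_grad(w, A, x, α)

# central finite differences for α and for the first stored entry of A
h = 1e-6
fd_α = (loss(A, x, α + h) - loss(A, x, α - h)) / (2h)
Ap, Am = copy(A), copy(A)
nonzeros(Ap)[1] += h
nonzeros(Am)[1] -= h
fd_a1 = (loss(Ap, x, α) - loss(Am, x, α)) / (2h)
```

Because L is linear in each parameter, the finite differences agree with the hand-written gradients up to rounding. Writing such a backward pass by hand for every function is exactly the work that running the reversible program backward on `GVar` avoids.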
Debugging
NiLang is still experimental, with a lot of unexpected "features". Feel free to join the Julia Slack and discuss with me; I appear in the #autodiff and #reversible-computing channels quite often.