How to write a program differentiably

In this blog post, I will show how to write a linear algebra function, sparse matrix-vector multiplication, differentiably by converting an existing irreversible program in SparseArrays into a reversible one, step by step.

First of all, you need to install NiLang (v0.9), a reversible embedded domain-specific language (eDSL), by typing

]add NiLang

Step 1: Find the function that you want to differentiate

As an example, I will rewrite the sparse matrix-vector multiplication function defined in

https://github.com/JuliaLang/julia/blob/b773bebcdb1eccaf3efee0bfe564ad552c0bcea7/stdlib/SparseArrays/src/linalg.jl#L10

The following is a commented version of this program. I have put remarks above the "evil" statements that make the program irreversible, so that we can fix them later.

function mul!(C::StridedVecOrMat, A::AbstractSparseMatrixCSC, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}, α::Number, β::Number)
    # 1. The following three lines are not allowed in reversible programming,
    # because a reversible program cannot reverse through an error interruption.
    size(A, 2) == size(B, 1) || throw(DimensionMismatch())
    size(A, 1) == size(C, 1) || throw(DimensionMismatch())
    size(B, 2) == size(C, 2) || throw(DimensionMismatch())
    # 2. The assignment `nzv = nonzeros(A)` is not reversible. One should use the
    # ancilla statement `nzv ← nonzeros(A)` instead. The difference is that the latter
    # automatically appends an ancilla deallocation statement `nzv → nonzeros(A)`
    # at the end of the current scope.
    nzv = nonzeros(A)
    rv = rowvals(A)
  
    # 3. This `if` statement looks reversible, since `β` is not changed while
    # executing the branch; however, the compiler does not know that. We can add
    # a trivial postcondition (`~`) to tell it, i.e. `if (β != 1, ~) ... end`.
    if β != 1
        # `rmul!` (i.e. `*=`) and `fill!` are not reversible; one would have to allocate
        # new memory to support them. Below, we forbid the case `β != 1` to simplify the discussion.
        β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
    end
    for k = 1:size(C, 2)
        @inbounds for col = 1:size(A, 2)
            # 4. Assignment is not allowed; accumulate into a zero-initialized ancilla with `+=` instead.
            αxj = B[col,k] * α
            for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
                C[rv[j], k] += nzv[j]*αxj
            end
        end
    end
    # 5. There is no need to write a return statement, because a reversible
    # function returns its input arguments automatically.
    C
end
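
To make these rules concrete, here is a minimal reversible sketch, assuming only NiLang itself; the function i_scaled_dot is made up for illustration and is not part of SparseArrays. It uses the reversible counterparts of the "evil" statements above:

using NiLang

# illustrative only: reversible out += α * dot(x, y)
@i function i_scaled_dot(out, x, y, α)
    anc ← zero(eltype(x))        # ancilla allocation (`←`)
    for i = 1:length(x)
        anc += x[i] * y[i]       # `+=` replaces a plain assignment
    end
    if (α != 0, ~)               # `~`: the postcondition equals the precondition
        out += α * anc
    end
    for i = 1:length(x)          # uncompute the ancilla back to zero
        anc -= x[i] * y[i]
    end
    anc → zero(eltype(x))        # ancilla deallocation (`→`)
end

Calling i_scaled_dot(0.0, [1.0, 2.0], [3.0, 4.0], 0.5) returns (5.5, [1.0, 2.0], [3.0, 4.0], 0.5), and the inverse (~i_scaled_dot) restores out to 0.0.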

Step 2: Modify the irreversible statements

After fixing these statements one by one, your reversible program will look like this:

using NiLang
using SparseArrays: SparseMatrixCSC, AbstractSparseMatrix, nonzeros, rowvals, getcolptr
@i function mul!(C::StridedVecOrMat, A::AbstractSparseMatrix, B::StridedVector{T}, α::Number, β::Number) where T
    @safe size(A, 2) == size(B, 1) || throw(DimensionMismatch())
    @safe size(A, 1) == size(C, 1) || throw(DimensionMismatch())
    @safe size(B, 2) == size(C, 2) || throw(DimensionMismatch())
    @routine begin
        nzv ← nonzeros(A)
        rv ← rowvals(A)
    end
    if (β != 1, ~)
        @safe error("only β = 1 is supported, got β = $(β).")
    end
    # Here, we turn off the reversibility check inside the loop to improve performance
    @invcheckoff for k = 1:size(C, 2)
        @inbounds for col = 1:size(A, 2)
            @routine begin
                αxj ← zero(T)
                αxj += B[col,k] * α
            end
            for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
                C[rv[j], k] += nzv[j]*αxj
            end
            ~@routine
        end
    end
    ~@routine
end
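
The @routine ... ~@routine pair implements the compute-copy-uncompute paradigm: @routine records a block of statements, and ~@routine executes their inverse, so ancillas allocated inside (such as αxj above) are restored to their initial values and deallocated safely. Here is a toy sketch of the pattern; the function name i_square is made up for illustration:

using NiLang

@i function i_square(out, x)
    @routine begin
        anc ← zero(x)
        anc += x          # compute the intermediate value
    end
    out += anc * x        # copy the result out
    ~@routine             # uncompute: anc returns to zero and is deallocated
end

i_square(0.0, 3.0)  # returns (9.0, 3.0)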

Step 3: Test and benchmark your program

using SparseArrays: sprand
import SparseArrays
using BenchmarkTools
n = 1000
sp1 = sprand(ComplexF64, n, n, 0.1)
v = randn(ComplexF64, n)
out = zero(v);
@benchmark SparseArrays.mul!($(copy(out)), $sp1, $v, 0.5+0im, 1)
@benchmark mul!($(copy(out)), $(sp1), $v, 0.5+0im, 1)

A reversible program written in this way is also differentiable. We can benchmark the automatic differentiation by typing

using NiLang.AD
@benchmark (~mul!)($(GVar(copy(out))), $(GVar(sp1)), $(GVar(v)), $(GVar(0.5)), 1)

Here, we use the GVar type to store the gradients; it is similar to the Dual type in ForwardDiff.jl, but operates in reverse mode. The benchmark above shows the performance of back-propagating the gradients, and it also demonstrates that complex-valued autodiff is supported in NiLang. In practice, one may want to define a real-valued output as the loss function. We recommend that readers consult our paper for more details.
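
For completeness, the following sketch shows how one might extract actual gradient values with a real-valued loss; the function i_sqnorm and the seeding convention are illustrative, not part of the SparseArrays example. We seed the adjoint of the loss to 1 with the two-argument GVar constructor, run the inverse program to back-propagate, and read the result with grad:

using NiLang, NiLang.AD

# a toy real-valued loss built from the same primitives as mul! above
@i function i_sqnorm(loss, x)
    for i = 1:length(x)
        loss += x[i] * x[i]
    end
end

x = [1.0, 2.0, 3.0]
loss, x = i_sqnorm(0.0, x)                          # forward pass: loss == 14.0
gloss, gx = (~i_sqnorm)(GVar(loss, 1.0), GVar.(x))  # seed ∂loss/∂loss = 1
grad.(gx)                                           # == 2x == [2.0, 4.0, 6.0]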

Debugging

NiLang is still experimental, with a lot of unexpected "features". Feel free to join the discussion with me on the Julia Slack; I appear in the #autodiff and #reversible-computing channels quite often.
