Reverse computing is not checkpointing

A recent proposal [arXiv: 2003.04617] to implement programming-language-level automatic differentiation in a reversible programming language suggests that reverse computing might be more practical for differentiating a whole language than the traditional checkpointing-based approach. Since reverse computing and checkpointing share many similarities, it is still not clear to many people what makes them different.

What is checkpointing?

Checkpointing [A Griewank, 2008] is a technique for tracing back prior computational states (note: this definition covers the default mode of reverse-mode AD, which "checkpoints" every step). In machine learning, it is well known as a flexible method for trading time against space in reverse-mode automatic differentiation.

To understand how it works, let's consider a checkpointing algorithm (illustrated below) that snapshots a program every 100 steps. To recover a state that is not cached, the program starts from the closest previously cached state (state 1) and re-computes states 2-100.

Here, the snapshot does not copy the whole memory. Instead, it only records the changed parts (or the changed variables in one operation). The more often checkpointing takes a snapshot, the less computational overhead is needed to trace back a state, and the more space is needed to store intermediate results.
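As an illustration, the snapshot-and-replay scheme above can be sketched in a few lines of Python (a toy sketch with hypothetical names such as `step` and `recover`, not code from any AD framework):

```python
# Toy checkpointing: snapshot every `interval` steps; recover any other
# state by replaying forward from the nearest earlier snapshot.

def step(state):
    # one irreversible step of a toy program
    return state * 3 % 1000

def run_with_checkpoints(x0, n_steps, interval):
    checkpoints = {0: x0}
    state = x0
    for t in range(1, n_steps + 1):
        state = step(state)
        if t % interval == 0:
            checkpoints[t] = state  # take a snapshot
    return checkpoints

def recover(checkpoints, t, interval):
    # start from the closest cached state and re-compute the rest
    base = (t // interval) * interval
    state = checkpoints[base]
    for _ in range(t - base):
        state = step(state)
    return state
```

A smaller `interval` stores more snapshots but replays fewer steps per query, which is exactly the time-space tradeoff described above.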

Reverse computing

Reverse computing is a strategy that reversible computing uses to free up memory. Depending on whether a variable is used in future computation, we classify data as output state (not shown in the figure) and garbage (red cubes). In regular computing, the program can deallocate the garbage space immediately or on demand, so that only the output variables are transferred to the next stage. In reversible computing, however, the program cannot erase a variable unless its value has been cleared (usually by reverse computing).

When reversible computing does not introduce any garbage variables, as shown in (b), there is no need to cache variables: one can trace back the states by running the program in the reversed direction. However, when the computation introduces garbage variables, as shown in (a), the program sometimes needs to free the garbage. This is done by copying the output to another space and running the program in the reversed direction, so that only the inputs and outputs remain in memory while everything else is zero-cleared. The more often the program uncomputes, the more time overhead it pays and the less garbage it keeps in memory; this is called the g-segment tradeoff in reverse computing [KS Perumalla, 2013]. It is also the reason why some people prefer a messy desk: when one returns items on the desktop to their correct positions (uncomputing), the desktop becomes more spacious, but returning them also costs time.
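The copy-and-uncompute procedure can be sketched in Python as well (a toy model, not production code: the "memory" is a list, every statement is a reversible `+=`/`-=` update, and all names are hypothetical):

```python
# Copy-and-uncompute: run forward (producing garbage), reversibly copy the
# output, then run the exact inverse of the forward pass so that the
# garbage registers return to zero.

def forward(mem):
    # mem = [x, g1, g2, out]; each statement is invertible given mem[0]
    mem[1] += mem[0] * mem[0]   # g1 += x^2  (garbage)
    mem[2] += mem[1] * mem[0]   # g2 += x^3  (garbage)

def backward(mem):
    # the inverse program: same statements in reverse order, += -> -=
    mem[2] -= mem[1] * mem[0]
    mem[1] -= mem[0] * mem[0]

def compute_copy_uncompute(x):
    mem = [x, 0, 0, 0]
    forward(mem)       # compute x^3, leaving garbage in g1, g2
    mem[3] += mem[2]   # reversibly copy the output
    backward(mem)      # uncompute: g1 and g2 are zero-cleared
    return mem         # only the input and the output survive
```

Running `backward` costs roughly the same time as `forward`; that extra run is the time overhead paid for freeing the garbage.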

The similarities

Reverse computing and checkpointing share many similarities.

  • Both of them can make a program reversible.

  • When there is no time overhead, both have a space overhead of O(T), i.e., linear in the computation time T.

Are they essentially the same thing in different forms?

Reverse computing is not checkpointing

When a polynomial overhead in time is allowed, reversible computing has a minimum space overhead of Ω(S log(T/S)), where T and S are the time and space of the original irreversible program [Robert Y. Levine, 1990]. For checkpointing, by contrast, there can be zero space overhead: one can simply recompute from the beginning to obtain any intermediate state.

Are there cases where reverse computing can do better? Yes!

Handling allocations better

Let's first consider the Pyramid example in the book "Evaluating Derivatives", Sec. 3.5.

function pyramid!(v!, x::AbstractVector{T}) where T
    @assert size(v!,2) == size(v!,1) == length(x)
    @inbounds for j=1:length(x)
        v![1,j] += x[j]
    end
    @inbounds for i=1:size(v!,1)-1
        for j=1:size(v!,2)-i
            v![i+1,j] += cos(v![i,j+1]) * sin(v![i,j])
        end
    end
    v![end,1]
end
1.0s
pyramid! (generic function with 1 method)

Here, we define an in-place function pyramid! that repeatedly modifies single elements of v!. In both PyTorch and TensorFlow, array mutation is not allowed (note: a PyTorch leaf tensor and a TensorFlow TensorArray allow writing a tensor element once, but they are not truly mutable): when a tensor element changes, the tensor must be copied to protect previously cached data, which is why both frameworks forbid mutable arrays. There is also a wide class of automatic differentiation frameworks, such as Tapenade and ADOL-C, that cache only scalar data; in these frameworks, intermediate results inside a loop are pushed onto a stack.

In the following, I will show the difference between reverse-computing and stack-based AD by implementing both in the reversible eDSL NiLang. The reversible-style implementation of the pyramid function is

using Pkg
Pkg.add("NiLang");
5.9s
using NiLang
0.1s
@i function i_pyramid!(y!, v!, x::AbstractVector{T}) where T
    @safe @assert size(v!,2) == size(v!,1) == length(x)
    @invcheckoff @inbounds for j=1:length(x)
        v![1,j] += x[j]
    end
    @invcheckoff @inbounds for i=1:size(v!,1)-1
        for j=1:size(v!,2)-i
            @routine begin
                @zeros T c s
                c += cos(v![i,j+1])
                s += sin(v![i,j])
            end
            v![i+1,j] += c * s
            ~@routine
        end
    end
    y! += v![end,1]
end
0.3s

One can print the reversible IR

ex_ir = NiLangCore.precom_ex(Main, :(begin
    @routine begin
        @zeros T c s
        c += cos(v![i,j+1])
        s += sin(v![i,j])
    end
    v![i+1,j] += c * s
    ~@routine
end)) |> NiLangCore.rmlines
0.2s
quote
    begin
        begin
            c ← zero(T)
            s ← zero(T)
        end
        c += cos(v![i, j + 1])
        s += sin(v![i, j])
    end
    v![i + 1, j] += c * s
    begin
        s -= sin(v![i, j])
        c -= cos(v![i, j + 1])
        begin
            s → zero(T)
            c → zero(T)
        end
    end
end

One can see that this IR is highly symmetric except for the middle statement v![i + 1, j] += c * s. Hence its reversed code can be obtained easily by typing

NiLangCore.dual_ex(Main, ex_ir)
0.2s
quote
    begin
        begin
            c ← zero(T)
            s ← zero(T)
        end
        c += cos(v![i, j + 1])
        s += sin(v![i, j])
    end
    v![i + 1, j] -= c * s
    begin
        s -= sin(v![i, j])
        c -= cos(v![i, j + 1])
        begin
            s → zero(T)
            c → zero(T)
        end
    end
end

Intermediate variables s and c inside the loop are recovered through uncomputing, which causes a ~2x overhead in time. Taking the forward computation, backward computation, and gradient computation together, the final gradient/loss ratio in time is ~6.2.

using BenchmarkTools
@benchmark pyramid!(zeros(20, 20), randn(20))
5.9s
BenchmarkTools.Trial:
  memory estimate:  3.48 KiB
  allocs estimate:  2
  --------------
  minimum time:     3.288 μs (0.00% GC)
  median time:      3.600 μs (0.00% GC)
  mean time:        4.292 μs (5.37% GC)
  maximum time:     721.192 μs (89.99% GC)
  --------------
  samples:          10000
  evals/sample:     8
@benchmark NiLang.AD.gradient(Val(1), i_pyramid!, (0.0, zeros(20, 20), randn(20)))
7.5s
BenchmarkTools.Trial:
  memory estimate:  13.94 KiB
  allocs estimate:  14
  --------------
  minimum time:     20.383 μs (0.00% GC)
  median time:      22.331 μs (0.00% GC)
  mean time:        24.313 μs (2.94% GC)
  maximum time:     2.997 ms (80.19% GC)
  --------------
  samples:          10000
  evals/sample:     1

The stack-based approach achieves similar performance; however, its memory overhead is much larger.

@i function i_pyramid_stack!(y!, v!, x::AbstractVector{T}, stack) where T
    @safe @assert size(v!,2) == size(v!,1) == length(x)
    @invcheckoff @inbounds for j=1:length(x)
        v![1,j] += x[j]
    end
    @invcheckoff @inbounds for i=1:size(v!,1)-1
        for j=1:size(v!,2)-i
            @zeros T c s
            c += cos(v![i,j+1])
            s += sin(v![i,j])
            v![i+1,j] += c * s
            PUSH!(stack, c)
            PUSH!(stack, s)
        end
    end
    y! += v![end,1]
end
0.3s
@benchmark NiLang.AD.gradient(Val(1), i_pyramid_stack!, (0.0, zeros(20, 20), randn(20), Float64[]))
7.2s
BenchmarkTools.Trial:
  memory estimate:  28.44 KiB
  allocs estimate:  25
  --------------
  minimum time:     20.964 μs (0.00% GC)
  median time:      23.753 μs (0.00% GC)
  mean time:        27.445 μs (5.34% GC)
  maximum time:     3.198 ms (78.45% GC)
  --------------
  samples:          10000
  evals/sample:     1

Here, we also want to warn users that stack allocation is not available inside GPU kernels. Hence, only the reverse-computing version is GPU-compatible.

Utilizing Reversibility

Reversibility can be used to save memory in reversible computing; the same argument does not hold for checkpointing. For example, one can parameterize a unitary matrix as a consecutive application of two-level unitaries. The reversible implementation of unitary matrix multiplication does not allocate any new space.

@i function i_umm!(x!::AbstractArray, θ)
    k ← 0
    for l = 1:size(x!, 2)
        for j=1:size(x!, 1)
            for i=size(x!, 1)-1:-1:j
                INC(k)
                ROT(x![i,l], x![i+1,l], θ[k])
            end
        end
    end
    k → length(θ)
end
0.5s

Unitary matrices are widely used in quantum simulation, recurrent neural networks [arXiv: 1612.05231], etc.
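To see why no extra space is needed, here is a Python sketch of the same idea (hypothetical helper names, for illustration only): a two-level rotation with angle θ is exactly undone by a rotation with -θ, so the backward pass reconstructs earlier states instead of caching them.

```python
import math

def rot(a, b, theta):
    # a two-level (Givens) rotation; its inverse is rot(., ., -theta)
    s, c = math.sin(theta), math.cos(theta)
    return c * a - s * b, s * a + c * b

def umm(x, thetas):
    # forward pass: apply the rotations in order; nothing is cached
    for ia, ib, th in thetas:
        x[ia], x[ib] = rot(x[ia], x[ib], th)
    return x

def umm_inverse(x, thetas):
    # backward pass: apply the inverse rotations in reverse order
    for ia, ib, th in reversed(thetas):
        x[ia], x[ib] = rot(x[ia], x[ib], -th)
    return x
```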

Again, since checkpointing does not assume a specific execution order, it cannot utilize reversibility to reduce memory usage.

Helping users code better

Suppose we want to compute x^100 without using the power function. When one does not have reversibility in mind, they might write something like

"""
blindly implemented power100
"""
function power100(x::T) where T
    y = one(T)
    for i=1:100
        y *= x
    end
    return y
end
0.4s
power100
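As a point of comparison, here is a Python sketch (hypothetical, for illustration only) of how a stack-based AD framework would differentiate this loop: since `y *= x` overwrites `y`, the old value must be pushed onto a stack in the forward pass and popped in the backward pass.

```python
def power100_with_grad(x):
    # forward pass: cache every overwritten y (100 scalars)
    stack = []
    y = 1.0
    for _ in range(100):
        stack.append(y)
        y *= x
    # backward pass: chain rule through y_new = y_old * x, popping the cache
    grad_y, grad_x = 1.0, 0.0
    for _ in range(100):
        y_old = stack.pop()
        grad_x += grad_y * y_old   # d(y_old * x)/dx
        grad_y = grad_y * x        # d(y_old * x)/dy_old
    return y, grad_x               # x^100 and its derivative 100*x^99
```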

One needs a vector of size 100 to cache intermediate results because *= is not reversible in regular arithmetic. A reversible-style version would be

"""
reverse friendly code. Chunk-wise allocation.
"""
@i function i_power100(out, x::T) where T
    @routine begin
        ys ← zeros(T, 100)
        ys[1] += x
        for i=2:100
            ys[i] += ys[i-1]*x
        end
    end
    out += ys[100]
    ~@routine
end
0.4s

One preallocates a chunk of memory to store intermediate results, which is much more efficient than push!. There is an even better way to write this function with constant memory: one can exploit the reversibility of *= in the logarithmic number system.
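Why *= becomes reversible in a logarithmic number system can be sketched in plain Python (an illustration with hypothetical names, not the ULogarithmic implementation): storing a positive `y` as `log(y)` turns `y *= x` into `ly += log(x)`, which is undone exactly by `ly -= log(x)`, so no intermediate value needs to be cached.

```python
import math

def power100_log(x):
    # represent the accumulator in log space: log(1) = 0
    ly = 0.0
    lx = math.log(abs(x))   # handle the sign separately
    for _ in range(100):
        ly += lx            # reversible: the inverse is `ly -= lx`
    return math.exp(ly)     # convert back to an ordinary float
```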

"""
no allocation at all utilizing log numbers.
"""
@i function i_power100_noalloc(x100, x::T) where T
    if (x != 0, ~)
        @routine begin
            absx ← zero(T)
            lx ← one(ULogarithmic{T})
            lx100 ← one(ULogarithmic{T})
            absx += abs(x)
            lx *= convert(absx)  # convert `x` to the log number
            for i=1:100
                lx100 *= lx
            end
        end
        # convert the result back to a fixed point/floating point number
        x100 += convert(lx100)
        ~@routine
    end
end
0.3s

This trick can also be used to compute arithmetic functions via Taylor expansion. Avoiding allocation (or using chunk-wise allocation) is important when one wants to write GPU-compatible code with KernelAbstractions.jl. It is the user who determines whether the code is suitable for recovering previous states; hence, a user's reversible thinking matters!
