Reverse computing is not checkpointing
A recent proposal of implementing programming-language-level automatic differentiation in a reversible programming language [arXiv: 2003.04617] suggests that using reverse computing to differentiate a whole language might be more practical than the traditional checkpointing-based approach. Since reverse computing and checkpointing share many similarities, it is still not clear to many people what makes them different.
What is checkpointing?
Checkpointing [A Griewank, 2008] is a technique for tracing back prior computational states (*Note: this definition covers the default mode of reverse-mode AD, which "checkpoints" every step). In machine learning, it is famous for being a flexible method to trade off time and space in reverse-mode automatic differentiation.
To understand how it works, let's consider a checkpointing algorithm (illustrated below) that snapshots the program state every 100 steps. To recover a state that is not cached, the program starts from the closest previously cached state (1) and re-computes the following steps (2-100).
Here, a snapshot does not copy the whole memory. Instead, it only records the changed parts (or the changed variables of one operation). The more often checkpointing takes a snapshot, the less computational overhead is needed to trace back a state, and the more space is needed to store intermediate results.
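To make the scheme concrete, below is a minimal sketch in Julia of snapshot-every-k checkpointing. The names checkpointed_run, recover_state and step are made up for illustration, and unlike the scheme described above it simply copies the whole state instead of only the changed parts.
# Sketch: snapshot the state every `k` steps; `step` maps state i-1 to state i.
function checkpointed_run(step, state0, nsteps; k=100)
    snapshots = Dict(0 => deepcopy(state0))        # step index => cached state
    state = state0
    for i in 1:nsteps
        state = step(state, i)
        i % k == 0 && (snapshots[i] = deepcopy(state))
    end
    return state, snapshots
end
# Recover the state after step `i`: start from the closest earlier snapshot
# and re-compute the missing steps.
function recover_state(step, snapshots, i; k=100)
    base = k * (i ÷ k)                             # index of the closest cached state
    state = deepcopy(snapshots[base])
    for j in base+1:i
        state = step(state, j)
    end
    return state
end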
Reverse computing
Reverse computing is the strategy that reversible computing uses to free up memory. Depending on whether a variable is used in future computation, we classify data as output state (not shown in the figure) or garbage (red cubes). In regular computing, the program can deallocate the garbage space immediately or on demand, so that only the output variables are transferred to the next stage. However, in reversible computing, the program cannot erase a variable unless its value has been cleared (usually by reverse computing).
When reversible computing does not introduce any garbage variables, as shown in (b), there is no need to cache variables: one can trace back the states by running the program in the reverse direction. However, when the computation introduces garbage variables, as shown in (a), the program sometimes needs to free the garbage. This is done by copying the output to another space and running the program in the reverse direction, so that only the inputs and outputs are left in memory while everything else is zero-cleared. The more often the program uncomputes, the more time overhead and the less garbage in memory, which is called the g-segment tradeoff in reverse computing [KS Perumalla, 2013]. It is also the reason why some people prefer to have a messy desk: when one returns stuff on the desk to its correct place (uncomputing), the desk becomes more spacious, but returning stuff also costs time.
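The compute-copy-uncompute pattern described above can be illustrated with a toy example in plain Julia that mimics reversible instructions (every statement below has an exact inverse); this is an illustration of the idea, not NiLang code.
x = 0.3
s = 0.0; c = 0.0; out = 0.0; result = 0.0   # all work registers start from zero
# compute: `s` and `c` are garbage variables, `out` holds the output
s += sin(x); c += cos(x); out += s * c
# copy: accumulate the output into a fresh zero register (itself reversible)
result += out
# uncompute: run the compute stage backwards, zero-clearing `out`, `c` and `s`
out -= s * c; c -= cos(x); s -= sin(x)
# only `x` and `result` carry data now; the garbage registers are free to reuse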
The similarities
Reverse computing and checkpointing share many similarities.
Both of them can make a program reversible.
When there is no time overhead, both have a space overhead of $O(T)$ (i.e., linear in the computing time $T$).
Are they essentially the same thing in different forms?
Reverse computing is not checkpointing
When only a polynomial overhead in time is allowed, reversible computing has a minimum space overhead of $\Omega(S\log T)$, where $S$ and $T$ are the space and time of the original program [Robert Y. Levine, 1990]. For checkpointing, on the other hand, there can be no space overhead at all: one can simply recompute from the beginning to obtain any intermediate state.
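For contrast, the zero-space-overhead extreme of checkpointing can be sketched in a few lines (again with a hypothetical step function): nothing is cached, and recovering the state after step i costs O(i) extra time.
# Recompute state `i` from the input alone: no cache, pure re-execution.
function recompute_state(step, state0, i)
    state = deepcopy(state0)
    for j in 1:i
        state = step(state, j)
    end
    return state
end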
Are there cases where reverse computing can do better? Yes!
Handling allocations better
Let's first consider the pyramid example in the book "Evaluating Derivatives" [A Griewank, 2008], Sec. 3.5.
function pyramid!(v!, x::AbstractVector{T}) where T
    @assert size(v!,2) == size(v!,1) == length(x)
    for j=1:length(x)
        v![1,j] += x[j]
    end
    for i=1:size(v!,1)-1
        for j=1:size(v!,2)-i
            v![i+1,j] += cos(v![i,j+1]) * sin(v![i,j])
        end
    end
    v![end,1]
end
Here, we define an in-place function pyramid! that repeatedly modifies single elements of v!. In both PyTorch and TensorFlow, such array mutation operations are not allowed (note: both a PyTorch leaf tensor and a TensorFlow TensorArray allow modifying the tensor once, but they are not truly mutable). When a tensor element is changed, copying the tensor is inevitable in order to protect previously cached data, which is why both frameworks forbid mutable arrays. There is also a wide class of automatic differentiation frameworks that only cache scalar data, such as Tapenade and ADOL-C. In these frameworks, intermediate results inside the loop are pushed onto a stack.
In the following, I will show the difference between reverse-computing-based and stack-based AD by implementing both in the reversible eDSL NiLang. The reversible style of writing the pyramid function is
using Pkg
Pkg.add("NiLang");
using NiLang
@i function i_pyramid!(y!, v!, x::AbstractVector{T}) where T
    @safe @assert size(v!,2) == size(v!,1) == length(x)
    for j=1:length(x)
        v![1,j] += x[j]
    end
    for i=1:size(v!,1)-1
        for j=1:size(v!,2)-i
            @routine begin
                @zeros T c s
                c += cos(v![i,j+1])
                s += sin(v![i,j])
            end
            v![i+1,j] += c * s
            ~@routine
        end
    end
    y! += v![end,1]
end
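As a quick sanity check (a usage sketch: a NiLang function returns its updated arguments, and ~f denotes the inverse program), running i_pyramid! forward and then backward restores the initial state without caching anything:
x = randn(20)
y, v, xs = i_pyramid!(0.0, zeros(20, 20), x)   # forward run
y0, v0, xs0 = (~i_pyramid!)(y, v, xs)          # inverse run uncomputes everything
@assert isapprox(y0, 0.0; atol=1e-12) && isapprox(v0, zeros(20, 20); atol=1e-12)
@assert xs0 == x                               # the input is left untouched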
One can print the reversible IR
ex_ir = NiLangCore.precom_ex(Main, :(begin
    @routine begin
        @zeros T c s
        c += cos(v![i,j+1])
        s += sin(v![i,j])
    end
    v![i+1,j] += c * s
    ~@routine
end)) |> NiLangCore.rmlines
One can see that this IR is highly symmetric except for the middle statement v![i + 1, j] += c * s. Hence its reversed code can be easily obtained by typing
NiLangCore.dual_ex(Main, ex_ir)
Intermediate variables s and c inside the loop are recovered through uncomputing, which causes a ~2x overhead in time. Taking the forward computing, backward computing and gradient computing all into consideration, the final gradient/loss time ratio is ~6.2.
using BenchmarkTools
@benchmark pyramid!(zeros(20, 20), randn(20))
@benchmark NiLang.AD.gradient(Val(1), i_pyramid!, (0.0, zeros(20, 20), randn(20)))
The stack-based approach can achieve similar performance. However, the memory overhead is much larger.
@i function i_pyramid_stack!(y!, v!, x::AbstractVector{T}, stack) where T
    @safe @assert size(v!,2) == size(v!,1) == length(x)
    for j=1:length(x)
        v![1,j] += x[j]
    end
    for i=1:size(v!,1)-1
        for j=1:size(v!,2)-i
            @zeros T c s
            c += cos(v![i,j+1])
            s += sin(v![i,j])
            v![i+1,j] += c * s
            PUSH!(stack, c)
            PUSH!(stack, s)
        end
    end
    y! += v![end,1]
end
@benchmark NiLang.AD.gradient(Val(1), i_pyramid_stack!, (0.0, zeros(20, 20), randn(20), Float64[]))
Here, we also want to warn users that stack allocation is not available inside GPU kernels. Hence only the reverse-computing version is GPU compatible.
Utilizing Reversibility
Reversibility can be used to save memory in reversible computing; the same argument does not hold for checkpointing. For example, one can parameterize a unitary matrix as a consecutive application of two-level unitaries. The reversible implementation of this unitary matrix multiplication does not allocate any new memory.
@i function i_umm!(x!::AbstractArray, θ)
    k ← 0
    for l = 1:size(x!, 2)
        for j = 1:size(x!, 1)
            for i = size(x!, 1)-1:-1:j
                INC(k)
                ROT(x![i,l], x![i+1,l], θ[k])
            end
        end
    end
    k → length(θ)
end
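A usage sketch of this function (the length of θ, n(n-1)/2 angles per column, follows from the loop bounds above): applying the rotations and then the inverse program recovers the input, with no cache involved.
n = 4
x = randn(n, n)                      # the columns to be rotated
θ = randn(n * n * (n - 1) ÷ 2)       # one angle per two-level rotation
xout, θout = i_umm!(copy(x), θ)      # forward: rotate in place
xback, _ = (~i_umm!)(xout, θout)     # backward: undo the rotations
@assert xback ≈ x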
Unitary matrices are widely used in quantum simulation, recurrent neural networks [arXiv: 1612.05231], and elsewhere.
Again, since checkpointing does not assume the program can be executed in the reverse order, it cannot utilize reversibility to reduce memory usage.
Helping users code better
Suppose we want to compute $x^{100}$ without using the power function. When one does not have reversibility in mind, they might write something like
"""
blindly implemented power100
"""
function power100(x::T) where T
y = one(T)
for i=1:100
y *= x
end
return y
end
One needs a vector of size 100 to cache intermediate results, because *= is not reversible in regular arithmetic. A reversible-style version of the code would be
"""
reverse friendly code. Chunk-wise allocation.
"""
function i_power100(out, x::T) where T
begin
ys ← zeros(T, 100)
ys[1] += x
for i=2:100
ys[i] += ys[i-1]*x
end
end
out += ys[100]
~
end
One preallocates a chunk of memory to store the intermediate results, which is much more efficient than push!-ing onto a dynamically growing stack. There is an even better way to write this function with constant memory: one can enjoy the reversibility of *= in the logarithmic number system.
"""
no allocation at all utilizing log numbers.
"""
function i_power100_noalloc(x100, x::T) where T
if (x!=0, ~)
begin
absx ← zero(T)
lx ← one(ULogarithmic{T})
lx100 ← one(ULogarithmic{T})
absx += abs(x)
lx *= convert(absx) # convert `x` to the log number
for i=1:100
lx100 *= lx
end
end
# convert the result to fixedpoint/floating point numbers
x100 += convert(lx100)
~
end
end
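A usage sketch, assuming the definition above: the first argument accumulates the result, so it starts from zero.
x100, _ = i_power100_noalloc(0.0, 1.02)
@assert isapprox(x100, 1.02^100; rtol=1e-6)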
This trick can be used to compute arithmetic functions via Taylor expansion. Avoiding allocation, or allocating chunk-wise, is important when one wants to write GPU-compatible code with KernelAbstractions.jl. It is the user who determines whether the code is suitable for recovering previous states. Hence a user's reversible thinking matters!