Reference Counting in Clojure

An investigation of temporary resource leaks during the evaluation of mathematical expressions.

Clojure is based on the Java Virtual Machine (JVM). The JVM has a "fire and forget" resource model, where objects are allocated on the JVM heap and automatically deallocated at some (albeit non-deterministic) later point in time by the Garbage Collector (GC). When dealing with large objects and native libraries, memory is often stored outside of the JVM, which is known as "off heap" allocation. The JVM does not know about this memory and makes no attempt to manage it for you. Instead, you have to do it manually (giving you C vibes, doesn't it?). Java provides a syntactic construct, try-with-resources, that automatically closes resources when they leave a scope. The Clojure equivalent is the with-open macro. While this works perfectly fine for simple expressions, it fails for the more complicated mathematical expressions commonly found in numeric computing, where intermediate values need immediate cleanup in order not to run Out Of Memory (OOM). This notebook investigates the problem in the context of Python interop, where it quickly becomes pressing. We will be using the Python interop library libpython-clj for our experiments.
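
For reference, this is what the with-open approach looks like for a single, simple resource (a minimal sketch; "data.txt" is a placeholder path):

(require '[clojure.java.io :as io])
;; The reader is closed when control leaves the body, even if an exception is thrown.
(with-open [reader (io/reader "data.txt")]
  (count (line-seq reader)))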

Setup

Start by specifying some Clojure dependencies

{:deps
 {org.clojure/clojure {:mvn/version "1.10.3"}
  clj-python/libpython-clj {:mvn/version "2.024"}}}

and installing the Python dependencies

pip install numpy
6.5s

Import numpy into Clojure

(ns cdeln.next.lexref
  (:require
   [libpython-clj2.python :as py]
   [libpython-clj2.require :refer [require-python]]))
(require-python '[builtins :as pyb])
(require-python '[numpy :as np])
21.8s

Define a helper function to inspect numpy arrays.

(defn array-info [array]
  {:dtype (py/py.- array dtype)
   :shape (py/py.- array shape)})
0.0s

Create a 2x3x4 array and ask about some facts.

(def tmp (np/zeros [2 3 4]))
(println :type (type tmp))
(println (array-info tmp))
(println tmp)
0.4s

Awesome! Let's write some arithmetic expressions

(let [ones (np/ones [2 3 4])
      fours (np/multiply 2 (np/add ones ones))]
  (println (array-info fours))
  (println fours))
0.3s

It works! However, as described in the intro, there are no guarantees that either of the arrays ones or fours, or the intermediate array produced by (np/add ones ones), is deallocated at the end of the let expression. libpython-clj provides a macro with-gil-stack-rc-context that grabs the Python Global Interpreter Lock (GIL) and ensures that all Python objects allocated in its scope are deallocated at the end. Since all Python objects are released when the macro exits, the final value needs to be converted back into a value on the JVM, and then copied back to Python again if we want to keep the result after scope exit. Let's check out these converter functions in libpython-clj

(let [np-array-1 (np/ones [3])
      jvm-array (py/->jvm np-array-1)
      np-array-2 (py/->python jvm-array)]
  (println (pyb/type np-array-1))
  (println (type jvm-array))
  (println (pyb/type np-array-2)))
0.3s

That's not the expected result. The array object in Python is copied into a built-in Clojure vector and then copied back into a built-in Python tuple. To enable consistent and better memory representations across the JVM/Python boundary, we need to require the np-array sub-namespace from libpython-clj.

(require 'libpython-clj2.python.np-array)
(let [np-array-1 (np/ones [3])
      jvm-array (py/->jvm np-array-1)
      np-array-2 (py/->python jvm-array)]
  (println (pyb/type np-array-1))
  (println (type jvm-array))
  (println (pyb/type np-array-2)))
2.1s

Looks better! Now the JVM representation is a DirectTensor instead of a PersistentVector, which is much more memory efficient (contiguous memory layout), and we can copy back and forth between Clojure and Python using native memory only (this probably boils down to copying some header data plus a memcpy in the end). Let's write a macro that automates this memory jiggling for us

(defmacro with-python [& body]
  `(py/->python
     (py/with-gil-stack-rc-context
       (py/->jvm (do ~@body)))))
0.0s

and define a dummy function representing a computation of a nested mathematical expression

(defn compute [x]
  (with-python
    (np/add x (np/add x (np/add x (np/add x (np/add x x)))))))
0.0s

and try it out

(let [x (np/ones [2 3 4])
      y (compute x)]
  (println (array-info y))
  (println y))
0.3s

So far so good! While this macro ensures that everything is cleaned up at the end, it does not ensure that intermediate expressions in the body are promptly released, so large expressions can still require much more memory than they need. For small arrays this is not an issue, but for larger arrays we can easily run OOM. We need to come up with a smarter strategy to handle these memory issues properly.
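
To see why, here is roughly what the body of compute allocates if every intermediate is named explicitly (a sketch, with hypothetical names t1 to t4; each binding holds a full result-sized array):

(let [t1 (np/add x x)
      t2 (np/add x t1)
      t3 (np/add x t2)
      t4 (np/add x t3)]
  (np/add x t4))
;; Inside with-gil-stack-rc-context none of t1..t4 is released until the scope
;; exits, so the peak usage is about five result-sized arrays instead of two.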

To confirm all of the above, we need some way to check the memory usage before, during and after the evaluation of an expression. While you can use clj-memory-meter for measuring memory usage on the JVM side, and sys.getsizeof on the Python side, these do not give accurate numbers after copying back and forth across the JVM/Python boundary (I will not showcase this in this notebook, but it was confirmed in a REPL). A foolproof approach is to just ask the operating system directly (running on GNU/Linux)

free -m
0.5s
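
The output has a header row with six column names followed by a Mem: row with six numbers (and a Swap: row), which is what the parsing code below relies on. Illustrative output:

              total        used        free      shared  buff/cache   available
Mem:          15893        1943        9875         210        4075       13500
Swap:          2048           0        2048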

Let's wrap that with Clojure!

(require '[clojure.java.shell :as shell])
(require '[clojure.string :as str])
(defn zip [& colls]
  (apply map vector colls))
(defn unzip [zipped]
  (apply zip zipped))
(defn string->keyword
  "Intern a string to a keyword.
  Replace / with - to not accidentally create namespaced keywords."
  [s]
  (keyword (str/replace s "/" "-")))
(defn string->integer [s]
  (Long/parseLong s))
(def header-regex (re-pattern (str/join "" (repeat 6 "\\s+([\\w/]+)"))))
(def memory-regex (re-pattern (str "Mem:" (str/join "" (repeat 6 "\\s+(\\d+)")))))
(defn get-memory-info []
  (let [row-strings (-> (shell/sh "free" "-m") :out (str/split #"\n"))
        header-items (rest (re-matches header-regex (first row-strings)))
        memory-items (rest (re-matches memory-regex (second row-strings)))]
    (into {} (zip (map string->keyword header-items)
                  (map string->integer memory-items)))))
(defn get-used-memory []
  (:used (get-memory-info)))
(defn get-available-memory []
  (:available (get-memory-info)))
(println "Used     :" (get-used-memory))
(println "Available:" (get-available-memory))
0.3s

The problem

Now that we have our environment and helper functions set up, let's take a closer look at the problem and check the memory usage of different expressions. First we need to write some code to monitor memory consumption during the evaluation of an expression.

(def thread-sleep-interval-millis 10)
(defn monitor-loop
  "Append pairs of timestamps and monitor-fn values to state while flag is true."
  [monitor-fn flag state]
  (let [start-time (System/nanoTime)]
    (while @flag
      (let [t1 (- (System/nanoTime) start-time)
            v (monitor-fn)
            t2 (- (System/nanoTime) start-time)
            t (* 1e-9 0.5 (+ t1 t2))]
        (swap! state conj [t v])
        (Thread/sleep thread-sleep-interval-millis)))))
(defn create-monitor [monitor-fn]
  (let [flag (atom true)
        state (atom [])
        thread (Thread. #(monitor-loop monitor-fn flag state))]
    {:state state
     :flag flag
     :thread thread}))
(defn run-with-monitor
  "Run main-fn with monitor-fn running as a monitor in the background.
  Return the result of main-fn and a tuple of timestamps and monitor-fn values.
  The monitor-fn is run for some time before and after main-fn to ensure an accurate memory profile."
  [main-fn monitor-fn]
  (let [{:keys [state flag thread]} (create-monitor monitor-fn)]
    (.start thread)
    (Thread/sleep (* 10 thread-sleep-interval-millis))
    (let [result (main-fn)]
      (Thread/sleep (* 10 thread-sleep-interval-millis))
      (reset! flag false)
      (.join thread)
      [result (unzip @state)])))
0.0s

And we need some plotting

(defn plot
  "Based on: https://nextjournal.com/btowers/using-plotly-with-clojure"
  [timestamps memory-usage]
  ^{:nextjournal/viewer :plotly}
   {:data [{:x timestamps :y memory-usage}]
    :layout {:title "Memory usage"
             :xaxis1 {:title "time"}
             :yaxis1 {:title "Mb"}
             :margin {:t 50 :l 50 :b 50}}})
(defn eval-and-plot-mem-usage
  "Evaluate a function of zero arguments and plot the memory profile."
  [f]
  (let [[_ [timestamps memory-usage]] (run-with-monitor f #(get-used-memory))]
    (plot timestamps memory-usage)))
(defmacro with-plot-mem-usage
  "Evaluate body and plot the memory profile."
  [& body]
  `(eval-and-plot-mem-usage
     (fn []
       ~@body
       :ok)))
0.0s

Let's try it out on our compute expression! Define a 100Mb input array globally so that its allocation does not affect our measurements later

(with-plot-mem-usage
  (def input-array (np/ones [1000 1000 100] :dtype :int8)))
0.5s

As expected, around 100Mb is allocated by that expression. Let's see what the memory profile looks like when evaluating our compute expression a couple of times. We will store the result in a variable so that the expected memory consumption is 100Mb at the end of the scope.

(with-plot-mem-usage
  (def x (compute input-array)))
1.3s

This clearly shows the massive temporary resource leak! This graph deserves some comments.

  1. Memory usage before evaluation is 1943 Mb.

  2. Memory usage after evaluation is 2044 Mb (about 100 Mb difference, as expected).

  3. The memory usage peaks at 2427 Mb (about 500 Mb difference), which is due to the intermediate results not being released as soon as they are no longer used (the result of each np/add allocates 100 Mb).

  4. The sharp drop from 2427 Mb to 1956 Mb is due to the release of the intermediate results on the Python side at scope exit of py/with-gil-stack-rc-context.

  5. The final rise from 1956 Mb to 2044 Mb is most likely due to re-allocation on the JVM side by py/->jvm or on the Python side by py/->python.

  6. It is unclear why the memory momentarily drops back to roughly the same level as before evaluation.

Let us evaluate the expression again and see what happens

(with-plot-mem-usage
  (def y (compute input-array)))
1.0s

The memory profile looks very similar to before. Let us run it one more time to confirm that the memory profile looks the same across consecutive evaluations

(with-plot-mem-usage
  (def z (compute input-array)))
1.1s

And indeed it does! These graphs illustrate the issue of "temporary resource leaks" that occur on the boundary between the JVM and Python. I call them temporary because in the end all resources are reclaimed (as long as all expressions are wrapped in with-python), but evaluation can consume more memory than necessary.

The solution

Equipped with the knowledge gained from the excursion in the previous section, we are ready to attack the temporary resource leak problem. We will solve it by implementing our own reference counting system in Clojure that mimics the one used by CPython (note the C: in general, a Python implementation is not required to release objects immediately on scope exit a la C++, and is allowed to behave more like the JVM GC). To illustrate the behaviour in CPython, check this out

class Test: 
  
    def __init__(self, name):
        self.name = name
        print('init', name)
        
    def __del__(self):
        print('del')
        
    def __add__(self, other):
         return Test(self.name + '+' + other.name)
0.0s
z = Test('x') + Test('y')
0.3s

As you can see, there are 3 inits and 2 dels, indicating that 3 Test objects were created and 2 deleted during the evaluation of the expression. This is what you would get using C++ constructors and destructors as well. We cannot do anything similar in Java. Java finalizers are the closest you get, but even with those (they are deprecated, so don't use them) you would most likely not get such a nice pairing of inits and dels, due to the GC's non-determinism. The seemingly official way to hook into an object's lifecycle is using reference queues and weak references. In my opinion, that approach is very complicated and does not help us solve our problem (the Python interop library already does this kind of thing in the GC macros used before). Instead, we can add a custom evaluation strategy a la SICP on top of Clojure and implement the resource ownership semantics we want!
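
For completeness, this is roughly what the reference queue approach looks like (a minimal sketch; note that enqueueing happens entirely at the GC's discretion, which is exactly the non-determinism we want to avoid):

(import '[java.lang.ref WeakReference ReferenceQueue])
(let [queue (ReferenceQueue.)
      ;; The freshly created Object is unreachable as soon as the constructor returns.
      weak  (WeakReference. (Object.) queue)]
  (System/gc)
  (Thread/sleep 100)
  ;; True only if the GC has already cleared and enqueued the weak reference.
  (identical? weak (.poll queue)))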

Let us think a bit about the expected semantics and define the terminology. Whenever we evaluate an expression and just let the result go, it should be released immediately (for example, right after being printed at the REPL). For example, evaluating the expression

(+ 1 (+ 2 3))
0.0s

binds the literals 1, 2 and 3 to the arguments of the + function. Supplying a value to a function binds it to a name, increasing its reference count by 1. The sub-expression (+ 2 3) produces a temporary value 5 which is bound to the second argument of +. The value of the whole expression is a temporary 6. The literals 1, 2 and 3 are owned by the caller, and their reference counts should be the same before and after evaluation. The final value 6 can have a different reference count because of potential aliasing (not applicable for fundamental types like integers, which have copy semantics, but important for arrays).
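
To make the bookkeeping concrete, here is the intended ownership flow for that expression, pretending for a moment that numbers are releasable resources:

;; (+ 2 3)         binds 2 and 3, returns a temporary 5 with count 0
;; (+ 1 <temp 5>)  binds 1 and the temporary 5 (count 1 while + runs),
;;                 then unbinds them; the now dangling 5 is released
;; result          a fresh temporary 6 with count 0, handed back to the caller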

What about this expression?

(conj [1] 2)
0.0s

The Clojure vector itself does not need to be reference counted. In general, no collection needs reference counting. In practice, literals like the numbers 1 and 2 do not need it either; just think of them as placeholders for larger array objects later. The expectation is that the resulting value will be a vector whose elements are treated as temporaries.

Let's boil this down to code. A good start is to define an interface for releasable objects along with related functions. We will define the interface with a multi-method instead of a protocol for reasons that will become apparent later

(defmulti release! type)
(defn releasable? [x]
  (not (nil? (get-method release! (type x)))))
0.0s
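
At this point nothing is releasable, since no release! methods have been registered yet (a method for numbers is added further down):

(releasable? 42)       ; => false for now
(releasable? "hello")  ; => false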

Then define what we mean by a tree (i.e. a collection that we do not intend to reference count, but whose elements we do want to do bookkeeping for) and how to traverse one (mapping a function over it while preserving its structure)

(defprotocol ITree
  (tree-vals [this])
  (tree-map [this f]))
(defn tree?
  "Check if an object is a tree.
  If an object is not a tree it is a leaf."
  [x]
  (satisfies? ITree x))
(defn leaf-map
  "Map a function over the leafs of a tree.
  The tree structure is preserved."
  [f x]
  (if (tree? x)
    (tree-map x (partial leaf-map f))
    (f x)))
0.0s

Then we implement the interface for all collections of interest

(extend-type clojure.lang.PersistentVector
  ITree
  (tree-vals [this] (seq this))
  (tree-map [this f]
    (mapv f this)))
(extend-type clojure.lang.PersistentList
  ITree
  (tree-vals [this] (seq this))
  (tree-map [this f]
    (into '() (reverse (map f this)))))
;; Function varargs become this type
(extend-type clojure.lang.ArraySeq
  ITree
  (tree-vals [this] (seq this))
  (tree-map [this f]
    (mapv f this)))
(extend-type clojure.lang.PersistentArrayMap
  ITree
  (tree-vals [this] (vals this))
  (tree-map [this f]
    (into {} (map (fn [[k v]] [k (f v)]) this))))
(extend-type clojure.lang.PersistentHashMap
  ITree
  (tree-vals [this] (vals this))
  (tree-map [this f]
    (into {} (map (fn [[k v]] [k (f v)]) this))))
(extend-type clojure.lang.PersistentHashSet
  ITree
  (tree-vals [this] (seq this))
  (tree-map [this f]
    (into #{} (map f this))))
0.0s

Let's try it out on a deeply nested collection and see whether it works as expected.

(require '[clojure.pprint :refer [pprint]])
(pprint (leaf-map vector '(1 [2 {:a :b :c #{3 4}}])))
0.4s

Indeed it does. All values are transformed to singleton vectors, map keys are preserved, and the overall structure of the data is preserved.

Define a function to enumerate all leaves as a flat sequence

(defn leaf-seq [x]
  (if (tree? x)
    (mapcat leaf-seq (tree-vals x))
    (list x)))
(pprint (leaf-seq '(1 [2 {:a :b :c #{3 4}}])))
0.3s

Define a function to filter leaves satisfying a predicate

(defn leaf-filter [pred tree]
  (filter pred (leaf-seq tree)))
(pprint (leaf-filter number? '(1 [2 {:a :b :c #{3 4}}])))
0.3s

and a function to search for a specific leaf element

(defn leaf-search [pred tree]
  (first (leaf-filter pred tree)))
(leaf-search keyword? '(1 [2 {:a :b :c #{3 4}}]))
0.0s

Define a record type for lexical reference counting together with associated functions. We define a lexical reference as a record holding a releasable value, a reference count and a boolean flag for debugging purposes. The reference count and the boolean state need to stay in sync, hence they are wrapped in refs (plain old Clojure refs, that is).

(defrecord LexRef [value count released?])
(defn lex-ref? [x]
  (instance? LexRef x))
(defn lex-ref-create
  ([value]
   (lex-ref-create value 0))
  ([value count]
   (assert (releasable? value))
   (assert (not (lex-ref? value)))
   (->LexRef value (ref count) (ref false))))
(defn lex-ref-value [x]
  (if (lex-ref? x)
    (:value x)
    x))
(defn lex-ref-inc! [x]
  (assert (lex-ref? x))
  (dosync
   (alter (:count x) inc)))
(defn lex-ref-dec! [x]
  (assert (lex-ref? x))
  (dosync
    (alter (:count x) dec)))
;; A lexical reference is released by releasing the underlying value.
;; The reference count must be zero, and the reference must not have been released already.
(defmethod release! LexRef [this]
  (dosync
   (assert (zero? @(:count this)))
   (assert (not @(:released? this)))
   (release! (:value this))
   (alter (:released? this) (fn [_] true))))
0.1s

Note that we do not put the release logic in lex-ref-dec!, since we will manually handle jiggling around temporaries with a zero reference count later.

Make numbers releasable for debugging purposes

(defmethod release! java.lang.Number [this]
  (println "release number " this))
0.0s
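
As a quick sanity check of the lifecycle so far (a sketch): create a reference, simulate binding and unbinding it, then release it.

(let [r (lex-ref-create 42)]
  (lex-ref-inc! r)   ; bound to a name, count goes to 1
  (lex-ref-dec! r)   ; unbound again, count back to 0
  (release! r))      ; should print "release number  42"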

Function application

The majority of the brain work will be in redefining how functions are evaluated. The semantics were described in the previous section, and now we will implement them. As mentioned, we need to be able to detect aliasing, so we start by defining a binary multi-method for equality testing between two values that dispatches on both of their types

(defmulti equals?
  (fn [a b] [(type a) (type b)]))
;; The default checks for object identity on the JVM
(defmethod equals? :default [a b] (identical? a b))
0.0s

Within a function body, any intermediate value needs to be resolved for aliasing against the function arguments. This is done by the function resolve-lex-refs, which takes a collection of lexical references (i.e. the function arguments) and a collection of values (potentially not yet lexically referenced), and transforms those values into properly initialized lexical references.

(defn- lex-ref-value-equals?
  ([y-val]
   (partial lex-ref-value-equals? y-val))
  ([y-val x-ref]
   (equals? y-val (lex-ref-value x-ref))))
(defn- resolve-lex-ref [x-refs y-val]
  (if-let [x-ref (leaf-search (lex-ref-value-equals? y-val) x-refs)]
    x-ref
    (if (lex-ref? y-val)
      y-val
      (if (releasable? y-val)
        (lex-ref-create y-val)
        y-val))))
(defn- resolve-lex-refs [x-refs y-vals]
  (leaf-map (partial resolve-lex-ref x-refs) y-vals))
0.0s

Now we are ready to implement function application. In order to support higher order functions we need to take some extra care. The implementation is really just the semantics described before encoded in code. When a function is applied to some arguments, all their reference counts are incremented by one. The arguments might be a mix of lexical references and plain values such as externally owned data or literals. Then the actual function is applied to transformed arguments: lexical references are lowered to plain values, and functions are lifted to functions operating on lexical references. The result is then resolved against the arguments to account for aliasing and to initialize new lexical references. On the way out, the function first decrements the reference counts of the arguments (they are being unbound from their names), bumps up the reference counts of the freshly produced references (they are about to be bound to a name in the surrounding let expression), and releases all dangling arguments (those that were passed as temporaries into this function). Finally, we decrement the reference counts of the return values: those that do not alias anything will end up with a reference count of 0 (temporary values), and those that alias an argument will have a non-zero count.

(declare lex-ref-apply)
(defn- dangling? [x]
  (and (lex-ref? x)
       (zero? @(:count x))))
(defn- lex-ref-fn [f]
  (fn [& xs]
    (lex-ref-apply f xs)))
(defn- fn-arg [x]
  (if (fn? x)
    (lex-ref-fn x)
    (lex-ref-value x)))
(defn lex-ref-apply [f args]
  (run! lex-ref-inc! (leaf-filter lex-ref? args))
  (let [y-vals (apply f (leaf-map fn-arg args))
        y-refs (resolve-lex-refs args y-vals)]
    (run! lex-ref-dec! (leaf-filter lex-ref? args))
    (run! lex-ref-inc! (leaf-filter lex-ref? y-refs))
    (run! release! (leaf-filter dangling? args))
    (run! lex-ref-dec! (leaf-filter lex-ref? y-refs))
    y-refs))
0.0s

Hope you are still reading and convinced that all those steps are necessary to properly do the reference counting.

Define a helper function for prettier printing of lexical references

(defn- lex-ref->map [x]
  (cond (lex-ref? x) {:value (:value x)
                      :count @(:count x)
                      :released? @(:released? x)}
        (tree? x) (leaf-map lex-ref->map x)
        :else x))
0.0s

Let's put this into action! Define a temporary and a bound variable (reference count 0 and 1, respectively) and see what happens when we apply a function that returns its first argument to them

(def temporary (lex-ref-create 13))
(def bound (lex-ref-create 37 1))
(println "temporary before" (lex-ref->map temporary))
(println "bound before    " (lex-ref->map bound))
(def result (lex-ref-apply (fn [a b] a) [temporary bound]))
(println "result          " (lex-ref->map result))
(println "temporary after " (lex-ref->map temporary))
(println "bound after     " (lex-ref->map bound))
0.3s

Behaves just as expected! Prior to the function application, the temporary has a count of 0 and the bound variable has a count of 1. The function propagates the temporary without releasing it. Note that we now have 2 names referring to the same temporary. That might look like a bug, but the def construct is not part of the reference counting machinery; it is just a way to illustrate the intermediate values computed by it, for development purposes. We will come back to this later.

Let's see what happens if we return the second argument instead

(def temporary (lex-ref-create 13))
(def bound (lex-ref-create 37 1))
(println "temporary before" (lex-ref->map temporary))
(println "bound before    " (lex-ref->map bound))
(def result (lex-ref-apply (fn [a b] [b b]) [temporary bound]))
(println "result          " (lex-ref->map result))
(println "temporary after " (lex-ref->map temporary))
(println "bound after     " (lex-ref->map bound))
0.3s

Also behaves as expected! The temporary gets released before the value is assigned to the result variable, just like we want it to. Again, note that if we assigned the result to a name in a reference counted context, the reference count would be 2 and not 1.

Expression transform

Now that function application is implemented, it's time to write the machinery that transforms any expression into a reference counted equivalent. Personally, I like to stay away from macros as long as possible and stick to normal functions for doing symbolic computations, and then wrap the result in a macro for end usage. Let's implement a function lex-ref-expr that takes a symbolic expression and transforms it into a reference counted expression.

(defn cons? [x]
  (= (type x) clojure.lang.Cons))
;; Yeah, lists are not cons in clojure...
(defn- list-like? [x]
  (or (list? x) (cons? x)))
(declare list-expr)
(defn lex-ref-expr [expr]
  (cond (list-like? expr) (list-expr expr)
        (tree? expr) (tree-map expr lex-ref-expr)
        :else expr))
0.0s

Alright, so an expression is either something listy, something treeish, or something else. Most of the time it will be a list representing a function or macro application, which is exactly what we want to transform. If it's a tree, we apply the structure preserving map. Otherwise we leave it as it is.

Transforming a list expression can be done using multi-methods again. This is a clean approach where we can dispatch on the list head and support literally any kind of list expression. If a list head is not registered with the multi-method but names a macro, it is expanded only if it is explicitly allowed. Otherwise it is assumed to be a function, and the expression is transformed into a call to the lex-ref-apply function defined previously.

(defmulti on-list-expr identity)
(def ^:private macro-whitelist (atom #{}))
(defn allow-macro [name]
  (swap! macro-whitelist (fn [s] (conj s name))))
(defn- macro? [sym]
  (when (symbol? sym)
    (:macro (meta (resolve sym)))))
(def ^:dynamic *allow-macros* false)
(defn- apply-expr [f xs]
  `(lex-ref-apply ~f ~(mapv lex-ref-expr xs)))
(defn- list-expr [expr]
  (let [[what & args] expr]
    (if-let [dispatch (get-method on-list-expr what)]
      (apply dispatch args)
      (if (macro? what)
        (if (or *allow-macros* (contains? @macro-whitelist what))
          (lex-ref-expr (macroexpand-1 expr))
          (throw (ex-info (str "Unsupported lexref macro " what) {})))
        (apply-expr what args)))))
0.0s

Implement multi-methods for built-in forms

(defmethod on-list-expr 'if [cond-expr & body-exprs]
  `(if ~cond-expr
     ~@(map lex-ref-expr body-exprs)))
(defmethod on-list-expr 'do [& args]
  `(do ~@(map lex-ref-expr args)))
(defmethod on-list-expr 'fn [args & body] `(fn ~args ~@body))
(defmethod on-list-expr 'fn* [args & body] `(fn* ~args ~@body))
0.0s

Let's check what a transformed expression looks like

(def test-expr `(conj [] (update {:a 1} :a + (+ 2 3))))
(pprint (lex-ref-expr test-expr))
0.3s

Transformed expressions can be quite tedious to read due to the added namespaces, so let's remove them for readability

(defn remove-namespace [expr]
  (cond
    (symbol? expr) (symbol (name expr))
    (keyword? expr) (keyword (name expr))
    (seq? expr) (map remove-namespace expr)
    (vector? expr) (mapv remove-namespace expr)
    :else expr))
(pprint (remove-namespace (lex-ref-expr test-expr)))
0.3s

That is easier to read. And it is correct. Let's finalize the framework by writing a macro that transforms an expression and handles some reference counting logic at the context boundary. Start by defining a function that initializes variables for reference counting

(defn lex-ref-init [x]
  (cond (releasable? x) (lex-ref-create x)
        :else x))
0.0s

Then define the top level expression transformer

(defn- bind-external-name-expr [var-name]
  [var-name `(leaf-map lex-ref-init ~var-name)])
(defn with-lexref-sym
  ([expr]
   (with-lexref-sym [] expr))
  ([vars expr]
   `(let [~@(mapcat bind-external-name-expr vars)]
      (leaf-map lex-ref-value
        (lex-ref-apply (fn ~vars ~(lex-ref-expr expr)) ~vars)))))
(defmacro with-lexref
  ([expr]
   (with-lexref-sym [] expr))
  ([vars expr]
   (with-lexref-sym vars expr)))
0.0s

The core logic is done by with-lexref-sym. It takes an expression, optionally preceded by a vector of external variables that are moved into the context. I borrow the terminology from the C++ world here, where move means transferring ownership between two scopes. This is useful when we do a computation yielding some result, which we then want to pass into a further computation while discarding the original binding. Unless the transformed expression aliases any of the passed variables, they are released at the end of the scope and the value of the evaluated expression is returned.

The implementation simply transforms the passed expression using lex-ref-expr, wraps it in a function and applies it to the externally passed variables. The logic defined previously takes care of all the reference counting for us, and we can simply extract the final result with lex-ref-value.
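
For example, a value computed outside a reference counted context can be moved in and released there. A minimal sketch with plain numbers (relying on the debug release! method for numbers defined earlier):

(let [x (+ 1 2)]   ; x owns the value 3 outside any lexref context
  (with-lexref [x]
    (+ x 1)))      ; should return 4 and print "release number  3",
                   ; since the result does not alias the moved-in x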

Let us play around by evaluating some expressions

(def x 1)
;; Does not release x since it is aliased
(pprint (with-lexref [x] x))
0.3s
;; Does not release x since it is aliased
(pprint (with-lexref [x] [x x x]))
0.3s
;; Does not release the 2 or 3 since they are literals
(pprint (with-lexref [] (assoc {:a 1} :a (+ 2 3))))
0.3s
;; Properly releases the intermediate value 5 produced by (+ 2 3)
(pprint (with-lexref [] (conj [] (update {:a 1} :a + (+ 2 3)))))
0.3s
;; Higher order functions are handled
(with-lexref [] (reduce + [1 2 3 4 5]))
0.4s
;; And fairly complex expressions
(pprint (lex-ref->map (with-lexref [] (reduce + (concat [(+ 1 1)] [(+ 1 2)] [(+ 1 3)])))))
0.3s

What about a let expression?

; Evaluating this will throw a compile time error "Unsupported lexref macro let"
; (with-lexref (let [x 1] (+ x x)))
0.0s

By default, the framework does not expand macros unless they are explicitly registered with a multi-method or whitelisted. Let expressions are an example of a macro that will not behave correctly if expanded in place. Let us define its expansion using with-lexref itself! This really shows the benefit of the extra level of indirection, since most transforms can be delegated to the framework itself.

(defmethod on-list-expr 'let [bindings & body]
  (let [names (take-nth 2 bindings)
        exprs (take-nth 2 (drop 1 bindings))
        exprs' (map with-lexref-sym exprs)]
    `(let [~@(interleave names exprs')]
       (with-lexref [~@names]
         ~@body))))
(pprint (with-lexref (let [x 1] (+ x x))))
0.3s
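
Macros that expand into forms the framework already understands can instead be whitelisted with allow-macro. A sketch using when, which expands into an if/do form:

(allow-macro 'when)
(with-lexref (when true (+ 1 2)))
;; should evaluate to 3, expanded through (if true (do ...))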

Our new evaluation strategy is complete! All temporaries are released as they should be. All that's left is to implement the resource interface for Numpy arrays.

Numpy

Integrating Numpy is easy now: just implement release! and equals?. We tap into the Foreign Function Interface (FFI) of the Python interop library to access the underlying Python reference counting functions

(require '[libpython-clj2.python.ffi :as py-ffi])
(defmethod release! :pyobject [x]
  (println "Release array")
  (py-ffi/Py_DecRef x))
(defmethod equals? :pyobject [a b]
  (if (and (pyb/isinstance a np/ndarray)
           (pyb/isinstance b np/ndarray))
    (let [a-base (py/py.- a base)
          b-base (py/py.- b base)]
      (if (and (nil? a-base)
               (nil? b-base))
        false
        (or (= a-base b)
            (= b-base a))))
    (identical? a b)))
0.0s

This is where the choice of a multi-method for release! is justified. Python objects created by the Python interop library are made using reify, so their classes are actually generated dynamically on the fly. The library sets the :type field of the object metadata to the :pyobject keyword, and that is all the information we have about the type of the object. Then we need to handle the GIL manually and disable the GC integration implemented by the Python interop library.

(require '[libpython-clj2.python.gc :as py-gc])
(defn- with-python-gc-gil [f]
  (py-gc/with-disabled-gc
    (py/with-gil
      (f))))
(def ^:private in-python-context? (atom false))
(defn with-python [f]
  (if (compare-and-set! in-python-context? false true)
    (let [result (with-python-gc-gil f)]
      (reset! in-python-context? false)
      result)
    (f)))
(defmacro with-python-lexref
  ([vars body]
   `(with-python
      (fn []
        (with-lexref ~vars
          ~body))))
  ([body]
   `(with-python-lexref [] ~body)))
0.0s

Now we can use the with-python-lexref macro to evaluate python expressions!

Let's try it out on the expressions used at the beginning of this article. First, redefine the compute function using our new macro

(defn compute [x]
  (with-python-lexref
    (np/add x (np/add x (np/add x (np/add x (np/add x x)))))))
0.0s

and test it out

(with-plot-mem-usage
  (def input-array (np/ones [1000 1000 100] :dtype :int8)))
0.5s
(with-plot-mem-usage
  (def y2 (compute input-array)))
1.0s
(with-plot-mem-usage
  (def z2 (compute input-array)))
0.9s

And the problem is solved! The 4 intermediate temporary arrays are released. Maximum temporary memory usage is now the expected 200Mb (100Mb for the temporary argument plus another 100Mb for the output of each call to np/add). We can now evaluate mathematical expressions of any size with minimal memory consumption! As a side effect, the computation seems to run faster too (most likely due to the removed copy operations)!

Conclusion

Evaluating complex array expressions can consume a lot of memory. The JVM does not provide strict means of hooking into the object lifecycle. The flexibility of Clojure allows the construction of a reference counting mechanism based on syntax instead of runtime computation. Expressions can be transformed into reference counted equivalents, as demonstrated by this notebook.

The code in this notebook is pasted directly from a complete implementation of lexical references, which is available as a library at https://github.com/cdeln/lexref-clj.
