The Rectangle Game - MCMC example in Clojure
Here is an implementation of The Rectangle Game from the Probabilistic Models of Cognition book, done purely in Clojure from scratch, without using a probabilistic programming library.
We will use the Metropolis-Hastings algorithm; the implementation is based on this article by Joseph Moukarzel.
```clojure
{:deps {org.clojure/clojure {:mvn/version "1.10.1"}
        cljplot             {:mvn/version "0.0.2-SNAPSHOT"}
        generateme/fastmath {:mvn/version "1.4.0-SNAPSHOT"}
        clojure2d           {:mvn/version "1.2.0-SNAPSHOT"}}}
```
```clojure
(require '[fastmath.core :as m]
         '[fastmath.random :as r]
         '[clojure2d.core :refer :all]
         '[clojure2d.extra.utils :as utils]
         '[clojure2d.color :as c]
         '[cljplot.build :as b]
         '[cljplot.core :as plot])
```
Goal
Citing the book:
> The data are a set of points in the plane, that we assume to be randomly sampled from within some unknown rectangle. Given the examples, what is the rectangle they came from? We can model this learning as a conditional of the rectangle given the points.
Model
First, let's define the data points and an `observe` alias, to show the similarity to a probabilistic programming model.
```clojure
(def observe r/log-likelihood)

;; 4 data points
(def data [{:x 0.40 :y 0.70}
           {:x 0.50 :y 0.40}
           {:x 0.46 :y 0.63}
           {:x 0.43 :y 0.51}])
```
Our parameters are just the corners of the rectangle, and we want all of them to lie in the range [0,1]. So we can define the prior probabilities as uniform.
```clojure
(def priors (repeat 4 (r/distribution :uniform-real {:lower 0 :upper 1})))

(let [p1 (first priors)]
  {:sample (r/sample p1)
   :pdf (r/pdf p1 0.5)
   :log-pdf (r/lpdf p1 0.5)
   :mean (r/mean p1)
   :likelihood (r/likelihood p1 [0 0.5 1])
   :log-likelihood (observe p1 [0 0.5 1])})
```
Let's define our model in terms of log likelihood. The log likelihood is just the logarithm of a distribution's PDF, and it is used to score the current set of distribution parameters against the data and the prior probability. See the article.
Why logarithms? Because instead of multiplying many small probabilities, which quickly underflows floating-point numbers, we can simply add their logarithms.
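A minimal illustration of that underflow (the numbers are arbitrary, chosen just to show the effect):

```clojure
;; illustration only: the product of many small probabilities underflows,
;; the sum of their logarithms does not
(let [ps (repeat 1000 0.01)]
  {:product (reduce * ps)                   ;; => 0.0 (double underflow)
   :sum-of-logs (reduce + (map m/log ps))}) ;; => about -4605.17
```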
`r/log-likelihood` (aliased as `observe`) accepts a collection of data and returns the sum of the log likelihoods of all samples.
The fastmath uniform distribution doesn't allow reversed ranges, so we have to check for them ourselves and return negative infinity to mark an impossible configuration.
```clojure
(defn model
  [[x1 x2 y1 y2]]
  (if (and (< x1 x2) (< y1 y2)) ;; ranges validation
    (let [distr-x (r/distribution :uniform-real {:lower x1 :upper x2})
          distr-y (r/distribution :uniform-real {:lower y1 :upper y2})]
      (+ (observe distr-x (mapv :x data))
         (observe distr-y (mapv :y data)))) ;; mapData in the book
    ##-Inf))

;; log likelihood of some rectangles
{:valid-rectangle (model [0.0 1.0 0.1 0.7])
 :invalid-rectangle (model [-1.0 1.0 0.3 1.3])
 :missing-points-rectangle (model [0.1 0.2 0.1 0.3])}
```
That's it for the model itself. We had to define just two elements: the prior distributions of our parameters, and the model expressed as target distributions with the sum of their log likelihoods against the data.
Algorithm
One element of MCMC algorithms and Bayesian inference is having priors and their log likelihood at the current parameter values. This part helps the algorithm keep the parameters within the desired range or distribution shape.
```clojure
(defn priors-log-likelihood
  [current]
  (reduce + (map #(r/lpdf %1 %2) priors current)))

{:valid-rectangle (priors-log-likelihood [0.5 0.5 0.5 0.5])
 :invalid-rectangle (priors-log-likelihood [0.5 -0.5 0.5 0.5])}
```
The total log likelihood (the logarithm of the numerator of Bayes' rule) is just the sum of the two functions above.
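In symbols, with $\theta$ standing for the four rectangle parameters and $\mathcal{D}$ for the data points:

$$\log p(\theta \mid \mathcal{D}) = \underbrace{\log p(\mathcal{D} \mid \theta)}_{\texttt{model}} + \underbrace{\log p(\theta)}_{\texttt{priors-log-likelihood}} - \log p(\mathcal{D})$$

The constant evidence term $\log p(\mathcal{D})$ can be dropped, because Metropolis-Hastings only ever compares the scores of two parameter sets.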
```clojure
(defn log-likelihood
  [current]
  (+ (model current) (priors-log-likelihood current)))

{:valid-rectangle (log-likelihood [0.0 1.0 0.1 0.7])
 :invalid-rectangle (log-likelihood [-1.0 1.0 0.3 1.3])
 :missing-points-rectangle (log-likelihood [0.1 0.2 0.1 0.3])}
```
Now on to the Metropolis-Hastings algorithm. First, let's define the random-walk part. We move the corners of the rectangle by random values drawn from a Gaussian distribution with a small standard deviation. This standard deviation controls our step size.
We set it to 0.01, which produces rather small steps and leads to one of the typical problems with MCMC algorithms. More about this later.
```clojure
(defn param-step
  [current]
  (mapv #(r/grand % 0.01) current))

(param-step [0 1 0 1])
```
The next function performs a step and attaches the score of the new sample.
```clojure
(defn sample
  [current]
  (let [new (param-step current)]
    {:state new
     :score (log-likelihood new)}))

(sample [0.1 0.9 0.1 0.9])
```
Finally, the Metropolis-Hastings MCMC algorithm itself. This is a very simple implementation based on the `iterate` function; the algorithm returns a lazy sequence of states.
```clojure
(defn acceptance-rule
  "Should we accept new points or not?"
  [old-l new-l]
  (or (> new-l old-l) ;; always accept when new score is greater than old one
      ;; if not, accept with the probability defined by the ratio of likelihoods
      (< (r/drand) (m/exp (- new-l old-l)))))

(defn metropolis-hastings
  [init]
  (let [first-step {:state init
                    :score (log-likelihood init)
                    :accepted [init]
                    :rejected []
                    :path [init]}]
    (iterate (fn [{:keys [state score accepted rejected path] :as all}]
               (let [{new-state :state new-score :score} (sample state)
                     new-step (if (acceptance-rule score new-score)
                                {:state new-state
                                 :score new-score
                                 :accepted (conj accepted new-state)
                                 :rejected rejected
                                 :path path}
                                (update all :rejected conj new-state))]
                 (update new-step :path conj new-state)))
             first-step)))

(first (drop 100 (metropolis-hastings [0 1 0 1])))
```
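For reference, `acceptance-rule` above implements the standard Metropolis acceptance probability; the Hastings correction term cancels here, because the Gaussian random-walk proposal is symmetric:

$$\alpha = \min\left(1,\; e^{\ell_{\text{new}} - \ell_{\text{old}}}\right)$$

where $\ell$ denotes the total log likelihood defined earlier.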
Visualization
Let's look at some of the rectangles, including the best one. We start with the biggest possible rectangle.
```clojure
(def mh-instance (metropolis-hastings [0 1 0 1]))

(def samples (->> (nth mh-instance 50000) ;; run MCMC inference
                  (:accepted)             ;; select accepted list
                  (drop 100)              ;; burn
                  (take-nth 100)          ;; lag
                  (take 150)))            ;; samples

(count samples)
```
The best rectangle (the minimal one) can be found by maximizing the log likelihood. Warning: it is not necessarily the last of the accepted samples.
```clojure
(def best (apply max-key log-likelihood samples))
best
```
Now we are going to draw the rectangles using the Clojure2d library.
```clojure
(def img (with-canvas [c (canvas 500 500)]
           (set-background c :white)
           (scale c 500)
           (set-color c 0x99ccff 10)
           (doseq [[x1 x2 y1 y2] samples]
             (rect c x1 y1 (- x2 x1) (- y2 y1)))
           (set-color c (c/darken 0x99ccff))
           (let [[x1 x2 y1 y2] best]
             (set-stroke c (/ 500.0))
             (rect c x1 y1 (- x2 x1) (- y2 y1) true))
           (set-color c :black)
           (doseq [{:keys [x y]} data]
             (ellipse c x y 0.01 0.01))
           c))

(save img "/results/rectangles.jpg")
```
Charts and analysis
Now we are going to look at some visualizations and discover one of the problems of MCMC methods.

First, let's see the distributions of all four rectangle corner coordinates.
```clojure
(def last-accepted (->> (nth mh-instance 50000)
                        (:accepted)
                        (drop 10000)))

(plot/save (plot/xy-chart {:width 700 :height 320}
                          (b/lattice :histogram
                                     (zipmap [:x1 :x2 :y1 :y2] (apply map vector last-accepted))
                                     {:bins 30 :density? true :type :lollipops}
                                     {:label name :shape [1 4]})
                          (b/update-scales :x :ticks 4)
                          (b/add-axes :bottom)
                          (b/add-axes :left)
                          (b/add-label :left "density")
                          (b/add-label :bottom "rectangle corner")
                          (b/add-label :top "Histograms of parameters"))
           "/results/densities.jpg")
```
The distributions are highly skewed, because we didn't allow rectangles that leave any point outside. This is also visible below.
Red points represent rejected samples, blue points accepted ones. You can also observe the random-walk path towards the corners.
```clojure
(def accepted-rejected (select-keys (nth mh-instance 10000) [:accepted :rejected]))

(defn xy1 [[x1 _ y1 _]] [x1 y1])
(defn xy2 [[_ x2 _ y2]] [x2 y2])

(plot/save (plot/xy-chart {:width 700 :height 700}
                          (b/series [:grid]
                                    [:scatter (map xy1 (:rejected accepted-rejected)) {:color :red}]
                                    [:scatter (map xy2 (:rejected accepted-rejected)) {:color :red}]
                                    [:scatter (map xy1 (:accepted accepted-rejected))]
                                    [:scatter (map xy2 (:accepted accepted-rejected))])
                          (b/add-axes :bottom)
                          (b/add-axes :left)
                          (b/add-label :left "y")
                          (b/add-label :bottom "x")
                          (b/add-label :top "Accepted (blue) and rejected (red) samples"))
           "/results/accepted-rejected.jpg")
```
```clojure
(def random-walking (:path (nth mh-instance 1000)))

(plot/save (plot/xy-chart {:width 700 :height 700}
                          (b/series [:grid]
                                    [:line (map xy1 random-walking)]
                                    [:line (map xy2 random-walking)])
                          (b/add-axes :bottom)
                          (b/add-axes :left)
                          (b/add-label :left "y")
                          (b/add-label :bottom "x")
                          (b/add-label :top "First 1000 steps"))
           "/results/paths.jpg")
```
The next chart shows how the score changed with every accepted sample. You can see that after about 2000 accepted samples the process has more or less stabilized.
However, the shape of the trace (log likelihood and the x1 value below) reveals a certain problem: the chain is not random enough. There are two reasons:
- A very small step in the random walk (the standard deviation of the Gaussian was set to 0.01); see the acceptance-rate sketch after this list.
- Too few samples (this was partly mitigated by taking every hundredth sample for our visualization).
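One cheap diagnostic for the step-size issue is the acceptance rate: with very timid steps almost every proposal is accepted, and consecutive samples are nearly identical. Here is a minimal sketch; `acceptance-rate` is a hypothetical helper, not part of the code above:

```clojure
;; hypothetical helper: the fraction of proposals that were accepted;
;; a value close to 1 combined with a tiny step size is a symptom of
;; high autocorrelation
(defn acceptance-rate
  [mh-state]
  (let [a (count (:accepted mh-state))
        r (count (:rejected mh-state))]
    (/ a (double (+ a r)))))

(acceptance-rate (nth mh-instance 10000))
```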
```clojure
(def ll-vals (->> (nth mh-instance 10000)
                  (:accepted)
                  (map log-likelihood)))

(defn trace-plot
  [data y-axis-name]
  (plot/xy-chart {:width 700 :height 300}
                 (b/series [:grid] [:line data])
                 (b/update-scales :x :fmt int)
                 (b/add-axes :bottom)
                 (b/add-axes :left)
                 (b/add-label :left y-axis-name)
                 (b/add-label :bottom "step")))

(plot/save (trace-plot (map-indexed vector ll-vals) "log(likelihood)")
           "/results/loglikelihood.jpg")
```
```clojure
(def x1-vals (->> (nth mh-instance 10000)
                  (:accepted)
                  (map first)))

(plot/save (trace-plot (map-indexed vector x1-vals) "value")
           "/results/x1.jpg")
```
The last plots, ACF and PACF, confirm our observation that the samples are highly correlated (an AR-like process) and that we need to adjust the algorithm's parameters.
See more about this here.
```clojure
(def near-end (map first (take-last 1000 last-accepted)))

(plot/save (plot/xy-chart {:width 700 :height 350}
                          (b/series [:grid nil {:position [0 1]}]
                                    [:acf near-end {:lags 100 :position [0 1] :label "ACF"}]
                                    [:grid nil {:position [0 0]}]
                                    [:pacf near-end {:lags 100 :position [0 0] :label "PACF"}])
                          (b/update-scales :x :fmt int)
                          (b/add-axes :bottom)
                          (b/add-axes :left)
                          (b/add-label :left "autocorrelation")
                          (b/add-label :bottom "lag")
                          (b/add-label :top "ACF/PACF"))
           "/results/acf-pacf.jpg")
```
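To back the plots with numbers, here is a quick hand-rolled lag-k sample autocorrelation; `autocorrelation` is an illustrative helper (assuming nothing beyond fastmath.core), not a fastmath function:

```clojure
;; illustrative helper: sample autocorrelation at lag k; values staying
;; close to 1 even at large lags confirm the strong correlation seen
;; on the ACF plot
(defn autocorrelation
  [xs k]
  (let [xs (vec xs)
        n (count xs)
        mean (/ (reduce + xs) n)
        var (reduce + (map #(m/sq (- % mean)) xs))]
    (/ (reduce + (map (fn [a b] (* (- a mean) (- b mean)))
                      xs (drop k xs)))
       var)))

(map #(autocorrelation near-end %) [1 10 50])
```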
Summary
With the above example I wanted to show three things:
- The Metropolis-Hastings algorithm is a very simple but powerful method for probabilistic inference. The Algorithm section contains a small but complete set of functions which can be reused for a wide variety of problems without (ok, almost without) any changes.
- Model definition can be done directly and very elegantly in Clojure. There is no real difference between the Clojure version here and the WebPPL version presented in the book. Of course, this is not always the case.
- Fortunately (yes!) we've encountered a problem along the way. Trace and ACF plots are great tools for investigating such issues.
I hope you enjoyed it, and thanks for reading! Questions and issues go to this topic.