CounterfactualFairness.jl
CounterfactualFairness.jl is a Julia package that provides an interface for causal inference and counterfactual fairness. This project was completed under the mentorship of Zenna Tavares, Moritz Schauer, Jiahao Chen and Sebastian Vollmer.
Link to the GitHub repository: https://github.com/zenna/CounterfactualFairness.jl
Link to the previous blog post: https://nextjournal.com/archanarw/counterfactual-fairness-blogpost-1
Brief Walkthrough of CounterfactualFairness.jl
The package is designed with Pearl's Causal Ladder in mind, and thus supports association (constructing the causal model), interventions, and counterfactuals.
Importing Required Packages
The required packages for the following demonstration -
CounterfactualFairness (arw branch)
Omega (lang branch)
CausalInference
Distributions (version - 0.25.11)
MLJ
Flux (version - 0.12.6)
For visualization- Plots, GraphPlot, Colors, PrettyPrinting
The packages required to precompile CounterfactualFairness.jl successfully (since it depends on some unregistered packages, they must also be added) -
InferenceBase
SoftPredicates
ReplicaExchange
OmegaCore
OmegaMH
```julia
using Pkg
Pkg.add(PackageSpec(url="https://github.com/zenna/Omega.jl", rev="lang:InferenceBase"))
Pkg.add(PackageSpec(url="https://github.com/zenna/Omega.jl", rev="lang:SoftPredicates"))
Pkg.add(PackageSpec(url="https://github.com/zenna/Omega.jl", rev="lang:ReplicaExchange"))
Pkg.add(PackageSpec(url="https://github.com/zenna/Omega.jl", rev="102cc01d1f7dbb4a4caad822746ced6fa5c7164b:OmegaCore"))
Pkg.add(PackageSpec(url="https://github.com/zenna/Omega.jl", rev="102cc01d1f7dbb4a4caad822746ced6fa5c7164b:OmegaMH"))
Pkg.add(PackageSpec(url="https://github.com/zenna/Omega.jl", rev="lang"))
Pkg.add(PackageSpec(url="https://github.com/zenna/CounterfactualFairness.jl", rev="arw"))
Pkg.add("CausalInference")
Pkg.add("Distributions")
Pkg.add("MLJ")
Pkg.add(name = "Flux", version = "0.12.0")
Pkg.add("GraphPlot")
Pkg.add("Plots")
Pkg.add("PrettyPrinting")
Pkg.add("Colors")
Pkg.add("DataFrames")

using Omega, OmegaCore
using DataFrames, CounterfactualFairness, CausalInference
using Distributions, MLJ, Flux
using GraphPlot, Plots, Colors, PrettyPrinting
```
Association
Using CounterfactualFairness.jl, you may construct a causal model in the following ways -
Automatically from data -
prob_causal_graph(df) can be used to construct a causal model from the dataframe df (assuming Gaussian mechanisms). The function uses pcalg from CausalInference.jl to construct the causal graph.
By loading a causal model from CounterfactualFairness.jl -
```julia
cm =   # (model reference elided in the original)
gplot(dag(cm), nodelabel = [variable(cm, i).name for i in 1:nv(cm)],
      nodefillc = colorant"seagreen2", edgestrokec = colorant"black",
      layout = shell_layout, NODESIZE = 0.4/sqrt(nv(cm)))
```
By entering the distributions over exogenous variables and functions over endogenous variables -
Consider the causal graph below -

Example from Causal Inference in Statistics: A Primer
This structure could describe the causal mechanism that connects a day's temperature (X), sales at an ice-cream shop (Y), and the number of crimes (Z). We may represent this model using CounterfactualFairness.jl as follows -
```julia
g = CausalModel();  # empty causal graph
```
Adding exogenous variables (variables that have no explicit cause within the model) -
```julia
U₁ = add_exo_variable!(g, :U₁, 1 ~ Normal(24, 8));
U₂ = add_exo_variable!(g, :U₂, 1 ~ Normal(15, 3));
U₃ = add_exo_variable!(g, :U₃, 1 ~ Normal(2, 1));
```
Adding endogenous variables (variables whose values depend on other variables in the model) -
```julia
Temp = add_endo_variable!(g, :Temp, identity, U₁);
IceCreamSales = add_endo_variable!(g, :IceCreamSales, *, Temp, U₂);
Crime = add_endo_variable!(g, :Crime, /, Temp, U₃);
```
Visualizing the graph -
```julia
gplot(dag(g), nodelabel = [variable(g, i).name for i in 1:nv(g)],
      nodefillc = colorant"seagreen2", edgestrokec = colorant"black",
      layout = stressmajorize_layout, NODESIZE = 0.4/sqrt(nv(g)))
```
To apply a context (values of the exogenous variables) to the model, we may use apply_context, or call g(ω), where ω is a random variable as defined in Omega.
```julia
apply_context(g, (U₁ = 26.2, U₂ = 14.8, U₃ = 2.)) |> pprint
g(defω(), return_type = NamedTuple) |> pprint
```
Interventions
An intervention fixes the value of a particular variable in the model and modifies the entire model accordingly: all incoming edges to that variable are removed, since its value is now fixed. Interventions may be computed as follows -
(Continuing with the previous example)
```julia
i = CounterfactualFairness.Intervention(:Temp, 24.)  # fixing the value of Temp to 24
i |> pprint
m = apply_intervention(g, i)
intervened_model = randsample(ω -> m(ω))
intervened_model |> pprint
gplot(dag(m), nodelabel = [variable(m, i).name for i in 1:nv(m)],
      nodefillc = colorant"seagreen2", edgestrokec = colorant"black",
      layout = stressmajorize_layout, NODESIZE = 0.4/sqrt(nv(m)))
```
Counterfactuals
A counterfactual answers the question: given that we observed V = v, what would Y have been had X been set to x? In the standard notation, this is
P(Y_{X ← x} | V = v)
where V contains the observed variables. It is computed in three steps: abduction (infer the exogenous variables from the observations), action (apply the intervention), and prediction (propagate the values through the modified model).
To obtain counterfactuals, we condition :Crime on the observed values and the intervention Temp ← 24 (the intervention i defined above).
```julia
count = ω -> counterfactual(:Crime, (IceCreamSales = 340.,), i, g, ω);
randsample(count);
```
Computing counterfactuals using MLJ
Wrapper to compute counterfactuals for each observation in a given dataset:
```julia
toy =   # synthetic causal model (elided in the original)
cfw = CounterfactualWrapper(test = gausscitest, p = 0.1, cf = :Y,
                            interventions = CounterfactualFairness.Intervention(:A, 40.))
```
Now we may use cfw in the fit/transform workflow of MLJ.
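As a sketch of that workflow (assuming a DataFrame df of observations sampled from the model; machine, fit!, and transform are MLJ's standard interface functions, and the exact behavior of the wrapper under transform is as described above, not verified here):

```julia
# Hedged sketch of the MLJ fit/transform workflow for the wrapper above.
# Assumes `df` is a DataFrame of observations sampled from the causal model.
mach = machine(cfw, df)    # bind the wrapper to the data
fit!(mach)                 # fit (runs the conditional-independence tests)
cf = transform(mach, df)   # counterfactual value of :Y for each observation
```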
Training a neural network so that the predictor is counterfactually fair
Check for Sufficient Condition for Counterfactual Fairness
Lemma 1: Let G be the causal graph of the given model (U, V, F). Then Ŷ will be counterfactually fair if it is a function of the non-descendants of A.
To check the sufficient condition given above, isNonDesc returns true if the condition is satisfied and false otherwise.
```julia
isNonDesc(g, (:IceCreamSales, :U₃), (:Temp,));  # false, since IceCreamSales is a descendant of Temp
isNonDesc(g, (:U₁, :IceCreamSales), (:Crime,)); # true, since neither U₁ nor IceCreamSales is a descendant of Crime
```
Training using MLJ Interface
Creating synthetic dataset -
```julia
toy =   # synthetic causal model (elided in the original)
n = 500
X = (CausalVar(toy, :X1), CausalVar(toy, :X2), CausalVar(toy, :X3), CausalVar(toy, :X4))
U = (CausalVar(toy, :U₁), CausalVar(toy, :U₂), CausalVar(toy, :U₃), CausalVar(toy, :U₄), CausalVar(toy, :U₅))
A = CausalVar(toy, :A)
Y = CausalVar(toy, :Y)
df = DataFrame(
    X1 = randsample(ω -> X[1](ω), n),
    X2 = randsample(ω -> X[2](ω), n),
    X3 = randsample(ω -> X[3](ω), n),
    X4 = randsample(ω -> X[4](ω), n),
    A = randsample(ω -> A(ω), n),
    Y = randsample(ω -> Y(ω), n))
pprint(df)
```
Wrapper for adversarial learning for counterfactual fairness -
```julia
model = AdversarialWrapper(cm = toy, grp = :A,
    latent = [:U₁, :U₂, :U₃, :U₄, :U₅],
    observed = [:X1, :X2, :X3, :X4],
    predictor = Chain(Dense(4, 3), Dense(3, 2), Dense(2, 1)),
    adversary = Chain(Dense(5, 3), Dense(3, 2, relu)),
    loss = Flux.Losses.logitbinarycrossentropy,
    iters = 5)
model |> pprint
```
Now, model fits into the same framework as other wrappers in MLJ and can be used in the same way.
We can train the model with fit! and obtain predictions with predict.
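For example, a minimal sketch of that workflow (assuming the df built above with :Y as the target; unpack, machine, fit!, and predict are MLJ's standard interface functions):

```julia
# Hedged sketch: training the adversarially-wrapped model with MLJ.
y, X = unpack(df, ==(:Y))    # split target :Y from the features
mach = machine(model, X, y)  # bind the wrapped model to the data
fit!(mach)                   # train predictor and adversary
ŷ = predict(mach, X)         # counterfactually fair predictions
```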
Future Work
Path-specific interventions are currently not computed correctly in the package; this must be fixed.
Benchmark counterfactual explanations
Add recourse methods