CounterfactualFairness.jl

CounterfactualFairness.jl is a Julia package that provides an interface for causal inference and counterfactual fairness. This project was completed under the mentorship of Zenna Tavares, Moritz Schauer, Jiahao Chen and Sebastian Vollmer.

Link to the GitHub repository: https://github.com/zenna/CounterfactualFairness.jl

Link to the previous blog post: https://nextjournal.com/archanarw/counterfactual-fairness-blogpost-1

Brief Walkthrough of CounterfactualFairness.jl

The package is designed with Pearl's Causal Ladder in mind, so it supports association (constructing the causal model), interventions and counterfactuals.

Importing Required Packages

The packages required for the following demonstration -

  • CounterfactualFairness (arw branch)

  • Omega (lang branch)

  • CausalInference

  • Distributions (version - 0.25.11)

  • MLJ

  • Flux (version - 0.12.6)

  • DataFrames

  • For visualization: Plots, GraphPlot, Colors, PrettyPrinting

CounterfactualFairness.jl depends on several unregistered packages, which must also be added for it to precompile successfully -

  • InferenceBase

  • SoftPredicates

  • ReplicaExchange

  • OmegaCore

  • OmegaMH

using Pkg
Pkg.add(PackageSpec(url="https://github.com/zenna/Omega.jl", rev="lang:InferenceBase"))
Pkg.add(PackageSpec(url="https://github.com/zenna/Omega.jl", rev="lang:SoftPredicates"))
Pkg.add(PackageSpec(url="https://github.com/zenna/Omega.jl", rev="lang:ReplicaExchange"))
Pkg.add(PackageSpec(url="https://github.com/zenna/Omega.jl",rev="102cc01d1f7dbb4a4caad822746ced6fa5c7164b:OmegaCore"))
Pkg.add(PackageSpec(url="https://github.com/zenna/Omega.jl", rev="102cc01d1f7dbb4a4caad822746ced6fa5c7164b:OmegaMH"))
Pkg.add(PackageSpec(url="https://github.com/zenna/Omega.jl", rev="lang"))
Pkg.add(PackageSpec(url="https://github.com/zenna/CounterfactualFairness.jl", rev="arw"))
Pkg.add("CausalInference")
Pkg.add(name = "Distributions", version = "0.25.11")
Pkg.add("MLJ")
Pkg.add(name = "Flux", version = "0.12.6")
Pkg.add("GraphPlot")
Pkg.add("Plots")
Pkg.add("PrettyPrinting")
Pkg.add("Colors")
Pkg.add("DataFrames")
using Omega, OmegaCore
using DataFrames, CounterfactualFairness, CausalInference
using Distributions, MLJ, Flux
using GraphPlot, Plots, Colors, PrettyPrinting

Association

Using CounterfactualFairness.jl, you may construct a causal model in the following ways -

  • Automatically from data: prob_causal_graph(df) constructs a causal model (using Gaussian mechanisms) from the DataFrame df. It uses pcalg from CausalInference.jl to learn the causal graph.

  • By loading one of the causal models bundled with CounterfactualFairness.jl, for example the law school model:

cm = @load_law_school;
gplot(dag(cm), nodelabel = ([variable(cm, i).name for i in 1:nv(cm)]), nodefillc = colorant"seagreen2", edgestrokec = colorant"black", layout = shell_layout, NODESIZE = 0.4/sqrt(nv(cm)))
  • Manually, by specifying the distributions of the exogenous variables and the functions that define the endogenous variables

Consider the causal graph below -

Example from Causal Inference in Statistics: A Primer

This structure could describe the causal mechanism that connects a day's temperature (X), sales at an ice-cream shop (Y) and the number of crimes (Z). We may represent this model using CounterfactualFairness.jl as follows -

g = CausalModel(); # Empty causal graph

Adding exogenous variables (variables that have no explicit cause within the model) -

U₁ = add_exo_variable!(g, :U₁, 1 ~ Normal(24, 8));
U₂ = add_exo_variable!(g, :U₂, 1 ~ Normal(15, 3));
U₃ = add_exo_variable!(g, :U₃, 1 ~ Normal(2, 1));

Adding endogenous variables (variables whose values depend on other variables in the model) -

Temp = add_endo_variable!(g, :Temp, identity, U₁);
IceCreamSales = add_endo_variable!(g, :IceCreamSales, *, Temp, U₂);
Crime = add_endo_variable!(g, :Crime, /, Temp, U₃);

Visualizing the graph -

gplot(dag(g), nodelabel = ([variable(g, i).name for i in 1:nv(g)]), nodefillc = colorant"seagreen2", edgestrokec = colorant"black", layout = stressmajorize_layout, NODESIZE = 0.4/sqrt(nv(g)))

To apply a context (values of the exogenous variables) to the model, we may use apply_context, or call g(ω), where ω is an element of the sample space as defined in Omega (random variables in Omega are functions of ω).

apply_context(g, (U₁ = 26.2, U₂ = 14.8, U₃ = 2.)) |> pprint
g(defω(), return_type = NamedTuple) |> pprint
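For this small model, applying a context amounts to evaluating the structural equations in topological order. The sketch below writes those equations by hand in base Julia, assuming apply_context performs the equivalent evaluation; solve_scm and sample_context are hypothetical helpers, not part of CounterfactualFairness.jl:

```julia
using Random

# Hypothetical hand-written version of the ice-cream model (not the package API):
#   Temp = U₁,  IceCreamSales = Temp * U₂,  Crime = Temp / U₃
function solve_scm(u::NamedTuple)
    temp = u.U₁            # Temp is determined by U₁ alone
    sales = temp * u.U₂    # IceCreamSales = Temp * U₂
    crime = temp / u.U₃    # Crime = Temp / U₃
    return (Temp = temp, IceCreamSales = sales, Crime = crime)
end

# Draw a context from U₁ ~ Normal(24, 8), U₂ ~ Normal(15, 3), U₃ ~ Normal(2, 1)
sample_context(rng::AbstractRNG) =
    (U₁ = 24 + 8 * randn(rng), U₂ = 15 + 3 * randn(rng), U₃ = 2 + randn(rng))

# Same context as the apply_context call above:
# Temp = 26.2, IceCreamSales ≈ 387.76, Crime ≈ 13.1
solve_scm((U₁ = 26.2, U₂ = 14.8, U₃ = 2.0))

# A random context, loosely analogous to evaluating g(ω)
solve_scm(sample_context(MersenneTwister(0)))
```

Here the random number generator plays roughly the role that ω plays in Omega: it supplies the randomness from which the exogenous values are drawn.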

Interventions

An intervention fixes the value of a particular variable and modifies the model accordingly: all incoming edges to that variable are removed, since its value no longer depends on its parents. Interventions may be computed as follows -

(Continuing with the previous example)

i = CounterfactualFairness.Intervention(:Temp, 24.) # Fixing value of Temp to 24
i |> pprint
m = apply_intervention(g, i) # Model in which Temp is fixed to 24
intervened_sample = randsample(ω -> m(ω)) # One sample from the intervened model
intervened_sample |> pprint
gplot(dag(m), nodelabel = ([variable(m, i).name for i in 1:nv(m)]), nodefillc = colorant"seagreen2", edgestrokec = colorant"black", layout = stressmajorize_layout, NODESIZE = 0.4/sqrt(nv(m)))
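The effect of an intervention is easy to see on the structural equations written by hand. In this hypothetical base-Julia sketch (solve_do_temp is an illustrative name, not a package function), do(Temp = 24) simply replaces the equation Temp = U₁ with the constant 24, which is what removing Temp's incoming edge amounts to:

```julia
# Hypothetical hand-written ice-cream model with an optional intervention on Temp:
#   Temp = U₁ (or a fixed value under do),  IceCreamSales = Temp * U₂,  Crime = Temp / U₃
function solve_do_temp(u::NamedTuple; do_temp::Union{Nothing,Real} = nothing)
    temp = do_temp === nothing ? u.U₁ : do_temp   # under do(Temp = t), U₁ is ignored
    return (Temp = temp, IceCreamSales = temp * u.U₂, Crime = temp / u.U₃)
end

u = (U₁ = 26.2, U₂ = 14.8, U₃ = 2.0)
observed   = solve_do_temp(u)                   # Temp = 26.2
intervened = solve_do_temp(u; do_temp = 24.0)   # Temp = 24.0, Crime = 12.0
```

Changing U₁ leaves the intervened result unchanged, while IceCreamSales and Crime still respond to U₂ and U₃.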

Counterfactuals

A counterfactual query has the form

$P(Y_{X \leftarrow x}(U) = y \mid V = v)$

where V contains the observed variables.

To obtain the counterfactual, we compute :Crime under the intervention $\text{Temp} \leftarrow 24$, conditioned on the observed value IceCreamSales = 340.

cf_crime = ω -> counterfactual(:Crime, (IceCreamSales = 340.,), i, g, ω); # Crime under Temp ← 24, given IceCreamSales = 340
@show randsample(cf_crime);
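The counterfactual above follows Pearl's three steps: abduction (infer U from the observations), action (apply the intervention) and prediction (recompute the downstream variables). The hypothetical base-Julia sketch below (abduct and predict_do_temp are illustrative names) assumes that Temp, IceCreamSales and Crime are all observed, so U can be recovered exactly by inverting the equations. When only IceCreamSales is observed, U has to be inferred probabilistically instead:

```julia
# Hypothetical hand-written ice-cream model:
#   Temp = U₁,  IceCreamSales = Temp * U₂,  Crime = Temp / U₃

# Abduction: invert the equations to recover U from a fully observed unit
abduct(obs::NamedTuple) =
    (U₁ = obs.Temp, U₂ = obs.IceCreamSales / obs.Temp, U₃ = obs.Temp / obs.Crime)

# Action + prediction: set Temp = t and re-evaluate its descendants with the same U
predict_do_temp(u::NamedTuple, t::Real) =
    (Temp = t, IceCreamSales = t * u.U₂, Crime = t / u.U₃)

obs = (Temp = 26.2, IceCreamSales = 387.76, Crime = 13.1)
u = abduct(obs)                      # U₃ ≈ 2.0
cf = predict_do_temp(u, 24.0)        # counterfactual Crime ≈ 12.0
```

In the query above only IceCreamSales = 340 is observed. Under Temp ← 24, Crime equals 24 / U₃, and U₃ is independent of IceCreamSales = U₁ · U₂, so in this model that observation leaves the counterfactual distribution of Crime unchanged.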
Computing counterfactuals using MLJ

Wrapper to compute counterfactuals for each observation in a given dataset:

toy = @load_synthetic # Synthetic causal model
cfw = CounterfactualWrapper(test = gausscitest, p = 0.1, cf = :Y, interventions = CounterfactualFairness.Intervention(:A, 40.)) 

cfw can now be used in the standard MLJ fit/transform workflow.
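The test and p keywords appear to configure how the causal graph is learned from data. gausscitest is CausalInference.jl's Gaussian conditional-independence test, and p is presumably its significance level. To roughly illustrate what such a test computes, here is a hypothetical base-Julia Fisher-z test on partial correlations. fisher_z_independent is an illustrative name, and this is not the CausalInference.jl implementation:

```julia
using LinearAlgebra, Random, Statistics

# Fisher-z test of "column i is independent of column j given the columns in S",
# assuming jointly Gaussian data. z_crit = 1.6449 is the two-sided critical
# value for significance level 0.1.
function fisher_z_independent(data::AbstractMatrix, i::Int, j::Int, S::Vector{Int} = Int[];
                              z_crit::Real = 1.6449)
    n = size(data, 1)
    P = inv(cov(data[:, [i; j; S]]))           # precision matrix of (i, j, S...)
    r = -P[1, 2] / sqrt(P[1, 1] * P[2, 2])     # partial correlation of i and j given S
    r = clamp(r, -0.999999, 0.999999)          # keep atanh finite
    z = sqrt(n - length(S) - 3) * atanh(r)     # Fisher z statistic
    return abs(z) < z_crit                     # true: independence not rejected
end

# Chain X → Y → Z: X and Z are dependent, but independent given Y
rng = MersenneTwister(1)
n = 5_000
xs = randn(rng, n)
ys = 2 .* xs .+ randn(rng, n)
zs = randn(rng, n) .- ys
data = hcat(xs, ys, zs)

fisher_z_independent(data, 1, 3)        # false: X and Z are strongly correlated
fisher_z_independent(data, 1, 3, [2])   # typically true: X ⟂ Z given Y
```

A PC-style algorithm runs many such tests, removing an edge whenever some conditioning set makes its endpoints independent.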

Training a neural network in a way that the predictor is counterfactually fair

Check for Sufficient Condition for Counterfactual Fairness

Lemma 1: Let G be the causal graph of the given model (U, V, F). Then Ŷ will be counterfactually fair if it is a function of the non-descendants of A, the protected attribute.

To check the sufficient condition above, isNonDesc(g, X, A) returns true if all variables in X are non-descendants of the variables in A, and false otherwise.

@show isNonDesc(g, (:IceCreamSales, U₃), (:Temp,)); # false since IceCreamSales is a descendant of Temp
@show isNonDesc(g, (:U₁, :IceCreamSales), (:Crime,)); # true since neither U₁ nor IceCreamSales is a descendant of Crime
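Lemma 1's sufficient condition reduces to graph reachability. Here is a minimal base-Julia sketch. It writes the ice-cream DAG as a hypothetical parent-to-children dictionary and uses illustrative helpers descendants and nondesc, which are not the package's isNonDesc:

```julia
# Hypothetical adjacency list for the ice-cream DAG (parent => children)
const ICECREAM_DAG = Dict{Symbol,Vector{Symbol}}(
    :U₁ => [:Temp], :U₂ => [:IceCreamSales], :U₃ => [:Crime],
    :Temp => [:IceCreamSales, :Crime],
    :IceCreamSales => Symbol[], :Crime => Symbol[],
)

# All nodes reachable from `node` along directed edges, i.e. its descendants
function descendants(dag::Dict{Symbol,Vector{Symbol}}, node::Symbol)
    seen = Set{Symbol}()
    stack = copy(dag[node])
    while !isempty(stack)
        v = pop!(stack)
        v in seen && continue
        push!(seen, v)
        append!(stack, dag[v])
    end
    return seen
end

# true if no variable in `xs` is one of `as` or a descendant of one of `as`
nondesc(dag, xs, as) =
    all(a -> !(a in xs) && isempty(intersect(Set(xs), descendants(dag, a))), as)

nondesc(ICECREAM_DAG, (:IceCreamSales, :U₃), (:Temp,))   # false
nondesc(ICECREAM_DAG, (:U₁, :IceCreamSales), (:Crime,))  # true
```

Both calls match the isNonDesc results above.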
Training using MLJ Interface

Creating synthetic dataset -

toy = @load_synthetic
n = 500
X = (CausalVar(toy, :X1), CausalVar(toy, :X2), CausalVar(toy, :X3), CausalVar(toy, :X4))
U = (CausalVar(toy, :U₁), CausalVar(toy, :U₂), CausalVar(toy, :U₃), CausalVar(toy, :U₄), CausalVar(toy, :U₅))
A = CausalVar(toy, :A)
Y = CausalVar(toy, :Y)
# Evaluate all columns on the same ω so that each row is one joint sample of the model
rows = randsample(ω -> (X1 = X[1](ω), X2 = X[2](ω), X3 = X[3](ω), X4 = X[4](ω), A = A(ω), Y = Y(ω)), n)
df = DataFrame(rows) # a vector of NamedTuples is a valid Tables.jl source
pprint(df)
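Each row of a synthetic dataset should be one joint draw of the model. If each randsample call draws its own independent ω values (an assumption), then columns generated by separate calls are independent of one another. The hypothetical two-variable model below uses an illustrative helper, draw, to show this effect in base Julia:

```julia
using Random, Statistics

# Hypothetical two-variable model: A ~ Normal(0, 1), Y = 2A + Normal(0, 1) noise
function draw(rng::AbstractRNG)
    a = randn(rng)
    return (A = a, Y = 2a + randn(rng))
end

rng = MersenneTwister(7)
n = 10_000

# Joint sampling: A and Y in each row come from the same draw
rows = [draw(rng) for _ in 1:n]
joint_A = [r.A for r in rows]
joint_Y = [r.Y for r in rows]

# Column-by-column sampling: A and Y come from different draws
indep_A = [draw(rng).A for _ in 1:n]
indep_Y = [draw(rng).Y for _ in 1:n]

cor(joint_A, joint_Y)   # close to 2/√5 ≈ 0.89
cor(indep_A, indep_Y)   # close to 0: the A → Y dependence is lost
```

For fairness experiments this matters, because the dependence of Y on A is exactly what a counterfactually fair predictor has to account for.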

Wrapper for adversarial learning of a counterfactually fair predictor -

model = AdversarialWrapper(cm = toy, 
  grp = :A, 
  latent = [:U₁, :U₂, :U₃, :U₄, :U₅], 
  observed = [:X1, :X2, :X3, :X4], 
  predictor = Chain(Dense(4, 3), Dense(3, 2), Dense(2, 1)), 
  adversary = Chain(Dense(5, 3), Dense(3, 2, relu)), 
  loss = Flux.Losses.logitbinarycrossentropy, 
  iters = 5)
model |> pprint

model now fits into the same framework as other MLJ wrappers: it can be trained with fit! and used to make predictions with predict.
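AdversarialWrapper is one way to obtain a counterfactually fair predictor. A simpler route follows directly from Lemma 1: feed the predictor only latent variables that are non-descendants of A. The hypothetical sketch below uses a toy additive model, not the package's synthetic model. It compares a predictor that uses X, a descendant of A, with one that uses U recovered by abduction. It does not describe how AdversarialWrapper works:

```julia
using Random

# Hypothetical toy additive model: A ∈ {0, 1},  U ~ Normal(0, 1),  X = U + 2A
x_of(a, u) = u + 2a

# Predictor that uses X, a descendant of A
pred_from_x(x, a) = 0.5 * x
# Predictor that uses only U = X - 2A (recovered by abduction), a non-descendant of A
pred_from_u(x, a) = 0.5 * (x - 2a)

# Change in prediction when A is flipped while the unit's U is held fixed
function counterfactual_gap(pred, a, u)
    a_cf = 1 - a
    return pred(x_of(a_cf, u), a_cf) - pred(x_of(a, u), a)
end

rng = MersenneTwister(3)
units = [(rand(rng, 0:1), randn(rng)) for _ in 1:1_000]

gaps_x = [counterfactual_gap(pred_from_x, a, u) for (a, u) in units]  # each ±1
gaps_u = [counterfactual_gap(pred_from_u, a, u) for (a, u) in units]  # each ≈ 0
```

The second predictor's output does not change when A is flipped counterfactually, which is the property Lemma 1 guarantees for functions of non-descendants of A.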

Future Work

  • Path-specific interventions are currently not computed correctly in the package and need to be fixed.

  • Benchmark counterfactual explanations

  • Add recourse methods

References
