CounterfactualFairness.jl
CounterfactualFairness.jl is a Julia package that provides an interface for causal inference and counterfactual fairness. This project was completed under the mentorship of Zenna Tavares, Moritz Schauer, Jiahao Chen and Sebastian Vollmer.
Link to the GitHub repository: https://github.com/zenna/CounterfactualFairness.jl
Link to the previous blog post: https://nextjournal.com/archanarw/counterfactual-fairness-blogpost-1
Brief Walkthrough of CounterfactualFairness.jl
The package is designed with Pearl's Causal Ladder in mind, and thus supports all three rungs: association (constructing the causal model), intervention, and counterfactuals.
Importing Required Packages
The required packages for the following demonstration -
CounterfactualFairness (arw branch)
Omega (lang branch)
CausalInference
Distributions (version - 0.25.11)
MLJ
Flux (version - 0.12.0)
For visualization- Plots, GraphPlot, Colors, PrettyPrinting
The packages required to precompile CounterfactualFairness.jl successfully (since it depends on some unregistered packages, these must also be added) -
InferenceBase
SoftPredicates
ReplicaExchange
OmegaCore
OmegaMH
using Pkg
Pkg.add(PackageSpec(url="https://github.com/zenna/Omega.jl", rev="lang:InferenceBase"))
Pkg.add(PackageSpec(url="https://github.com/zenna/Omega.jl", rev="lang:SoftPredicates"))
Pkg.add(PackageSpec(url="https://github.com/zenna/Omega.jl", rev="lang:ReplicaExchange"))
Pkg.add(PackageSpec(url="https://github.com/zenna/Omega.jl", rev="102cc01d1f7dbb4a4caad822746ced6fa5c7164b:OmegaCore"))
Pkg.add(PackageSpec(url="https://github.com/zenna/Omega.jl", rev="102cc01d1f7dbb4a4caad822746ced6fa5c7164b:OmegaMH"))
Pkg.add(PackageSpec(url="https://github.com/zenna/Omega.jl", rev="lang"))
Pkg.add(PackageSpec(url="https://github.com/zenna/CounterfactualFairness.jl", rev="arw"))
Pkg.add("CausalInference")
Pkg.add("Distributions")
Pkg.add("MLJ")
Pkg.add(name = "Flux", version = "0.12.0")
Pkg.add("GraphPlot")
Pkg.add("Plots")
Pkg.add("PrettyPrinting")
Pkg.add("Colors")
Pkg.add("DataFrames")
using Omega, OmegaCore
using DataFrames, CounterfactualFairness, CausalInference
using Distributions, MLJ, Flux
using GraphPlot, Plots, Colors, PrettyPrinting
Association
Using CounterfactualFairness.jl, you may construct a causal model in the following ways -
Automatically from data -
prob_causal_graph(df) constructs a causal model from the dataframe df (by the Gaussian mechanism). The function uses pcalg from CausalInference.jl to construct the causal graph.
By loading a causal model from CounterfactualFairness.jl -
cm = ;
gplot(dag(cm), nodelabel = ([variable(cm, i).name for i in 1:nv(cm)]), nodefillc = colorant"seagreen2", edgestrokec = colorant"black", layout = shell_layout, NODESIZE = 0.4/sqrt(nv(cm)))
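As a minimal sketch of the first, data-driven route: the snippet below builds a hypothetical two-column dataset and passes it to prob_causal_graph. The dataframe contents are invented for illustration; only the prob_causal_graph(df) call itself comes from the package as described above.

```julia
using DataFrames, CounterfactualFairness

# Hypothetical dataset with a simple linear X → Y dependence
df = DataFrame(X = randn(100))
df.Y = 2 .* df.X .+ 0.1 .* randn(100)

# Construct a causal model from the data
# (pcalg from CausalInference.jl with a Gaussian conditional-independence test)
cm = prob_causal_graph(df)
```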
By entering the distributions over exogenous variables and functions over endogenous variables -
Consider the causal graph below -
This structure could describe the causal mechanism that connects a day's temperature (X), sales at an ice-cream shop (Y) and number of crimes (Z). We may represent this model using CounterfactualFairness.jl as given below -
g = CausalModel(); # Empty causal graph
Adding exogenous variables (variables that have no explicit cause within the model)-
U₁ = add_exo_variable!(g, :U₁, 1 ~ Normal(24, 8));
U₂ = add_exo_variable!(g, :U₂, 1 ~ Normal(15, 3));
U₃ = add_exo_variable!(g, :U₃, 1 ~ Normal(2, 1));
Adding endogenous variables (variables whose values depend on other variables in the model) -
Temp = add_endo_variable!(g, :Temp, identity, U₁);
IceCreamSales = add_endo_variable!(g, :IceCreamSales, *, Temp, U₂);
Crime = add_endo_variable!(g, :Crime, /, Temp, U₃);
Visualizing the graph -
gplot(dag(g), nodelabel = ([variable(g, i).name for i in 1:nv(g)]), nodefillc = colorant"seagreen2", edgestrokec = colorant"black", layout = stressmajorize_layout, NODESIZE = 0.4/sqrt(nv(g)))
To apply a context (values of exogenous variables) to the model, we may use apply_context, or use g(ω), where ω is a random variable as defined in Omega.
apply_context(g, (U₁ = 26.2, U₂ = 14.8, U₃ = 2.)) |> pprint
g(defω(), return_type = NamedTuple) |> pprint
Interventions
An intervention is applied to a causal model by fixing the value of a particular variable and modifying the model accordingly (all incoming edges to that variable are removed, since its value is now fixed). Interventions may be computed as follows -
(Continuing with the previous example)
i = CounterfactualFairness.Intervention(:Temp, 24.) # Fixing value of Temp to 24
i |> pprint
m = apply_intervention(g, i)
intervened_model = randsample(ω -> m(ω))
intervened_model |> pprint
gplot(dag(m), nodelabel = ([variable(m, i).name for i in 1:nv(m)]), nodefillc = colorant"seagreen2", edgestrokec = colorant"black", layout = stressmajorize_layout, NODESIZE = 0.4/sqrt(nv(m)))
Counterfactuals
A counterfactual is the distribution P(Y_{X ← x} | V = v), where V contains the observed variables.
To obtain counterfactuals, we condition :Crime on the observed values and apply the intervention i (Temp ← 24) defined above.
count = ω -> counterfactual(:Crime, (IceCreamSales = 340., ), i, g, ω);
randsample(count);
Computing counterfactuals using MLJ
Wrapper to compute counterfactuals for each observation in a given dataset:
toy = # Synthetic causal model
cfw = CounterfactualWrapper(test = gausscitest, p = 0.1, cf = :Y, interventions = CounterfactualFairness.Intervention(:A, 40.))
Now we may use cfw in the fit/transform workflow in MLJ.
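A minimal sketch of that workflow, assuming toy has been defined and df holds samples drawn from it. machine, fit!, and transform are the standard MLJ calls; that the wrapper binds as machine(cfw, df) is an assumption here, not documented API.

```julia
using MLJ

# Bind the counterfactual wrapper to the dataset and fit it
mach = machine(cfw, df)
fit!(mach)

# One counterfactual per observation in df
cf_values = transform(mach, df)
```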
Training a neural network so that the predictor is counterfactually fair
Check for Sufficient Condition for Counterfactual Fairness
Lemma 1: Let G be the causal graph of the given model (U, V, F). Then Ŷ will be counterfactually fair if it is a function of the non-descendants of A.
To check the sufficient condition given above, isNonDesc returns true if the condition is satisfied and false otherwise.
isNonDesc(g, (:IceCreamSales, :U₃), (:Temp,)); # false since IceCreamSales is a descendant of Temp
isNonDesc(g, (:U₁, :IceCreamSales), (:Crime,)); # true since neither U₁ nor IceCreamSales is a descendant of Crime
Training using MLJ Interface
Creating synthetic dataset -
toy =
n = 500
X = (CausalVar(toy, :X1), CausalVar(toy, :X2), CausalVar(toy, :X3), CausalVar(toy, :X4))
U = (CausalVar(toy, :U₁), CausalVar(toy, :U₂), CausalVar(toy, :U₃), CausalVar(toy, :U₄), CausalVar(toy, :U₅))
A = CausalVar(toy, :A)
Y = CausalVar(toy, :Y)
df = DataFrame(
X1 = randsample(ω -> X[1](ω), n),
X2 = randsample(ω -> X[2](ω), n),
X3 = randsample(ω -> X[3](ω), n),
X4 = randsample(ω -> X[4](ω), n),
A = randsample(ω -> A(ω), n),
Y = randsample(ω -> Y(ω), n)
)
pprint(df)
Wrapper for the adversarial learning for counterfactual fairness -
model = AdversarialWrapper(cm = toy,
grp = :A,
latent = [:U₁, :U₂, :U₃, :U₄, :U₅],
observed = [:X1, :X2, :X3, :X4],
predictor = Chain(Dense(4, 3), Dense(3, 2), Dense(2, 1)),
adversary = Chain(Dense(5, 3), Dense(3, 2, relu)),
loss = Flux.Losses.logitbinarycrossentropy,
iters = 5)
model |> pprint
Now, model fits into the same framework as the other wrappers in MLJ and can be used in the same way: we can train it with fit! and obtain predictions with predict.
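A sketch of that training loop on the synthetic dataset, assuming df from above. The feature/target split and the supervised binding machine(model, X, y) are assumptions about how AdversarialWrapper is meant to be used; machine, fit!, and predict are the standard MLJ calls.

```julia
using DataFrames, MLJ

# Features and target from the synthetic dataset
X = select(df, Not(:Y))
y = df.Y

# Train the adversarial wrapper and predict
mach = machine(model, X, y)
fit!(mach)
ŷ = predict(mach, X)
```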
Future Work
Fix path-specific interventions, which the package does not currently compute correctly
Benchmark counterfactual explanations
Add recourse methods