Code for the ICML 2023 paper:
High Fidelity Image Counterfactuals with Probabilistic Causal Models
Fabio De Sousa Ribeiro¹, Tian Xia¹, Miguel Monteiro¹, Nick Pawlowski², Ben Glocker¹
¹Imperial College London, ²Microsoft Research Cambridge, UK
BibTeX:
@InProceedings{pmlr-v202-de-sousa-ribeiro23a,
title={High Fidelity Image Counterfactuals with Probabilistic Causal Models},
author={De Sousa Ribeiro, Fabio and Xia, Tian and Monteiro, Miguel and Pawlowski, Nick and Glocker, Ben},
booktitle={Proceedings of the 40th International Conference on Machine Learning},
pages={7390--7425},
year={2023},
volume={202},
series={Proceedings of Machine Learning Research},
month={23--29 Jul},
url={https://proceedings.mlr.press/v202/de-sousa-ribeiro23a.html}
}
📦src # main source code directory
┣ 📂pgm # graphical models for all SCM mechanisms except the image's
┃ ┣ 📜dscm.py # deep structural causal model PyTorch module
┃ ┣ 📜flow_pgm.py # Flow mechanisms in Pyro
┃ ┣ 📜layers.py # utility modules/layers
┃ ┣ 📜resnet.py # resnet model definition
┃ ┣ 📜run.sh # example launch script for counterfactual training (slurm)
┃ ┣ 📜train_cf.py # counterfactual training code
┃ ┣ 📜train_pgm.py # SCM mechanisms training code (Pyro)
┃ ┗ 📜utils_pgm.py # graphical model utilities
┣ 📜datasets.py # dataset definitions
┣ 📜dmol.py # discretized mixture of logistics likelihood
┣ 📜hps.py # hyperparameters for all datasets
┣ 📜main.py # main file
┣ 📜run_local.sh # example launch script for HVAE causal mechanism training
┣ 📜run_slurm.sh # same as above but for slurm jobs
┣ 📜simple_vae.py # single stochastic layer VAE
┣ 📜trainer.py # training code for image x's causal mechanism
┣ 📜train_setup.py # training helpers
┣ 📜utils.py # utilities for training/plotting
┗ 📜vae.py # HVAE definition; exogenous prior and latent mediator models
Our deep structural causal models (SCMs) were designed to be modular: in all instances, the causal mechanism for the structured variable (i.e. image
We use the universal probabilistic programming language (PPL) Pyro for the following:
- Modelling and training all SCM mechanisms except for the image $\mathbf{x}$'s, see code in src/pgm;
- The counterfactual inference engine, see src/pgm/flow_pgm.py;
- The proposed constrained counterfactual training technique, see src/pgm/train_cf.py.
Pyro enables flexible and expressive deep probabilistic modelling; for more details, refer to the official site.
Our HVAE-based causal mechanisms (src/vae.py) are trained outside of Pyro using PyTorch, and all trained mechanisms are subsequently merged into a single PyTorch module to create a DSCM. See src/pgm/dscm.py for an example.
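To make the role of the counterfactual inference engine more concrete, here is a minimal, framework-free sketch of the usual three-step procedure (abduction, action, prediction) on a toy two-variable SCM. It is not the code in src/pgm/flow_pgm.py or src/pgm/dscm.py: the class `ToySCM`, the variable names `parent` and `child`, and the linear additive-noise mechanism are all illustrative assumptions, whereas the mechanisms in this repository are learned models.

```python
from dataclasses import dataclass


@dataclass
class ToySCM:
    """Hypothetical toy SCM: parent := u_parent, child := weight * parent + u_child."""

    weight: float = 2.0  # assumed mechanism parameter, for illustration only

    def forward(self, u_parent: float, u_child: float) -> dict:
        # Generate an observation from exogenous noise.
        parent = u_parent
        child = self.weight * parent + u_child
        return {"parent": parent, "child": child}

    def abduct(self, obs: dict) -> dict:
        # Abduction: invert each mechanism to recover the exogenous noise.
        u_parent = obs["parent"]
        u_child = obs["child"] - self.weight * obs["parent"]
        return {"u_parent": u_parent, "u_child": u_child}

    def counterfactual(self, obs: dict, do: dict) -> dict:
        noise = self.abduct(obs)
        # Action: an intervened variable takes its do-value and ignores its mechanism.
        parent = do.get("parent", noise["u_parent"])
        # Prediction: push the abducted noise through the modified model.
        child = do.get("child", self.weight * parent + noise["u_child"])
        return {"parent": parent, "child": child}


scm = ToySCM(weight=2.0)
observed = scm.forward(u_parent=1.0, u_child=0.5)
cf = scm.counterfactual(observed, do={"parent": 3.0})
print("observed:", observed)
print("counterfactual:", cf)
```

Because the toy mechanism is exactly invertible, an empty intervention (`do={}`) reproduces the observation, which is a handy sanity check for any counterfactual pipeline.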
To run the code you will need to install the requirements listed in the requirements.txt file. For example, from inside your environment of choice, run:
pip install -r requirements.txt
For ease of use, we provide the Morpho-MNIST dataset we used in datasets/morphomnist. For more details on the associated SCM and data-generating process see the source code here and the original DSCM paper here.
The Colour-MNIST dataset we used was generated according to this paper.
Unfortunately, we are unable to share the UK Biobank brain data or the MIMIC-CXR chest x-ray data.
If you're interested in gaining access, the application process and eligibility criteria are detailed here and here respectively.
To launch (local) training of the HVAE mechanism, simply run the following script from inside the src directory:
bash run_local.sh your_experiment_name
To run in the background you can append nohup to the command: bash run_local.sh your_experiment_name nohup. Adjust the run_command inside the script as needed. Hyperparameters can be found in src/hps.py.
If using the Slurm Workload Manager, adjust src/run_slurm.sh as needed and launch with bash run_slurm.sh.
Example (loose) steps to add your own dataset and associated SCM:
- Add the dataset class definition to src/datasets.py and set up the dataloader in src/train_setup.py
- Add the associated causal graph and mechanism definitions in src/pgm/flow_pgm.py
- Adjust the HVAE hyperparameters needed for your dataset (input resolution, architecture, etc.) in src/hps.py
- Train the HVAE mechanism as above, and train all other mechanisms (separately) using src/pgm/train_pgm.py
Note: src/pgm/train_cf.py implements the optional counterfactual training/fine-tuning procedure outlined in Section 3.4 of the paper. This step may not be necessary if the model already performs well enough at counterfactual inference.
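As a rough illustration of the first step above, a dataset for an image SCM typically returns each image together with the values of its causal parents. The sketch below uses only the standard library so that it runs anywhere; the class name `ToyImageDataset`, the dictionary keys `"x"` and `"brightness"`, and the synthetic data are hypothetical and do not reflect the interface used in src/datasets.py. In practice you would likely subclass `torch.utils.data.Dataset`, return tensors, and match whatever keys the existing dataset classes and src/train_setup.py expect.

```python
import random


class ToyImageDataset:
    """Hypothetical map-style dataset pairing an image with its causal parents."""

    def __init__(self, num_samples: int = 8, resolution: int = 4, seed: int = 0):
        rng = random.Random(seed)
        self.samples = []
        for _ in range(num_samples):
            # Stand-in parent attribute; replace with the parents in your causal graph.
            brightness = rng.random()
            # Stand-in "image": a resolution x resolution grid whose pixel
            # values are determined by the parent attribute.
            image = [[brightness] * resolution for _ in range(resolution)]
            self.samples.append({"x": image, "brightness": brightness})

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> dict:
        # Return the image and its parents together so that both the image
        # mechanism and the parent mechanisms can be trained from one sample.
        return self.samples[idx]


dataset = ToyImageDataset(num_samples=8, resolution=4)
sample = dataset[0]
print(len(dataset), len(sample["x"]), len(sample["x"][0]))
```

Keeping the parents alongside the image in a single sample means the same dataset class can serve both the HVAE training and the separate training of the remaining mechanisms.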
If you'd like to make the HVAE more lightweight, you can try reducing the number of blocks at each resolution and reducing the block width (hyperparameters enc_arch, dec_arch, and width found in src/hps.py). The block version == "light" in src/vae.py also uses half as much VRAM.
To resume training from a checkpoint, simply adjust the argument: --resume=/path/to/your/checkpoint.pt.
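For readers wiring a similar flag into their own scripts, the following is a small sketch of how a `--resume` argument might be parsed and validated before any checkpoint loading takes place. Only the flag name comes from this README; the helper names `parse_args` and `resolve_checkpoint`, the empty-string default, and the error handling are assumptions of this sketch and are not taken from src/main.py.

```python
import argparse
import os


def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="Toy resume-flag example")
    # The flag name mirrors the one mentioned above; the empty-string default
    # (meaning "start from scratch") is an assumption of this sketch.
    parser.add_argument("--resume", type=str, default="", help="path to a checkpoint")
    return parser.parse_args(argv)


def resolve_checkpoint(args):
    """Return the checkpoint path to load, or None to start from scratch."""
    if not args.resume:
        return None
    if not os.path.isfile(args.resume):
        # Fail early rather than after the model and data have been set up.
        raise FileNotFoundError(f"checkpoint not found: {args.resume}")
    return args.resume


args = parse_args([])  # no --resume given
print("resume from:", resolve_checkpoint(args))
```

Validating the path up front avoids spending time on dataset and model setup only to fail when the checkpoint is finally read.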