This is the code associated with the paper DRIM: Learning Disentangled Representations from Incomplete Multimodal Healthcare Data (accepted at MICCAI 2024).
Navigate to the `data` folder to download and preprocess the data.
The file size limit of the supplemental material does not allow the WSI pre-trained models to be supplied, so only the MRI pre-trained model is provided in `data/models/`.
N.B.: The data used comes from the TCGA-GBMLGG cohort and must have the following structure:
├── GBMLGG
│ ├── MRI
│ ├── WSI
│ ├── DNAm
│ ├── RNA
The `data` folder must also contain, for each sample, a `files/train_brain.csv` and a `files/test_brain.csv` listing, for each modality, the path to the corresponding file. Again, check the `README.md` in the `data` folder for further explanations.
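For illustration, a `files/train_brain.csv` could look roughly like the snippet below. The column names, IDs, and paths here are hypothetical (the exact format is described in the `data` README); an empty cell would simply mean that the modality is missing for that sample.

```csv
submitter_id,MRI,WSI,DNAm,RNA
TCGA-XX-0001,GBMLGG/MRI/TCGA-XX-0001.nii.gz,GBMLGG/WSI/TCGA-XX-0001.pt,GBMLGG/DNAm/TCGA-XX-0001.csv,GBMLGG/RNA/TCGA-XX-0001.csv
TCGA-XX-0002,GBMLGG/MRI/TCGA-XX-0002.nii.gz,,GBMLGG/DNAm/TCGA-XX-0002.csv,GBMLGG/RNA/TCGA-XX-0002.csv
```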
Start by creating a dedicated conda environment and installing all the required packages:
$ conda create python=3.10.10 --name drim
$ conda activate drim
$ pip install -r requirements.txt
N.B.: This repository relies on Hydra to run experiments and benefits from all the advantages it offers (multiruns, sweeps, flexibility, ...).
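For context, a Hydra entry point typically looks like the sketch below. This is not the exact code of `train_unimodal.py`; it only illustrates how the CLI overrides used throughout this README (e.g. `general.modalities=...` or `fusion.name=...`) map onto fields of the YAML files in `configs/`.

```python
import hydra
from omegaconf import DictConfig, OmegaConf

# Hypothetical sketch: values passed on the command line override the
# corresponding fields of configs/unimodal.yaml before training starts.
@hydra.main(config_path="configs", config_name="unimodal", version_base=None)
def main(cfg: DictConfig) -> None:
    print(OmegaConf.to_yaml(cfg))  # resolved configuration after overrides
    # ... build the dataset, model and trainer from cfg here ...

if __name__ == "__main__":
    main()
```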
Here are a number of scripts for training and cross-validation.
Unimodal
For example, to launch a unimodal training session on each of the modalities, using Hydra's features, you can enter the following in the CLI:
$ python train_unimodal.py -m general.modalities=DNAm,WSI,RNA,MRI
Simple Multimodal
To train a simple multimodal fusion architecture with the fusion strategy of your choice, run for example:
$ python train_vanilla_multimodal.py fusion.name=maf
To sweep over all the available fusion functions in one go, use Hydra's multirun mode:
$ python train_vanilla_multimodal.py -m fusion.name=mean,concat,masked_mean,tensor,sum,max,maf
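As a rough illustration of what one of these simple fusions does, the `masked_mean` option presumably averages only the embeddings of the modalities that are actually present for a sample. A minimal sketch of that idea (not the repository's implementation) could look like:

```python
import torch

def masked_mean_fusion(embeddings: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Average modality embeddings while ignoring missing modalities.

    embeddings: (batch, n_modalities, dim) per-modality features.
    mask:       (batch, n_modalities) boolean, True where the modality is available.
    """
    mask = mask.unsqueeze(-1).float()         # (batch, n_modalities, 1)
    summed = (embeddings * mask).sum(dim=1)   # zero out missing modalities
    counts = mask.sum(dim=1).clamp(min=1.0)   # avoid division by zero
    return summed / counts
```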
Auxiliary Multimodal
In this repository, in addition to the fusion, you can add an auxiliary cost function to the training objective:
$ python train_aux_multimodal.py aux_loss.name=contrastive
$ python train_aux_multimodal.py aux_loss.name=mmo aux_loss.alpha=0.5
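The `aux_loss.alpha` parameter suggests the auxiliary term is weighted and added to the main objective. A hedged sketch of that pattern (the actual weighting scheme in the code may differ):

```python
import torch

def total_objective(task_loss: torch.Tensor,
                    aux_loss: torch.Tensor,
                    alpha: float = 0.5) -> torch.Tensor:
    """Hypothetical combination of the survival loss with an auxiliary loss
    (e.g. a contrastive term), weighted by alpha as in aux_loss.alpha=0.5."""
    return task_loss + alpha * aux_loss
```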
The DRIM method is very flexible and can be adapted to any type of task. For example, we offer two alternatives: DRIMSurv and DRIMU.
DRIMSurv
DRIMSurv enables end-to-end training of a survival model by training the encoders at the same time as the higher-level attention blocks needed to merge the various representations (Fig. 1c).
$ python train_drimsurv.py
DRIMU
The other alternative, DRIMU, does not require labels and is trained using a reconstruction scheme. The provided code pre-trains the backbone encoders and then freezes them for a 10-epoch fine-tuning on a survival task, with a fusion scheme similar to DRIMSurv (Fig. 1b).
$ python train_drimu.py
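To make the pre-train-then-freeze step concrete, freezing the backbone encoders before the short survival fine-tuning could be done along these lines (a generic PyTorch sketch under assumed names, not the repository's exact code):

```python
import torch

def freeze(module: torch.nn.Module) -> None:
    """Freeze a pre-trained encoder so only the fusion/survival head is updated."""
    for param in module.parameters():
        param.requires_grad = False
    module.eval()  # also fix batch-norm / dropout behaviour

# Hypothetical usage after DRIMU pre-training:
# for encoder in encoders.values():
#     freeze(encoder)
# optimizer = torch.optim.Adam(
#     (p for p in model.parameters() if p.requires_grad), lr=1e-4
# )
```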
Once each model has been trained on the 5 splits, you can test how it performs and how robust it is to different combinations of input modalities. Simply run the `robustness.py` script, which is based on the `configs/robustness.yaml` configuration file, and look at the reported CS score.
To check all available methods:
$ python robustness.py -m method=drim,max,tensor,concat
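The robustness evaluation amounts to scoring each model on every combination of available modalities. A hypothetical sketch of how those combinations can be enumerated (the actual loop lives in `robustness.py`):

```python
from itertools import combinations

MODALITIES = ["DNAm", "WSI", "RNA", "MRI"]

# Every non-empty subset of input modalities, e.g. ("MRI",), ("DNAm", "RNA"), ...
subsets = [
    combo
    for r in range(1, len(MODALITIES) + 1)
    for combo in combinations(MODALITIES, r)
]

for subset in subsets:
    # Hypothetical: mask out the other modalities, evaluate the model,
    # and record the score for this subset.
    print(subset)
```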
All the code related to this stratification, as well as the results of the log-rank test, can be found in the notebook `clustering.ipynb`.
All our training runs were logged with Weights & Biases. You can also use Weights & Biases to track the progress of your own runs. To do this, add your entity and project to the configuration files `configs/unimodal.yaml` and `configs/multimodal.yaml`:
wandb:
  entity: MyAwesomeNickname
  project: MyAwesomeProject
Otherwise, training progression and evaluation will be displayed directly in the console.
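These fields are presumably passed to `wandb.init` internally. If you just want to check that your entity and project are set up correctly, something like the following (standard Weights & Biases API, not repository code) should log a test run to your project:

```python
import wandb

# Hypothetical stand-alone check of the entity/project set in the config files.
run = wandb.init(entity="MyAwesomeNickname", project="MyAwesomeProject")
wandb.log({"sanity_check": 1.0})
run.finish()
```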
@InProceedings{robinet_drim_2024,
  author    = {Robinet, Lucas and Berjaoui, Ahmad and Kheil, Ziad and Cohen-Jonathan Moyal, Elizabeth},
  title     = {{DRIM: Learning Disentangled Representations from Incomplete Multimodal Healthcare Data}},
  booktitle = {Proceedings of Medical Image Computing and Computer Assisted Intervention -- MICCAI 2024},
  year      = {2024},
  publisher = {Springer Nature Switzerland},
}