This is the code repository for the paper "Building Universal Foundation Models for Medical Image Analysis with Spatially Adaptive Networks" (arXiv; formerly titled "Pre-trained Universal Medical Image Transformer").
Pre-trained model weights are available on the releases page.
This release is intended for reproducing the results of the paper and was also used for double-blind review, so the code may not be very clean. You may switch to the dev branch for a cleaner version of the code; note that it has been developed further and may not be consistent with the paper.
Mamba is recommended for managing virtual environments.
Download the datasets and place them under the ./datasets folder, e.g.:
├── AbdomenCT-1K
├── ACDC
├── AMOS22
├── BCV
├── BrainPTM-2021
├── BraTS2023
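Before running pre-processing, it can help to confirm that the folders landed where expected. The sketch below is a hypothetical helper (missing_datasets is not part of the repository); it only checks for the example folder names listed above, so extend EXPECTED_DATASETS with the datasets you actually use.

```python
from pathlib import Path

# Hypothetical list: only the example folder names shown above.
# Extend it with whichever datasets you actually downloaded.
EXPECTED_DATASETS = ("AbdomenCT-1K", "ACDC", "AMOS22", "BCV", "BrainPTM-2021", "BraTS2023")


def missing_datasets(root="./datasets", expected=EXPECTED_DATASETS):
    """Return the expected dataset folders that do not exist as directories under `root`."""
    root = Path(root)
    return [name for name in expected if not (root / name).is_dir()]
```

Calling missing_datasets() from the repository root returns an empty list when every listed folder is present.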
Then run the pre-processing script: python scripts/data/ <dataset name>. More details can be found in the script.
Train the tokenizer:
python scripts/tokenizer/ -c conf/tokenizer/simple/main.yaml --data.dl_conf.train_batch_size 8 --data.dl_conf.num_workers 10 --training.benchmark true --model.quantize.mode soft --model.quantize.num_embeddings 1024 --loss.entropy_weight 1 --loss.quant_weight 0.03
Note that you may need to adjust the batch size according to the number of GPUs.
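Assuming train_batch_size is the per-GPU batch size (as with distributed data-parallel training, where each process loads its own batch; check the repository's configs to confirm), the effective batch size is the per-GPU value times the number of GPUs. A hypothetical helper for keeping that product fixed when the GPU count changes, falling back to gradient accumulation if the per-GPU batch does not fit in memory (whether the training config exposes an accumulation option is not stated here):

```python
def split_batch(global_batch_size: int, num_gpus: int, max_per_device: int) -> tuple[int, int]:
    """Return (per_device_batch_size, accumulate_steps) such that
    per_device_batch_size * accumulate_steps * num_gpus == global_batch_size."""
    if global_batch_size % num_gpus:
        raise ValueError("global batch size must be divisible by the number of GPUs")
    per_gpu = global_batch_size // num_gpus
    # Use the fewest accumulation steps whose micro-batch still fits within max_per_device.
    for accumulate in range(1, per_gpu + 1):
        if per_gpu % accumulate == 0 and per_gpu // accumulate <= max_per_device:
            return per_gpu // accumulate, accumulate
    raise ValueError("max_per_device must be at least 1")
```

For example, split_batch(32, 2, 8) returns (8, 2): eight samples per GPU with two accumulation steps.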
Then pre-train the model with masked image modeling (MIM), using the trained tokenizer:
scripts/model/mim-b.zsh --data.dl_conf.train_batch_size 14 --data.dl_conf.num_workers 10 --model.tokenizer.path <tokenizer checkpoint path>
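The commands above override configuration values with dotted options such as --data.dl_conf.train_batch_size 8, which presumably address nested keys of the YAML config passed with -c. As an illustration of that idea only (not the repository's actual argument parser; apply_overrides and parse_value are hypothetical), the following sketch writes such pairs into a nested dict:

```python
import copy
import json


def parse_value(text: str):
    """Decode numbers and booleans as JSON scalars; fall back to a plain string."""
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return text


def apply_overrides(config: dict, argv: list[str]) -> dict:
    """Return a copy of `config` with each `--a.b.c value` pair written to config[a][b][c]."""
    result = copy.deepcopy(config)
    args = iter(argv)
    for flag in args:
        if not flag.startswith("--"):
            raise ValueError(f"expected an option starting with '--', got {flag!r}")
        value = next(args, None)
        if value is None:
            raise ValueError(f"missing value for {flag}")
        *parents, leaf = flag[2:].split(".")
        node = result
        for key in parents:
            node = node.setdefault(key, {})  # create intermediate sections as needed
        node[leaf] = parse_value(value)
    return result
```

Values are decoded as JSON scalars, so 8 becomes an integer, 0.03 a float, true a boolean, and anything else (e.g. soft or a checkpoint path) stays a string.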
The downstream experiments below assume that the pre-trained checkpoint is placed at ./pre-trained/pumit.ckpt.
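If the weights downloaded from the release page have a different file name, they can simply be copied to that location. A minimal sketch (install_checkpoint is a hypothetical helper; the source file name depends on the release asset you downloaded):

```python
import shutil
from pathlib import Path


def install_checkpoint(downloaded: str, target: str = "./pre-trained/pumit.ckpt") -> Path:
    """Copy a downloaded checkpoint to the path the downstream scripts expect."""
    target_path = Path(target)
    target_path.parent.mkdir(parents=True, exist_ok=True)  # create ./pre-trained if missing
    shutil.copyfile(downloaded, target_path)
    return target_path
```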
For MedMNIST v2, execute the scripts under scripts/downstream/medmnistv2 to train and evaluate each model.
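The exact script names in that folder are not listed here; assuming they are zsh scripts like the other downstream entry points, a hypothetical runner that executes them one after another could look like this. It defaults to a dry run that only returns the commands:

```python
import subprocess
from pathlib import Path


def run_scripts(script_dir="scripts/downstream/medmnistv2", pattern="*.zsh", dry_run=True):
    """Collect every script matching `pattern` in name order; run them unless dry_run is set."""
    commands = [["zsh", str(path)] for path in sorted(Path(script_dir).glob(pattern))]
    if not dry_run:
        for cmd in commands:
            subprocess.run(cmd, check=True)  # raises on the first script that fails
    return commands
```

Pass dry_run=False to actually execute them; check=True stops at the first script that exits with an error.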
For BTCV, download the data from the official challenge website and the train/validation split file from SMIT's repository, then organize the files as follows:
├── BTCV
│ ├── smit.json
│ ├── Testing
│ └── Training
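To catch path mistakes before fine-tuning, the split file can be checked against the directory. This sketch assumes smit.json follows the MONAI-style datalist layout (a JSON object whose "training" and "validation" lists contain "image"/"label" paths relative to the BTCV folder); that layout is an assumption, so adjust the keys if the file differs. check_btcv_split is a hypothetical helper:

```python
import json
from pathlib import Path


def check_btcv_split(root, split_file="smit.json"):
    """Return image/label paths listed in the split file that do not exist under `root`.

    ASSUMPTION: MONAI-style datalist with "training"/"validation" lists of
    {"image": ..., "label": ...} entries, paths relative to `root`.
    """
    root = Path(root)
    for sub in ("Training", "Testing"):
        if not (root / sub).is_dir():
            raise FileNotFoundError(root / sub)
    datalist = json.loads((root / split_file).read_text())
    missing = []
    for split in ("training", "validation"):
        for entry in datalist.get(split, []):
            for key in ("image", "label"):
                if key in entry and not (root / entry[key]).exists():
                    missing.append(entry[key])
    return missing
```

check_btcv_split("<path to BTCV>") returns the listed files that are missing; an empty list means every referenced file was found.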
Then run fine-tuning and inference:
scripts/downstream/btcv/pumit-b.zsh --data.num_workers 10 --data.ratio 1 pumit-b --data.train_batch_size 4
scripts/downstream/btcv/test-b.zsh --data.num_workers 10 --ckpt_path <output checkpoint path> pumit-b
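The test script needs the checkpoint written by the fine-tuning run. Where that run saves its outputs is not specified here, so the sketch below simply searches a directory you provide for the most recently modified .ckpt file and assembles the test command shown above. Both helpers are hypothetical:

```python
from pathlib import Path


def latest_checkpoint(output_dir) -> Path:
    """Return the most recently modified *.ckpt file anywhere under output_dir."""
    ckpts = list(Path(output_dir).rglob("*.ckpt"))
    if not ckpts:
        raise FileNotFoundError(f"no .ckpt files under {output_dir}")
    return max(ckpts, key=lambda p: p.stat().st_mtime)


def build_test_command(ckpt, model="pumit-b", num_workers=10):
    """Argument list mirroring the BTCV test command above."""
    return ["scripts/downstream/btcv/test-b.zsh", "--data.num_workers", str(num_workers),
            "--ckpt_path", str(ckpt), model]
```

The newest checkpoint is not necessarily the best one; if the run keeps several (e.g. selected by validation score), pass the intended path explicitly instead.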
For CHAOS, first run the pre-processing script to convert the DICOM series into NIfTI format: python scripts/downstream/chaos/
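The repository's own conversion script is the one to use for reproducing results. For reference, a generic DICOM-to-NIfTI conversion with SimpleITK looks roughly like the sketch below; it illustrates the technique, is not the contents of that script, and reads only the first series found in a folder (CHAOS MR folders may contain more than one). dicom_series_to_nifti is a hypothetical helper:

```python
from pathlib import Path


def dicom_series_to_nifti(dicom_dir, out_path) -> None:
    """Read the first DICOM series in dicom_dir and write it as a NIfTI volume."""
    import SimpleITK as sitk  # optional dependency, imported only when the function is called

    series_ids = sitk.ImageSeriesReader.GetGDCMSeriesIDs(str(dicom_dir))
    if not series_ids:
        raise FileNotFoundError(f"no DICOM series found in {dicom_dir}")
    # GDCM returns the slice files of the chosen series sorted into spatial order.
    file_names = sitk.ImageSeriesReader.GetGDCMSeriesFileNames(str(dicom_dir), series_ids[0])
    reader = sitk.ImageSeriesReader()
    reader.SetFileNames(file_names)
    image = reader.Execute()
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    sitk.WriteImage(image, str(out_path))  # a ".nii.gz" suffix writes compressed NIfTI
```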
Then run fine-tuning and inference:
scripts/downstream/chaos/pumit-b.zsh --data.num_workers 10 --data.ratio 1 pumit-b --data.train_batch_size 8
scripts/downstream/chaos/predict-b.zsh --data.num_workers 10 --ckpt_path <output checkpoint path> pumit-b