A framework for training any-to-any multimodal foundation models.
Scalable. Open-sourced. Across tens of modalities and tasks.
EPFL - Apple
Official implementation and pre-trained models for:
4M: Massively Multimodal Masked Modeling, NeurIPS 2023 (Spotlight)
David Mizrahi*, Roman Bachmann*, Oğuzhan Fatih Kar, Teresa Yeo, Mingfei Gao, Afshin Dehghan, Amir Zamir
4M-21: An Any-to-Any Vision Model for Tens of Tasks and Modalities, arXiv 2024
Roman Bachmann*, Oğuzhan Fatih Kar*, David Mizrahi*, Ali Garjani, Mingfei Gao, David Griffiths, Jiaming Hu, Afshin Dehghan, Amir Zamir
4M is a framework for training "any-to-any" foundation models, using tokenization and masking to scale to many diverse modalities. Models trained using 4M can perform a wide range of vision tasks, transfer well to unseen tasks and modalities, and are flexible and steerable multimodal generative models. We are releasing code and models for "4M: Massively Multimodal Masked Modeling" (here denoted 4M-7), as well as "4M-21: An Any-to-Any Vision Model for Tens of Tasks and Modalities" (here denoted 4M-21).
- Clone this repository and navigate to the root directory:
git clone https://github.com/apple/ml-4m
cd ml-4m
- Create a new conda environment, then install the package and its dependencies:
conda create -n fourm python=3.9 -y
conda activate fourm
pip install --upgrade pip # enable PEP 660 support
pip install -e .
- Verify that CUDA is available in PyTorch by running the following in a Python shell:
# Run in Python shell
import torch
print(torch.cuda.is_available()) # Should print True
If CUDA is not available, consider re-installing PyTorch following the official installation instructions. Likewise, if you want to install xFormers (optional, for faster tokenizers), follow their README to ensure that the CUDA version is correct.
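As a convenience, the checks above can be combined into a single script. The following is a minimal sketch and not part of this repository: `check_environment` is a hypothetical helper name, and it only reports whether `torch` and `xformers` can be found and whether PyTorch sees a CUDA device. It does not verify that the installed CUDA versions match.

```python
import importlib.util


def check_environment() -> dict:
    """Report whether PyTorch, CUDA, and xFormers look usable.

    Hypothetical helper for illustration; not part of the fourm package.
    """
    report = {"torch": False, "cuda": False, "xformers": False}

    # Only import torch if it is installed, so the check never crashes.
    if importlib.util.find_spec("torch") is not None:
        try:
            import torch

            report["torch"] = True
            report["cuda"] = bool(torch.cuda.is_available())
        except (ImportError, OSError):
            # A broken installation is reported as "missing".
            pass

    # xFormers is optional; here we only check that it can be found.
    report["xformers"] = importlib.util.find_spec("xformers") is not None
    return report


if __name__ == "__main__":
    for name, ok in check_environment().items():
        print(f"{name}: {'ok' if ok else 'missing'}")
```

If `cuda` is reported as missing while `torch` is present, the re-installation advice above applies.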
We provide a demo wrapper for getting started quickly with 4M models on RGB-to-all or {caption, bounding boxes}-to-all generation tasks. For example, to generate all modalities from a given RGB input, call:
from fourm.demo_4M_sampler import Demo4MSampler, img_from_url
sampler = Demo4MSampler(fm='EPFL-VILAB/4M-21_XL').cuda()
img = img_from_url('https://storage.googleapis.com/four_m_site/images/demo_rgb.png') # 1x3x224x224 ImageNet-standardized PyTorch Tensor
preds = sampler({'rgb@224': img.cuda()}, seed=None)
sampler.plot_modalities(preds, save_path=None)
You should see a plot of the generated modalities.
To perform caption-to-all generation, replace the sampler input with: preds = sampler({'caption': 'A lake house with a boat in front [S_1]'})
For a list of available 4M models, please see the model zoo below, and see README_GENERATION.md for more instructions on generation.
See README_DATA.md for instructions on how to prepare aligned multimodal datasets.
See README_TOKENIZATION.md for instructions on how to train modality-specific tokenizers.
See README_TRAINING.md for instructions on how to train 4M models.
See README_GENERATION.md for instructions on how to use 4M models for inference / generation. We also provide a generation notebook that contains examples for 4M inference, specifically performing conditional image generation and common vision tasks (i.e. RGB-to-All).
We provide 4M and tokenizer checkpoints as safetensors, and also offer easy loading via Hugging Face Hub.
Model | # Mod. | Datasets | # Params | Config | Weights |
---|---|---|---|---|---|
4M-B | 7 | CC12M | 198M | Config | Checkpoint / HF Hub |
4M-B | 7 | COYO700M | 198M | Config | Checkpoint / HF Hub |
4M-B | 21 | CC12M+COYO700M+C4 | 198M | Config | Checkpoint / HF Hub |
4M-L | 7 | CC12M | 705M | Config | Checkpoint / HF Hub |
4M-L | 7 | COYO700M | 705M | Config | Checkpoint / HF Hub |
4M-L | 21 | CC12M+COYO700M+C4 | 705M | Config | Checkpoint / HF Hub |
4M-XL | 7 | CC12M | 2.8B | Config | Checkpoint / HF Hub |
4M-XL | 7 | COYO700M | 2.8B | Config | Checkpoint / HF Hub |
4M-XL | 21 | CC12M+COYO700M+C4 | 2.8B | Config | Checkpoint / HF Hub |
To load models from Hugging Face Hub:
from fourm.models.fm import FM
fm7b_cc12m = FM.from_pretrained('EPFL-VILAB/4M-7_B_CC12M')
fm7b_coyo = FM.from_pretrained('EPFL-VILAB/4M-7_B_COYO700M')
fm21b = FM.from_pretrained('EPFL-VILAB/4M-21_B')
fm7l_cc12m = FM.from_pretrained('EPFL-VILAB/4M-7_L_CC12M')
fm7l_coyo = FM.from_pretrained('EPFL-VILAB/4M-7_L_COYO700M')
fm21l = FM.from_pretrained('EPFL-VILAB/4M-21_L')
fm7xl_cc12m = FM.from_pretrained('EPFL-VILAB/4M-7_XL_CC12M')
fm7xl_coyo = FM.from_pretrained('EPFL-VILAB/4M-7_XL_COYO700M')
fm21xl = FM.from_pretrained('EPFL-VILAB/4M-21_XL')
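The Hub IDs listed above appear to follow a regular naming pattern, which can be handy when selecting a model programmatically. The sketch below is an illustration only: `hub_id` is a hypothetical helper, not part of `fourm`, and the pattern is inferred solely from the nine IDs shown above.

```python
from typing import Optional

ORG = "EPFL-VILAB"
SIZES = {"B", "L", "XL"}
DATASETS_4M7 = {"CC12M", "COYO700M"}


def hub_id(num_modalities: int, size: str, dataset: Optional[str] = None) -> str:
    """Build a Hugging Face Hub ID for a base 4M model.

    Hypothetical helper; the naming pattern is inferred from the IDs
    listed in this README and covers only those nine models.
    """
    size = size.upper()
    if size not in SIZES:
        raise ValueError(f"Unknown model size: {size!r}")

    if num_modalities == 21:
        # 4M-21 IDs carry no dataset suffix.
        return f"{ORG}/4M-21_{size}"

    if num_modalities == 7:
        if dataset not in DATASETS_4M7:
            raise ValueError(f"4M-7 models need a dataset in {sorted(DATASETS_4M7)}")
        return f"{ORG}/4M-7_{size}_{dataset}"

    raise ValueError("num_modalities must be 7 or 21")
```

For example, `hub_id(7, "L", "COYO700M")` yields the same string that is passed to `FM.from_pretrained` for `fm7l_coyo` above.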
To load the checkpoints manually, first download the safetensors files from the above links and call:
from fourm.utils import load_safetensors
from fourm.models.fm import FM
ckpt, config = load_safetensors('/path/to/checkpoint.safetensors')
fm = FM(config=config)
fm.load_state_dict(ckpt)
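One possible sanity check after loading a checkpoint manually is to compare the total parameter count with the "# Params" column of the table. The sketch below assumes that the loaded model behaves like a standard PyTorch module and exposes `parameters()`; `count_parameters` and `format_count` are hypothetical helpers, and the result may differ slightly from the table depending on how the reported numbers were computed.

```python
def count_parameters(model) -> int:
    """Sum the number of elements over all parameters of a model.

    Works with any object exposing `parameters()` whose items have
    `numel()`, as PyTorch modules do. Hypothetical helper.
    """
    return sum(p.numel() for p in model.parameters())


def format_count(n: int) -> str:
    """Format a parameter count in the style of the table (e.g. 198M, 2.8B)."""
    if n >= 1_000_000_000:
        return f"{n / 1e9:.1f}B"
    if n >= 1_000_000:
        return f"{n / 1e6:.0f}M"
    return str(n)
```

With a loaded model `fm`, calling `format_count(count_parameters(fm))` gives a value to compare against the table.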
The text-to-image (T2I) models below were initialized from the standard 4M-7 CC12M models and then trained further with a modality mixture heavily biased towards text inputs. They can still perform all other tasks, but they are better at text-to-image generation than the non-fine-tuned models.
Model | # Mod. | Datasets | # Params | Config | Weights |
---|---|---|---|---|---|
4M-T2I-B | 7 | CC12M | 198M | Config | Checkpoint / HF Hub |
4M-T2I-L | 7 | CC12M | 705M | Config | Checkpoint / HF Hub |
4M-T2I-XL | 7 | CC12M | 2.8B | Config | Checkpoint / HF Hub |
To load models from Hugging Face Hub:
from fourm.models.fm import FM
fm7b_t2i_cc12m = FM.from_pretrained('EPFL-VILAB/4M-7-T2I_B_CC12M')
fm7l_t2i_cc12m = FM.from_pretrained('EPFL-VILAB/4M-7-T2I_L_CC12M')
fm7xl_t2i_cc12m = FM.from_pretrained('EPFL-VILAB/4M-7-T2I_XL_CC12M')
Loading manually from checkpoints is performed in the same way as above for the base 4M models.
A super-resolution (SR) variant of 4M-7 is also available:
Model | # Mod. | Datasets | # Params | Config | Weights |
---|---|---|---|---|---|
4M-SR-L | 7 | CC12M | 198M | Config | Checkpoint / HF Hub |
To load models from Hugging Face Hub:
from fourm.models.fm import FM
fm7l_sr_cc12m = FM.from_pretrained('EPFL-VILAB/4M-7-SR_L_CC12M')
Loading manually from checkpoints is performed in the same way as above for the base 4M models.
The modality-specific tokenizers are listed below:
Modality | Resolution | Number of tokens | Codebook size | Diffusion decoder | Weights |
---|---|---|---|---|---|
RGB | 224-448 | 196-784 | 16k | ✓ | Checkpoint / HF Hub |
Depth | 224-448 | 196-784 | 8k | ✓ | Checkpoint / HF Hub |
Normals | 224-448 | 196-784 | 8k | ✓ | Checkpoint / HF Hub |
Edges (Canny, SAM) | 224-512 | 196-1024 | 8k | ✓ | Checkpoint / HF Hub |
COCO semantic segmentation | 224-448 | 196-784 | 4k | ✗ | Checkpoint / HF Hub |
CLIP-B/16 | 224-448 | 196-784 | 8k | ✗ | Checkpoint / HF Hub |
DINOv2-B/14 | 224-448 | 256-1024 | 8k | ✗ | Checkpoint / HF Hub |
DINOv2-B/14 (global) | 224 | 16 | 8k | ✗ | Checkpoint / HF Hub |
ImageBind-H/14 | 224-448 | 256-1024 | 8k | ✗ | Checkpoint / HF Hub |
ImageBind-H/14 (global) | 224 | 16 | 8k | ✗ | Checkpoint / HF Hub |
SAM instances | - | 64 | 1k | ✗ | Checkpoint / HF Hub |
3D Human poses | - | 8 | 1k | ✗ | Checkpoint / HF Hub |
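The token counts for the image-like modalities in the table are consistent with a square grid of non-overlapping patches. The sketch below reproduces that arithmetic. It rests on an assumption: the patch size of 16 for RGB, depth, normals, edges, and semantic segmentation is inferred from the numbers in the table rather than stated in it, while 16 and 14 for CLIP-B/16, DINOv2-B/14, and ImageBind-H/14 are taken from the model names. `num_tokens` is a hypothetical helper.

```python
def num_tokens(resolution: int, patch_size: int) -> int:
    """Number of tokens for a square image split into square patches.

    Hypothetical helper; assumes one token per non-overlapping patch.
    """
    if resolution % patch_size != 0:
        raise ValueError("resolution must be divisible by patch_size")
    side = resolution // patch_size  # patches per image side
    return side * side


# Patch size 16 (inferred): 224 -> 196, 448 -> 784, 512 -> 1024
# Patch size 14 (from the model names): 224 -> 256, 448 -> 1024
for res, patch in [(224, 16), (448, 16), (512, 16), (224, 14), (448, 14)]:
    print(f"{res}px, patch {patch}: {num_tokens(res, patch)} tokens")
```

The global feature, SAM instance, and 3D human pose tokenizers use fixed token counts (16, 64, and 8 in the table) and are not covered by this formula.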
To load models from Hugging Face Hub:
from fourm.vq.vqvae import VQVAE, DiVAE
# 4M-7 modalities
tok_rgb = DiVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_rgb_16k_224-448')
tok_depth = DiVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_depth_8k_224-448')
tok_normal = DiVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_normal_8k_224-448')
tok_semseg = VQVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_semseg_4k_224-448')
tok_clip = VQVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_CLIP-B16_8k_224-448')
# 4M-21 modalities
tok_edge = DiVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_edge_8k_224-512')
tok_dinov2 = VQVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_DINOv2-B14_8k_224-448')
tok_dinov2_global = VQVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_DINOv2-B14-global_8k_16_224')
tok_imagebind = VQVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_ImageBind-H14_8k_224-448')
tok_imagebind_global = VQVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_ImageBind-H14-global_8k_16_224')
sam_instance = VQVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_sam-instance_1k_64')
human_poses = VQVAE.from_pretrained('EPFL-VILAB/4M_tokenizers_human-poses_1k_8')
To load the checkpoints manually, first download the safetensors files from the above links and call:
from fourm.utils import load_safetensors
from fourm.vq.vqvae import VQVAE, DiVAE
ckpt, config = load_safetensors('/path/to/checkpoint.safetensors')
tok = VQVAE(config=config) # Or DiVAE for models with a diffusion decoder
tok.load_state_dict(ckpt)
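Choosing between `VQVAE` and `DiVAE` can be made explicit with a small lookup. The sketch below simply mirrors the Hub loading example above (the four tokenizers with a diffusion decoder use `DiVAE`, the rest use `VQVAE`); `TOKENIZER_CLASSES` and `tokenizer_class_name` are hypothetical names and not part of `fourm`.

```python
# Hub IDs copied from the loading example above, mapped to the class used there.
TOKENIZER_CLASSES = {
    "EPFL-VILAB/4M_tokenizers_rgb_16k_224-448": "DiVAE",
    "EPFL-VILAB/4M_tokenizers_depth_8k_224-448": "DiVAE",
    "EPFL-VILAB/4M_tokenizers_normal_8k_224-448": "DiVAE",
    "EPFL-VILAB/4M_tokenizers_edge_8k_224-512": "DiVAE",
    "EPFL-VILAB/4M_tokenizers_semseg_4k_224-448": "VQVAE",
    "EPFL-VILAB/4M_tokenizers_CLIP-B16_8k_224-448": "VQVAE",
    "EPFL-VILAB/4M_tokenizers_DINOv2-B14_8k_224-448": "VQVAE",
    "EPFL-VILAB/4M_tokenizers_DINOv2-B14-global_8k_16_224": "VQVAE",
    "EPFL-VILAB/4M_tokenizers_ImageBind-H14_8k_224-448": "VQVAE",
    "EPFL-VILAB/4M_tokenizers_ImageBind-H14-global_8k_16_224": "VQVAE",
    "EPFL-VILAB/4M_tokenizers_sam-instance_1k_64": "VQVAE",
    "EPFL-VILAB/4M_tokenizers_human-poses_1k_8": "VQVAE",
}


def tokenizer_class_name(hub_id: str) -> str:
    """Return "DiVAE" or "VQVAE" for a tokenizer Hub ID listed in this README."""
    try:
        return TOKENIZER_CLASSES[hub_id]
    except KeyError:
        raise KeyError(f"Unknown tokenizer ID: {hub_id!r}") from None


# Usage sketch (requires the fourm package to be installed):
# import fourm.vq.vqvae as vq
# hub_id = "EPFL-VILAB/4M_tokenizers_rgb_16k_224-448"
# tok = getattr(vq, tokenizer_class_name(hub_id)).from_pretrained(hub_id)
```

Returning the class name as a string keeps the lookup importable without `fourm`; the commented lines show how it could be resolved to the actual class.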
The code in this repository is released under the Apache 2.0 license as found in the LICENSE file.
The model weights in this repository are released under the Sample Code license as found in the LICENSE_WEIGHTS file.
If you find this repository helpful, please consider citing our work:
@inproceedings{4m,
title={{4M}: Massively Multimodal Masked Modeling},
author={David Mizrahi and Roman Bachmann and O{\u{g}}uzhan Fatih Kar and Teresa Yeo and Mingfei Gao and Afshin Dehghan and Amir Zamir},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023},
}
@article{4m21,
title={{4M-21}: An Any-to-Any Vision Model for Tens of Tasks and Modalities},
author={Roman Bachmann and O{\u{g}}uzhan Fatih Kar and David Mizrahi and Ali Garjani and Mingfei Gao and David Griffiths and Jiaming Hu and Afshin Dehghan and Amir Zamir},
journal={arXiv 2024},
year={2024},
}