Simple and Effective Masked Diffusion Language Models (NeurIPS 2024)
By Subham Sekhar Sahoo, Marianne Arriola, Yair Schiff, Aaron Gokaslan, Edgar Marroquin, Justin T Chiu, Alexander Rush, Volodymyr Kuleshov
We introduce MDLM, a Masked discrete Diffusion Language Model that features a novel (SUBS)titution-based parameterization, which simplifies the absorbing-state diffusion loss to a mixture of classical masked language modeling losses. In doing so, we achieve SOTA perplexity among diffusion models on LM1B and OpenWebText, while achieving zero-shot perplexity competitive with SOTA AR models on numerous datasets. We provide a demo in this notebook and a video tutorial here.
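To make the SUBS parameterization concrete, here is a minimal pure-Python sketch, not the repository's implementation. It assumes a toy 5-token vocabulary whose last index is the [MASK] token, a log-linear schedule alpha_t = 1 - t (each token is masked independently with probability t), and hypothetical helper names (subs_probs, mask_tokens, mdlm_loss). SUBS (i) gives [MASK] zero probability in the denoiser output and (ii) copies unmasked tokens through unchanged, so only masked positions contribute to the loss. Under the assumed schedule, the weight alpha'_t / (1 - alpha_t) on the log-likelihood equals -1/t, so the objective becomes a masked-LM cross-entropy scaled by 1/t.

```python
import math
import random

MASK_ID = 4  # hypothetical: index 4 of a toy 5-token vocabulary is [MASK]


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


def subs_probs(logits, z, mask_id=MASK_ID):
    """Per-position distributions over clean tokens under SUBS (sketch)."""
    probs = []
    for row, token in zip(logits, z):
        if token != mask_id:
            # Carry-over unmasking: an unmasked token is copied through as-is.
            probs.append([1.0 if v == token else 0.0 for v in range(len(row))])
        else:
            # Zero masking probability: [MASK] is never predicted as a clean token.
            row = [-math.inf if v == mask_id else x for v, x in enumerate(row)]
            probs.append(softmax(row))
    return probs


def mask_tokens(x, t, rng, mask_id=MASK_ID):
    """Forward process with alpha_t = 1 - t: mask each token with probability t."""
    return [mask_id if rng.random() < t else tok for tok in x]


def mdlm_loss(probs, x, z, t, mask_id=MASK_ID):
    """(1/t) * sum of -log p(x_i) over masked positions; unmasked ones add log 1 = 0."""
    nll = sum(-math.log(p[tok]) for p, tok, zt in zip(probs, x, z) if zt == mask_id)
    return nll / t


x = [0, 1, 2, 3]
z = [0, MASK_ID, 2, MASK_ID]
probs = subs_probs([[0.0] * 5 for _ in x], z)  # uniform toy logits
print(mdlm_loss(probs, x, z, t=0.5))  # 4 * log(4), about 5.545
```

In training, t would typically be drawn uniformly per sequence and the loss averaged over a batch; that outer loop is omitted here.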
In this repo, we release:
- The MDLM framework:
  - SUBStitution-based parameterization
  - Simplified loss calculation for masked diffusion processes
- Baseline implementations
- Samplers
Code organization:
- main.py: Routines for training and evaluation
- noise_schedule.py: Noise schedules
- diffusion.py: Forward/reverse diffusion
- dataloader.py: Dataloaders
- utils.py: LR scheduler, logging, fsspec handling
- models/: Denoising network architectures. Supports DiT, AR transformer, and Mamba
- configs/: Config files for datasets/denoising networks/noise schedules/LR schedules
- scripts/: Shell scripts for training/evaluation
To get started, create a conda environment containing the required dependencies.
conda env create -f requirements.yaml
conda activate mdlm
Create the following directories to store saved models and Slurm logs:
mkdir outputs
mkdir watch_folder
and run the training as a batch job:
sbatch scripts/train_owt_mdlm.sh
We have uploaded an MDLM model trained on OpenWebText for 1M training steps to the Hugging Face Hub 🤗: kuleshov-group/mdlm-owt. Furthermore, we have released the checkpoints for the AR and SEDD baselines trained on OpenWebText in this Google Drive folder.
Below, we describe the steps required for reproducing the experiments in the paper.
Throughout, the main entry point for running experiments is the main.py script. We also provide sample Slurm scripts for launching pre-training and downstream fine-tuning experiments in the scripts/ directory.
The argument to sampling.predictor specifies the sampler and takes one of the following values:
- ddpm_cache: our proposed sampler, which is ~3-4x faster than the samplers proposed in D3PM and SEDD.
- ddpm: Ancestral sampling proposed in D3PM.
- analytic: Analytic sampler proposed in SEDD.
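The caching idea behind ddpm_cache can be illustrated with a small sketch. It assumes, as the time_conditioning=True flag in the SEDD perplexity command later in this README suggests, that the MDLM denoiser does not take t as input by default. The denoiser output then depends only on the current sequence, so when a reverse step reveals no new token, the previous output can be reused instead of calling the network again. The sketch uses a log-linear schedule, under which a masked token is revealed with probability (t - s)/t when stepping from t to s, and a hypothetical uniform toy_denoiser in place of the network; it illustrates the idea and is not the repository's sampler.

```python
import random

MASK_ID = 4  # hypothetical [MASK] index in a toy 5-token vocabulary


def toy_denoiser(z):
    """Hypothetical stand-in for the network: uniform over non-mask tokens."""
    return [[0.25, 0.25, 0.25, 0.25, 0.0] for _ in z]


def sample_cached(length, steps, denoiser, rng, mask_id=MASK_ID):
    """Ancestral sampling that reuses the denoiser output while z is unchanged."""
    z = [mask_id] * length  # start from the all-[MASK] prior at t = 1
    cache, calls = None, 0
    for i in range(steps, 0, -1):
        t, s = i / steps, (i - 1) / steps
        if cache is None:
            cache = denoiser(z)  # network call only when the input changed
            calls += 1
        p_unmask = (t - s) / t  # log-linear schedule; equals 1 at the last step
        new_z = list(z)
        for j, tok in enumerate(z):
            if tok == mask_id and rng.random() < p_unmask:
                new_z[j] = rng.choices(range(len(cache[j])), weights=cache[j])[0]
        if new_z != z:
            cache = None  # some token was revealed, so the cached output is stale
        z = new_z
    return z, calls


rng = random.Random(0)
tokens, calls = sample_cached(length=8, steps=1000, denoiser=toy_denoiser, rng=rng)
print(len(tokens), calls)  # 8 tokens, and far fewer than 1000 network calls
```

Every network call after the first follows a step that revealed at least one token, so in this sketch the number of calls is at most the sequence length plus one, however large steps is.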
In the following table, we report the wall-clock time to generate 64 samples on a single A5000 GPU with batch_size=1.
| Method | | |
|---|---|---|
| SEDD | 127.1 | 229.3 |
| MDLM + ddpm | 113.8 | 206.6 |
| MDLM + ddpm_cache | 40.1 | 60.4 |
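As a quick check of the ~3-4x figure, the snippet below divides each baseline time in the table above by the corresponding MDLM + ddpm_cache time, column by column. Ratios are unit-free, so the check holds whatever time unit the table uses; the helper name speedups is only for illustration.

```python
# Wall-clock times copied from the table above, in column order.
times = {
    "SEDD": (127.1, 229.3),
    "MDLM + ddpm": (113.8, 206.6),
    "MDLM + ddpm_cache": (40.1, 60.4),
}


def speedups(baseline, target, table=times):
    """Per-column ratio baseline_time / target_time, rounded to 2 decimals."""
    return tuple(round(b / c, 2) for b, c in zip(table[baseline], table[target]))


print(speedups("SEDD", "MDLM + ddpm_cache"))         # (3.17, 3.8)
print(speedups("MDLM + ddpm", "MDLM + ddpm_cache"))  # (2.84, 3.42)
```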
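To generate samples from a pre-trained model, use one of the following commands. The first loads the released checkpoint from the Hugging Face Hub (backbone=hf_dit); the second loads a local checkpoint (backbone=dit):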
python main.py \
mode=sample_eval \
eval.checkpoint_path=kuleshov-group/mdlm-owt \
data=openwebtext-split \
model.length=1024 \
sampling.predictor=ddpm_cache \
sampling.steps=1000 \
loader.eval_batch_size=1 \
sampling.num_sample_batches=10 \
backbone=hf_dit
python main.py \
mode=sample_eval \
eval.checkpoint_path=/path/to/checkpoint/mdlm.ckpt \
data=openwebtext-split \
model.length=1024 \
sampling.predictor=ddpm_cache \
sampling.steps=10000 \
loader.eval_batch_size=1 \
sampling.num_sample_batches=1 \
backbone=dit
MDLM can also generate samples of arbitrary length in a semi-autoregressive (SAR) manner. We generate 200 sequences of 2048 tokens each on a single 3090 GPU and evaluate their generative perplexity under a pre-trained GPT-2 model. As the table below shows, in addition to achieving better generative perplexity, MDLM decodes 25-30x faster than SSD-LM in the SAR setting.
| Method | Gen. PPL (↓) | Sec/Seq (↓) |
|---|---|---|
| SSD-LM | 35.43 | 2473.9 |
| MDLM + ddpm_cache | 27.18 | 89.3 |
Gen. PPL: generative perplexity; Sec/Seq: seconds per sequence.
To generate SAR samples, use the following command:
python main.py \
mode=sample_eval \
eval.checkpoint_path=kuleshov-group/mdlm-owt \
data=openwebtext-split \
parameterization=subs \
model.length=1024 \
sampling.predictor=ddpm_cache \
sampling.steps=1000 \
loader.eval_batch_size=1 \
sampling.num_sample_batches=2 \
sampling.semi_ar=True \
sampling.stride_length=512 \
sampling.num_strides=2 \
backbone=hf_dit
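The length arithmetic behind SAR decoding can be sketched as follows. This is a reading of the flags, not a documented guarantee: each stride is assumed to keep the last model.length - sampling.stride_length tokens of the current window as fixed context and to denoise sampling.stride_length new [MASK] slots after it. Under that assumption, model.length=1024, stride_length=512, and num_strides=2 give 1024 + 2 * 512 = 2048 tokens, matching the sequence length quoted above. The filler passed in stands in for a full reverse-diffusion run; make_toy_filler and MASK are hypothetical.

```python
import itertools

MASK = None  # hypothetical placeholder for a [MASK] position


def semi_ar_generate(model_length, stride_length, num_strides, fill_masked):
    """Grow a sequence beyond model_length by sliding a fixed diffusion window."""
    window = fill_masked([MASK] * model_length)  # first window: fully masked
    output = list(window)
    for _ in range(num_strides):
        # Keep the last (model_length - stride_length) tokens as clean context
        # and append stride_length fresh [MASK] slots to be denoised.
        context = window[stride_length:]
        window = fill_masked(context + [MASK] * stride_length)
        output.extend(window[len(context):])  # keep only the newly generated tokens
    return output


def make_toy_filler():
    """Hypothetical denoiser stand-in: fills each mask with the next integer."""
    counter = itertools.count()
    return lambda tokens: [next(counter) if t is MASK else t for t in tokens]


seq = semi_ar_generate(1024, 512, 2, make_toy_filler())
print(len(seq))  # 2048
```

The toy filler leaves context tokens untouched, mirroring the carry-over behavior of SUBS, which is what keeps earlier tokens fixed across strides.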
To train MDLM from scratch on OpenWebText use the following command:
python main.py \
model=small \
data=openwebtext-split \
wandb.name=mdlm-owt \
parameterization=subs \
model.length=1024 \
eval.compute_generative_perplexity=True \
sampling.steps=1000
The arguments loader.batch_size and loader.eval_batch_size set the per-GPU batch sizes for training and evaluation, respectively. If loader.batch_size * num_gpus is less than the global batch size, PyTorch Lightning resorts to gradient accumulation. You can also launch a training job on Slurm using the command sbatch scripts/train_owt_mdlm.sh. The Slurm scripts to train the autoregressive and SEDD baselines are scripts/train_lm1b_ar.sh and scripts/train_owt_sedd.sh, respectively.
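The gradient-accumulation rule above comes down to simple arithmetic. This sketch assumes the number of accumulated micro-batches is the global batch size divided by loader.batch_size * num_gpus, rounded up; the function name and the example values are hypothetical, and this README does not name the config key for the global batch size.

```python
import math


def accumulation_steps(global_batch_size, per_gpu_batch_size, num_gpus):
    """Micro-batches each GPU accumulates before one optimizer step."""
    per_step = per_gpu_batch_size * num_gpus  # samples processed in parallel
    # Round up so each optimizer step sees at least the global batch size.
    return math.ceil(global_batch_size / per_step)


print(accumulation_steps(512, 16, 8))  # 4
print(accumulation_steps(512, 64, 8))  # 1: no accumulation needed
```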
To compute test perplexity, use mode=ppl_eval. Example scripts are provided in scripts/. An example command for perplexity evaluation on OpenWebText is:
python main.py \
mode=ppl_eval \
loader.batch_size=16 \
loader.eval_batch_size=16 \
data=openwebtext-split \
model=small \
parameterization=subs \
backbone=dit \
model.length=1024 \
eval.checkpoint_path=/path/to/checkpoint/mdlm.ckpt \
+wandb.offline=true
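Test perplexity is the exponential of the average per-token negative log-likelihood. The sketch below shows only the aggregation step, with a hypothetical function name. For diffusion models such as MDLM and SEDD, the per-token term comes from a variational bound, so the resulting number is an upper bound on the true perplexity rather than an exact value.

```python
import math


def corpus_perplexity(batches):
    """Perplexity from (summed NLL in nats, token count) pairs, one per batch."""
    batches = list(batches)  # materialize so generators can be iterated twice
    total_nll = sum(nll for nll, _ in batches)
    total_tokens = sum(n for _, n in batches)
    # Pool tokens across batches rather than averaging per-batch perplexities.
    return math.exp(total_nll / total_tokens)


# A model that assigns probability 1/4 to every token has perplexity 4.
print(corpus_perplexity([(2 * math.log(4), 2), (3 * math.log(4), 3)]))  # approximately 4.0
```

Pooling matters when batches contain different numbers of tokens: averaging per-batch perplexities would over-weight short batches.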
We release the checkpoints for the SEDD and AR baselines trained on OpenWebText in this Google Drive folder. Download the checkpoints ar.ckpt and sedd.ckpt, then use the following commands (AR first, then SEDD) to compute test perplexity:
python main.py \
mode=ppl_eval \
loader.batch_size=16 \
loader.eval_batch_size=16 \
data=openwebtext-split \
model=small-ar \
parameterization=ar \
backbone=ar \
model.length=1024 \
eval.checkpoint_path=/path/to/checkpoint/ar.ckpt \
+wandb.offline=true
python main.py \
mode=ppl_eval \
loader.batch_size=16 \
loader.eval_batch_size=16 \
data=openwebtext-split \
model=small \
parameterization=sedd \
backbone=dit \
model.length=1024 \
eval.checkpoint_path=/path/to/checkpoint/sedd.ckpt \
time_conditioning=True \
sampling.predictor=analytic \
+wandb.offline=true
This repository was built on top of SEDD.
@inproceedings{
sahoo2024simple,
title={Simple and Effective Masked Diffusion Language Models},
author={Subham Sekhar Sahoo and Marianne Arriola and Aaron Gokaslan and Edgar Mariano Marroquin and Alexander M Rush and Yair Schiff and Justin T Chiu and Volodymyr Kuleshov},
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
year={2024},
url={https://openreview.net/forum?id=L4uaAR4ArM}
}