Text example

This example trains a discrete flow matching model on text data. It provides the tools and scripts needed to train and evaluate such models.

Note: this example was tested only with PyTorch 2.5 on a single node with 8 H100 GPUs. With this setup, we achieved approximately 380k training steps in 24 hours.
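
As a rough planning aid, the reported throughput can be converted into a per-second rate and a time-to-target estimate. The sketch below only does arithmetic on figures quoted in this README (about 380k steps per 24 hours, and the one million steps used for the results). It assumes throughput stays constant for the whole run, which may not hold on other hardware or configurations.

```python
# Back-of-the-envelope estimate from the throughput reported in the note.
# Assumption: throughput stays constant over the whole run.
STEPS_PER_DAY = 380_000      # approximate figure from the note
TARGET_STEPS = 1_000_000     # training length reported in the Results section

steps_per_second = STEPS_PER_DAY / (24 * 60 * 60)
hours_to_target = TARGET_STEPS / STEPS_PER_DAY * 24

print(f"{steps_per_second:.1f} steps/s")         # 4.4 steps/s
print(f"{hours_to_target:.0f} hours to target")  # 63 hours to target
```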

Installation

To get started with this project, follow these steps to set up your environment:

conda env create -f environment.yml
conda activate discrete_flow_matching

Usage

Specify the data cache and checkpoint directories. Data will automatically be downloaded into the cache directory.

CACHE_DIR=...
HYDRA_RUN_DIR=...

To train a discrete flow matching model on FineWeb-EDU, run:

python run_train.py data.cache_dir=${CACHE_DIR}
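
The data.cache_dir=... argument is a dotted key=value override: the part before "=" names a nested config field and the part after it is the new value. The following is a minimal pure-Python sketch of that mapping, for illustration only. It is not how the training script parses its arguments; apply_overrides and the /tmp/cache path are hypothetical.

```python
def apply_overrides(config: dict, overrides: list[str]) -> dict:
    """Apply 'a.b.c=value' overrides to a nested dict, in place."""
    for item in overrides:
        dotted_key, value = item.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = config
        for name in parents:
            # Walk down, creating intermediate sections as needed.
            node = node.setdefault(name, {})
        node[leaf] = value
    return config


# Hypothetical config and path; only the data.cache_dir key comes from the command.
config = {"data": {"cache_dir": None}}
apply_overrides(config, ["data.cache_dir=/tmp/cache"])
print(config)  # {'data': {'cache_dir': '/tmp/cache'}}
```

Values stay strings in this sketch; a real config system would typically also convert them to numbers or booleans.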

To use Slurm, modify the Slurm config to match the cluster you are working on, and run:

python run_train.py data.cache_dir=${CACHE_DIR} hydra_dir=${HYDRA_RUN_DIR} -m &

Results

We trained models with a linear scheduler (PolynomialConvexScheduler(n=1.0)) for one million steps on FineWeb-EDU.

Evaluation command:

PYTHONPATH="." python scripts/run_eval.py --work_dir "/path/to/exp/folder" --ngpus 8 --eval_elbo --eval_perplexity

| Scheduler | Source distribution | Loss           | Generative perplexity | ELBO |
|-----------|---------------------|----------------|-----------------------|------|
| Linear    | Mask                | Cross-entropy  | 128.9                 | 53.2 |
| Linear    | Mask                | Generalized KL | 132.2                 | 47.9 |
| Linear    | Uniform             | Cross-entropy  | 90.9                  | 71.7 |
| Linear    | Uniform             | Generalized KL | 82.1                  | 71.3 |
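
To make the terms "linear scheduler" and "source distribution" concrete, here is a minimal, library-free sketch of a polynomial schedule and of corrupting a token sequence at time t. It assumes the schedule is kappa(t) = t^n (so n=1.0 is linear) and that each position independently keeps its data token with probability kappa(t) and otherwise takes a token from the source (a mask token, or a uniform draw from the vocabulary). The names kappa, sample_path, and MASK_ID are hypothetical and chosen for illustration; this is not the repository's API.

```python
import random

MASK_ID = -1  # hypothetical id for the mask token


def kappa(t: float, n: float = 1.0) -> float:
    """Polynomial schedule t**n; n=1.0 gives the linear schedule."""
    return t ** n


def sample_path(x_1, t, source="mask", vocab_size=8, n=1.0, rng=random):
    """Corrupt data tokens x_1 at time t in [0, 1].

    Each position keeps its data token with probability kappa(t);
    otherwise it takes a source token (mask, or uniform over the vocab).
    """
    k = kappa(t, n)
    x_t = []
    for token in x_1:
        if rng.random() < k:
            x_t.append(token)
        elif source == "mask":
            x_t.append(MASK_ID)
        else:
            x_t.append(rng.randrange(vocab_size))
    return x_t


rng = random.Random(0)
x_1 = [3, 1, 4, 1, 5, 2, 6, 5]
print(sample_path(x_1, t=0.0, rng=rng))  # all MASK_ID: pure source
print(sample_path(x_1, t=0.5, rng=rng))  # a random mix of data and mask tokens
print(sample_path(x_1, t=1.0, rng=rng))  # identical to x_1: pure data
```

At t=0 the sample comes entirely from the source and at t=1 it equals the data, so a model trained on intermediate times learns to move from source tokens toward data tokens.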

Folder structure

.
├── configs        # Train configs
│   └── ...
├── data           # Data loading and preprocessing
│   └── ...
├── logic          # Logic components, such as flow related classes
│   └── ...
├── model          # Transformer implementation
│   └── ...
├── scripts        # Evaluation script
│   └── ...
├── utils          # Utility functions
│   └── ...
├── README.md
├── environment.yml
├── train.py
└── run_train.py   # Run training script

Implemented methods

This repository implements the following papers:

Acknowledgements

This example partially uses code from:

License

The majority of the code in this example is licensed under CC-BY-NC; however, portions of the project are available under separate license terms:

  • Flash Attention and TorchData are under the BSD 3 license.
  • "Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution" and "GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models" are under the MIT license.