From 0bc3617a1befcd249f6a95584dba9634bd6b879c Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Mon, 24 Apr 2023 10:18:29 -0700 Subject: [PATCH 01/19] Added transformer engine support and GPU optimizations Co-authored-by: Sahil Jain Co-authored-by: Terry Kong Co-authored-by: Yu-Hang Tang Co-authored-by: Ming Huang Co-authored-by: Frederic Bastien Co-authored-by: Sharath Turuvekere Sreenivas Co-authored-by: Xiaowei Ren Co-authored-by: Ryan Jeng Co-authored-by: Reese Wang --- README.md | 5 +- docs/usage/gpu-usage.md | 144 ++++++--- t5x/contrib/gpu/Dockerfile | 6 +- t5x/contrib/gpu/README.md | 90 +----- t5x/contrib/gpu/T5X_TE_README.md | 99 ++++++ .../scripts_gpu/example_slurm_ft_frompile.sub | 46 +-- .../example_slurm_pretrain_pile.sub | 38 ++- .../scripts_gpu/multiprocess_ft_frompile.sh | 82 +++-- .../scripts_gpu/multiprocess_pretrain_pile.sh | 120 ++++---- .../gpu/scripts_gpu/singlenode_ft_frompile.sh | 17 +- .../scripts_gpu/singlenode_pretrain_pile.sh | 17 +- t5x/contrib/gpu/t5/configs/runs/finetune.gin | 6 +- .../gpu/t5/configs/runs/finetune_mnli.gin | 6 +- .../gpu/t5/configs/runs/finetune_squad1.gin | 6 +- t5x/contrib/gpu/t5/configs/runs/pretrain.gin | 4 +- t5x/contrib/gpu/t5/network.py | 121 ++++++-- .../examples/large_mnli2_finetune_adam.gin | 1 + .../examples/large_squad1_finetune_adam.gin | 1 + .../examples/small_mnli2_finetune_adam.gin | 1 + .../examples/small_squad1_finetune_adam.gin | 1 + .../examples/xl_mnli2_finetune_adam.gin | 1 + .../examples/xl_squad1_finetune_adam.gin | 1 + t5x/models.py | 161 ++++++++-- t5x/partitioning.py | 3 + t5x/te_helper.py | 284 ++++++++++++++++++ t5x/train.py | 59 ++++ t5x/train_state.py | 2 +- t5x/trainer.py | 46 +-- 28 files changed, 1044 insertions(+), 324 deletions(-) create mode 100644 t5x/contrib/gpu/T5X_TE_README.md create mode 100644 t5x/te_helper.py diff --git a/README.md b/README.md index 916b684b8..bb0146a31 100644 --- a/README.md +++ b/README.md @@ -72,8 +72,11 @@ be read by TensorBoard. ## GPU Usage Note: NVIDIA has released an updated version of this repository with H100 FP8 support and broad GPU performance improvements. Please visit the [NVIDIA Rosetta](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x) repository for more details and usage instructions. -T5X can be run easily on GPUs either in single-node configurations or multi-node configurations with a SLURM+pyxis cluster. Further instructions at [t5x/contrib/gpu](https://github.com/google-research/t5x/blob/main/t5x/contrib/gpu/README.md). The `t5x/contrib/gpu/scripts_gpu` folder contains example scripts for pretraining T5X on [The Pile](https://pile.eleuther.ai/) and for finetuning on SQuAD and MNLI. These scripts and associated `gin` configurations also contain additional GPU optimizations for better throughput. More examples and instructions can be found in the [NVIDIA Rosetta](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x) repository maintained by NVIDIA with H100 FP8 support and broad GPU performance improvements. +T5X can be run easily on GPUs either in single-node configurations or multi-node configurations with a SLURM+pyxis cluster. Further instructions at [Rosetta T5X README](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x/README.md). The `t5x/contrib/gpu/scripts_gpu` folder contains example scripts for pretraining T5X on [The Pile](https://pile.eleuther.ai/) and for finetuning on SQuAD and MNLI. 
These scripts and associated `gin` configurations also contain additional GPU optimizations for better throughput. More examples and instructions can be found in the [NVIDIA Rosetta](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x) repository maintained by NVIDIA with H100 FP8 support and broad GPU performance improvements. +We now have support for: +- [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) FP8 +- Improved performance on H100/A100 GPUs ## Installation diff --git a/docs/usage/gpu-usage.md b/docs/usage/gpu-usage.md index dedcd88ab..a9974e1a3 100644 --- a/docs/usage/gpu-usage.md +++ b/docs/usage/gpu-usage.md @@ -1,4 +1,4 @@ -# GPU Scripts +# GPU Scripts and Usage # Warning! An updated version of T5x with optimized GPU performance (18-80% perf gains!) and new features, including FP8 with [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) and H100 support can be found here: [NVIDIA Rosetta](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x). @@ -6,12 +6,14 @@ An updated version of T5x with optimized GPU performance (18-80% perf gains!) an **NVIDIA no longer recommends using this repository and won't be updating it further.** ----- -The [t5x/contrib/gpu](../../t5x/contrib/gpu) directory contains scripts optimized for GPU usage. +The [t5x/contrib/gpu/scripts_gpu](../../t5x/contrib/gpu/scripts_gpu) directory contains scripts optimized for GPU usage and includes FP8 support via [Transformer Engine](https://github.com/NVIDIA/TransformerEngine). Install with `pip install -r pile_requirements.txt` to get all pile dependencies. ## Building the container -The Dockerfile in `t5x/contrib/gpu` given will build a container with all gpu/pile dependencies. It can be built with `t5x/contrib/gpu/docker/build.sh ` +We provide a fully built and ready-to-use container here: [ghcr.io/nvidia/t5x:te-fp8-reference](ghcr.io/nvidia/t5x:te-fp8-reference) +If you'd like you build your own, +The Dockerfile in `t5x/contrib/gpu` will build a container with all gpu/pile dependencies. It can be built with `t5x/contrib/gpu/docker/build.sh ` ## Running interactively Note: this should only be done with singlenode jobs and/or for downloading the pile. Use `t5x/contrib/gpu/docker/interactive_pull_and_launch.sh`. This takes arguments for the URL to pull a container from and the location of the dataset directory to mount. For example: @@ -19,7 +21,7 @@ Note: this should only be done with singlenode jobs and/or for downloading the p `t5x/contrib/gpu/docker/interactive_pull_and_launch.sh [URL] /my/dataset/dir` ## Downloading The Pile -Run `download_the_pile.py` to download the pile. It will download to the directory set in the environment variable: `TFDS_DATA_DIR`. After that, set the `TFDS_DATA_DIR` to the same directory in your scripts to use. +We use The Pile for our pretraining experiments. If you would like to as well, run `download_the_pile.py` to download it. The download is approximately 1TB. It will download to the directory set in the environment variable: `TFDS_DATA_DIR`. After that, set the `TFDS_DATA_DIR` to the same directory in your scripts to use. ## Single Node runs Pretraining and Finetuning can be done with `singlenode_*.sh`. These will build a T5X model with the Adam optimizer and relevant parameters. These will allow multi-gpu on one host. @@ -27,61 +29,127 @@ Pretraining and Finetuning can be done with `singlenode_*.sh`. 
These will build ## Multi Node runs For a SLURM+pyxis cluster, `example*.sub` files provide example slurm submit files (edit with your details), which call `multiprocess*.sh` to execute training. You can add a binding script in the `.sub` file for your cluster, or remove it entirely (dropping some throughput) -## Convergence -For our Pile convergence runs, we used a Global batch size of 2304 for XXL and 2048 for all other models, where GBS is defined as #GPUs * BS/GPU / Tensor Parallel(TP). Below are example (tested) hardware topologies on NVIDIA DGX A100 (8x A100 80G) nodes. +## Convergence and performance +For our Pile convergence runs, we used a Global batch size of 2304 for XXL and 2016-2048 for all other models, where GBS is defined as #GPUs * BS/GPU / Tensor Parallel(TP). Below are example (tested) hardware topologies on NVIDIA DGX A100 (8x A100-SXM4-80G) and H100-SXM-80G nodes. -| size | #GPUs | TP | BS / GPU | Sequences/Sec | Estimated Walltime | MNLI 2.0 - matched | SQuAD v1.1 (EM/F1) | Convergence Log | -| ---- | ----- | ----- | -------- | ------------- | ------------------ | ------------------ | ------------------ | --------------- | -| small| 8 | 1 | 256 | ~3168 | 7.48 days | 83.06% | 78.33 / 86.63 | [log](https://tensorboard.dev/experiment/lWnHal7PRnOLeZuewyWVxQ/#scalars&_smoothingWeight=0) | -| large| 64 | 1 | 32 | ~3886 | 6.10 days | 90.50% | 87.31 / 94.04 | [log](https://tensorboard.dev/experiment/aOxJBIvTQBeTJ8XGXxaL6Q/#scalars&_smoothingWeight=0) | -| xl | 256 | 1 | 8 | ~3652 | 6.49 days | 91.15% | 89.36 / 95.29 | [log](https://tensorboard.dev/experiment/vuRoEYgkRgWiEtbvgxlOqw/#scalars&_smoothingWeight=0) | -| xxl | 512 | 8 | 36 | ~1346 | 19.81 days | N/A(partial run) | N/A(partial run) | N/A(partial run)| +| size | GPU | Precision | #GPUs | TP | BS / GPU | Sequences/Sec | Seq/Sec/GPU | Est. 
Walltime | GPU-days | MNLI 2.0 - matched | SQuAD v1.1 (EM/F1) | Convergence Log | Config | +| ---- | ------------ | --------- | ----- | ----- | -------- | ------------- | ----------- | ------------- | -------- |------------------ | ------------------ | --------------- | ---- | +| [T5-v1.1-small](../t5/t5_1_1/small.gin) | A100 80G SXM | bf16 | 8 | 1 | 256 | ~5712 | 714 | 4.2 days | 33 | 83.06% | 78.33 / 86.63 | [log](https://tensorboard.dev/experiment/lWnHal7PRnOLeZuewyWVxQ/#scalars&_smoothingWeight=0) | [pile](../t5/t5_1_1/examples/small_pile_pretrain.gin) +| [T5-v1.1-large](../t5/t5_1_1/large.gin) | A100 80G SXM | bf16 | 64 | 1 | 32 | ~4853 | 75.8 | 4.8 days | 309 | 90.50% | 87.31 / 94.04 | [log](https://tensorboard.dev/experiment/aOxJBIvTQBeTJ8XGXxaL6Q/#scalars&_smoothingWeight=0) |[pile](../t5/t5_1_1/examples/large_pile_pretrain.gin) +| [T5-v1.1-xl](../t5/t5_1_1/xl.gin) | A100 80G SXM | bf16 | 144 | 1 | 8 | ~3021 | 21.0 | 7.9 days | 1,133 | N/A(perf test) | N/A (perf test) | |[pile](../t5/t5_1_1/examples/xl_pile_pretrain.gin) +| [T5-v1.1-xl](../t5/t5_1_1/xl.gin) | A100 80G SXM | bf16 | 256 | 1 | 8 | ~4322 | 16.9 | 5.5 days | 1,408 | 91.15% | 89.36 / 95.29 | [log](https://tensorboard.dev/experiment/vuRoEYgkRgWiEtbvgxlOqw/#scalars&_smoothingWeight=0) |[pile](../t5/t5_1_1/examples/xl_pile_pretrain.gin) +| [T5-v1.1-xxl](../t5/t5_1_1/xxl.gin) | A100 80G SXM | bf16 | 512 | 8 | 36 | ~1887 | 3.69 | 12.6 days | 6,431 |N/A(partial run) | N/A(partial run) | |[pile](../t5/t5_1_1/examples/xxl_pile_pretrain.gin) +| [T5-v1.1-large](../t5/t5_1_1/large.gin) | **H100 80G SXM** | TE-fp8 | 64 | 1 | 32 | ~10156 | **158.7** | **2.3 days** | **147** | 89.1% | 86.36 / 93.5 | |[pile](../t5/t5_1_1/examples/large_pile_pretrain.gin) +| [T5-v1.1-xl](../t5/t5_1_1/xl.gin) | **H100 80G SXM** | TE-fp8 | 144 | 1 | 14 | ~7257 | **50.4** | **3.3 days** | **475** | N/A (perf test) | N/A (perf test) | |[pile](../t5/t5_1_1/examples/xl_pile_pretrain.gin) +| [T5-v1.1-xl](../t5/t5_1_1/xl.gin) | **H100 80G SXM** | TE-fp8 | 256 | 1 | 8 | ~9688 | **37.8** | **2.4 days** | **614** | N/A (perf test) | N/A (perf test) | |[pile](../t5/t5_1_1/examples/xl_pile_pretrain.gin) Note: Convergence (as shown in log) was not necessarily done with the hardware topology listed, but the listed topology is tested. Estimated Walltime is calculated assuming full throughput (seq/sec) continuously. In practice, there are compilation overheads at the beginning of each run/restart(in cluster settings) + checkpointing overheads (if any). -(More perf improvements coming soon!) - Other hyperparameters are specified in the associated pile `gin` files in the `contrib/gpu/t5/t5_1_1/examples` directory. ## Pretraining run commands -### Singlenode -small: +### Multinode +Arguments are set by environment variable as such: -`t5x/contrib/gpu/t5/scripts_gpu/singlenode_pretrain_pile.sh small bfloat16 8 256 {LOGDIR - create before running} {MODEL_DIR} {GRADIENT_ACCUMULATION (1 by default)}` +`PREC={PRECISION} T5_SIZE={SIZE} BSIZE_PER_GPU={BSIZE} ..... sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub {GPUS_PER_NODE}` -Finetuning: -MNLI v2: -`t5x/contrib/gpu/t5/scripts_gpu/singlenode_ft_frompile.sh mnli2 small bfloat16 8 256 {LOGDIR - create before running} {MODEL_DIR(to restore pretrained checkpoint from)} {GRADIENT_ACCUMULATION}` +All parameters can be found in the relevant script. +### Example Pretraining Commands +Assumes 8GPU 80GB A100/H100 Nodes. 
`ENABLE_FP8` uses transformer engine (included in container) and requires H100 -### Multinode -Arguments are as such: +* Note: To use, FP8 set `ENABLE_FP8` to `1`. This will automatically set `PREC` to `bfloat16` as is required by internals for `FP8` usage. +#### [T5-v1.1-small](../t5/t5_1_1/small.gin) (60M): +```sh +PREC=bfloat16 T5_SIZE=small BSIZE_PER_GPU=256 TRAIN_STEPS=1000000 NUM_MICROBATCHES=1 ENABLE_FP8=1 TP_SIZE=1 \ +sbatch -N1 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub +``` -`sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub {MODEL_SIZE} {MODEL_PREC} {GPU/NODE} {BS/GPU} {MODEL_DIR} {GRADIENT_ACCUMULATION} {TENSOR_PARALLEL}` +#### [T5-v1.1-large](../t5/t5_1_1/large.gin) (770M): +```sh +PREC=bfloat16 T5_SIZE=large BSIZE_PER_GPU=32 TRAIN_STEPS=1000000 NUM_MICROBATCHES=1 ENABLE_FP8=1 TP_SIZE=1 \ +sbatch -N8 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub +``` -small: +#### [T5-v1.1-xl](../t5/t5_1_1/xl.gin) (3B): +```sh +PREC=bfloat16 T5_SIZE=large BSIZE_PER_GPU=8 TRAIN_STEPS=1000000 NUM_MICROBATCHES=1 ENABLE_FP8=1 TP_SIZE=1 \ +sbatch -N 32 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub +``` -`sbatch -N 1 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub small bfloat16 8 256 {MODEL_DIR} 1 1` +### Example Finetuning Commands +Finetuning commands simply change the script and have an additional `{FT_TASK}` as the first argument (along with relevant hyperparameter changes). Your `MODEL_DIR` should contain the pretrained checkpoint to restore from. -large: +#### MNLI v2: +```sh +FT_TASK=mnli2 PREC=bfloat16 T5_SIZE={SIZE} BSIZE_PER_GPU={BSIZE} NUM_MICROBATCHES=1 ENABLE_FP8=1 TP_SIZE=1 \ +sbatch -N{NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_ft_frompile.sub +``` -`sbatch -N 8 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub large bfloat16 8 32 {MODEL_DIR} 1 1` +#### SQuAD v1.1: +```sh +FT_TASK=squad1 PREC=bfloat16 T5_SIZE={SIZE} BSIZE_PER_GPU={BSIZE} NUM_MICROBATCHES=1 ENABLE_FP8=1 TP_SIZE=1 \ +sbatch -N{NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_ft_frompile.sub -xl: +``` -`sbatch -N 32 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub xl bfloat16 8 8 {MODEL_DIR} 1 1` +## Performance Settings: +There are 3 major performance settings: `ENABLE_FP8`, `FUSE_QKV` and `TRANSPOSE_BS` (all of which are controllable via env var in the commands above). +We recommend always enabling `TRANSPOSE_BS` (default), but only using `FUSE_QKV` when using `ENABLE_FP8` for optimal performance. -Finetuning commands simply change the script and have an additional `{FT_TASK}` as the first argument (along with relevant hyperparameter changes). Your `MODEL_DIR` should contain the pretrained checkpoint to restore from. +On all finetuning runs, we use a Global Batch Size of 256 with bfloat16 precision + FP8. -MNLI v2: - -`sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_ft_frompile.sub mnli2 {MODEL_SIZE} {MODEL_PREC} {GPU/NODE} {BS/GPU} {MODEL_DIR} {GRADIENT_ACCUMULATION} {TENSOR_PARALLEL}` - -SQuAD v1.1 +WARNING: Finetuning is configured by default to save every checkpoint and delete none (to avoid accidentally deleting your pretrained checkpoint). Watch your disk space! This behavior can be changed in `t5x/configs/runs/finetune_{TASK}.gin`, however this puts the pretrained checkpoint at risk unless backed up. 
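For reference, the `ENABLE_FP8`, `TRANSPOSE_BS`, and `FUSE_QKV` switches described under Performance Settings are not consumed by T5X directly; the launch scripts added later in this patch (`singlenode_*.sh` and `multiprocess_*.sh`) forward them to gin. The overrides below are excerpted from those scripts (part of a longer `t5x/train.py` invocation), shown here only to make the plumbing explicit:

```sh
# Excerpt from the launch scripts in this patch: how the performance switches reach gin.
  --gin.train.te_config_cls=@te_helper.TransformerEngineConfig \
  --gin.te_helper.TransformerEngineConfig.enabled=${ENABLE_FP8} \
  --gin.te_helper.TransformerEngineConfig.fp8_format=\"hybrid\" \
  --gin.network.T5Config.transpose_batch_sequence=${TRANSPOSE_BS} \
  --gin.network.T5Config.fuse_qkv_params=${FUSE_QKV} \
```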
-`sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_ft_frompile.sub squad1 {MODEL_SIZE} {MODEL_PREC} {GPU/NODE} {BS/GPU} {MODEL_DIR} {GRADIENT_ACCUMULATION} {TENSOR_PARALLEL}` +### Singlenode (single process) +small: -On all finetuning runs, we use a Global Batch Size of 128 with bfloat16 precision. +```sh +t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh \ + small \ + bfloat16 \ + 8 \ + 256 \ + {LOGDIR - create before running} \ + {MODEL_DIR} \ + {GRADIENT_ACCUMULATION (1 by default)} \ + {ENABLE_FP8 (1 by default)} \ + {TRANSPOSE_BS (1 by default)} \ + {FUSE_QKV (1 by default)} \ + {PACK (0 by default)} +``` + +WARNING: Finetuning is configured by default to save every checkpoint and delete none (to avoid accidentally deleting your pretrained checkpoint). Watch your disk space! This behavior can be changed in `t5x/configs/runs/finetune_{TASK}.gin`, however this puts the pretrained checkpoint at risk unless backed up. -WARNING: Finetuning is configured by default to save every checkpoint and delete none (to avoid accidentally deleting your pretrained checkpoint). Watch your disk space! This behavior can be changed in `t5x/configs/runs/finetune_{TASK}.gin`, however this puts the pretrained checkpoint at risk unless backed up. \ No newline at end of file +Finetuning: +MNLI v2: +```sh +t5x/contrib/gpu/scripts_gpu/singlenode_ft_frompile.sh \ + mnli2 \ + small \ + bfloat16 \ + 8 \ + 256 \ + {LOGDIR - create before running} \ + {MODEL_DIR(to restore pretrained checkpoint from)} \ + {GRADIENT_ACCUMULATION (1 by default)} \ + {MAKE_FT_DIR (false by default)} + {ENABLE_FP8 (1 by default)} \ + {TRANSPOSE_BS (1 by default)} \ + {FUSE_QKV (1 by default)} \ + {PACK (0 by default)} +``` + +WARNING: Finetuning is configured by default to save every checkpoint and delete none (to avoid accidentally deleting your pretrained checkpoint). Watch your disk space! This behavior can be changed in `t5x/configs/runs/finetune_{TASK}.gin`, however this puts the pretrained checkpoint at risk unless backed up. +# Changelog +- Added Transformer Engine + FP8 support +- Added the Transposed Batch-Sequence GPU optimization +- A100 Perf gains! (BF16) + - 80% speedup - T5-small + - 23% speedup - T5-large + - 18% speedup - T5-xl + - 40% speedup - T5-xxl +- H100 FP8 support, with gains over A100 + - 2.08x faster - T5-large (FP8) + - 2.24x faster - T5-xl (FP8) diff --git a/t5x/contrib/gpu/Dockerfile b/t5x/contrib/gpu/Dockerfile index 4ab560e01..5c904350b 100644 --- a/t5x/contrib/gpu/Dockerfile +++ b/t5x/contrib/gpu/Dockerfile @@ -1,14 +1,12 @@ -ARG FROM_IMAGE_NAME=nvcr.io/nvidia/tensorflow:22.08-tf2-py3 +ARG FROM_IMAGE_NAME=ghcr.io/nvidia/jax-toolbox-internal:5061977725-te FROM ${FROM_IMAGE_NAME} -# Install the latest jax -RUN pip install jax[cuda]==0.4.1 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - # setup directory paths for T5x ENV TFDS_DATA_DIR=/t5x_home/datasets/ ENV T5X_DIR=/t5x_home/ ENV T5X_WORKSPACE_DIR=/t5x_home/workspace ENV PYTHONPATH=/t5x_home/ + WORKDIR /t5x_home # install the requirements for T5x diff --git a/t5x/contrib/gpu/README.md b/t5x/contrib/gpu/README.md index 6e7cc57d2..7208713bc 100644 --- a/t5x/contrib/gpu/README.md +++ b/t5x/contrib/gpu/README.md @@ -1,90 +1,2 @@ # GPU Scripts - -# Warning! 
-An updated version of T5x with optimized GPU performance and new features, including FP8 with [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) and H100 support can be found here: [NVIDIA Rosetta](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x). ------ -**NVIDIA no longer recommends using this repository and won't be updating it further.** ------ - -The [t5x/contrib/gpu/scripts_gpu](scripts_gpu) directory contains scripts optimized for GPU usage. - -To get all dependencies for the Pile dataset, install with the `gpu` extra: -```bash -pip install '.[gpu]' -``` - -## Building the container -The Dockerfile in `t5x/contrib/gpu` given will build a container with all gpu/pile dependencies. It can be built with `t5x/contrib/gpu/docker/build.sh ` - -## Running interactively -Note: this should only be done with singlenode jobs and/or for downloading the pile. Use `t5x/contrib/gpu/docker/interactive_pull_and_launch.sh`. This takes arguments for the URL to pull a container from and the location of the dataset directory to mount. For example: - -`t5x/contrib/gpu/docker/interactive_pull_and_launch.sh [URL] /my/dataset/dir` - -## Downloading The Pile -Run `download_the_pile.py` to download the pile. It will download to the directory set in the environment variable: `TFDS_DATA_DIR`. After that, set the `TFDS_DATA_DIR` to the same directory in your scripts to use. - -## Single Node runs -Pretraining and Finetuning can be done with `singlenode_*.sh`. These will build a T5X model with the Adam optimizer and relevant parameters. These will allow multi-gpu on one host. - -## Multi Node runs -For a SLURM+pyxis cluster, `example*.sub` files provide example slurm submit files (edit with your details), which call `multiprocess*.sh` to execute training. You can add a binding script in the `.sub` file for your cluster, or remove it entirely (dropping some throughput) - -## Convergence -For our Pile convergence runs, we used a Global batch size of 2304 for XXL and 2048 for all other models, where GBS is defined as #GPUs * BS/GPU / Tensor Parallel(TP). Below are example (tested) hardware topologies on NVIDIA DGX A100 (8x A100 80G) nodes. - -| size | #GPUs | TP | BS / GPU | Sequences/Sec | Estimated Walltime | MNLI 2.0 - matched | SQuAD v1.1 (EM/F1) | Convergence Log | -| ---- | ----- | ----- | -------- | ------------- | ------------------ | ------------------ | ------------------ | --------------- | -| small| 8 | 1 | 256 | ~3168 | 7.48 days | 83.06% | 78.33 / 86.63 | [log](https://tensorboard.dev/experiment/lWnHal7PRnOLeZuewyWVxQ/#scalars&_smoothingWeight=0) | -| large| 64 | 1 | 32 | ~3886 | 6.10 days | 90.50% | 87.31 / 94.04 | [log](https://tensorboard.dev/experiment/aOxJBIvTQBeTJ8XGXxaL6Q/#scalars&_smoothingWeight=0) | -| xl | 256 | 1 | 8 | ~3652 | 6.49 days | 91.15% | 89.36 / 95.29 | [log](https://tensorboard.dev/experiment/vuRoEYgkRgWiEtbvgxlOqw/#scalars&_smoothingWeight=0) | -| xxl | 512 | 8 | 36 | ~1346 | 19.81 days | N/A(partial run) | N/A(partial run) | N/A(partial run)| - -Note: Convergence (as shown in log) was not necessarily done with the hardware topology listed, but the listed topology is tested. Estimated Walltime is calculated assuming full throughput (seq/sec) continuously. In practice, there are compilation overheads at the beginning of each run/restart(in cluster settings) + checkpointing overheads (if any). - -(More perf improvements coming soon!) 
-
-Other hyperparameters are specified in the associated pile `gin` files in the `contrib/gpu/t5/t5_1_1/examples` directory.
-
-## Pretraining run commands
-
-### Singlenode
-small:
-
-`t5x/contrib/gpu/t5/scripts_gpu/singlenode_pretrain_pile.sh small bfloat16 8 256 {LOGDIR - create before running} {MODEL_DIR} {GRADIENT_ACCUMULATION (1 by default)}`
-
-Finetuning:
-MNLI v2:
-`t5x/contrib/gpu/t5/scripts_gpu/singlenode_ft_frompile.sh mnli2 small bfloat16 8 256 {LOGDIR - create before running} {MODEL_DIR(to restore pretrained checkpoint from)} {GRADIENT_ACCUMULATION}`
-
-
-### Multinode
-Arguments are as such:
-
-`sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub {MODEL_SIZE} {MODEL_PREC} {GPU/NODE} {BS/GPU} {MODEL_DIR} {GRADIENT_ACCUMULATION} {TENSOR_PARALLEL}`
-
-small:
-
-`sbatch -N 1 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub small bfloat16 8 256 {MODEL_DIR} 1 1`
-
-large:
-
-`sbatch -N 8 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub large bfloat16 8 32 {MODEL_DIR} 1 1`
-
-xl:
-
-`sbatch -N 32 t5x/contrib/gpu/t5/scripts_gpu/example_slurm_pretrain_pile.sub xl bfloat16 8 8 {MODEL_DIR} 1 1`
-
-Finetuning commands simply change the script and have an additional `{FT_TASK}` as the first argument (along with relevant hyperparameter changes). Your `MODEL_DIR` should contain the pretrained checkpoint to restore from.
-
-MNLI v2:
-
-`sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_ft_frompile.sub mnli2 {MODEL_SIZE} {MODEL_PREC} {GPU/NODE} {BS/GPU} {MODEL_DIR} {GRADIENT_ACCUMULATION} {TENSOR_PARALLEL}`
-
-SQuAD v1.1
-
-`sbatch -N {NODE_CT} t5x/contrib/gpu/t5/scripts_gpu/example_slurm_ft_frompile.sub squad1 {MODEL_SIZE} {MODEL_PREC} {GPU/NODE} {BS/GPU} {MODEL_DIR} {GRADIENT_ACCUMULATION} {TENSOR_PARALLEL}`
-
-On all finetuning runs, we use a Global Batch Size of 128 with bfloat16 precision.
-
-WARNING: Finetuning is configured by default to save every checkpoint and delete none (to avoid accidentally deleting your pretrained checkpoint). Watch your disk space! This behavior can be changed in `t5x/configs/runs/finetune_{TASK}.gin`, however this puts the pretrained checkpoint at risk unless backed up.
+This folder contains scripts that help run optimized T5x code on GPU with FP8 support. Please refer to the [Rosetta T5X README](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x/README.md) for further guides.
diff --git a/t5x/contrib/gpu/T5X_TE_README.md b/t5x/contrib/gpu/T5X_TE_README.md
new file mode 100644
index 000000000..f182f9b3d
--- /dev/null
+++ b/t5x/contrib/gpu/T5X_TE_README.md
@@ -0,0 +1,99 @@
+# T5X with Transformer Engine Summary #
+
+**Highlight:**
+1. Add `TransformerEngineHelper` to allow users to switch between running with or without Transformer Engine.
+2. Add the feature of transposing the batch_size and sequence dimensions to accelerate performance.
+3. Hide the FP8 metadata in `flax_mutable`. `flax_mutable` is a variable collection originally declared by T5X.
+
+## The *.gin files ##
+They are configurations to set up T5X. The major change is to replace the AdaFactor optimizer with AdamW because of performance concerns. With older XLA versions, using AdaFactor generated a lot of D2D copies and slowed down performance. Although that issue has since been resolved, we kept AdamW for the convergence and performance tests for now.
+
+## network.py ##
+1. The `TransformerEngineHelper` is a singleton that manages switching Transformer Engine on and off and hides the if-else statements inside.
The pseudo code is like:
+    ```python
+    class TransformerEngineHelper:
+        @staticmethod
+        def foo(x):
+            if _IS_TRANSFORMER_ENGINE_INSTALLED and use_te:
+                y = TransformerEngine.foo(x)
+            else:
+                y = T5X.foo(x)
+            return y
+    ```
+2. The input tensor is in BATCH_SEQ_HIDDEN format (i.e., batch_size, sequence, ...) by default. If `cfg.transpose_batch_sequence` is True, the input tensor is transposed to SEQ_BATCH_HIDDEN format because SEQ_BATCH_HIDDEN is currently faster; this may no longer be necessary after integrating cuDNN MHA. The `output_format` argument controls whether the output tensor is transposed back, which makes debugging easier.
+3. The reason to rename the mask from `encoder_mask`/`decoder_mask` to `attention_mask` is to align the kwargs of TransformerLayer between T5X and Transformer Engine. The original T5X TransformerLayer has a different parameter list than Transformer Engine's, which prevents us from using a single functor to switch between the two. The pseudo code is like:
+    ```python
+    if use_te:
+        TransformerLayer = te.TransformerLayer
+    else:
+        TransformerLayer = t5x.TransformerLayer
+
+    y = TransformerLayer(x, attention_mask=mask)
+    ```
+4. The `TransformerEngineHelper.get_attn_mask(*_mask)` is used to convert the T5X mask to the format required by Transformer Engine. In T5X, `1` means keep and `0` means drop, but in Transformer Engine, the meaning is reversed.
+
+## utils.py ##
+1. The `jax.eval_shape` has to be wrapped by `TransformerEngineHelper.eval_shape_guard()` because the `ShardingResource` must be set first. Otherwise, xmap cannot infer the shape of each layer of the model, and an exception will be thrown.
+2. The `flax_mutables` is a variable collection that contains FP8 metadata and sharding information (e.g., named logical axes). It is required for FP8 training and tensor parallelism.
+
+## trainer.py ##
+1. At the code: `grad_fn = jax.value_and_grad(model.loss_fn, argnums=(0, 3), ...)`, the number `0` refers to the 1st argument of loss_fn (the model parameters), and the number `3` refers to the 4th argument (the `flax_mutables` collection, which contains the FP8 metadata). In order to get the updated FP8 metadata after one training step, we ask JAX to differentiate `flax_mutables`. Note that the FP8 metadata is NOT actually calculated by differentiation; it is maintained by Transformer Engine. Differentiating it is a trick to retrieve the updated FP8 metadata, since we did not find another interface or approach to obtain it.
+2. At the code:
+    ```diff
+    - initial_flax_mutables = train_state.flax_mutables if train_state.flax_mutables else None
+    + initial_flax_mutables = train_state.flax_mutables if train_state.flax_mutables else {}
+    ```
+    The `None` appears to be a T5X bug: it triggers exceptions when `flax_mutables` has to be passed into JAX routines. Although T5X declares `flax_mutables`, it doesn't actually use it, so T5X developers were likely unaware of the issue.
+3. The `grad_accum` becomes a tuple of variable collections because two arguments are differentiated. The 1st is the model parameters; the 2nd is the FP8 metadata.
+4. At the code:
+    ```python
+    grad_accum = (grad_accum[0],
+                  TransformerEngineHelper.update_fp8_metas(
+                      grad_accum[1], flax_mutables, train_state.step))
+    ```
+    It is a workaround for a T5X (or JAX) bug whose root cause we have not yet had time to investigate: T5X always misses one round of gradient accumulation.
For example, if the accumulation step is 10, T5X should run 10 micro-batches and accumulate the gradient of each micro-batch, but it only accumulates the gradient 9 times. If the accumulation step is 1, T5X doesn't update the gradient at all. Thus, the workaround is to accumulate the gradient one extra time manually.
+
+## train_state.py ##
+1. Add `flax_mutables_axes`, so xmap can know how to shard the FP8 metadata.
+
+## train.py ##
+1. Import `TransformerEngineHelper` and initialize it.
+
+## te_helper.py ##
+1. A new file that contains the `TransformerEngineHelper` implementation. Note that it uses the Transformer Engine internal APIs `FP8Helper.initialize` and `FP8Helper.finalize`. This is a trade-off between the number of changed lines and the recommended way of enabling FP8 training. The recommended approach is:
+    ```python
+    with te.fp8_autocast(fp8_format, ...):
+        model = Net()
+        variable_collection = model.init(rng, inputs)
+        state = TrainState.create(apply_fn=model.apply, ...)
+        train_epoch(state, dataset)
+    ```
+    It is equivalent to:
+    ```python
+    FP8Helper.initialize(fp8_format, ...) # allocate FP8 metadata and setup
+    model = Net()
+    variable_collection = model.init(rng, inputs)
+    state = TrainState.create(apply_fn=model.apply, ...)
+    train_epoch(state, dataset)
+    FP8Helper.finalize() # release FP8 metadata
+    ```
+
+## partitioning.py ##
+1. Append the sharding rules needed by Transformer Engine after T5X's rules.
+
+## models.py ##
+1. Add `eval_fn` because a new argument - `flax_mutable` - is needed.
+2. Add `predict_batch` because a new argument - `flax_mutable` - is needed.
+3. At the code:
+    ```python
+    module.apply(
+        {'params': params, **flax_mutable},
+        ...
+    )
+    ```
+    `module.apply` only accepts one variable collection, so the model parameters and the FP8 metadata need to be merged before being passed to `apply`.
+4. The `cache_offset` indicates which dimension is the batch dimension for beam search. Thus, it must be changed if `cfg.transpose_batch_sequence` is True.
+
+## run_t5x_*.sh ##
+1. Convenience shell scripts for running experiments.
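To make the switching and mask-conversion behavior summarized above concrete, here is a minimal, self-contained sketch. It is not part of the patch: the names `TransformerEngineHelperSketch` and `_TE_INSTALLED` are illustrative stand-ins, and the real helper in `t5x/te_helper.py` wraps Transformer Engine's actual layer classes and manages the FP8 metadata rather than raising.

```python
# Minimal sketch (not from the patch) of the helper pattern described above.
# Assumes only jax.numpy; t5x/te_helper.py additionally wraps TE layers and FP8 state.
import jax.numpy as jnp

_TE_INSTALLED = False  # stand-in for the "is Transformer Engine available" check


class TransformerEngineHelperSketch:
    """Single switch point between native T5X layers and TE-backed layers."""

    @staticmethod
    def get_attn_mask(mask):
        # T5X masks use 1 = attend, 0 = drop; Transformer Engine expects the
        # reversed convention, so flip the mask only when the TE path is active.
        if mask is None:
            return None
        return 1 - mask.astype(jnp.int32) if _TE_INSTALLED else mask

    @staticmethod
    def get_encoder_layer(config, relative_embedding, name, original_cls):
        # With TE installed, this would return a TE-backed TransformerLayer built
        # from `config`; otherwise it falls back to the stock T5X EncoderLayer.
        if _TE_INSTALLED:
            raise NotImplementedError('TE-backed layer omitted in this sketch')
        return original_cls(config=config,
                            relative_embedding=relative_embedding,
                            name=name)


# Example: a [batch, 1, q_len, kv_len] mask passes through unchanged without TE.
mask = jnp.ones((2, 1, 4, 4), dtype=jnp.float32)
print(TransformerEngineHelperSketch.get_attn_mask(mask).shape)  # (2, 1, 4, 4)
```

The sketch only shows the dispatch-and-mask-flip idea; the real helper also exposes decoder-layer, sharding-rule, and FP8-metadata hooks as described in the sections above.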
+ diff --git a/t5x/contrib/gpu/scripts_gpu/example_slurm_ft_frompile.sub b/t5x/contrib/gpu/scripts_gpu/example_slurm_ft_frompile.sub index 966ea6928..19ec14ee9 100755 --- a/t5x/contrib/gpu/scripts_gpu/example_slurm_ft_frompile.sub +++ b/t5x/contrib/gpu/scripts_gpu/example_slurm_ft_frompile.sub @@ -48,19 +48,28 @@ T5X_WORKSPACE_DIR=/t5x_home/workspace MOUNTS="--container-mounts=$BASE_T5X_DIR:/$T5X_DIR,$BASE_TFDS_DATA_DIR:/$TFDS_DATA_DIR,$BASE_T5X_WORKSPACE_DIR:$T5X_WORKSPACE_DIR" # Add T5x/JAX specific exports -EXPORTS="--export=ALL,TFDS_DATA_DIR=${TFDS_DATA_DIR},T5X_DIR=${T5X_DIR},T5X_WORKSPACE_DIR=${T5X_WORKSPACE_DIR}" +EXPORTS="--export=ALL,TFDS_DATA_DIR=${TFDS_DATA_DIR},T5X_DIR=${T5X_DIR},T5X_WORKSPACE_DIR=${T5X_WORKSPACE_DIR},PYTHONPATH=${T5X_DIR}" #------------------------------------------------------------------------------- -# Command line arguments needed by the underlying scripts -TASK=$1 # mnli2 or squad1, add others with corresponding gin files -T5_SIZE=$2 # small, base, large, xl, xxl -PREC="$3" # bfloat16, float32 -GPUS_PER_NODE=$4 # usually 8 -BSIZE_PER_GPU=$5 # local batch size/gpu -MODEL_DIR_LOCAL=$6 # directory to save checkpoints and config dump to -NUM_MICROBATCHES=$7 # number of gradient accumulation steps - -NUM_GPUS=$(( GPUS_PER_NODE * SLURM_JOB_NUM_NODES )) +FT_TASK=${FT_TASK:-mnli2} +PREC=${PREC:="bfloat16"} +T5_SIZE=${T5_SIZE:="large"} +BSIZE_PER_GPU=${BSIZE_PER_GPU:=32} +ENC_SL=${ENC_SL:=512} +DEC_SL=${DEC_SL:=128} +NUM_MICROBATCHES=${NUM_MICROBATCHES:=1} +ENABLE_FP8=${ENABLE_FP8:=1} +TP_SIZE=${TP_SIZE:=1} +TRANSPOSE_BS=${TRANSPOSE_BS:=1} +MODEL_DIR=${MODEL_DIR:=model_dir} +FUSE_QKV=${FUSE_QKV:=1} +PACK=${PACK:=0} + +export GPUS_PER_NODE=${1:-8} +export BASE_SCRIPT=${2:-"${T5X_DIR}/t5x/contrib/gpu/scripts_gpu/multiprocess_ft_frompile.sh"} +export WITH_MP=1 + +NUM_GPUS=$((GPUS_PER_NODE * SLURM_JOB_NUM_NODES )) # redirect both stdout and stderr in the same file for ease of analysis OUTDIR="outputs/multinode/${TASK}_t5_${T5_SIZE}-prec_${PREC}-nodes_${SLURM_JOB_NUM_NODES}-gpus_${NUM_GPUS}-bs_${BSIZE_PER_GPU}-sl_${SL}" @@ -73,16 +82,17 @@ LOGDIR="${T5X_WORKSPACE_DIR}/${OUTDIR}" # You can add binding to the command below with the following line (after nvidia-smi). Remove the '&&' on the next bash line. 
# && bash <>/bind.sh --cpu=exclusive --ib=single -- \ read -r -d '' cmd <>/bind.sh --cpu=exclusive --ib=single -- \ read -r -d '' cmd < \ - ${LOG_DIR}/ft_${T5_SIZE}_gpu_${NUM_GPUS}_${PREC}_gbs_${BSIZE}.log + --gin.infer_eval/utils.DatasetConfig.batch_size=${BSIZE} \ + --gin.train/utils.DatasetConfig.pack=${PACK} \ + --gin.train_eval/utils.DatasetConfig.pack=${PACK} \ + --gin.train.te_config_cls=@te_helper.TransformerEngineConfig \ + --gin.te_helper.TransformerEngineConfig.enabled=${ENABLE_FP8} \ + --gin.te_helper.TransformerEngineConfig.fp8_format=\"hybrid\" \ + --gin.network.T5Config.transpose_batch_sequence=${TRANSPOSE_BS} \ + --gin.network.T5Config.fuse_qkv_params=${FUSE_QKV} \ + &> \ + ${LOG_DIR}/ft_${T5_SIZE}_gpu_${NUM_GPUS}_${PREC}_gbs_${BSIZE}_fp8_${ENABLE_FP8}_fuseqkv_${FUSE_QKV}_transbs_${TRANSPOSE_BS}.log diff --git a/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh b/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh index c82322a2c..def1a1a78 100755 --- a/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh +++ b/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh @@ -29,6 +29,11 @@ LOG_DIR=$5 # Output log directory MODEL_DIR_LOCAL=${6:-"model_dir"} MODEL_DIR=${PWD}/${MODEL_DIR_LOCAL} NUM_MICROBATCHES=${7:-0} +ENABLE_FP8=${8:-1} +[[ $ENABLE_FP8 -eq 1 ]] && PREC='bfloat16' # Required for t5x te fp8 to work +TRANSPOSE_BS=${9:-1} +FUSE_QKV=${10:-1} +PACK=${11:-0} echo $MODEL_DIR @@ -49,5 +54,13 @@ python3 -u ${T5X_DIR}/t5x/train.py \ --gin.train/utils.DatasetConfig.batch_size=${BSIZE} \ --gin.trainer.Trainer.num_microbatches=${NUM_MICROBATCHES} \ --gin.train_eval/utils.DatasetConfig.batch_size=${BSIZE} \ - --gin.infer_eval/utils.DatasetConfig.batch_size=${BSIZE} &> \ - ${LOG_DIR}/${T5_SIZE}_gpu_${NUM_GPUS}_${PREC}_gbs_${BSIZE}.log + --gin.infer_eval/utils.DatasetConfig.batch_size=${BSIZE} \ + --gin.train/utils.DatasetConfig.pack=${PACK} \ + --gin.train_eval/utils.DatasetConfig.pack=${PACK} \ + --gin.train.te_config_cls=@te_helper.TransformerEngineConfig \ + --gin.te_helper.TransformerEngineConfig.enabled=${ENABLE_FP8} \ + --gin.te_helper.TransformerEngineConfig.fp8_format=\"hybrid\" \ + --gin.network.T5Config.transpose_batch_sequence=${TRANSPOSE_BS} \ + --gin.network.T5Config.fuse_qkv_params=${FUSE_QKV} \ + &> \ + ${LOG_DIR}/${T5_SIZE}_gpu_${NUM_GPUS}_${PREC}_gbs_${BSIZE}_fp8_${ENABLE_FP8}_fuseqkv_${FUSE_QKV}_transbs_${TRANSPOSE_BS}.log diff --git a/t5x/contrib/gpu/t5/configs/runs/finetune.gin b/t5x/contrib/gpu/t5/configs/runs/finetune.gin index a76d80957..0d68c29e1 100644 --- a/t5x/contrib/gpu/t5/configs/runs/finetune.gin +++ b/t5x/contrib/gpu/t5/configs/runs/finetune.gin @@ -41,6 +41,7 @@ TASK_FEATURE_LENGTHS = %gin.REQUIRED MIXTURE_OR_TASK_MODULE = %gin.REQUIRED TRAIN_STEPS = %gin.REQUIRED INITIAL_CHECKPOINT_PATH = %gin.REQUIRED +RESET_STATE_AFTER = None # a flag to reset optimizer and fp8 states (if exist) after a set number of sets (i.e. 
after pretraining) # Commonly overridden DROPOUT_RATE = 0.1 @@ -81,6 +82,7 @@ train_script.train: use_hardware_rng = %USE_HARDWARE_RNG summarize_config_fn = @gin_utils.summarize_gin_config inference_evaluator_cls = @seqio.Evaluator + reset_state_after = %RESET_STATE_AFTER partitioning.PjitPartitioner: num_partitions = 1 @@ -103,7 +105,7 @@ train/utils.DatasetConfig: shuffle = True seed = None # use a new seed each run/restart use_cached = %USE_CACHED_TASKS - pack = True + pack = False module = %MIXTURE_OR_TASK_MODULE train_eval/utils.DatasetConfig: @@ -114,7 +116,7 @@ train_eval/utils.DatasetConfig: shuffle = False seed = 42 use_cached = %USE_CACHED_TASKS - pack = True + pack = False module = %MIXTURE_OR_TASK_MODULE infer_eval/utils.DatasetConfig: diff --git a/t5x/contrib/gpu/t5/configs/runs/finetune_mnli.gin b/t5x/contrib/gpu/t5/configs/runs/finetune_mnli.gin index 7cebbb6d9..da4482710 100644 --- a/t5x/contrib/gpu/t5/configs/runs/finetune_mnli.gin +++ b/t5x/contrib/gpu/t5/configs/runs/finetune_mnli.gin @@ -40,6 +40,7 @@ MIXTURE_OR_TASK_NAME = %gin.REQUIRED TASK_FEATURE_LENGTHS = %gin.REQUIRED MIXTURE_OR_TASK_MODULE = %gin.REQUIRED TRAIN_STEPS = %gin.REQUIRED +RESET_STATE_AFTER = None # a flag to reset optimizer and fp8 states (if exist) after a set number of sets (i.e. after pretraining) INITIAL_CHECKPOINT_PATH = %gin.REQUIRED # Commonly overridden @@ -80,6 +81,7 @@ train_script.train: use_hardware_rng = %USE_HARDWARE_RNG summarize_config_fn = @gin_utils.summarize_gin_config inference_evaluator_cls = @seqio.Evaluator + reset_state_after = %RESET_STATE_AFTER partitioning.PjitPartitioner: num_partitions = 1 @@ -102,7 +104,7 @@ train/utils.DatasetConfig: shuffle = True seed = None # use a new seed each run/restart use_cached = %USE_CACHED_TASKS - pack = True + pack = False module = %MIXTURE_OR_TASK_MODULE train_eval/utils.DatasetConfig: @@ -113,7 +115,7 @@ train_eval/utils.DatasetConfig: shuffle = False seed = 42 use_cached = %USE_CACHED_TASKS - pack = True + pack = False module = %MIXTURE_OR_TASK_MODULE infer_eval/utils.DatasetConfig: diff --git a/t5x/contrib/gpu/t5/configs/runs/finetune_squad1.gin b/t5x/contrib/gpu/t5/configs/runs/finetune_squad1.gin index 4ea952c91..b1d8e7ede 100644 --- a/t5x/contrib/gpu/t5/configs/runs/finetune_squad1.gin +++ b/t5x/contrib/gpu/t5/configs/runs/finetune_squad1.gin @@ -40,6 +40,7 @@ MIXTURE_OR_TASK_NAME = %gin.REQUIRED TASK_FEATURE_LENGTHS = %gin.REQUIRED MIXTURE_OR_TASK_MODULE = %gin.REQUIRED TRAIN_STEPS = %gin.REQUIRED +RESET_STATE_AFTER = None # a flag to reset optimizer and fp8 states (if exist) after a set number of sets (i.e. 
after pretraining) INITIAL_CHECKPOINT_PATH = %gin.REQUIRED # Commonly overridden @@ -80,6 +81,7 @@ train_script.train: use_hardware_rng = %USE_HARDWARE_RNG summarize_config_fn = @gin_utils.summarize_gin_config inference_evaluator_cls = @seqio.Evaluator + reset_state_after = %RESET_STATE_AFTER partitioning.PjitPartitioner: num_partitions = 1 @@ -102,7 +104,7 @@ train/utils.DatasetConfig: shuffle = True seed = None # use a new seed each run/restart use_cached = %USE_CACHED_TASKS - pack = True + pack = False module = %MIXTURE_OR_TASK_MODULE train_eval/utils.DatasetConfig: @@ -113,7 +115,7 @@ train_eval/utils.DatasetConfig: shuffle = False seed = 42 use_cached = %USE_CACHED_TASKS - pack = True + pack = False module = %MIXTURE_OR_TASK_MODULE infer_eval/utils.DatasetConfig: diff --git a/t5x/contrib/gpu/t5/configs/runs/pretrain.gin b/t5x/contrib/gpu/t5/configs/runs/pretrain.gin index de1286467..e9807f901 100644 --- a/t5x/contrib/gpu/t5/configs/runs/pretrain.gin +++ b/t5x/contrib/gpu/t5/configs/runs/pretrain.gin @@ -73,7 +73,7 @@ train/utils.DatasetConfig: shuffle = %SHUFFLE_TRAIN_EXAMPLES seed = None # use a new seed each run/restart use_cached = %USE_CACHED_TASKS - pack = True + pack = False module = %MIXTURE_OR_TASK_MODULE train_eval/utils.DatasetConfig: @@ -84,7 +84,7 @@ train_eval/utils.DatasetConfig: shuffle = False seed = 42 use_cached = %USE_CACHED_TASKS - pack = True + pack = False module = %MIXTURE_OR_TASK_MODULE utils.CheckpointConfig: diff --git a/t5x/contrib/gpu/t5/network.py b/t5x/contrib/gpu/t5/network.py index dd6107869..2be008c0e 100644 --- a/t5x/contrib/gpu/t5/network.py +++ b/t5x/contrib/gpu/t5/network.py @@ -13,14 +13,18 @@ # limitations under the License. """T5.1.1 Transformer model.""" - from typing import Any, Sequence +from enum import Enum from flax import linen as nn from flax import struct import jax.numpy as jnp from t5x.contrib.gpu.t5 import layers +from t5x.te_helper import TransformerEngineHelper +class SeqDataFormat(Enum): + BATCH_SEQ_HIDDEN = 'bsh' + SEQ_BATCH_HIDDEN = 'sbh' @struct.dataclass class T5Config: @@ -43,6 +47,10 @@ class T5Config: float32_attention_logits: bool = False # Whether to scale attention logits by sqrt(d_k). Default to False for adafactor scale_attn_logits: bool = False + # Whether to transpose batch and sequence to avoid explicit transposes in MHA + transpose_batch_sequence: bool = False + # Whether to fuse the QKV proj in MHA + fuse_qkv_params: bool = False class EncoderLayer(nn.Module): @@ -51,12 +59,13 @@ class EncoderLayer(nn.Module): relative_embedding: nn.Module @nn.compact - def __call__(self, inputs, encoder_mask=None, deterministic=False): + def __call__(self, inputs, attention_mask=None, deterministic=False): cfg = self.config # Relative position embedding as attention biases. - encoder_bias = self.relative_embedding(inputs.shape[-2], inputs.shape[-2], - True) + sequence_dim = 0 if cfg.transpose_batch_sequence else 1 + encoder_bias = self.relative_embedding( + inputs.shape[sequence_dim], inputs.shape[sequence_dim], True) # Attention block. 
assert inputs.ndim == 3 @@ -72,7 +81,7 @@ def __call__(self, inputs, encoder_mask=None, deterministic=False): float32_logits=cfg.float32_attention_logits, name='attention', scale_attn_logits=cfg.scale_attn_logits)( - x, x, encoder_mask, encoder_bias, deterministic=deterministic) + x, x, attention_mask, encoder_bias, deterministic=deterministic) x = nn.Dropout( rate=cfg.dropout_rate, broadcast_dims=(-2,))( x, deterministic=deterministic) @@ -105,7 +114,7 @@ class DecoderLayer(nn.Module): def __call__(self, inputs, encoded, - decoder_mask=None, + attention_mask=None, encoder_decoder_mask=None, deterministic=False, decode=False, @@ -113,7 +122,8 @@ def __call__(self, cfg = self.config # Relative position embedding as attention biases. - l = max_decode_length if decode and max_decode_length else inputs.shape[-2] + sequence_dim = 0 if cfg.transpose_batch_sequence else 1 + l = max_decode_length if decode and max_decode_length else inputs.shape[sequence_dim] decoder_bias = self.relative_embedding(l, l, False) # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] @@ -132,7 +142,7 @@ def __call__(self, scale_attn_logits=cfg.scale_attn_logits)( x, x, - decoder_mask, + attention_mask, decoder_bias, deterministic=deterministic, decode=decode) @@ -185,8 +195,11 @@ class Encoder(nn.Module): def __call__(self, encoder_input_tokens, encoder_mask=None, - deterministic=False): - cfg = self.config + deterministic=False, + output_format=SeqDataFormat.BATCH_SEQ_HIDDEN): + + cfg = TransformerEngineHelper.get_t5x_config(self.config) + assert encoder_input_tokens.ndim == 2 # [batch, length] rel_emb = layers.RelativePositionBiases( num_buckets=32, @@ -199,6 +212,9 @@ def __call__(self, # [batch, length] -> [batch, length, emb_dim] x = self.shared_embedding(encoder_input_tokens.astype('int32')) + if cfg.transpose_batch_sequence: + # [batch, length, emb_dim] -> [length, batch, emb_dim] + x = x.transpose((1, 0, 2)) x = nn.Dropout( rate=cfg.dropout_rate, broadcast_dims=(-2,))( x, deterministic=deterministic) @@ -206,12 +222,20 @@ def __call__(self, for lyr in range(cfg.num_encoder_layers): # [batch, length, emb_dim] -> [batch, length, emb_dim] - x = EncoderLayer( + encoder_lyr = TransformerEngineHelper.get_encoder_layer( config=cfg, relative_embedding=rel_emb, - name=f'layers_{lyr}')(x, encoder_mask, deterministic) + name=f'layers_{lyr}', original_cls=EncoderLayer) + x = encoder_lyr(x, attention_mask=encoder_mask, deterministic=deterministic) x = layers.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x) - return nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic) + x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic) + + if (cfg.transpose_batch_sequence and output_format is SeqDataFormat.BATCH_SEQ_HIDDEN) or \ + (not cfg.transpose_batch_sequence and output_format is SeqDataFormat.SEQ_BATCH_HIDDEN): + x = x.transpose((1, 0, 2)) + + return x + class Decoder(nn.Module): @@ -228,8 +252,16 @@ def __call__(self, encoder_decoder_mask=None, deterministic=False, decode=False, - max_decode_length=None): - cfg = self.config + max_decode_length=None, + encoded_format=SeqDataFormat.BATCH_SEQ_HIDDEN, + output_format=SeqDataFormat.BATCH_SEQ_HIDDEN): + cfg = TransformerEngineHelper.get_t5x_config(self.config) + + if (cfg.transpose_batch_sequence and encoded_format is SeqDataFormat.BATCH_SEQ_HIDDEN) or \ + (not cfg.transpose_batch_sequence and encoded_format is SeqDataFormat.SEQ_BATCH_HIDDEN): + encoded = encoded.transpose((1, 0, 2)) + + assert decoder_input_tokens.ndim == 2 # 
[batch, len] rel_emb = layers.RelativePositionBiases( num_buckets=32, @@ -242,6 +274,10 @@ def __call__(self, # [batch, length] -> [batch, length, emb_dim] y = self.shared_embedding(decoder_input_tokens.astype('int32')) + if cfg.transpose_batch_sequence: + # [batch, length, emb_dim] -> [length, batch, emb_dim] + y = y.transpose((1, 0, 2)) + y = nn.Dropout( rate=cfg.dropout_rate, broadcast_dims=(-2,))( y, deterministic=deterministic) @@ -249,15 +285,16 @@ def __call__(self, for lyr in range(cfg.num_decoder_layers): # [batch, length, emb_dim] -> [batch, length, emb_dim] - y = DecoderLayer( - config=cfg, relative_embedding=rel_emb, name=f'layers_{lyr}')( - y, - encoded, - decoder_mask=decoder_mask, - encoder_decoder_mask=encoder_decoder_mask, - deterministic=deterministic, - decode=decode, - max_decode_length=max_decode_length) + decoder_lyr = TransformerEngineHelper.get_decoder_layer( + config=cfg, relative_embedding=rel_emb, + name=f'layers_{lyr}', original_cls=DecoderLayer) + y = decoder_lyr(y, + encoded, + attention_mask=decoder_mask, + encoder_decoder_mask=encoder_decoder_mask, + deterministic=deterministic, + decode=decode, + max_decode_length=max_decode_length) y = layers.LayerNorm(dtype=cfg.dtype, name='decoder_norm')(y) y = nn.Dropout( @@ -277,6 +314,11 @@ def __call__(self, kernel_axes=('embed', 'vocab'), name='logits_dense')( y) + + if (cfg.transpose_batch_sequence and output_format is SeqDataFormat.BATCH_SEQ_HIDDEN) or \ + (not cfg.transpose_batch_sequence and output_format is SeqDataFormat.SEQ_BATCH_HIDDEN): + # [length, batch, vocab_size] -> [batch, length, vocab_size] + logits = logits.transpose((1, 0, 2)) return logits @@ -285,7 +327,7 @@ class Transformer(nn.Module): config: T5Config def setup(self): - cfg = self.config + cfg = TransformerEngineHelper.get_t5x_config(self.config) self.shared_embedding = layers.Embed( num_embeddings=cfg.vocab_size, features=cfg.emb_dim, @@ -301,7 +343,8 @@ def setup(self): def encode(self, encoder_input_tokens, encoder_segment_ids=None, - enable_dropout=True): + enable_dropout=True, + output_format=SeqDataFormat.BATCH_SEQ_HIDDEN): """Applies Transformer encoder-branch on the inputs.""" cfg = self.config assert encoder_input_tokens.ndim == 2 # (batch, len) @@ -319,8 +362,11 @@ def encode(self, jnp.equal, dtype=cfg.dtype)) + encoder_mask = TransformerEngineHelper.get_attn_mask(encoder_mask) + return self.encoder( - encoder_input_tokens, encoder_mask, deterministic=not enable_dropout) + encoder_input_tokens, encoder_mask, deterministic=not enable_dropout, + output_format=output_format) def decode( self, @@ -333,7 +379,9 @@ def decode( decoder_positions=None, enable_dropout=True, decode=False, - max_decode_length=None): + max_decode_length=None, + encoded_format=SeqDataFormat.BATCH_SEQ_HIDDEN, + output_format=SeqDataFormat.BATCH_SEQ_HIDDEN): """Applies Transformer decoder-branch on encoded-input and target.""" cfg = self.config @@ -369,6 +417,9 @@ def decode( jnp.equal, dtype=cfg.dtype)) + decoder_mask = decoder_mask if decode else TransformerEngineHelper.get_attn_mask(decoder_mask) + encoder_decoder_mask = TransformerEngineHelper.get_attn_mask(encoder_decoder_mask) + logits = self.decoder( encoded, decoder_input_tokens=decoder_input_tokens, @@ -377,7 +428,9 @@ def decode( encoder_decoder_mask=encoder_decoder_mask, deterministic=not enable_dropout, decode=decode, - max_decode_length=max_decode_length) + max_decode_length=max_decode_length, + encoded_format=encoded_format, + output_format=output_format) return logits def __call__(self, @@ -412,10 
+465,15 @@ def __call__(self, Returns: logits array from full transformer. """ + cfg = TransformerEngineHelper.get_t5x_config(self.config) + encoded_format = SeqDataFormat.BATCH_SEQ_HIDDEN + if cfg.transpose_batch_sequence: + encoded_format = SeqDataFormat.SEQ_BATCH_HIDDEN encoded = self.encode( encoder_input_tokens, encoder_segment_ids=encoder_segment_ids, - enable_dropout=enable_dropout) + enable_dropout=enable_dropout, + output_format=encoded_format) return self.decode( encoded, @@ -426,4 +484,5 @@ def __call__(self, decoder_segment_ids=decoder_segment_ids, decoder_positions=decoder_positions, enable_dropout=enable_dropout, - decode=decode) + decode=decode, + encoded_format=encoded_format) diff --git a/t5x/contrib/gpu/t5/t5_1_1/examples/large_mnli2_finetune_adam.gin b/t5x/contrib/gpu/t5/t5_1_1/examples/large_mnli2_finetune_adam.gin index 12e2ee6fa..901bd03ca 100644 --- a/t5x/contrib/gpu/t5/t5_1_1/examples/large_mnli2_finetune_adam.gin +++ b/t5x/contrib/gpu/t5/t5_1_1/examples/large_mnli2_finetune_adam.gin @@ -14,6 +14,7 @@ include "t5x/contrib/gpu/t5/configs/runs/finetune_mnli.gin" MIXTURE_OR_TASK_NAME = "glue_mnli_v2" TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 16} TRAIN_STEPS = 1_006_001 # 1000000 pre-trained steps + 6000 fine-tuning steps. +RESET_STATE_AFTER = 1_000_000 # 1000000 pre-trained steps, after which we should reset optimizer states and fp8 metas, if they exist DROPOUT_RATE = 0.1 INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_large/checkpoint_1000000" # `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained diff --git a/t5x/contrib/gpu/t5/t5_1_1/examples/large_squad1_finetune_adam.gin b/t5x/contrib/gpu/t5/t5_1_1/examples/large_squad1_finetune_adam.gin index 87b896f58..8f49a6fb6 100644 --- a/t5x/contrib/gpu/t5/t5_1_1/examples/large_squad1_finetune_adam.gin +++ b/t5x/contrib/gpu/t5/t5_1_1/examples/large_squad1_finetune_adam.gin @@ -14,6 +14,7 @@ include "t5x/contrib/gpu/t5/configs/runs/finetune_squad1.gin" MIXTURE_OR_TASK_NAME = "squad_v010_allanswers" TASK_FEATURE_LENGTHS = {"inputs": 956, "targets": 256} TRAIN_STEPS = 1_006_001 # 1000000 pre-trained steps + 6000 fine-tuning steps. +RESET_STATE_AFTER = 1_000_000 # 1000000 pre-trained steps, after which we should reset optimizer states and fp8 metas, if they exist DROPOUT_RATE = 0.1 INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_large/checkpoint_1000000" # `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained diff --git a/t5x/contrib/gpu/t5/t5_1_1/examples/small_mnli2_finetune_adam.gin b/t5x/contrib/gpu/t5/t5_1_1/examples/small_mnli2_finetune_adam.gin index 7f600f12b..5391f0ef1 100644 --- a/t5x/contrib/gpu/t5/t5_1_1/examples/small_mnli2_finetune_adam.gin +++ b/t5x/contrib/gpu/t5/t5_1_1/examples/small_mnli2_finetune_adam.gin @@ -14,6 +14,7 @@ include "t5x/contrib/gpu/t5/configs/runs/finetune_mnli.gin" MIXTURE_OR_TASK_NAME = "glue_mnli_v2" TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 16} TRAIN_STEPS = 1_015_001 # 1000000 pre-trained steps + 15000 fine-tuning steps. 
+RESET_STATE_AFTER = 1_000_000 # 1000000 pre-trained steps, after which we should reset optimizer states and fp8 metas, if they exist DROPOUT_RATE = 0.1 INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000" # `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained diff --git a/t5x/contrib/gpu/t5/t5_1_1/examples/small_squad1_finetune_adam.gin b/t5x/contrib/gpu/t5/t5_1_1/examples/small_squad1_finetune_adam.gin index ba5af03d9..c2b2797b0 100644 --- a/t5x/contrib/gpu/t5/t5_1_1/examples/small_squad1_finetune_adam.gin +++ b/t5x/contrib/gpu/t5/t5_1_1/examples/small_squad1_finetune_adam.gin @@ -14,6 +14,7 @@ include "t5x/contrib/gpu/t5/configs/runs/finetune_squad1.gin" MIXTURE_OR_TASK_NAME = "squad_v010_allanswers" TASK_FEATURE_LENGTHS = {"inputs": 956, "targets": 256} TRAIN_STEPS = 1_015_001 # 1000000 pre-trained steps + 15000 fine-tuning steps. +RESET_STATE_AFTER = 1_000_000 # 1000000 pre-trained steps, after which we should reset optimizer states and fp8 metas, if they exist DROPOUT_RATE = 0.1 INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000" # `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained diff --git a/t5x/contrib/gpu/t5/t5_1_1/examples/xl_mnli2_finetune_adam.gin b/t5x/contrib/gpu/t5/t5_1_1/examples/xl_mnli2_finetune_adam.gin index 7a5856225..75e0fe5b9 100644 --- a/t5x/contrib/gpu/t5/t5_1_1/examples/xl_mnli2_finetune_adam.gin +++ b/t5x/contrib/gpu/t5/t5_1_1/examples/xl_mnli2_finetune_adam.gin @@ -14,6 +14,7 @@ include "t5x/contrib/gpu/t5/configs/runs/finetune_mnli.gin" MIXTURE_OR_TASK_NAME = "glue_mnli_v2" TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 16} TRAIN_STEPS = 1_006_001 # 1000000 pre-trained steps + 6000 fine-tuning steps. +RESET_STATE_AFTER = 1_000_000 # 1000000 pre-trained steps, after which we should reset optimizer states and fp8 metas, if they exist DROPOUT_RATE = 0.1 INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_xl/checkpoint_1000000" # `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained diff --git a/t5x/contrib/gpu/t5/t5_1_1/examples/xl_squad1_finetune_adam.gin b/t5x/contrib/gpu/t5/t5_1_1/examples/xl_squad1_finetune_adam.gin index 6a4c7b289..a1f50db91 100644 --- a/t5x/contrib/gpu/t5/t5_1_1/examples/xl_squad1_finetune_adam.gin +++ b/t5x/contrib/gpu/t5/t5_1_1/examples/xl_squad1_finetune_adam.gin @@ -14,6 +14,7 @@ include "t5x/contrib/gpu/t5/configs/runs/finetune_squad1.gin" MIXTURE_OR_TASK_NAME = "squad_v010_allanswers" TASK_FEATURE_LENGTHS = {"inputs": 956, "targets": 256} TRAIN_STEPS = 1_006_001 # 1000000 pre-trained steps + 6000 fine-tuning steps. 
+RESET_STATE_AFTER = 1_000_000 # 1000000 pre-trained steps, after which we should reset optimizer states and fp8 metas, if they exist DROPOUT_RATE = 0.1 INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_xl/checkpoint_1000000" # `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained diff --git a/t5x/models.py b/t5x/models.py index 7433ea2e6..5891c08c1 100644 --- a/t5x/models.py +++ b/t5x/models.py @@ -43,6 +43,8 @@ import tensorflow as tf import typing_extensions +from t5x.te_helper import TransformerEngineHelper + # Remove _ShardedDeviceArray when users of t5x have their types updated _ShardedDeviceArray = Any Array = Union[np.ndarray, jnp.ndarray, _ShardedDeviceArray, tf.Tensor] @@ -135,6 +137,7 @@ def loss_fn( params: PyTree, batch: Mapping[str, jnp.ndarray], dropout_rng: Optional[jax.Array], + flax_mutables: Optional[PyTreeDef] = None, ) -> Tuple[jnp.ndarray, MetricsMap]: """Computes loss and metrics. @@ -155,6 +158,7 @@ def eval_fn( self, params: PyTree, batch: Mapping[str, jnp.ndarray], + flax_mutables: Optional[PyTreeDef] = None, ) -> Tuple[jnp.ndarray, MetricsMap]: """Computes loss and metrics during the evaluation. @@ -172,6 +176,7 @@ def eval_fn( params=params, batch=batch, dropout_rng=None, + flax_mutables=flax_mutables, ) def predict_batch( @@ -179,6 +184,7 @@ def predict_batch( params: PyTree, batch: Mapping[str, jnp.ndarray], rng: Optional[jax.Array] = None, + flax_mutables: Optional[PyTreeDef] = None, ) -> jnp.ndarray: """Predicts a batch of outputs from the model. @@ -190,7 +196,7 @@ def predict_batch( Returns: The model predictions. """ - return self.predict_batch_with_aux(params=params, batch=batch, rng=rng)[0] + return self.predict_batch_with_aux(params=params, batch=batch, rng=rng, flax_mutables=flax_mutables)[0] @abc.abstractmethod def predict_batch_with_aux( @@ -198,6 +204,7 @@ def predict_batch_with_aux( params: PyTree, batch: Mapping[str, jnp.ndarray], rng: Optional[jax.Array] = None, + flax_mutables: Optional[PyTreeDef] = None, ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: """Predict a batch from the modelwith auxiliary outputs. @@ -218,6 +225,7 @@ def score_batch( params: PyTree, batch: Mapping[str, jnp.ndarray], return_intermediates: bool = False, + flax_mutables: Optional[PyTreeDef] = None, ) -> jnp.ndarray: """Computes scores for batch.""" pass @@ -281,6 +289,7 @@ def _compute_logits( params: PyTree, batch: Mapping[str, jnp.ndarray], dropout_rng: Optional[jax.Array] = None, + flax_mutables: Optional[PyTreeDef] = None, ) -> jnp.ndarray: """Computes logits via a forward pass of the model.""" pass @@ -290,9 +299,11 @@ def loss_fn( params: PyTree, batch: Mapping[str, jnp.ndarray], dropout_rng: Optional[jax.Array], + flax_mutables: Optional[PyTreeDef] = None, ) -> Tuple[jnp.ndarray, MetricsMap]: """Loss function used for training with a cross-entropy loss.""" - logits = self._compute_logits(params, batch, dropout_rng) + logits = self._compute_logits(params, batch, dropout_rng, + flax_mutables=flax_mutables) loss_normalizing_factor: Optional[ Union[float, int, str, losses.SpecialLossNormalizingFactor] @@ -336,6 +347,31 @@ def loss_fn( ) return loss, metrics + def eval_fn( + self, + params: PyTreeDef, + batch: Mapping[str, jnp.ndarray], + flax_mutables: Optional[PyTreeDef] = None, + ) -> Tuple[jnp.ndarray, MetricsMap]: + """Computes loss and metrics during the evaluation. + + Args: + params: model parameters. + batch: a batch of inputs. + + Returns: + loss: the loss computed for the given inputs and parameters. 
+ aux: + weight_sum: sum of the per-token weights applied to the loss. + metrics: a mapping of metrics computed for this batch. + """ + return self.loss_fn( + params=params, + batch=batch, + dropout_rng=None, + flax_mutables=flax_mutables, + ) + def _compute_metrics( self, logits: jnp.ndarray, @@ -459,15 +495,18 @@ def _compute_logits( # pytype: disable=signature-mismatch # jax-ndarray batch: Mapping[str, jnp.ndarray], dropout_rng: Optional[jax.Array] = None, mutable: flax_scope.CollectionFilter = False, - other_variables: Optional[PyTree] = None, + flax_mutables: Optional[PyTreeDef] = None, ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, flax_scope.FrozenVariableDict]]: """Computes logits via a forward pass of `self.module_cls`.""" # Dropout is provided only for the training mode. rngs = {'dropout': dropout_rng} if dropout_rng is not None else None - if other_variables is None: - other_variables = {} + if flax_mutables is None: + flax_mutables = {} return self.module.apply( - {'params': params, **other_variables}, + { + 'params': params, + **flax_mutables, + }, batch['encoder_input_tokens'], batch['decoder_input_tokens'], batch['decoder_target_tokens'], @@ -488,17 +527,25 @@ def _compute_logits_from_slice( encoded_inputs: jnp.ndarray, raw_inputs: jnp.ndarray, max_decode_length: int, + flax_mutables: Optional[PyTreeDef] = None, ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: """Token slice to logits from decoder model.""" flat_ids = decoding_state.cur_token flat_cache = decoding_state.cache + if flax_mutables is None: + flax_mutables = {} + # flat_ids: [batch * beam, seq_len=1] # cache is expanded inside beam_search to become flat_cache # flat_cache: [batch * beam, num_heads, depth_per_head, max_decode_len] # flat_logits: [batch * beam, seq_len=1, vocab] flat_logits, new_vars = self.module.apply( - {'params': params, 'cache': flat_cache}, + { + 'params': params, + 'cache': flat_cache, + **flax_mutables, + }, encoded_inputs, raw_inputs, # only needed for encoder padding mask flat_ids, @@ -521,6 +568,7 @@ def _compute_kv_cache( encoder_input_tokens: jnp.ndarray, decoder_input_tokens: jnp.ndarray, prefill_decoder_prompt: bool = False, + flax_mutables: Optional[PyTreeDef] = None, ) -> Tuple[PyTree, Optional[jnp.ndarray]]: """Initialize the key/value cache, with optional prompt. @@ -539,8 +587,14 @@ def _compute_kv_cache( initial_index: The index of the next position following prefill or None if `prefill_decoder_prompt` is False. """ + if flax_mutables is None: + flax_mutables = {} + del encoded_inputs _, initial_variables = self.module.apply( - {'params': params}, + { + 'params': params, + **flax_mutables, + }, encoder_input_tokens=jnp.ones_like(encoder_input_tokens), decoder_input_tokens=jnp.ones_like(decoder_input_tokens), decoder_target_tokens=jnp.ones_like(decoder_input_tokens), @@ -581,6 +635,24 @@ def _compute_kv_cache( return cache, inputs_lengths + def predict_batch(self, + params: PyTreeDef, + batch: Mapping[str, jnp.ndarray], + rng: Optional[jax.random.KeyArray] = None, + flax_mutables: Optional[PyTreeDef] = None) -> jnp.ndarray: + """Predicts a batch of outputs from the model. + + Args: + params: model parameters. + batch: a batch of inputs. + rng: an optional RNG to use during prediction (e.g., for decoding). + + Returns: + The model predictions. 
+ """ + return self.predict_batch_with_aux(params=params, batch=batch, + rng=rng, flax_mutables=flax_mutables)[0] + def predict_batch_with_aux( self, params: PyTree, @@ -590,6 +662,7 @@ def predict_batch_with_aux( return_all_decodes: bool = None, num_decodes: int = None, # pytype:disable=annotation-type-mismatch prompt_with_targets: bool = False, + flax_mutables: Optional[PyTreeDef] = None, ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: """Predict with fast decoding beam search on a batch. @@ -646,11 +719,30 @@ def predict_batch_with_aux( return_all_decodes = self._default_decoder_params.return_all_decodes if num_decodes is None: num_decodes = self._default_decoder_params.num_decodes + if flax_mutables is None: + flax_mutables = {} # [batch, input_len] encoder_input_tokens = batch['encoder_input_tokens'] decoder_input_tokens = batch['decoder_input_tokens'] + # Prepare transformer fast-decoder call for beam search: for beam search, we + # need to set up our decoder model to handle a batch size equal to + # batch_size * num_decodes, where each batch item's data is expanded + # in-place rather than tiled. + # i.e. if we denote each batch element subtensor as el[n]: + # [el0, el1, el2] --> beamsize=2 --> [el0,el0,el1,el1,el2,el2] + # [batch * num_decodes, input_len, emb_dim] + encoded_inputs = decoding.flat_batch_beam_expand( + self.module.apply( + {'params': params, **flax_mutables}, + encoder_input_tokens, + enable_dropout=False, + method=self.module.encode, + ), + num_decodes, + ) + # `decoder_prompt_inputs` is initialized from the batch's # `decoder_input_tokens`. The EOS is stripped to avoid decoding to stop # after the prompt by matching to `output_vocabulary.eos_id`. @@ -691,6 +783,7 @@ def predict_batch_with_aux( encoder_input_tokens=encoder_input_tokens, decoder_input_tokens=decoder_prompt_inputs, prefill_decoder_prompt=prefill_decoder_prompt, + flax_mutables=flax_mutables, ) # Prepare transformer fast-decoder call for beam search: for beam search, we @@ -712,6 +805,7 @@ def predict_batch_with_aux( encoder_input_tokens, num_decodes ), max_decode_length=decoder_input_tokens.shape[1], + flax_mutables=flax_mutables, ) if decoder_params is None: @@ -737,6 +831,7 @@ def predict_batch_with_aux( # decodes: [batch, num_decodes, max_decode_len + 1] # scores: [batch, num_decodes] scanned = hasattr(self.module, 'scan_layers') and self.module.scan_layers + cfg = TransformerEngineHelper.get_t5x_config(self.module.config) if 'eos_id' not in decoder_params: decoder_params['eos_id'] = self.output_vocabulary.eos_id @@ -745,8 +840,8 @@ def predict_batch_with_aux( cache=cache, tokens_to_logits=tokens_ids_to_logits, num_decodes=num_decodes, - cache_offset=1 if scanned else 0, - **decoder_params, + cache_offset=1 if (scanned or cfg.transpose_batch_sequence) else 0, + **decoder_params ) # Beam search returns [n_batch, n_beam, n_length] with beam dimension sorted @@ -762,6 +857,7 @@ def score_batch( # pytype: disable=signature-mismatch # jax-ndarray params: PyTree, batch: Mapping[str, jnp.ndarray], return_intermediates: bool = False, + flax_mutables: Optional[PyTreeDef] = None, ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Mapping[str, Any]]]: """Compute log likelihood score on a batch.""" weights = batch['decoder_loss_weights'] @@ -769,7 +865,8 @@ def score_batch( # pytype: disable=signature-mismatch # jax-ndarray if return_intermediates: logits, modified_variables = self._compute_logits( - params=params, batch=batch, mutable=['intermediates'] + params=params, batch=batch, mutable=['intermediates'], + 
flax_mutables=flax_mutables, ) # Inside self.module, we called nn.Module.sow to track various @@ -787,7 +884,7 @@ def score_batch( # pytype: disable=signature-mismatch # jax-ndarray # `intermediates` should be tuples tracking all instantiations of a value. # These values each have just one instantiation, hence singletons. else: - logits = self._compute_logits(params, batch) # type: jnp.ndarray # pytype: disable=annotation-type-mismatch # jax-ndarray + logits = self._compute_logits(params, batch, flax_mutables=flax_mutables) # type: jnp.ndarray # pytype: disable=annotation-type-mismatch # jax-ndarray # Purposefully don't use config.z_loss because that term is for training # stability and shouldn't affect our reported scores. @@ -896,16 +993,19 @@ def _compute_logits( batch: Mapping[str, jnp.ndarray], dropout_rng: Optional[jax.Array] = None, mutable: flax_scope.CollectionFilter = False, - other_variables: Optional[PyTree] = None, + flax_mutables: Optional[PyTreeDef] = None, ) -> jnp.ndarray: """Computes logits via a forward pass of `self.module`.""" rngs = {'dropout': dropout_rng} if dropout_rng is not None else None decoder_causal_attention = self._get_decoder_causal_attention(batch) - if other_variables is None: - other_variables = {} + if flax_mutables is None: + flax_mutables = {} return self.module.apply( - {'params': params, **other_variables}, + { + 'params': params, + **flax_mutables, + }, batch['decoder_input_tokens'], batch['decoder_target_tokens'], decoder_segment_ids=batch.get('decoder_segment_ids', None), @@ -922,6 +1022,7 @@ def _compute_logits_from_slice( decoding_state: decoding.DecodingState, params: PyTree, max_decode_length: int, + flax_mutables: Optional[PyTreeDef] = None, ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: """Token slice to logits from decoder model.""" flat_ids = decoding_state.cur_token @@ -932,7 +1033,11 @@ def _compute_logits_from_slice( # flat_cache['cache_index']: [batch] # flat_logits: [batch, seq_len=1, vocab] flat_logits, new_vars = self.module.apply( - {'params': params, 'cache': flat_cache}, + { + 'params': params, + 'cache': flat_cache, + **flax_mutables, + }, flat_ids, flat_ids, enable_dropout=False, @@ -950,6 +1055,7 @@ def score_batch( params: PyTree, batch: Mapping[str, jnp.ndarray], return_intermediates: bool = False, + flax_mutables: Optional[PyTreeDef] = None, ) -> jnp.ndarray: """Compute log likelihood score on a batch.""" @@ -962,6 +1068,7 @@ def score_batch( batch=batch, dropout_rng=None, mutable=['intermediates'], + flax_mutables=flax_mutables, ) # Inside self.module, we called nn.Module.sow to track various @@ -980,7 +1087,7 @@ def score_batch( # These values each have just one instantiation, hence singletons. else: logits = self._compute_logits( - params=params, batch=batch, dropout_rng=None + params=params, batch=batch, dropout_rng=None, flax_mutables=flax_mutables, ) token_scores = ( @@ -1008,6 +1115,7 @@ def _compute_kv_cache( params: PyTree, inputs: jnp.ndarray, causal_attention_mask: jnp.ndarray, + flax_mutables: Optional[PyTreeDef] = None, ) -> Tuple[PyTree, jnp.ndarray]: """Compute the key/value cache on the input prompt. @@ -1021,12 +1129,17 @@ def _compute_kv_cache( cache: The prefilled cache. initial_index: The index of the next position following prefill. """ + if flax_mutables is None: + flax_mutables = {} # The lengths of the inputs match the number of non-padding positions, # excluding the initial BOS. 
inputs_lengths = jnp.sum(inputs[:, 1:] != 0, axis=-1) _, initial_variables = self.module.apply( - {'params': params}, + { + 'params': params, + **flax_mutables, + }, jnp.ones_like(inputs), jnp.ones_like(inputs), enable_dropout=False, @@ -1064,7 +1177,11 @@ def _compute_kv_cache( ) _, variables_with_cache = self.module.apply( - {'params': params, 'cache': cache}, + { + 'params': params, + 'cache': cache, + **flax_mutables, + }, decoder_input_tokens=inputs, # Use the `decoder_causal_attention`, which has 1 for all input # positions, including the BOS token, as the targets so when the @@ -1091,6 +1208,7 @@ def predict_batch_with_aux( return_all_decodes: bool = False, num_decodes: int = 1, decoder_params: Optional[MutableMapping[str, Any]] = None, + flax_mutables: Optional[PyTreeDef] = None, ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: """Predict with prefix. @@ -1189,7 +1307,7 @@ def predict_batch_with_aux( inputs = batch['decoder_input_tokens'] * batch['decoder_causal_attention'] prefilled_cache, initial_index = self._compute_kv_cache( - params, inputs, batch['decoder_causal_attention'] + params, inputs, batch['decoder_causal_attention'], flax_mutables=flax_mutables, ) target_shape = batch['decoder_input_tokens'].shape @@ -1199,6 +1317,7 @@ def predict_batch_with_aux( self._compute_logits_from_slice, params=params, max_decode_length=max_decode_length, + flax_mutables=flax_mutables, ) if decoder_params is None: diff --git a/t5x/partitioning.py b/t5x/partitioning.py index 0b57e67fd..847dc24c5 100644 --- a/t5x/partitioning.py +++ b/t5x/partitioning.py @@ -35,6 +35,7 @@ from jax.sharding import PartitionSpec import numpy as np from t5x import train_state as train_state_lib +from t5x.te_helper import TransformerEngineHelper JaxDevice = jax.Device TpuMesh = Tuple[int, int, int, int] # (x, y, z, num_cores). @@ -621,6 +622,8 @@ def standard_logical_axis_rules( if additional_rules: rules.extend(additional_rules) + rules = TransformerEngineHelper.extend_logical_axis_rules(rules) + return rules diff --git a/t5x/te_helper.py b/t5x/te_helper.py new file mode 100644 index 000000000..fb5f48f08 --- /dev/null +++ b/t5x/te_helper.py @@ -0,0 +1,284 @@ +# Copyright 2023 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from absl import logging +from contextlib import contextmanager +import gin +import jax + +try: + from transformer_engine.common.recipe import DelayedScaling + from transformer_engine.common.recipe import Format as FP8Format + import transformer_engine.jax as te + _IS_TRANSFORMER_ENGINE_INSTALLED = True + +except ModuleNotFoundError as e: + _IS_TRANSFORMER_ENGINE_INSTALLED = False + + +def _canonicalize_fp8_format(fp8_format): + if not _IS_TRANSFORMER_ENGINE_INSTALLED: + return None + + fp8_format = fp8_format.lower() + if fp8_format in ['fp8_e4m3', 'fp8e4m3', 'e4m3']: + return FP8Format.E4M3 + if fp8_format in ['fp8_e5m2', 'fp8e5m2', 'e5m2']: + return FP8Format.E5M2 + if fp8_format in ['fp8_hybrid', 'fp8hybrid', 'hybrid']: + return FP8Format.HYBRID + raise ValueError('fp8_format must be one of [fp8_e4m3, fp8_e5m2, fp8_hybrid]' + f'but the value is {fp8_format}') + +@gin.configurable +class TransformerEngineConfig: + def __init__(self, enabled=False, fp8_format='fp8_hybrid', margin=0., amax_history_len=1024): + assert (_IS_TRANSFORMER_ENGINE_INSTALLED or (not enabled)), \ + 'Attempt to run transformer engine FP8 without installing transformer_engine.' + + self.enabled = enabled + self.fp8_format = _canonicalize_fp8_format(fp8_format) + self.margin = margin + self.amax_history_len = amax_history_len + + def __str__(self): + return f"TransformerEngineConfig: enabled:{self.enabled}," \ + f" fp8_format: {self.fp8_format}, margin: {self.margin}," \ + f" amax_history_len: {self.amax_history_len}." + + +class TransformerEngineHelperBase: + + @staticmethod + def is_fp8_enabled(): + raise NotImplementedError + + @staticmethod + @contextmanager + def fp8_autocast(te_config, dp_mesh_axis=None, tp_mesh_axis=None): + raise NotImplementedError + + @staticmethod + def extend_logical_axis_rules(rules): + raise NotImplementedError + + @staticmethod + def update_fp8_metas(grad_accum, flax_mutables): + raise NotImplementedError + + @staticmethod + def check_dataset_cfg(config): + raise NotImplementedError + + @staticmethod + def get_t5x_config(config): + raise NotImplementedError + + @staticmethod + def get_attn_mask(mask): + raise NotImplementedError + + @staticmethod + def get_encoder_layer(config, relative_embedding, name, original_cls): + raise NotImplementedError + + @staticmethod + def get_decoder_layer(config, relative_embedding, name, original_cls): + raise NotImplementedError + + +class TENotInstalledHelper(TransformerEngineHelperBase): + + @staticmethod + def is_fp8_enabled(): + return False + + @staticmethod + @contextmanager + def fp8_autocast(te_config, dp_mesh_axis=None, tp_mesh_axis=None): + try: + yield + finally: + pass + + @staticmethod + def extend_logical_axis_rules(rules): + return rules + + @staticmethod + def update_fp8_metas(grad_accum, flax_mutables): + return flax_mutables + + @staticmethod + def check_dataset_cfg(config): + pass + + @staticmethod + def get_t5x_config(config): + assert not config.transpose_batch_sequence, \ + "Only allow transpose_batch_sequence when Transformer Engine is installed." 
+ return config + + @staticmethod + def get_attn_mask(mask): + return mask + + @staticmethod + def get_encoder_layer(config, relative_embedding, name, original_cls): + return original_cls(config=config, + relative_embedding=relative_embedding, name=name) + + @staticmethod + def get_decoder_layer(config, relative_embedding, name, original_cls): + return original_cls(config=config, + relative_embedding=relative_embedding, name=name) + + +class TEInstalledHelper(TransformerEngineHelperBase): + + @staticmethod + def is_fp8_enabled(): + return te.fp8.FP8Helper.is_fp8_enabled() + + @staticmethod + @contextmanager + def fp8_autocast(te_config, dp_mesh_axis="data", tp_mesh_axis="model"): + delay_scaling = DelayedScaling(margin=te_config.margin, + fp8_format=te_config.fp8_format, + amax_history_len=te_config.amax_history_len, + amax_compute_algo="max") + try: + with te.fp8_autocast(enabled=te_config.enabled, fp8_recipe=delay_scaling, + sharding_resource=te.ShardingResource(dp_mesh_axis, tp_mesh_axis)): + yield + finally: + pass + + @staticmethod + def extend_logical_axis_rules(rules): + # Apply fp8_autocast to correctly set sharding_resource up. + with TEInstalledHelper.fp8_autocast(TransformerEngineConfig()): + return te.extend_logical_axis_rules(rules) + + @staticmethod + def update_fp8_metas(grad_accum, flax_mutables): + update_coll = te.update_collections(grad_accum, flax_mutables) + # As the suggestion of FP8 training, updating FP8 scales as frequent as possible. + update_coll = te.update_fp8_metas(update_coll) + return update_coll + + @staticmethod + def check_dataset_cfg(config): + assert not config.pack, \ + "Transformer Engine does not support dataset.packing, please turn it off." + + @staticmethod + def get_t5x_config(config): + return config + + @staticmethod + def get_attn_mask(mask): + # Invert T5X's mask by 0->1, and 1->0 + mask_ = mask + mask_ = 1 - mask_.astype(jax.numpy.uint8) + return mask_ + + @staticmethod + def get_encoder_layer(config, relative_embedding, name, original_cls): + hidden_dropout_dims = (-3,) if config.transpose_batch_sequence else(-2,) + return te.TransformerLayer( + hidden_size=config.num_heads*config.head_dim, + mlp_hidden_size=config.mlp_dim, + layernorm_type="rmsnorm", + num_attention_heads=config.num_heads, + hidden_dropout=config.dropout_rate, + hidden_dropout_dims = hidden_dropout_dims, + attention_dropout=config.dropout_rate, + mlp_activations=config.mlp_activations, + transpose_batch_sequence=config.transpose_batch_sequence, + float32_attention_logits=config.float32_attention_logits, + scale_attn_logits=config.scale_attn_logits, + scaled_query_init=True, + fuse_qkv_params=config.fuse_qkv_params, + relative_embedding=relative_embedding, + dtype=config.dtype, layer_type=te.TransformerLayerType.ENCODER, name=name) + + @staticmethod + def get_decoder_layer(config, relative_embedding, name, original_cls): + hidden_dropout_dims = (-3,) if config.transpose_batch_sequence else(-2,) + return te.TransformerLayer( + hidden_size=config.num_heads*config.head_dim, + mlp_hidden_size=config.mlp_dim, + layernorm_type="rmsnorm", + num_attention_heads=config.num_heads, + hidden_dropout=config.dropout_rate, + hidden_dropout_dims = hidden_dropout_dims, + attention_dropout=config.dropout_rate, + mlp_activations=config.mlp_activations, + transpose_batch_sequence=config.transpose_batch_sequence, + float32_attention_logits=config.float32_attention_logits, + scale_attn_logits=config.scale_attn_logits, + scaled_query_init=True, + fuse_qkv_params=config.fuse_qkv_params, + 
relative_embedding=relative_embedding, + dtype=config.dtype, layer_type=te.TransformerLayerType.DECODER, name=name) + + +class TransformerEngineHelper(TransformerEngineHelperBase): + + @staticmethod + def get_helper(): + if _IS_TRANSFORMER_ENGINE_INSTALLED: + return TEInstalledHelper + return TENotInstalledHelper + + @staticmethod + def is_fp8_enabled(): + return TransformerEngineHelper.get_helper().is_fp8_enabled() + + @staticmethod + @contextmanager + def fp8_autocast(te_config, dp_mesh_axis="data", tp_mesh_axis="model"): + try: + with TransformerEngineHelper.get_helper().fp8_autocast(te_config, dp_mesh_axis, tp_mesh_axis): + yield + finally: + pass + + @staticmethod + def extend_logical_axis_rules(rules): + return TransformerEngineHelper.get_helper().extend_logical_axis_rules(rules) + + @staticmethod + def update_fp8_metas(grad_accum, flax_mutables): + return TransformerEngineHelper.get_helper().update_fp8_metas(grad_accum, flax_mutables) + + @staticmethod + def check_dataset_cfg(config): + return TransformerEngineHelper.get_helper().check_dataset_cfg(config) + + @staticmethod + def get_t5x_config(config): + return TransformerEngineHelper.get_helper().get_t5x_config(config) + + @staticmethod + def get_attn_mask(mask): + return TransformerEngineHelper.get_helper().get_attn_mask(mask) + + @staticmethod + def get_encoder_layer(config, relative_embedding, name, original_cls): + return TransformerEngineHelper.get_helper().get_encoder_layer(config, relative_embedding, name, original_cls) + + @staticmethod + def get_decoder_layer(config, relative_embedding, name, original_cls): + return TransformerEngineHelper.get_helper().get_decoder_layer(config, relative_embedding, name, original_cls) diff --git a/t5x/train.py b/t5x/train.py index 303ad2f13..e70bcfb5b 100644 --- a/t5x/train.py +++ b/t5x/train.py @@ -46,6 +46,8 @@ from t5x import train_state as train_state_lib from t5x import trainer as trainer_lib from t5x import utils +from t5x.te_helper import TransformerEngineConfig, TransformerEngineHelper +import atexit import tensorflow as tf # pylint:enable=g-import-not-at-top @@ -109,6 +111,7 @@ def train( eval_steps: int, eval_period: int, relative_steps: Optional[int] = None, + te_config_cls: Type[TransformerEngineConfig] = TransformerEngineConfig, stats_period: Optional[int] = None, random_seed: Optional[int], use_hardware_rng: bool = False, @@ -129,6 +132,7 @@ def train( Callable[[utils.DatasetConfig, models.BaseTransformerModel], None] ] = utils.verify_matching_vocabs, gc_period: int = 0, + reset_state_after: Optional[int] = None, ) -> Tuple[int, train_state_lib.TrainState]: """Train function. @@ -196,6 +200,12 @@ def train( for training_eval (e.g., {'task_average/accuracy': ['task_a', 'task_b']}). infer_eval_average_metrics: Averages the metric over the list of tasks for infer_eval (e.g., {'task_average/accuracy': ['task_a', 'task_b']}). # END + reset_state_after: Optional number of steps after which to reset the + optimizer states and fp8 metadata. In finetuning, you may not want to + keep fp8 metas due to the distribution change messing up `amax` + statistics. Ex: to reset to finetuning, set this to the number of + pretraining steps. Only triggered if the train step on restore equals + reset_state_after. Otherwise ignored Returns: The tuple of (last_step, last_train_state). 
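The reset described in the docstring above fires only when the step restored from the checkpoint exactly matches `reset_state_after`. A minimal sketch of that condition (illustrative only; names follow the docstring, this is not the actual t5x helper):

```python
# Sketch of when the optimizer/fp8-state reset triggers. Assumption: the
# restored step and reset_state_after are plain Python ints, as in train.py.
from typing import Optional

def should_reset_state(restored_step: int, reset_state_after: Optional[int]) -> bool:
  # Fires once, right after restoring the pre-trained checkpoint whose step
  # equals reset_state_after (e.g. 1_000_000 for the t5_1_1 checkpoints).
  return reset_state_after is not None and restored_step == reset_state_after

assert should_reset_state(1_000_000, 1_000_000)      # first fine-tuning start
assert not should_reset_state(1_000_100, 1_000_000)  # later restarts keep state
assert not should_reset_state(1_000_000, None)       # feature disabled
```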
@@ -206,6 +216,18 @@ def train( if use_orbax: logging.info('Checkpointing with Orbax enabled.') + te_config = te_config_cls() + logging.info(te_config) + + # Note(terry): The proper usage of the TE API is to use fp8_autocast as a + # contextmanager using the "with" statement. The reason it is done this + # way here is to avoid indenting the code. Please refer to TE documentation + # for more details. + te_ctx_mgr = TransformerEngineHelper.fp8_autocast(te_config) + # Register a hook with atexit in case exception raised + atexit.register(lambda: te_ctx_mgr.__exit__(None, None, None)) + te_ctx_mgr.__enter__() + # Each "epoch" of the training loop should be the min of the eval period, # checkpoint period or the full training. # We compute here to ensure that the eval period and checkpoint period are @@ -245,6 +267,8 @@ def train( # Initialize datasets # --------------------------------------------------------------------------- + TransformerEngineHelper.check_dataset_cfg(train_dataset_cfg) + if train_dataset_cfg.seed and not ( checkpoint_cfg.save and checkpoint_cfg.save.save_dataset ): @@ -283,6 +307,7 @@ def _verify_matching_vocabs(cfg: utils.DatasetConfig): input_types = jax.tree_map(lambda x: x.dtype, train_iter.element_spec) if train_eval_dataset_cfg: + TransformerEngineHelper.check_dataset_cfg(train_eval_dataset_cfg) _verify_matching_vocabs(train_eval_dataset_cfg) train_eval_datasets = train_eval_get_dataset_fn( train_eval_dataset_cfg, @@ -381,6 +406,38 @@ def _verify_matching_vocabs(cfg: utils.DatasetConfig): ) ) + if train_state is not None: + # Only triggered if the train step is equal to reset_state_after right after + # restore + host_step = int(utils.get_local_data(train_state.step)) # pytype: disable=attribute-error + if reset_state_after and reset_state_after == host_step: + logging.info('Resetting optimizer and fp8 states. Preserving only optimizer targets') + assert use_orbax == False, "resetting the states in the train loop is not\ + supported with orbax. If you need to do this,\ + please delete the optimizer state and fp8 metas\ + from the checkpoint folder directly " + old_step = train_state.step + train_state_initializer = train_state_initializer_cls( + optimizer_def=model.optimizer_def, + init_fn=model.get_initial_variables, + input_shapes=input_shapes, + input_types=input_types, + partitioner=partitioner) + from_scratch_state = train_state_initializer.from_scratch(init_rng) + _optimizer = from_scratch_state._optimizer.replace( + target=train_state.params) + train_state = from_scratch_state.replace(_optimizer=_optimizer) + train_state = train_state.replace_step(old_step) + + checkpoint_manager = utils.LegacyCheckpointManager( + save_cfg=checkpoint_cfg.save, + restore_cfg=valid_restore_cfg, + train_state_shape=train_state_initializer.global_train_state_shape, + partitioner=partitioner, + ds_iter=train_iter, + model_dir=model_dir, + use_gda=use_gda) + # Start warming up the input pipeline in the background. This must happen # after input pipeline checkpoints were restored. first_batch_ready = train_iter.peek_async() @@ -454,6 +511,7 @@ def _verify_matching_vocabs(cfg: utils.DatasetConfig): # Init evaluator to set up cached datasets evaluator = None if infer_eval_dataset_cfg is not None: + TransformerEngineHelper.check_dataset_cfg(infer_eval_dataset_cfg) evaluator = eval_lib.InferenceEvaluator( infer_eval_dataset_cfg=infer_eval_dataset_cfg, inference_evaluator_cls=inference_evaluator_cls, @@ -802,6 +860,7 @@ def _as_gda(spec): # the same interpreter. 
gc.enable() + te_ctx_mgr.__exit__(None, None, None) return host_step, trainer.train_state diff --git a/t5x/train_state.py b/t5x/train_state.py index 04a86288e..d3f091ee1 100644 --- a/t5x/train_state.py +++ b/t5x/train_state.py @@ -228,7 +228,7 @@ def as_logical_axes(self) -> 'FlaxOptimTrainState': self._optimizer, flax_partitioning.get_axis_names(self.params_axes) ), flax_mutables=flax_partitioning.get_axis_names(flax_mutables_axes), - ) + flax_mutables_axes=self.flax_mutables_axes) class InferenceState(flax.struct.PyTreeNode): diff --git a/t5x/trainer.py b/t5x/trainer.py index 965bd09a1..cebb815b7 100644 --- a/t5x/trainer.py +++ b/t5x/trainer.py @@ -44,6 +44,7 @@ from t5x import partitioning from t5x import train_state as train_state_lib from t5x import utils +from t5x.te_helper import TransformerEngineHelper import typing_extensions @@ -731,7 +732,7 @@ def accumulate_grads_microbatched( """ batch_size = next(iter(batch.values())).shape[0] - grad_fn = jax.value_and_grad(model.loss_fn, has_aux=True) + grad_fn = jax.value_and_grad(model.loss_fn, argnums=(0, 3), has_aux=True) # We assume that the model loss_fn supports flax mutables if and only if # the train state includes non-empty flax mutables. @@ -739,18 +740,13 @@ def accumulate_grads_microbatched( # them and return flax_mutables from `get_initial_variables` and `loss_fn`. initial_flax_mutables = ( - train_state.flax_mutables if train_state.flax_mutables else None + train_state.flax_mutables if train_state.flax_mutables else {} ) if num_microbatches is None or num_microbatches <= 1: - - if initial_flax_mutables is None: - (_, metrics), grad_accum = grad_fn(train_state.params, batch, dropout_rng) - flax_mutables = None - else: - (_, (metrics, flax_mutables)), grad_accum = grad_fn( - train_state.params, batch, dropout_rng, initial_flax_mutables - ) + (_, metrics), grad_accum = grad_fn(train_state.params, batch, + dropout_rng, initial_flax_mutables) + flax_mutables=initial_flax_mutables else: assert ( batch_size % num_microbatches == 0 @@ -783,14 +779,9 @@ def metrics_and_grad(loop_cnt, dropout_rng, flax_mutables=None): ), mbatch, ) - if flax_mutables is None: - (_, metrics), grad = grad_fn( - train_state.params, mbatch, sub_dropout_rng - ) - else: - (_, (metrics, flax_mutables)), grad = grad_fn( - train_state.params, mbatch, sub_dropout_rng, flax_mutables - ) + (_, metrics), grad = grad_fn(train_state.params, mbatch, + sub_dropout_rng, flax_mutables) + return metrics, grad, flax_mutables def per_microbatch_train_step( @@ -812,7 +803,9 @@ def per_microbatch_train_step( loop_cnt, dropout_rng, flax_mutables ) - grad_accum = jax.tree_util.tree_map(jnp.add, grad_accum, grad) + grad_accum[0] = jax.tree_map(jnp.add, grad_accum[0], grad[0]) + grad_accum[1] = TransformerEngineHelper.update_fp8_metas(grad[1], flax_mutables) + flax_mutables = grad_accum[1] metrics = jax.lax.cond( loop_cnt == 0, lambda _: metrics, @@ -823,9 +816,10 @@ def per_microbatch_train_step( # Initialize gradient accumulation loop state. 
accum_dtype = jnp.float32 - grad_accum_init = jax.tree_util.tree_map( - lambda x: jnp.zeros(x.shape, accum_dtype), train_state.params - ) + grad_accum_init = [jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape, accum_dtype), + train_state.params), + jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape, accum_dtype), + train_state.flax_mutables)] initial_metrics_shape, _, _ = jax.eval_shape( metrics_and_grad, loop_cnt=0, @@ -849,6 +843,9 @@ def per_microbatch_train_step( del new_dropout_rng + grad_accum = (grad_accum[0], + TransformerEngineHelper.update_fp8_metas(grad_accum[1], flax_mutables)) + return grad_accum, metrics, flax_mutables @@ -877,9 +874,12 @@ def apply_grads( """ if other_state_variables is None: other_state_variables = {} + + other_state_variables["flax_mutables"] = FrozenDict(grad_accum[1]) + # Update optimizer using accumulated gradient. new_train_state = train_state.apply_gradient( - grad_accum, learning_rate=learning_rate, **other_state_variables + grad_accum[0], learning_rate=learning_rate, **other_state_variables ) metrics["learning_rate"] = clu.metrics.Average.from_model_output( jnp.asarray([learning_rate]) From dcbbb37dea8cef5b0c2d798b57cdfd2d8600013d Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Tue, 11 Jul 2023 12:10:33 -0700 Subject: [PATCH 02/19] UNINSTALL_TE in fine-tuning scripts now defaults to no-action --- .../gpu/scripts_gpu/multiprocess_ft_frompile.sh | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/t5x/contrib/gpu/scripts_gpu/multiprocess_ft_frompile.sh b/t5x/contrib/gpu/scripts_gpu/multiprocess_ft_frompile.sh index 388d2ec46..135ecf68e 100755 --- a/t5x/contrib/gpu/scripts_gpu/multiprocess_ft_frompile.sh +++ b/t5x/contrib/gpu/scripts_gpu/multiprocess_ft_frompile.sh @@ -72,13 +72,9 @@ case $FT_TASK in ;; esac -case $UNINSTALL_TE in - 0) - ;; - *) - pip uninstall -y transformer_engine - ;; -esac +if [[ -n "${UNINSTALL_TE:-}" && ${UNINSTALL_TE:-} -ne 0 ]]; then + pip uninstall -y transformer_engine +fi # Global batch size BSIZE=$(( GPUS_PER_NODE * BSIZE_PER_GPU * SLURM_JOB_NUM_NODES / TP_SIZE)) From db6fc5519e240eb74ce39668047327da323f323a Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Wed, 12 Jul 2023 20:26:28 -0700 Subject: [PATCH 03/19] remove use_gda from LegacyCheckpointManager in train.py for fp8 --- t5x/train.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/t5x/train.py b/t5x/train.py index e70bcfb5b..0e11482d3 100644 --- a/t5x/train.py +++ b/t5x/train.py @@ -435,8 +435,7 @@ def _verify_matching_vocabs(cfg: utils.DatasetConfig): train_state_shape=train_state_initializer.global_train_state_shape, partitioner=partitioner, ds_iter=train_iter, - model_dir=model_dir, - use_gda=use_gda) + model_dir=model_dir) # Start warming up the input pipeline in the background. This must happen # after input pipeline checkpoints were restored. 
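Under the trainer changes in the first patch above, the accumulated quantity becomes a pair: parameter gradients are summed across microbatches, while the FP8 metadata collection is refreshed from the latest microbatch rather than summed. A rough sketch of that pattern (names and shapes are illustrative; the real trainer additionally routes the metas through `TransformerEngineHelper.update_fp8_metas`):

```python
# Illustrative accumulation loop over per-microbatch (param_grads, fp8_metas)
# pytree pairs; not the t5x implementation.
import jax
import jax.numpy as jnp

def accumulate(per_microbatch):
  """per_microbatch: list of (param_grads, fp8_metas) pytree pairs."""
  param_accum = jax.tree_util.tree_map(jnp.zeros_like, per_microbatch[0][0])
  fp8_metas = per_microbatch[0][1]
  for param_grads, microbatch_fp8_metas in per_microbatch:
    # Parameter gradients are additive, as in the original accumulation path.
    param_accum = jax.tree_util.tree_map(jnp.add, param_accum, param_grads)
    # FP8 amax/scale statistics are not additive; keep the most recent update.
    fp8_metas = microbatch_fp8_metas
  return param_accum, fp8_metas

grads = [({'w': jnp.ones(2)}, {'amax': jnp.array(0.5)}),
         ({'w': jnp.ones(2)}, {'amax': jnp.array(0.7)})]
acc, metas = accumulate(grads)  # acc['w'] == [2., 2.], metas['amax'] == 0.7
```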
From a39a08ec58e47976facda6ff8e340d5d1cec1f97 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Tue, 18 Jul 2023 15:55:01 -0700 Subject: [PATCH 04/19] Allow singlenode scripts to tee to stdout for better indication of training status --- t5x/contrib/gpu/scripts_gpu/singlenode_ft_frompile.sh | 2 +- t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/t5x/contrib/gpu/scripts_gpu/singlenode_ft_frompile.sh b/t5x/contrib/gpu/scripts_gpu/singlenode_ft_frompile.sh index c27c4a068..73e113973 100755 --- a/t5x/contrib/gpu/scripts_gpu/singlenode_ft_frompile.sh +++ b/t5x/contrib/gpu/scripts_gpu/singlenode_ft_frompile.sh @@ -90,5 +90,5 @@ python3 -u ${T5X_DIR}/t5x/train.py \ --gin.te_helper.TransformerEngineConfig.fp8_format=\"hybrid\" \ --gin.network.T5Config.transpose_batch_sequence=${TRANSPOSE_BS} \ --gin.network.T5Config.fuse_qkv_params=${FUSE_QKV} \ - &> \ + 2>&1 | tee \ ${LOG_DIR}/ft_${T5_SIZE}_gpu_${NUM_GPUS}_${PREC}_gbs_${BSIZE}_fp8_${ENABLE_FP8}_fuseqkv_${FUSE_QKV}_transbs_${TRANSPOSE_BS}.log diff --git a/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh b/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh index def1a1a78..0d12f305c 100755 --- a/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh +++ b/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh @@ -62,5 +62,5 @@ python3 -u ${T5X_DIR}/t5x/train.py \ --gin.te_helper.TransformerEngineConfig.fp8_format=\"hybrid\" \ --gin.network.T5Config.transpose_batch_sequence=${TRANSPOSE_BS} \ --gin.network.T5Config.fuse_qkv_params=${FUSE_QKV} \ - &> \ + 2>&1 | tee \ ${LOG_DIR}/${T5_SIZE}_gpu_${NUM_GPUS}_${PREC}_gbs_${BSIZE}_fp8_${ENABLE_FP8}_fuseqkv_${FUSE_QKV}_transbs_${TRANSPOSE_BS}.log From 39e637f3bb12dbe75dada4383f1d30805b946dc4 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Fri, 14 Jul 2023 05:00:58 -0700 Subject: [PATCH 05/19] Explicit specify self_attn_mask_type --- t5x/te_helper.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/t5x/te_helper.py b/t5x/te_helper.py index fb5f48f08..f3750ca88 100644 --- a/t5x/te_helper.py +++ b/t5x/te_helper.py @@ -211,7 +211,10 @@ def get_encoder_layer(config, relative_embedding, name, original_cls): scaled_query_init=True, fuse_qkv_params=config.fuse_qkv_params, relative_embedding=relative_embedding, - dtype=config.dtype, layer_type=te.TransformerLayerType.ENCODER, name=name) + dtype=config.dtype, + layer_type=te.TransformerLayerType.ENCODER, + self_attn_mask_type='padding', + name=name) @staticmethod def get_decoder_layer(config, relative_embedding, name, original_cls): @@ -231,7 +234,10 @@ def get_decoder_layer(config, relative_embedding, name, original_cls): scaled_query_init=True, fuse_qkv_params=config.fuse_qkv_params, relative_embedding=relative_embedding, - dtype=config.dtype, layer_type=te.TransformerLayerType.DECODER, name=name) + dtype=config.dtype, + layer_type=te.TransformerLayerType.DECODER, + self_attn_mask_type='causal', + name=name) class TransformerEngineHelper(TransformerEngineHelperBase): From d016f83835aabd2966a50f37126138c885b5a3b0 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Thu, 3 Aug 2023 14:25:22 -0700 Subject: [PATCH 06/19] Disables check for packing by the te_helper util since not all dataset configs use packing (CV/Multimodal) --- t5x/te_helper.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/t5x/te_helper.py b/t5x/te_helper.py index f3750ca88..f5857525d 100644 --- a/t5x/te_helper.py +++ b/t5x/te_helper.py @@ -179,6 +179,8 @@ def 
update_fp8_metas(grad_accum, flax_mutables): @staticmethod def check_dataset_cfg(config): + if not hasattr(config, 'pack'): + return assert not config.pack, \ "Transformer Engine does not support dataset.packing, please turn it off." From 83a2b20dfc6e6df59887cd44000c5f84608d0400 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Sat, 26 Aug 2023 16:13:11 -0700 Subject: [PATCH 07/19] Corrected T5x large baselines Updated T5x-large MNLI and SQUAD baselines --- docs/usage/gpu-usage.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/usage/gpu-usage.md b/docs/usage/gpu-usage.md index a9974e1a3..660df3a84 100644 --- a/docs/usage/gpu-usage.md +++ b/docs/usage/gpu-usage.md @@ -35,7 +35,7 @@ For our Pile convergence runs, we used a Global batch size of 2304 for XXL and 2 | size | GPU | Precision | #GPUs | TP | BS / GPU | Sequences/Sec | Seq/Sec/GPU | Est. Walltime | GPU-days | MNLI 2.0 - matched | SQuAD v1.1 (EM/F1) | Convergence Log | Config | | ---- | ------------ | --------- | ----- | ----- | -------- | ------------- | ----------- | ------------- | -------- |------------------ | ------------------ | --------------- | ---- | | [T5-v1.1-small](../t5/t5_1_1/small.gin) | A100 80G SXM | bf16 | 8 | 1 | 256 | ~5712 | 714 | 4.2 days | 33 | 83.06% | 78.33 / 86.63 | [log](https://tensorboard.dev/experiment/lWnHal7PRnOLeZuewyWVxQ/#scalars&_smoothingWeight=0) | [pile](../t5/t5_1_1/examples/small_pile_pretrain.gin) -| [T5-v1.1-large](../t5/t5_1_1/large.gin) | A100 80G SXM | bf16 | 64 | 1 | 32 | ~4853 | 75.8 | 4.8 days | 309 | 90.50% | 87.31 / 94.04 | [log](https://tensorboard.dev/experiment/aOxJBIvTQBeTJ8XGXxaL6Q/#scalars&_smoothingWeight=0) |[pile](../t5/t5_1_1/examples/large_pile_pretrain.gin) +| [T5-v1.1-large](../t5/t5_1_1/large.gin) | A100 80G SXM | bf16 | 64 | 1 | 32 | ~4853 | 75.8 | 4.8 days | 309 | 89.23% | 86.12 / 93.21 | [log](https://tensorboard.dev/experiment/aOxJBIvTQBeTJ8XGXxaL6Q/#scalars&_smoothingWeight=0) |[pile](../t5/t5_1_1/examples/large_pile_pretrain.gin) | [T5-v1.1-xl](../t5/t5_1_1/xl.gin) | A100 80G SXM | bf16 | 144 | 1 | 8 | ~3021 | 21.0 | 7.9 days | 1,133 | N/A(perf test) | N/A (perf test) | |[pile](../t5/t5_1_1/examples/xl_pile_pretrain.gin) | [T5-v1.1-xl](../t5/t5_1_1/xl.gin) | A100 80G SXM | bf16 | 256 | 1 | 8 | ~4322 | 16.9 | 5.5 days | 1,408 | 91.15% | 89.36 / 95.29 | [log](https://tensorboard.dev/experiment/vuRoEYgkRgWiEtbvgxlOqw/#scalars&_smoothingWeight=0) |[pile](../t5/t5_1_1/examples/xl_pile_pretrain.gin) | [T5-v1.1-xxl](../t5/t5_1_1/xxl.gin) | A100 80G SXM | bf16 | 512 | 8 | 36 | ~1887 | 3.69 | 12.6 days | 6,431 |N/A(partial run) | N/A(partial run) | |[pile](../t5/t5_1_1/examples/xxl_pile_pretrain.gin) From 5944f07c1924783a90e1a61915a9d8f2ea7b2215 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Fri, 8 Sep 2023 15:09:08 -0700 Subject: [PATCH 08/19] Add t5-large FP8 logs --- docs/usage/gpu-usage.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/usage/gpu-usage.md b/docs/usage/gpu-usage.md index 660df3a84..c31094da5 100644 --- a/docs/usage/gpu-usage.md +++ b/docs/usage/gpu-usage.md @@ -39,7 +39,7 @@ For our Pile convergence runs, we used a Global batch size of 2304 for XXL and 2 | [T5-v1.1-xl](../t5/t5_1_1/xl.gin) | A100 80G SXM | bf16 | 144 | 1 | 8 | ~3021 | 21.0 | 7.9 days | 1,133 | N/A(perf test) | N/A (perf test) | |[pile](../t5/t5_1_1/examples/xl_pile_pretrain.gin) | [T5-v1.1-xl](../t5/t5_1_1/xl.gin) | A100 80G SXM | bf16 | 256 | 1 | 8 | ~4322 | 16.9 | 5.5 days | 1,408 | 91.15% | 89.36 / 95.29 | 
[log](https://tensorboard.dev/experiment/vuRoEYgkRgWiEtbvgxlOqw/#scalars&_smoothingWeight=0) |[pile](../t5/t5_1_1/examples/xl_pile_pretrain.gin) | [T5-v1.1-xxl](../t5/t5_1_1/xxl.gin) | A100 80G SXM | bf16 | 512 | 8 | 36 | ~1887 | 3.69 | 12.6 days | 6,431 |N/A(partial run) | N/A(partial run) | |[pile](../t5/t5_1_1/examples/xxl_pile_pretrain.gin) -| [T5-v1.1-large](../t5/t5_1_1/large.gin) | **H100 80G SXM** | TE-fp8 | 64 | 1 | 32 | ~10156 | **158.7** | **2.3 days** | **147** | 89.1% | 86.36 / 93.5 | |[pile](../t5/t5_1_1/examples/large_pile_pretrain.gin) +| [T5-v1.1-large](../t5/t5_1_1/large.gin) | **H100 80G SXM** | TE-fp8 | 64 | 1 | 32 | ~10156 | **158.7** | **2.3 days** | **147** | 89.1% | 86.36 / 93.5 | [log](https://tensorboard.dev/experiment/QJYnDaaBSeuZtYPXXtAG3Q/#scalars&_smoothingWeight=0) |[pile](../t5/t5_1_1/examples/large_pile_pretrain.gin) | [T5-v1.1-xl](../t5/t5_1_1/xl.gin) | **H100 80G SXM** | TE-fp8 | 144 | 1 | 14 | ~7257 | **50.4** | **3.3 days** | **475** | N/A (perf test) | N/A (perf test) | |[pile](../t5/t5_1_1/examples/xl_pile_pretrain.gin) | [T5-v1.1-xl](../t5/t5_1_1/xl.gin) | **H100 80G SXM** | TE-fp8 | 256 | 1 | 8 | ~9688 | **37.8** | **2.4 days** | **614** | N/A (perf test) | N/A (perf test) | |[pile](../t5/t5_1_1/examples/xl_pile_pretrain.gin) From 2d2fbe8e857990a336f9d949ad7e8d82351b1633 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Fri, 20 Oct 2023 14:26:09 +0800 Subject: [PATCH 09/19] Fix missing fp8_meta_collection in the eval stage. --- t5x/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/t5x/models.py b/t5x/models.py index 5891c08c1..e13b81091 100644 --- a/t5x/models.py +++ b/t5x/models.py @@ -756,7 +756,7 @@ def predict_batch_with_aux( decoder_prompt_inputs = jnp.zeros_like(decoder_input_tokens) encoded_inputs = self.module.apply( - {'params': params}, + {'params': params, **flax_mutables}, encoder_input_tokens, enable_dropout=False, method=self.module.encode, From 4a86f76013173545d1299feaeb0e3383e890bee9 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Fri, 20 Oct 2023 14:48:40 +0800 Subject: [PATCH 10/19] Remove redundant code. --- t5x/models.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/t5x/models.py b/t5x/models.py index e13b81091..cc3348fad 100644 --- a/t5x/models.py +++ b/t5x/models.py @@ -726,23 +726,6 @@ def predict_batch_with_aux( encoder_input_tokens = batch['encoder_input_tokens'] decoder_input_tokens = batch['decoder_input_tokens'] - # Prepare transformer fast-decoder call for beam search: for beam search, we - # need to set up our decoder model to handle a batch size equal to - # batch_size * num_decodes, where each batch item's data is expanded - # in-place rather than tiled. - # i.e. if we denote each batch element subtensor as el[n]: - # [el0, el1, el2] --> beamsize=2 --> [el0,el0,el1,el1,el2,el2] - # [batch * num_decodes, input_len, emb_dim] - encoded_inputs = decoding.flat_batch_beam_expand( - self.module.apply( - {'params': params, **flax_mutables}, - encoder_input_tokens, - enable_dropout=False, - method=self.module.encode, - ), - num_decodes, - ) - # `decoder_prompt_inputs` is initialized from the batch's # `decoder_input_tokens`. The EOS is stripped to avoid decoding to stop # after the prompt by matching to `output_vocabulary.eos_id`. From 7b878db04ad1e424bf6ca330c3e6fc8f9e0c1c3b Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Fri, 20 Oct 2023 15:05:40 +0800 Subject: [PATCH 11/19] Fix deprecating warning about TE. 
--- t5x/te_helper.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/t5x/te_helper.py b/t5x/te_helper.py index f5857525d..568f59698 100644 --- a/t5x/te_helper.py +++ b/t5x/te_helper.py @@ -198,7 +198,7 @@ def get_attn_mask(mask): @staticmethod def get_encoder_layer(config, relative_embedding, name, original_cls): hidden_dropout_dims = (-3,) if config.transpose_batch_sequence else(-2,) - return te.TransformerLayer( + return te.flax.TransformerLayer( hidden_size=config.num_heads*config.head_dim, mlp_hidden_size=config.mlp_dim, layernorm_type="rmsnorm", @@ -214,14 +214,14 @@ def get_encoder_layer(config, relative_embedding, name, original_cls): fuse_qkv_params=config.fuse_qkv_params, relative_embedding=relative_embedding, dtype=config.dtype, - layer_type=te.TransformerLayerType.ENCODER, + layer_type=te.flax.TransformerLayerType.ENCODER, self_attn_mask_type='padding', name=name) @staticmethod def get_decoder_layer(config, relative_embedding, name, original_cls): hidden_dropout_dims = (-3,) if config.transpose_batch_sequence else(-2,) - return te.TransformerLayer( + return te.flax.TransformerLayer( hidden_size=config.num_heads*config.head_dim, mlp_hidden_size=config.mlp_dim, layernorm_type="rmsnorm", @@ -237,7 +237,7 @@ def get_decoder_layer(config, relative_embedding, name, original_cls): fuse_qkv_params=config.fuse_qkv_params, relative_embedding=relative_embedding, dtype=config.dtype, - layer_type=te.TransformerLayerType.DECODER, + layer_type=te.flax.TransformerLayerType.DECODER, self_attn_mask_type='causal', name=name) From 4c604770bf02d3ff5fab6b1e795c5b6074032c7e Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Fri, 27 Oct 2023 09:08:10 -0700 Subject: [PATCH 12/19] Updates TE api from te.extend_* to te.flax.extend_* (#7) Co-authored-by: NVIDIA --- t5x/te_helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/t5x/te_helper.py b/t5x/te_helper.py index 568f59698..05c5f6b99 100644 --- a/t5x/te_helper.py +++ b/t5x/te_helper.py @@ -168,7 +168,7 @@ def fp8_autocast(te_config, dp_mesh_axis="data", tp_mesh_axis="model"): def extend_logical_axis_rules(rules): # Apply fp8_autocast to correctly set sharding_resource up. 
with TEInstalledHelper.fp8_autocast(TransformerEngineConfig()): - return te.extend_logical_axis_rules(rules) + return te.flax.extend_logical_axis_rules(rules) @staticmethod def update_fp8_metas(grad_accum, flax_mutables): From a3f2ab9fdf0681e10ffbcb2588274941e52700ba Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Tue, 31 Oct 2023 21:52:22 -0700 Subject: [PATCH 13/19] Adds ENABLE_TE env var and renames TEConfig.enabled -> TEConfig.enable_fp8 (#8) * Allows ENABLE_TE env var to control whether TE code path is invoked * Changes enabled -> enable_fp8 to be more consistent with PAX and avoid confusion with ENABLE_TE * Remove UNINSTALL_TE logic since it is no longer required --------- Co-authored-by: NVIDIA --- .../scripts_gpu/multiprocess_ft_frompile.sh | 6 +---- .../scripts_gpu/multiprocess_pretrain_pile.sh | 2 +- .../gpu/scripts_gpu/singlenode_ft_frompile.sh | 2 +- .../scripts_gpu/singlenode_pretrain_pile.sh | 2 +- t5x/te_helper.py | 22 +++++++++++++------ 5 files changed, 19 insertions(+), 15 deletions(-) diff --git a/t5x/contrib/gpu/scripts_gpu/multiprocess_ft_frompile.sh b/t5x/contrib/gpu/scripts_gpu/multiprocess_ft_frompile.sh index 135ecf68e..cd563ec4c 100755 --- a/t5x/contrib/gpu/scripts_gpu/multiprocess_ft_frompile.sh +++ b/t5x/contrib/gpu/scripts_gpu/multiprocess_ft_frompile.sh @@ -72,10 +72,6 @@ case $FT_TASK in ;; esac -if [[ -n "${UNINSTALL_TE:-}" && ${UNINSTALL_TE:-} -ne 0 ]]; then - pip uninstall -y transformer_engine -fi - # Global batch size BSIZE=$(( GPUS_PER_NODE * BSIZE_PER_GPU * SLURM_JOB_NUM_NODES / TP_SIZE)) export GPU_DEVICES=$(seq -s, 0 $((GPUS_PER_NODE - 1)) ) @@ -105,7 +101,7 @@ python3 -u ${T5X_DIR}/t5x/train.py \ --gin.train.eval_period=1000 \ --gin.train.gc_period=2000 \ --gin.train.te_config_cls=@te_helper.TransformerEngineConfig \ - --gin.te_helper.TransformerEngineConfig.enabled=${ENABLE_FP8} \ + --gin.te_helper.TransformerEngineConfig.enable_fp8=${ENABLE_FP8} \ --gin.te_helper.TransformerEngineConfig.fp8_format=\"hybrid\" \ --gin.network.T5Config.transpose_batch_sequence=${TRANSPOSE_BS} \ --gin.network.T5Config.fuse_qkv_params=${FUSE_QKV} \ diff --git a/t5x/contrib/gpu/scripts_gpu/multiprocess_pretrain_pile.sh b/t5x/contrib/gpu/scripts_gpu/multiprocess_pretrain_pile.sh index f807105a5..d0835403a 100755 --- a/t5x/contrib/gpu/scripts_gpu/multiprocess_pretrain_pile.sh +++ b/t5x/contrib/gpu/scripts_gpu/multiprocess_pretrain_pile.sh @@ -95,7 +95,7 @@ python3 ${T5X_DIR}/t5x/train.py \ --gin.train.eval_period=1000 \ --gin.train.gc_period=${TRAIN_STEPS} \ --gin.train.te_config_cls=@te_helper.TransformerEngineConfig \ - --gin.te_helper.TransformerEngineConfig.enabled=${ENABLE_FP8} \ + --gin.te_helper.TransformerEngineConfig.enable_fp8=${ENABLE_FP8} \ --gin.te_helper.TransformerEngineConfig.fp8_format=\"hybrid\" \ --gin.network.T5Config.transpose_batch_sequence=${TRANSPOSE_BS} \ --gin.network.T5Config.fuse_qkv_params=${FUSE_QKV} \ diff --git a/t5x/contrib/gpu/scripts_gpu/singlenode_ft_frompile.sh b/t5x/contrib/gpu/scripts_gpu/singlenode_ft_frompile.sh index 73e113973..981fb216c 100755 --- a/t5x/contrib/gpu/scripts_gpu/singlenode_ft_frompile.sh +++ b/t5x/contrib/gpu/scripts_gpu/singlenode_ft_frompile.sh @@ -86,7 +86,7 @@ python3 -u ${T5X_DIR}/t5x/train.py \ --gin.train/utils.DatasetConfig.pack=${PACK} \ --gin.train_eval/utils.DatasetConfig.pack=${PACK} \ --gin.train.te_config_cls=@te_helper.TransformerEngineConfig \ - --gin.te_helper.TransformerEngineConfig.enabled=${ENABLE_FP8} \ + --gin.te_helper.TransformerEngineConfig.enable_fp8=${ENABLE_FP8} \ 
--gin.te_helper.TransformerEngineConfig.fp8_format=\"hybrid\" \ --gin.network.T5Config.transpose_batch_sequence=${TRANSPOSE_BS} \ --gin.network.T5Config.fuse_qkv_params=${FUSE_QKV} \ diff --git a/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh b/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh index 0d12f305c..8ca54b0e8 100755 --- a/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh +++ b/t5x/contrib/gpu/scripts_gpu/singlenode_pretrain_pile.sh @@ -58,7 +58,7 @@ python3 -u ${T5X_DIR}/t5x/train.py \ --gin.train/utils.DatasetConfig.pack=${PACK} \ --gin.train_eval/utils.DatasetConfig.pack=${PACK} \ --gin.train.te_config_cls=@te_helper.TransformerEngineConfig \ - --gin.te_helper.TransformerEngineConfig.enabled=${ENABLE_FP8} \ + --gin.te_helper.TransformerEngineConfig.enable_fp8=${ENABLE_FP8} \ --gin.te_helper.TransformerEngineConfig.fp8_format=\"hybrid\" \ --gin.network.T5Config.transpose_batch_sequence=${TRANSPOSE_BS} \ --gin.network.T5Config.fuse_qkv_params=${FUSE_QKV} \ diff --git a/t5x/te_helper.py b/t5x/te_helper.py index 05c5f6b99..7657c52be 100644 --- a/t5x/te_helper.py +++ b/t5x/te_helper.py @@ -15,16 +15,20 @@ from contextlib import contextmanager import gin import jax +import os + +logging.set_verbosity(logging.INFO) try: from transformer_engine.common.recipe import DelayedScaling from transformer_engine.common.recipe import Format as FP8Format import transformer_engine.jax as te _IS_TRANSFORMER_ENGINE_INSTALLED = True + logging.info('Transformer Engine is installed') except ModuleNotFoundError as e: _IS_TRANSFORMER_ENGINE_INSTALLED = False - + logging.info('Transformer Engine is not installed') def _canonicalize_fp8_format(fp8_format): if not _IS_TRANSFORMER_ENGINE_INSTALLED: @@ -42,17 +46,17 @@ def _canonicalize_fp8_format(fp8_format): @gin.configurable class TransformerEngineConfig: - def __init__(self, enabled=False, fp8_format='fp8_hybrid', margin=0., amax_history_len=1024): - assert (_IS_TRANSFORMER_ENGINE_INSTALLED or (not enabled)), \ + def __init__(self, enable_fp8=False, fp8_format='fp8_hybrid', margin=0., amax_history_len=1024): + assert (_IS_TRANSFORMER_ENGINE_INSTALLED or (not enable_fp8)), \ 'Attempt to run transformer engine FP8 without installing transformer_engine.' - self.enabled = enabled + self.enable_fp8 = enable_fp8 self.fp8_format = _canonicalize_fp8_format(fp8_format) self.margin = margin self.amax_history_len = amax_history_len def __str__(self): - return f"TransformerEngineConfig: enabled:{self.enabled}," \ + return f"TransformerEngineConfig: enable_fp8:{self.enable_fp8}," \ f" fp8_format: {self.fp8_format}, margin: {self.margin}," \ f" amax_history_len: {self.amax_history_len}." 
@@ -158,7 +162,7 @@ def fp8_autocast(te_config, dp_mesh_axis="data", tp_mesh_axis="model"): amax_history_len=te_config.amax_history_len, amax_compute_algo="max") try: - with te.fp8_autocast(enabled=te_config.enabled, fp8_recipe=delay_scaling, + with te.fp8_autocast(enabled=te_config.enable_fp8, fp8_recipe=delay_scaling, sharding_resource=te.ShardingResource(dp_mesh_axis, tp_mesh_axis)): yield finally: @@ -243,10 +247,14 @@ def get_decoder_layer(config, relative_embedding, name, original_cls): class TransformerEngineHelper(TransformerEngineHelperBase): + @staticmethod + def is_enabled_te(): + enable_te = bool(int((os.environ.get("ENABLE_TE", False)))) + return (_IS_TRANSFORMER_ENGINE_INSTALLED and enable_te) @staticmethod def get_helper(): - if _IS_TRANSFORMER_ENGINE_INSTALLED: + if TransformerEngineHelper.is_enabled_te(): return TEInstalledHelper return TENotInstalledHelper From 4abe3e591574b0b2e5bbcdcd5e96df4f8d4367de Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Tue, 7 Nov 2023 10:57:34 +0800 Subject: [PATCH 14/19] Adapting to TE/JAX/Custom_partitioning. --- t5x/te_helper.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/t5x/te_helper.py b/t5x/te_helper.py index 7657c52be..b064d2b25 100644 --- a/t5x/te_helper.py +++ b/t5x/te_helper.py @@ -163,7 +163,8 @@ def fp8_autocast(te_config, dp_mesh_axis="data", tp_mesh_axis="model"): amax_compute_algo="max") try: with te.fp8_autocast(enabled=te_config.enable_fp8, fp8_recipe=delay_scaling, - sharding_resource=te.ShardingResource(dp_mesh_axis, tp_mesh_axis)): + mesh_resource=te.MeshResource(dp_resource=dp_mesh_axis, + tp_resource=tp_mesh_axis)): yield finally: pass @@ -177,8 +178,6 @@ def extend_logical_axis_rules(rules): @staticmethod def update_fp8_metas(grad_accum, flax_mutables): update_coll = te.update_collections(grad_accum, flax_mutables) - # As the suggestion of FP8 training, updating FP8 scales as frequent as possible. 
- update_coll = te.update_fp8_metas(update_coll) return update_coll @staticmethod From bfa6313ce0bc9b67f6d1b48fca460e1bfea16670 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Wed, 22 Nov 2023 13:53:49 +0800 Subject: [PATCH 15/19] Running Partitioner.compile within Mesh context-manager --- t5x/partitioning.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/t5x/partitioning.py b/t5x/partitioning.py index 847dc24c5..63873cf7a 100644 --- a/t5x/partitioning.py +++ b/t5x/partitioning.py @@ -960,6 +960,13 @@ def lower(self, *args, **kwargs): self._logical_axis_rules): return self._pjitted_fn.lower(*args, **kwargs) + def lower_and_compile(self, *args, **kwargs): + with Mesh(self._mesh.devices, + self._mesh.axis_names), flax_partitioning.axis_rules( + self._logical_axis_rules): + return self._pjitted_fn.lower(*args, **kwargs).compile() + + class BasePjitPartitioner(BasePartitioner): """Partitioner that uses T5X version of jax.pjit.""" @@ -998,7 +1005,7 @@ def partition( def compile(self, partitioned_fn: PjittedFnWithContext, *args) -> CompiledPartitionedCallable: - return partitioned_fn.lower(*args).compile() + return partitioned_fn.lower_and_compile(*args) class PjitPartitioner(BasePjitPartitioner): From b4dbfdea8313b0f098d7583d80dee4eea6180cdb Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Tue, 14 Nov 2023 23:42:35 -0800 Subject: [PATCH 16/19] Updates multiprocessing scripts to use SLURM output variables instead of input variables (#9) * Update multiprocess scripts * No longer need UNINSTALL_TE * Removes slurm scripts as the source of truth has moved to rosetta * Adds "Finished" message to multiprocess scripts * Remove BENCHMARK_ARGS which is no longer used * Fix typo in BENCHMARK_MODE and straggling if keyword * Address comments --- .../scripts_gpu/example_slurm_ft_frompile.sub | 98 ------------------- .../example_slurm_pretrain_pile.sub | 92 ----------------- .../scripts_gpu/multiprocess_ft_frompile.sh | 67 +++++++++---- .../scripts_gpu/multiprocess_pretrain_pile.sh | 70 ++++++++----- 4 files changed, 91 insertions(+), 236 deletions(-) delete mode 100755 t5x/contrib/gpu/scripts_gpu/example_slurm_ft_frompile.sub delete mode 100755 t5x/contrib/gpu/scripts_gpu/example_slurm_pretrain_pile.sub diff --git a/t5x/contrib/gpu/scripts_gpu/example_slurm_ft_frompile.sub b/t5x/contrib/gpu/scripts_gpu/example_slurm_ft_frompile.sub deleted file mode 100755 index 19ec14ee9..000000000 --- a/t5x/contrib/gpu/scripts_gpu/example_slurm_ft_frompile.sub +++ /dev/null @@ -1,98 +0,0 @@ -#!/bin/bash -#SBATCH -A example # slurm account -#SBATCH -p partition # slurm partition name -#SBATCH -N 1 # number of nodes -#SBATCH -t 04:00:00 # wall time -#SBATCH -J "t5x:train" # slurm job name -#SBATCH --exclusive # exclusive node access -#SBATCH --mem=0 # all mem avail -#SBATCH --mail-type=FAIL # only send email on failure -#SBATCH --overcommit -#SBATCH --dependency=singleton # tells slurm to run only one job with the same job name at a time -set -x - -# Copyright 2022 The T5X Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -# File system and volume glue code -#------------------------------------------------------------------------------- -# << CHANGE ! >> -SLURM_ACCOUNT='example' -USERID='exampleperson' - -# << CHANGE ! >> -CONTAINER="" # Add link to your built container - -# << CHANGE ! >> -BASE_T5X_DIR="...../t5x_git" # path to your clone of the repo -BASE_TFDS_DATA_DIR="" # path to tfds data directory -BASE_T5X_WORKSPACE_DIR="${BASE_T5X_DIR}/workspace" # path to where outputs will be dumped - -# Default env variables for paths required by t5x training scripts -TFDS_DATA_DIR=/t5x_home/datasets/ -T5X_DIR=/t5x_home/ -T5X_WORKSPACE_DIR=/t5x_home/workspace - -# Add the T5x/JAX specific mounts -MOUNTS="--container-mounts=$BASE_T5X_DIR:/$T5X_DIR,$BASE_TFDS_DATA_DIR:/$TFDS_DATA_DIR,$BASE_T5X_WORKSPACE_DIR:$T5X_WORKSPACE_DIR" - -# Add T5x/JAX specific exports -EXPORTS="--export=ALL,TFDS_DATA_DIR=${TFDS_DATA_DIR},T5X_DIR=${T5X_DIR},T5X_WORKSPACE_DIR=${T5X_WORKSPACE_DIR},PYTHONPATH=${T5X_DIR}" -#------------------------------------------------------------------------------- - -FT_TASK=${FT_TASK:-mnli2} -PREC=${PREC:="bfloat16"} -T5_SIZE=${T5_SIZE:="large"} -BSIZE_PER_GPU=${BSIZE_PER_GPU:=32} -ENC_SL=${ENC_SL:=512} -DEC_SL=${DEC_SL:=128} -NUM_MICROBATCHES=${NUM_MICROBATCHES:=1} -ENABLE_FP8=${ENABLE_FP8:=1} -TP_SIZE=${TP_SIZE:=1} -TRANSPOSE_BS=${TRANSPOSE_BS:=1} -MODEL_DIR=${MODEL_DIR:=model_dir} -FUSE_QKV=${FUSE_QKV:=1} -PACK=${PACK:=0} - -export GPUS_PER_NODE=${1:-8} -export BASE_SCRIPT=${2:-"${T5X_DIR}/t5x/contrib/gpu/scripts_gpu/multiprocess_ft_frompile.sh"} -export WITH_MP=1 - -NUM_GPUS=$((GPUS_PER_NODE * SLURM_JOB_NUM_NODES )) - -# redirect both stdout and stderr in the same file for ease of analysis -OUTDIR="outputs/multinode/${TASK}_t5_${T5_SIZE}-prec_${PREC}-nodes_${SLURM_JOB_NUM_NODES}-gpus_${NUM_GPUS}-bs_${BSIZE_PER_GPU}-sl_${SL}" - -OUTFILE="${BASE_T5X_WORKSPACE_DIR}/${OUTDIR}/output-%j-%n.txt" - -LOGDIR="${T5X_WORKSPACE_DIR}/${OUTDIR}" - -# << CHANGE ! >> -# You can add binding to the command below with the following line (after nvidia-smi). Remove the '&&' on the next bash line. -# && bash <>/bind.sh --cpu=exclusive --ib=single -- \ -read -r -d '' cmd <> -SLURM_ACCOUNT='example' -USERID='exampleperson' - -# << CHANGE ! >> -CONTAINER="" # Add link to your built container - -# << CHANGE ! 
-BASE_T5X_DIR="...../t5x_git" # path to your clone of the repo
-BASE_TFDS_DATA_DIR="" # path to tfds data directory
-BASE_T5X_WORKSPACE_DIR="${BASE_T5X_DIR}/workspace" # path to where outputs will be dumped
-
-# Default env variables for paths required by t5x training scripts
-TFDS_DATA_DIR=/t5x_home/datasets/
-T5X_DIR=/t5x_home/
-T5X_WORKSPACE_DIR=/t5x_home/workspace
-
-# Add the T5x/JAX specific mounts
-MOUNTS="--container-mounts=$BASE_T5X_DIR:/$T5X_DIR,$BASE_TFDS_DATA_DIR:/$TFDS_DATA_DIR,$BASE_T5X_WORKSPACE_DIR:$T5X_WORKSPACE_DIR"
-
-# Add T5x/JAX specific exports
-EXPORTS="--export=ALL,TFDS_DATA_DIR=${TFDS_DATA_DIR},T5X_DIR=${T5X_DIR},T5X_WORKSPACE_DIR=${T5X_WORKSPACE_DIR},PYTHONPATH=${T5X_DIR}"
-#-------------------------------------------------------------------------------
-
-# Command line arguments needed by the underlying scripts
-PREC=${PREC:="bfloat16"}
-T5_SIZE=${T5_SIZE:="large"}
-BSIZE_PER_GPU=${BSIZE_PER_GPU:=32}
-ENC_SL=${ENC_SL:=512}
-DEC_SL=${DEC_SL:=128}
-TRAIN_STEPS=${TRAIN_STEPS:=500}
-NUM_MICROBATCHES=${NUM_MICROBATCHES:=1}
-ENABLE_FP8=${ENABLE_FP8:=1} # Uses TransformerEngine FP8
-TP_SIZE=${TP_SIZE:=1}
-TRANSPOSE_BS=${TRANSPOSE_BS:=1} # An optimization for GPUs
-MODEL_DIR=${MODEL_DIR}
-FUSE_QKV=${FUSE_QKV:=1} # Used with TransformerEngine
-PACK=${PACK:=0} # Not supported with TransformerEngine
-
-export GPUS_PER_NODE=${1:-8}
-export BASE_SCRIPT=${2:-"${T5X_DIR}/t5x/contrib/gpu/scripts_gpu/multiprocess_pretrain_pile.sh"}
-export WITH_MP=1
-
-NUM_GPUS=$(( GPUS_PER_NODE * SLURM_JOB_NUM_NODES ))
-
-# << CHANGE ! >>
-# You can add binding to the command below with the following line (after nvidia-smi). Remove the '&&' on the next bash line.
-# && bash <>/bind.sh --cpu=exclusive --ib=single -- \
-read -r -d '' cmd <

Date: Wed, 6 Dec 2023 14:56:45 -0800
Subject: [PATCH 17/19] Force initial flax mutables to be a frozen dict (#11)

---
 t5x/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/t5x/trainer.py b/t5x/trainer.py
index cebb815b7..8910ce81a 100644
--- a/t5x/trainer.py
+++ b/t5x/trainer.py
@@ -740,7 +740,7 @@ def accumulate_grads_microbatched(
 
     # them and return flax_mutables from `get_initial_variables` and `loss_fn`.
     initial_flax_mutables = (
-        train_state.flax_mutables if train_state.flax_mutables else {}
+        train_state.flax_mutables if train_state.flax_mutables else FrozenDict({})
     )
 
     if num_microbatches is None or num_microbatches <= 1:

From 06be7c2e50535630bebe023f6bca922ccfe93448 Mon Sep 17 00:00:00 2001
From: ashors1
Date: Thu, 28 Dec 2023 07:24:33 -0800
Subject: [PATCH 18/19] update rng dtype in predict_batch

---
 t5x/models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/t5x/models.py b/t5x/models.py
index cc3348fad..eb7bd376f 100644
--- a/t5x/models.py
+++ b/t5x/models.py
@@ -638,7 +638,7 @@ def _compute_kv_cache(
   def predict_batch(self,
                     params: PyTreeDef,
                     batch: Mapping[str, jnp.ndarray],
-                    rng: Optional[jax.random.KeyArray] = None,
+                    rng: Optional[jax.Array] = None,
                     flax_mutables: Optional[PyTreeDef] = None) -> jnp.ndarray:
     """Predicts a batch of outputs from the model.
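The two one-line changes above are denser than they look. Below is a minimal sketch, with hypothetical helper names (`normalize_flax_mutables` and `predict_batch_stub` are illustrations, not T5X APIs): falling back to `flax.core.FrozenDict({})` keeps the empty-mutables case the same container type as a populated `train_state.flax_mutables`, and annotating the RNG as `jax.Array` follows the deprecation of the `jax.random.KeyArray` alias in newer JAX releases.

```python
# Illustrative sketch only; the helper names are hypothetical and exist just to
# demonstrate the two conventions adopted by patches 17 and 18 above.
from typing import Any, Mapping, Optional

import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict


def normalize_flax_mutables(flax_mutables: Any) -> Any:
  # Fall back to an empty FrozenDict rather than a plain {} so callers always
  # receive the same container type, whether or not any mutable collections
  # (e.g. FP8 metadata) are present.
  return flax_mutables if flax_mutables else FrozenDict({})


def predict_batch_stub(
    params: Any,
    batch: Mapping[str, jnp.ndarray],
    rng: Optional[jax.Array] = None,  # jax.Array replaces the deprecated jax.random.KeyArray
) -> jnp.ndarray:
  del params  # A real model would run decoding with `params` here.
  if rng is None:
    rng = jax.random.PRNGKey(0)  # PRNG keys are themselves jax.Array values.
  # Stand-in "prediction": random token ids with the batch's leading dimension.
  batch_size = next(iter(batch.values())).shape[0]
  return jax.random.randint(rng, (batch_size, 8), 0, 32128)
```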
From 339b03461fce0caca5e838139289888086c46d15 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Fri, 22 Mar 2024 23:49:37 -0700 Subject: [PATCH 19/19] Change decoder attn mask type to padding_causal Signed-off-by: Reese Wang --- t5x/te_helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/t5x/te_helper.py b/t5x/te_helper.py index b064d2b25..9410aa479 100644 --- a/t5x/te_helper.py +++ b/t5x/te_helper.py @@ -241,7 +241,7 @@ def get_decoder_layer(config, relative_embedding, name, original_cls): relative_embedding=relative_embedding, dtype=config.dtype, layer_type=te.flax.TransformerLayerType.DECODER, - self_attn_mask_type='causal', + self_attn_mask_type='padding_causal', name=name)
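For context on the mask-type switch: in Transformer Engine's naming, `'causal'` applies only the autoregressive (lower-triangular) mask, while `'padding_causal'` additionally masks padded key positions, which matters once batched decoder inputs have unequal lengths. The sketch below is illustrative only; `config` and `relative_embedding` stand in for the objects `te_helper.py` already passes, and every keyword argument shown is taken from the diff above.

```python
# Minimal sketch of the decoder-layer construction after this patch; not the
# full te_helper.py helper, just the call shape it ends up with.
import transformer_engine.jax as te


def build_te_decoder_layer(config, relative_embedding, name):
  return te.flax.TransformerLayer(
      relative_embedding=relative_embedding,
      dtype=config.dtype,
      layer_type=te.flax.TransformerLayerType.DECODER,
      # 'padding_causal' masks future positions *and* padded positions;
      # plain 'causal' only enforced the autoregressive triangle.
      self_attn_mask_type='padding_causal',
      name=name)
```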