Reorganize documentation with existing content
Showing 16 changed files with 153 additions and 33 deletions.
# Per-device batch size

The `per_device_batch_size` parameter dictates how much training data is fed to each chip on every step. It can be a whole number or a fraction between 0 and 1 (for example, 0.5). Tuning the value of `per_device_batch_size` can improve the MFU for your training run.

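For example, assuming the usual pattern of passing overrides to the train script on the command line (the run name and output bucket below are illustrative), a fractional batch size can be set like this:

```
# Feed half a batch to each device; fractional values are useful when a
# full batch per device does not fit in memory.
python3 MaxText/train.py MaxText/configs/base.yml \
    run_name=batch_size_demo \
    base_output_directory=gs://my-bucket/output \
    per_device_batch_size=0.5
```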
# Checkpointing

MaxText provides the ability to run training with the following checkpointing options:

- enabled/disabled
- asynchronous (true/false)
- checkpointing frequency

These are dictated by the following parameters:

- `enable_checkpointing` (`True`/`False`)
- `checkpoint_period` (integer value)
- `async_checkpointing` (`True`/`False`)

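As a minimal sketch (the run name, output bucket, and checkpoint period are illustrative), these options can be passed as command-line overrides:

```
# Write a checkpoint every 100 steps, asynchronously.
python3 MaxText/train.py MaxText/configs/base.yml \
    run_name=checkpointing_demo \
    base_output_directory=gs://my-bucket/output \
    enable_checkpointing=true \
    checkpoint_period=100 \
    async_checkpointing=true
```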
# Codebase Walkthrough

MaxText is written purely in JAX and Python. Below are some folders and files that show the high-level organization of the code and some key files.

File/Folder | Description
---------|---------------------------------
`configs` | Contains all the config files, including model configs (llama2, mistral, etc.) and pre-optimized configs for different model sizes on different TPUs
`input_pipelines` | Input training data related code
`layers` | Model layer implementations
`end_to_end` | Example scripts for running MaxText
`MaxText/train.py` | The main training script you will run directly
`MaxText/configs/base.yml` | The base configuration file containing all the related settings: checkpointing, model architecture, sharding schema, data input, learning rate, profiling, compilation, decoding
`MaxText/decode.py` | A script to run offline inference with a sample prompt
`setup.sh` | Bash script used to install all needed library dependencies

## Training configuration

The [MaxText/configs/base.yml](https://github.com/AI-Hypercomputer/maxtext/blob/main/MaxText/configs/base.yml) file holds a set of default configurations. These can be overridden directly via the CLI when invoking the MaxText train scripts; the command-line parameters overwrite the default values. A few of the key parameters are described below:

- `load_parameters_path`: Path to a MaxText checkpoint to load parameters from.
- `base_output_directory`: Base path to save the outputs (logs and data).
- [`dataset_type`](https://github.com/AI-Hypercomputer/maxtext/blob/main/MaxText/configs/base.yml#L273): One of synthetic, tfds, grain, or hf (Hugging Face).
- `dataset_path`: For `dataset_type=tfds`, the path to the dataset.
- `tokenizer_path`: Path to a tokenizer for the model. The tokenizers are present in ...
- `quantization`: Whether to use quantized training with AQT. Valid values are ['int8'].
- `per_device_batch_size`: How many batches each TPU/device receives. To improve the MFU, you can increase this value. It can also be a fraction. For this tutorial, we will use the default value of 1.
- `enable_checkpointing`: Boolean value; whether to generate checkpoints.
- `checkpoint_period`: After how many steps checkpointing is performed.
- `async_checkpointing`: Boolean value; whether to use asynchronous checkpointing. Here, we set it to False.
- `attention`: On TPU v3 and earlier, attention must be set to `dot_product`. Newer generations support the `flash` attention value. On GPU, use `cudnn_flash_te`.
- `steps`: Number of steps to train. For this tutorial, we will use a small value of 10 steps.

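Putting a few of these parameters together, a short tutorial-style run might look like the sketch below (the run name and output bucket are placeholders; synthetic data is used so no dataset path is needed):

```
# A small 10-step run with the default per-device batch size and no checkpoints.
python3 MaxText/train.py MaxText/configs/base.yml \
    run_name=walkthrough_demo \
    base_output_directory=gs://my-bucket/output \
    dataset_type=synthetic \
    per_device_batch_size=1 \
    enable_checkpointing=false \
    steps=10
```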
# How to load the data

MaxText supports input data pipelines in the following ways:

- Tf.data[^1]
- Grain
- Hugging Face Datasets

[^1]: Tf.data is the most performant way of loading large scale datasets.

You can read more about the pipelines in [Data Input Pipeline](getting_started/Data_Input_Pipeline.md).

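As a rough illustration (the run name, output bucket, and dataset path are placeholders), the pipeline is selected with the `dataset_type` parameter described in the training configuration above:

```
# Use the tf.data (TFDS) pipeline and point it at a prepared dataset.
python3 MaxText/train.py MaxText/configs/base.yml \
    run_name=data_pipeline_demo \
    base_output_directory=gs://my-bucket/output \
    dataset_type=tfds \
    dataset_path=gs://my-bucket/datasets
```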
# Full Finetuning Llama2/Llama3 Optimized Configuration

## Parameters to achieve high MFU

This page is in progress.
# Getting started with GCE/GKE+XPK

This page is in progress.
# Inference (JetStream)

This page is in progress.
# Profiling and Pre-training: Xplane and Tensorboard

This page is in progress.
# Remat Policy and Host Offloading

For large-scale model training, accelerator memory is a limited resource, and we often make trade-offs such as activation re-materialization, which exchanges compute cycles for accelerator memory. Host offload is another technique, recently introduced in the XLA compiler, that leverages host DRAM: activations computed during the forward pass are offloaded to the host and reused during the backward pass for gradient computation, saving activation recomputation cycles.

MaxText provides a parameter called `remat_policy`. This parameter controls whether activation memory is kept in HBM, offloaded to the host, or recomputed during the backward pass.

Activations from the forward pass are also needed in the backward pass. There are three options for where these activations are accessible for the backward pass:

1. In HBM (MaxText remat policy "minimal")
2. On host (MaxText remat policy "minimal_offloaded")
3. Re-computed during the backward pass (MaxText remat policy "full")

We can choose different remat policies for different activations (e.g., the FF activations versus the QKV projection activations), which allows us to trade memory usage against compute: generally we want to use all of our HBM. Both host offloading (option 2) and re-computation (a.k.a. remat, option 3) use as little HBM as possible; which one is faster depends on model size, device compute speed, and host-to-device bandwidth.

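A minimal sketch of choosing a policy on the command line (the run name and output bucket are illustrative):

```
# Offload forward-pass activations to host DRAM instead of recomputing them.
python3 MaxText/train.py MaxText/configs/base.yml \
    run_name=remat_demo \
    base_output_directory=gs://my-bucket/output \
    remat_policy=minimal_offloaded
```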
# Sharding

MaxText supports the following sharding mechanisms:

- Distributed Data Parallelism
- Tensor Parallelism
- Fully Sharded Data Parallelism
- Sequence Parallelism

These are controlled by the parameters below, shown with their default values from base.yml. Use the following sharding parameters to set sharding within a single TPU slice or a GPU slice:

```
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_fsdp_transpose_parallelism: 1
ici_sequence_parallelism: 1
ici_tensor_parallelism: 1
```

The following sharding values dictate how training happens across multiple TPU pods:

```
dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: 1
dcn_fsdp_transpose_parallelism: 1
dcn_sequence_parallelism: 1 # never recommended
dcn_tensor_parallelism: 1 # never recommended
dcn_pipeline_parallelism: 1
dcn_expert_parallelism: 1
dcn_autoregressive_parallelism: 1 # never recommended
```