diff --git a/README.md b/README.md
index 13097d7dd..fcf99e4ff 100644
--- a/README.md
+++ b/README.md
@@ -157,7 +157,7 @@ trainer:
     tags: [ "openwebtext", "gpt2"]
   mp: p=f32,c=bfloat16
-  model_axis_size: 1
+  model_ici_axis_size: 1
   per_device_parallelism: 4
   train_batch_size: 512
diff --git a/config/backpack.yaml b/config/backpack.yaml
index 735d40c01..2300c1e44 100644
--- a/config/backpack.yaml
+++ b/config/backpack.yaml
@@ -18,7 +18,7 @@ trainer:
   num_train_steps: 50000
   train_batch_size: 1024
-  model_axis_size: 1
+  model_ici_axis_size: 1
 optimizer:
   learning_rate: 6E-4
diff --git a/config/backpack_nano.yaml b/config/backpack_nano.yaml
index 7bcc8ab6f..e068d94cb 100644
--- a/config/backpack_nano.yaml
+++ b/config/backpack_nano.yaml
@@ -15,7 +15,7 @@ trainer:
   num_train_steps: 100
   train_batch_size: 32
-  model_axis_size: 1
+  model_ici_axis_size: 1
 optimizer:
   learning_rate: 6E-4
diff --git a/config/config-markweb-web_comparison-owt_1b.yaml b/config/config-markweb-web_comparison-owt_1b.yaml
new file mode 100644
index 000000000..376861edf
--- /dev/null
+++ b/config/config-markweb-web_comparison-owt_1b.yaml
@@ -0,0 +1,106 @@
+data:
+  cache_dir: "gs://levanter-data/tokenized/markweb_llama/"
+  tokenizer: "meta-llama/Llama-2-7b-hf"
+  configs:
+    openwebtext:
+      train_urls:
+        - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
+      validation_urls:
+        - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
+    # these are just for eval
+    "paloma/4chan":
+      validation_urls:
+        - gs://levanter-data/paloma/4chan_meta_sep/val/val*.jsonl.gz
+    "paloma/c4_100_domains":
+      validation_urls:
+        - gs://levanter-data/paloma/c4_100_domains/val/val*.jsonl.gz
+    "paloma/c4_en":
+      validation_urls:
+        - gs://levanter-data/paloma/c4_en/val/val*.jsonl.gz
+    "paloma/dolma-v1_5":
+      validation_urls:
+        - gs://levanter-data/paloma/dolma-v1_5/val/val*.jsonl.gz
+    "paloma/dolma_100_programing_languages":
+      validation_urls:
+        - gs://levanter-data/paloma/dolma_100_programing_languages/val/val*.jsonl.gz
+    "paloma/dolma_100_subreddits":
+      validation_urls:
+        - gs://levanter-data/paloma/dolma_100_subreddits/val/val*.jsonl.gz
+    "paloma/falcon-refinedweb":
+      validation_urls:
+        - gs://levanter-data/paloma/falcon-refinedweb/val/val*.jsonl.gz
+    "paloma/gab":
+      validation_urls:
+        - gs://levanter-data/paloma/gab/val/val*.jsonl.gz
+    "paloma/m2d2_s2orc_unsplit":
+      validation_urls:
+        - gs://levanter-data/paloma/m2d2_s2orc_unsplit/val/val*.jsonl.gz
+    "paloma/m2d2_wikipedia_unsplit":
+      validation_urls:
+        - gs://levanter-data/paloma/m2d2_wikipedia_unsplit/val/val*.jsonl.gz
+    "paloma/manosphere_meta_sep":
+      validation_urls:
+        - gs://levanter-data/paloma/manosphere_meta_sep/val/val*.jsonl.gz
+    "paloma/mc4":
+      validation_urls:
+        - gs://levanter-data/paloma/mc4/val/val*.jsonl.gz
+    "paloma/ptb":
+      validation_urls:
+        - gs://levanter-data/paloma/ptb/val/val*.jsonl.gz
+    "paloma/redpajama":
+      validation_urls:
+        - gs://levanter-data/paloma/redpajama/val/val*.jsonl.gz
+    "paloma/twitterAAE_HELM_fixed":
+      validation_urls:
+        - gs://levanter-data/paloma/twitterAAE_HELM_fixed/val/val*.jsonl.gz
+    "paloma/wikitext_103":
+      validation_urls:
+        - gs://levanter-data/paloma/wikitext_103/val/val*.jsonl.gz
+
+  train_weights:
+    openwebtext: 1.0
+    paloma/4chan: 0.0
+    paloma/c4_100_domains: 0.0
+    paloma/c4_en: 0.0
+    paloma/dolma-v1_5: 0.0
+    paloma/dolma_100_programing_languages: 0.0
+    paloma/dolma_100_subreddits: 0.0
+    paloma/falcon-refinedweb: 0.0
+    paloma/gab: 0.0
+    paloma/m2d2_s2orc_unsplit: 0.0
+    paloma/m2d2_wikipedia_unsplit: 0.0
+    paloma/manosphere_meta_sep: 0.0
+    paloma/mc4: 0.0
+    paloma/ptb: 0.0
+    paloma/redpajama: 0.0
+    paloma/twitterAAE_HELM_fixed: 0.0
+    paloma/wikitext_103: 0.0
+model:
+  # 1B class model
+  type: llama
+  seq_len: 2048
+  hidden_dim: 2048
+  intermediate_dim: 4096
+  num_layers: 24
+  num_heads: 32
+  num_kv_heads: 32
+  use_flash_attention: True
+  flash_attention_block_size: 2048
+trainer:
+  tracker:
+    type: wandb
+    project: "markweb"
+    tags: ["owt", "llama", "web_comparison"]
+
+  mp: p=f32,c=bfloat16
+  train_batch_size: 512
+  num_train_steps: 50000
+  steps_per_eval: 1000
+  per_device_eval_parallelism: 64
+  tensor_parallel_axes: ["mlp", "heads"]
+  fsdp_axis: "embed"
+  batch_axis: "batch"
+optimizer:
+  learning_rate: 2E-4
+  weight_decay: 0.1
+  min_lr_ratio: 0.1
diff --git a/config/gpt2_1536.yaml b/config/gpt2_1536.yaml
index bbce6e1f6..bc0fcf0cb 100644
--- a/config/gpt2_1536.yaml
+++ b/config/gpt2_1536.yaml
@@ -13,7 +13,7 @@ trainer:
     tags: [ "openwebtext", "gpt2"]
   mp: p=f32,c=bfloat16
-  model_axis_size: 1
+  model_ici_axis_size: 1
   per_device_eval_parallelism: 8
 optimizer:
   learning_rate: 1E-4
diff --git a/config/gpt2_1536_sophiah.yaml b/config/gpt2_1536_sophiah.yaml
index 4c6ad2b18..37042954a 100644
--- a/config/gpt2_1536_sophiah.yaml
+++ b/config/gpt2_1536_sophiah.yaml
@@ -19,7 +19,7 @@ trainer:
     tags: [ "openwebtext", "gpt2"]
   mp: p=f32,c=bfloat16
-  model_axis_size: 1
+  model_ici_axis_size: 1
 optimizer:
   type: sophia-h
   learning_rate: 2E-4
diff --git a/config/gpt2_7b.yaml b/config/gpt2_7b.yaml
index 36a3d4fd2..1f90efde3 100644
--- a/config/gpt2_7b.yaml
+++ b/config/gpt2_7b.yaml
@@ -17,7 +17,7 @@ trainer:
   mp: p=f32,c=bfloat16
-  model_axis_size: 1
+  model_ici_axis_size: 1
   per_device_parallelism: -1
   per_device_eval_parallelism: -1
diff --git a/config/gpt2_large.yaml b/config/gpt2_large.yaml
index 8a8aea8d7..69ce01852 100644
--- a/config/gpt2_large.yaml
+++ b/config/gpt2_large.yaml
@@ -13,7 +13,7 @@ trainer:
     tags: [ "openwebtext", "gpt2"]
   mp: p=f32,c=bfloat16
-  model_axis_size: 1
+  model_ici_axis_size: 1
   per_device_parallelism: -1
 optimizer:
   learning_rate: 2E-4
diff --git a/config/gpt2_medium.yaml b/config/gpt2_medium.yaml
index 2451153ac..f43891fec 100644
--- a/config/gpt2_medium.yaml
+++ b/config/gpt2_medium.yaml
@@ -13,7 +13,7 @@ trainer:
     tags: [ "openwebtext", "gpt2"]
   mp: p=f32,c=bfloat16
-  model_axis_size: 1
+  model_ici_axis_size: 1
 optimizer:
   learning_rate: 3E-4
   weight_decay: 0.1
diff --git a/config/gpt2_small.yaml b/config/gpt2_small.yaml
index b3e0295af..c83f0806e 100644
--- a/config/gpt2_small.yaml
+++ b/config/gpt2_small.yaml
@@ -13,7 +13,7 @@ trainer:
     tags: [ "openwebtext", "gpt2"]
   mp: p=f32,c=bfloat16
-  model_axis_size: 1
+  model_ici_axis_size: 1
   per_device_parallelism: -1
   train_batch_size: 512
diff --git a/config/gpt2_small_fast.yaml b/config/gpt2_small_fast.yaml
index 6242a37bc..56356478d 100644
--- a/config/gpt2_small_fast.yaml
+++ b/config/gpt2_small_fast.yaml
@@ -14,7 +14,7 @@ trainer:
     tags: [ "openwebtext", "gpt2", "itest"]
   mp: p=f32,c=bfloat16
-  model_axis_size: 1
+  model_ici_axis_size: 1
   per_device_parallelism: -1
   train_batch_size: 256
diff --git a/config/gpt2_small_fast_fp8.yaml b/config/gpt2_small_fast_fp8.yaml
index 2d195a05b..5905be6d5 100644
--- a/config/gpt2_small_fast_fp8.yaml
+++ b/config/gpt2_small_fast_fp8.yaml
@@ -15,7 +15,7 @@ trainer:
   mp: p=f32,c=bfloat16
   fp8: true
-  model_axis_size: 1
+  model_ici_axis_size: 1
   per_device_parallelism: -1
   train_batch_size: 256
diff --git a/config/gpt2_small_fast_mix.yaml b/config/gpt2_small_fast_mix.yaml
index deb2fd7c0..47f2ad52e 100644
--- a/config/gpt2_small_fast_mix.yaml
+++ b/config/gpt2_small_fast_mix.yaml
@@ -26,7 +26,7 @@ trainer:
     tags: [ "openwebtext+wiki", "gpt2", "itest"]
   mp: p=f32,c=bfloat16
-  model_axis_size: 1
+  model_ici_axis_size: 1
   train_batch_size: 256
   num_train_steps: 20000
diff --git a/config/gpt2_small_fast_pile.yaml b/config/gpt2_small_fast_pile.yaml
index 3a21732a7..be037662b 100644
--- a/config/gpt2_small_fast_pile.yaml
+++ b/config/gpt2_small_fast_pile.yaml
@@ -13,7 +13,7 @@ trainer:
     tags: [ "pile", "gpt2", "itest"]
   mp: p=f32,c=bfloat16
-  model_axis_size: 1
+  model_ici_axis_size: 1
   train_batch_size: 256
   num_train_steps: 20000
diff --git a/config/gpt2_small_fast_public.yaml b/config/gpt2_small_fast_public.yaml
index 1b466ef91..79a351400 100644
--- a/config/gpt2_small_fast_public.yaml
+++ b/config/gpt2_small_fast_public.yaml
@@ -20,7 +20,7 @@ trainer:
     tags: [ "openwebtext", "gpt2", "itest"]
   mp: p=f32,c=bfloat16
-  model_axis_size: 1
+  model_ici_axis_size: 1
   per_device_parallelism: -1
   train_batch_size: 256
diff --git a/config/gpt2_small_fast_sophia_h.yaml b/config/gpt2_small_fast_sophia_h.yaml
index 0037664f1..68e329da8 100644
--- a/config/gpt2_small_fast_sophia_h.yaml
+++ b/config/gpt2_small_fast_sophia_h.yaml
@@ -13,7 +13,7 @@ trainer:
     tags: [ "openwebtext", "gpt2", "itest", "sophia-h"]
   mp: p=f32,c=bfloat16
-  model_axis_size: 1
+  model_ici_axis_size: 1
   train_batch_size: 256
   num_train_steps: 20000
diff --git a/config/gpt2_small_fast_sophiah.yaml b/config/gpt2_small_fast_sophiah.yaml
index 71675312c..17aa1c70b 100644
--- a/config/gpt2_small_fast_sophiah.yaml
+++ b/config/gpt2_small_fast_sophiah.yaml
@@ -13,7 +13,7 @@ trainer:
     tags: [ "openwebtext", "gpt2", "itest"]
   mp: p=f32,c=bfloat16
-  model_axis_size: 1
+  model_ici_axis_size: 1
   per_device_parallelism: -1
   train_batch_size: 256
diff --git a/config/gpt2_small_fast_wiki.yaml b/config/gpt2_small_fast_wiki.yaml
index a25736434..a786ad48b 100644
--- a/config/gpt2_small_fast_wiki.yaml
+++ b/config/gpt2_small_fast_wiki.yaml
@@ -14,7 +14,7 @@ trainer:
     tags: [ "openwebtext", "gpt2", "itest"]
   mp: p=f32,c=bfloat16
-  model_axis_size: 1
+  model_ici_axis_size: 1
   per_device_parallelism: -1
   train_batch_size: 256
diff --git a/config/gpt2_small_pile.yaml b/config/gpt2_small_pile.yaml
index 07aeb24ee..9638ad07c 100644
--- a/config/gpt2_small_pile.yaml
+++ b/config/gpt2_small_pile.yaml
@@ -13,7 +13,7 @@ trainer:
     tags: [ "pile", "gpt2"]
   mp: p=f32,c=bfloat16
-  model_axis_size: 1
+  model_ici_axis_size: 1
   train_batch_size: 256
   num_train_steps: 50000
diff --git a/config/gpt2_small_pile_mixture.yaml b/config/gpt2_small_pile_mixture.yaml
index c6c5338cd..207fff75c 100644
--- a/config/gpt2_small_pile_mixture.yaml
+++ b/config/gpt2_small_pile_mixture.yaml
@@ -13,7 +13,7 @@ trainer:
     tags: [ "pile", "gpt2"]
   mp: p=f32,c=bfloat16
-  model_axis_size: 1
+  model_ici_axis_size: 1
   train_batch_size: 256
   num_train_steps: 50000
diff --git a/config/gpt2_small_sophiah.yaml b/config/gpt2_small_sophiah.yaml
index fd82ab226..c77bace75 100644
--- a/config/gpt2_small_sophiah.yaml
+++ b/config/gpt2_small_sophiah.yaml
@@ -13,7 +13,7 @@ trainer:
     tags: [ "openwebtext", "gpt2", "sophia-h"]
   mp: p=f32,c=bfloat16
-  model_axis_size: 1
+  model_ici_axis_size: 1
   train_batch_size: 512
 optimizer: !include optim/sophia-h_small.yaml
diff --git a/config/llama2_7b_continued.yaml b/config/llama2_7b_continued.yaml
index 1c16a2f16..35ae720d0 100644
--- a/config/llama2_7b_continued.yaml
+++ b/config/llama2_7b_continued.yaml
@@ -13,7 +13,7 @@ trainer:
   mp: p=f32,c=bfloat16
-  model_axis_size: 1
+  model_ici_axis_size: 1
   per_device_eval_parallelism: 4
   train_batch_size: 1024
diff --git a/config/llama_small_fast.yaml b/config/llama_small_fast.yaml
index 5fb6d911c..96c2142cc 100644
--- a/config/llama_small_fast.yaml
+++ b/config/llama_small_fast.yaml
@@ -21,7 +21,7 @@ trainer:
     tags: [ "openwebtext", "llama", "itest"]
   mp: p=f32,c=bfloat16
-  model_axis_size: 1
+  model_ici_axis_size: 1
   per_device_parallelism: -1
   train_batch_size: 256
diff --git a/config/lora/mpt_biomed.yaml b/config/lora/mpt_biomed.yaml
index 6b19d0ab5..e51156683 100644
--- a/config/lora/mpt_biomed.yaml
+++ b/config/lora/mpt_biomed.yaml
@@ -18,7 +18,7 @@ trainer:
   mp: p=f32,c=bfloat16
-  model_axis_size: 1
+  model_ici_axis_size: 1
   per_device_parallelism: 4
   per_device_eval_parallelism: 4
diff --git a/config/whisper_tiny_librispeech.yaml b/config/whisper_tiny_librispeech.yaml
index 0b13491ae..850fa93b6 100644
--- a/config/whisper_tiny_librispeech.yaml
+++ b/config/whisper_tiny_librispeech.yaml
@@ -15,7 +15,7 @@ trainer:
     tags: [ "librispeech", "whisper"]
   mp: p=f32,c=bf16
-  model_axis_size: 1
+  model_ici_axis_size: 1
   per_device_parallelism: -1
   train_batch_size: 128
diff --git a/docs/Configuration-Guide.md b/docs/Configuration-Guide.md
index bdb09e4f1..79ffda4c1 100644
--- a/docs/Configuration-Guide.md
+++ b/docs/Configuration-Guide.md
@@ -41,7 +41,7 @@ trainer:
     tags: [ "openwebtext", "gpt2"]
   mp: p=f32,c=bfloat16
-  model_axis_size: 1
+  model_ici_axis_size: 1
   per_device_parallelism: 4
   train_batch_size: 512
@@ -100,7 +100,7 @@ The following table lists some of the parameters that you might want to change.
 | `seed` | The random seed | 0 |
 | `num_train_steps` | The number of training steps to run | 400,000 |
 | `train_batch_size` | The batch size | 32 |
-| `per_device_train_parallelism` | Number of examples to process on each device during training | `train_batch_size / (num_accelerators * model_axis_size)` |
+| `per_device_train_parallelism` | Number of examples to process on each device during training | `train_batch_size / (num_accelerators * model_ici_axis_size)` |
 | `per_device_eval_parallelism` | Number of examples to process on each device during eval | `per_device_train_parallelism` |
 | `steps_per_eval` | How often to evaluate the model during training | 1,000 |
 | `max_eval_batches` | How many batches to evaluate during each evaluation | `None` (meaning all) |
@@ -133,7 +133,7 @@ reasonable defaults and an "advanced" mode that gives you more control.
 | `batch_axis` | The axis to shard the batch over, for distributed data parallelism | `"batch"` |
 | `fsdp_axis` | The axis or axes to shard the model over, for Fully Sharded Data Parallelism | `"embed"` |
 | `tensor_parallel_axes` | The axis or axes to shard the model over, for Tensor Parallelism | `None` |
-| `model_axis_size` | How many devices for tensor parallelism | `1` |
+| `model_ici_axis_size` | How many devices for tensor parallelism | `1` |
 
 #### Advanced Mode
 
@@ -142,7 +142,7 @@
 | `axis_resources` | Mapping from logical axis to physical axis shared by both mappings | -- |
 | `parameter_axis_resources` | Mapping from logical axis to physical axis for the parameter mapping | -- |
 | `compute_axis_resources` | Mapping from logical axis to physical axis for the compute mapping | -- |
-| `model_axis_size` | How many devices for tensor parallelism | `1` |
+| `model_ici_axis_size` | How many devices for tensor parallelism | `1` |
 
 ### Checkpointing and Initialization
diff --git a/docs/tutorials/Training-On-Audio-Data.md b/docs/tutorials/Training-On-Audio-Data.md
index c378fda08..379ee2659 100644
--- a/docs/tutorials/Training-On-Audio-Data.md
+++ b/docs/tutorials/Training-On-Audio-Data.md
@@ -117,7 +117,7 @@ trainer:
     tags: [ "librispeech", "whisper"]
   mp: p=f32,c=bf16
-  model_axis_size: 1
+  model_ici_axis_size: 1
   per_device_parallelism: -1
   train_batch_size: 128
diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py
index 00e7b280d..7a22a0eb8 100644
--- a/src/levanter/trainer.py
+++ b/src/levanter/trainer.py
@@ -9,28 +9,15 @@
 from dataclasses import dataclass
 from functools import cached_property
 from pathlib import Path
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    Generic,
-    Iterable,
-    List,
-    Mapping,
-    Optional,
-    Protocol,
-    Sequence,
-    Tuple,
-    TypeVar,
-    Union,
-)
+from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, Protocol, Sequence, Tuple, TypeVar, Union
 
 import equinox as eqx
 import jax
 import jmp
 import numpy as np
 from draccus import field
 from jax.experimental import multihost_utils
+from jax.experimental.mesh_utils import create_device_mesh, create_hybrid_device_mesh
 from jax.sharding import Mesh
 from jaxtyping import PRNGKeyArray, PyTree
 from optax import GradientTransformation
@@ -550,11 +537,18 @@ class TrainerConfig:
     tensor_parallel_axes: Optional[List[str]] = None  # Axes, if any, to use for tensor parallelism
 
     # TODO: in theory we can support tuples of physical axis names, but I don't think anyone actually uses that.
-    axis_resources: Mapping[str, str] = field(default_factory=dict)
+    axis_resources: ResourceMapping = field(default_factory=dict)
     """mapping from logical axis to physical axis. batch_axis, fsdp_axis, and tensor_parallel_axes are preferred"""
-    parameter_axis_resources: Mapping[str, str] = field(default_factory=dict)  # overrides axis_mapping for parameter
+    parameter_axis_resources: ResourceMapping = field(default_factory=dict)  # overrides axis_mapping for parameter
     """logical->physical mapping for parameter/optimizer sharding. fsdp_axis and tensor_parallel_axes are preferred"""
-    model_axis_size: int = 1  # how many devices to shard each model over. Data axis is the other axis
+
+    """Interchip Interconnect (ICI) & Data Center Networking (DCN) shardings"""
+    replica_ici_axis_size: int = 1
+    model_ici_axis_size: int = 1
+    """how many devices within each slice to use for sharding with DP and TP. The rest of the devices are used for FSDP."""
+    replica_dcn_axis_size: int = 1
+    model_dcn_axis_size: int = 1
+    """how many slices in the multislice scheme to use for sharding with DP and TP. The rest of the slices are used for FSDP."""
 
     # Config related to batch sizes
     train_batch_size: int = 512
@@ -636,19 +630,58 @@ def initialize(self):
 
     @cached_property
     def device_mesh(self) -> Mesh:
-        devices = jax.devices()
-        devices = np.array(devices).reshape(self.data_axis_size, self.model_axis_size)
-        return Mesh(devices, (ResourceAxis.DATA, ResourceAxis.MODEL))
+        is_multislice = hasattr(jax.devices()[0], "slice_index")
+        if is_multislice:
+            devices = create_hybrid_device_mesh(
+                (self.replica_ici_axis_size, self.data_ici_axis_size, self.model_ici_axis_size),
+                (self.replica_dcn_axis_size, self.data_dcn_axis_size, self.model_dcn_axis_size),
+            )
+        else:
+            devices = create_device_mesh(
+                (self.replica_ici_axis_size, self.data_ici_axis_size, self.model_ici_axis_size)
+            )
+        return Mesh(devices, ("replica", ResourceAxis.DATA, ResourceAxis.MODEL))
 
     @property
     def eval_batch_size(self):
         return self.per_device_eval_parallelism * self.data_axis_size
 
+    @property
+    def num_slices(self):
+        """number of slices (1 when not running on a multislice topology)"""
+        return getattr(jax.devices()[-1], "slice_index", 0) + 1
+
+    @property
+    def num_local_devices(self):
+        """number of devices within a slice"""
+        return jax.device_count() // self.num_slices
+
+    @property
+    def data_ici_axis_size(self):
+        """size of the FSDP axis within slices"""
+        assert self.num_local_devices % (self.replica_ici_axis_size * self.model_ici_axis_size) == 0
+        return self.num_local_devices // (self.replica_ici_axis_size * self.model_ici_axis_size)
+
+    @property
+    def data_dcn_axis_size(self):
+        """size of the FSDP axis across slices"""
+        assert self.num_slices % (self.replica_dcn_axis_size * self.model_dcn_axis_size) == 0
+        return self.num_slices // (self.replica_dcn_axis_size * self.model_dcn_axis_size)
+
     @property
     def data_axis_size(self):
         """size of the data parallel/batch parallel axis."""
-        assert jax.device_count() % self.model_axis_size == 0
-        return jax.device_count() // self.model_axis_size
+        return self.data_dcn_axis_size * self.data_ici_axis_size
+
+    @property
+    def replica_axis_size(self):
+        """size of the replica (pure data parallel) axis."""
+        return self.replica_dcn_axis_size * self.replica_ici_axis_size
+
+    @property
+    def model_axis_size(self):
+        """size of the model (tensor parallel) axis."""
+        return self.model_dcn_axis_size * self.model_ici_axis_size
 
     @cached_property
     def compute_axis_mapping(self) -> ResourceMapping:
@@ -662,7 +695,7 @@ def compute_axis_mapping(self) -> ResourceMapping:
                 axes_to_return[axis] = ResourceAxis.MODEL
 
         if self.batch_axis is not None:
-            axes_to_return[self.batch_axis] = ResourceAxis.DATA
+            axes_to_return[self.batch_axis] = ("replica", ResourceAxis.DATA)
 
         return axes_to_return
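
A minimal sketch of how the new ICI/DCN axis sizes compose into the ("replica", "data", "model") device mesh. The topology and field values below are hypothetical (two slices of 16 chips each), chosen only to illustrate the arithmetic behind data_ici_axis_size, data_dcn_axis_size, and the derived *_axis_size properties; they are not taken from this change.

# Hypothetical topology: 2 slices (DCN) x 16 chips per slice (ICI).
num_slices = 2
num_local_devices = 16

# Example TrainerConfig settings (illustrative assumptions, not defaults).
replica_ici_axis_size = 1   # no extra data-parallel replicas inside a slice
model_ici_axis_size = 4     # tensor parallelism inside a slice
replica_dcn_axis_size = 2   # pure data parallelism across slices
model_dcn_axis_size = 1     # no tensor parallelism across slices

# FSDP gets the leftover devices on each level, mirroring the
# data_ici_axis_size / data_dcn_axis_size properties above.
assert num_local_devices % (replica_ici_axis_size * model_ici_axis_size) == 0
data_ici_axis_size = num_local_devices // (replica_ici_axis_size * model_ici_axis_size)  # 4

assert num_slices % (replica_dcn_axis_size * model_dcn_axis_size) == 0
data_dcn_axis_size = num_slices // (replica_dcn_axis_size * model_dcn_axis_size)  # 1

# Effective axis sizes seen by the rest of the trainer.
replica_axis_size = replica_ici_axis_size * replica_dcn_axis_size  # 2
data_axis_size = data_ici_axis_size * data_dcn_axis_size           # 4 (FSDP axis)
model_axis_size = model_ici_axis_size * model_dcn_axis_size        # 4 (tensor-parallel axis)

# Every device is accounted for, and the batch is sharded over the combined
# ("replica", "data") axes while each model shard spans 4 chips.
assert replica_axis_size * data_axis_size * model_axis_size == num_slices * num_local_devices
print(replica_axis_size, data_axis_size, model_axis_size)  # 2 4 4

With these sizes, create_hybrid_device_mesh((1, 4, 4), (2, 1, 1)) lays the 32 chips out so that, in this example, only the replica axis crosses the slower DCN links, while the FSDP and tensor-parallel axes stay within a slice.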