rename model axis size
blahBlahhhJ committed May 26, 2024
1 parent 8b7b804 commit 130c212
Showing 27 changed files with 38 additions and 44 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -157,7 +157,7 @@ trainer:
tags: [ "openwebtext", "gpt2"]

mp: p=f32,c=bfloat16
-model_ici_axis_size: 1
+model_axis_size: 1
per_device_parallelism: 4

train_batch_size: 512
2 changes: 1 addition & 1 deletion config/backpack.yaml
@@ -18,7 +18,7 @@ trainer:

num_train_steps: 50000
train_batch_size: 1024
-model_ici_axis_size: 1
+model_axis_size: 1

optimizer:
learning_rate: 6E-4
2 changes: 1 addition & 1 deletion config/backpack_nano.yaml
@@ -15,7 +15,7 @@ trainer:

num_train_steps: 100
train_batch_size: 32
-model_ici_axis_size: 1
+model_axis_size: 1

optimizer:
learning_rate: 6E-4
2 changes: 1 addition & 1 deletion config/gpt2_1536.yaml
@@ -13,7 +13,7 @@ trainer:
tags: [ "openwebtext", "gpt2"]

mp: p=f32,c=bfloat16
-model_ici_axis_size: 1
+model_axis_size: 1
per_device_eval_parallelism: 8
optimizer:
learning_rate: 1E-4
2 changes: 1 addition & 1 deletion config/gpt2_1536_sophiah.yaml
@@ -19,7 +19,7 @@ trainer:
tags: [ "openwebtext", "gpt2"]

mp: p=f32,c=bfloat16
-model_ici_axis_size: 1
+model_axis_size: 1
optimizer:
type: sophia-h
learning_rate: 2E-4
2 changes: 1 addition & 1 deletion config/gpt2_7b.yaml
@@ -17,7 +17,7 @@ trainer:

mp: p=f32,c=bfloat16

-model_ici_axis_size: 1
+model_axis_size: 1
per_device_parallelism: -1
per_device_eval_parallelism: -1

2 changes: 1 addition & 1 deletion config/gpt2_large.yaml
@@ -13,7 +13,7 @@ trainer:
tags: [ "openwebtext", "gpt2"]

mp: p=f32,c=bfloat16
-model_ici_axis_size: 1
+model_axis_size: 1
per_device_parallelism: -1
optimizer:
learning_rate: 2E-4
2 changes: 1 addition & 1 deletion config/gpt2_medium.yaml
@@ -13,7 +13,7 @@ trainer:
tags: [ "openwebtext", "gpt2"]

mp: p=f32,c=bfloat16
-model_ici_axis_size: 1
+model_axis_size: 1
optimizer:
learning_rate: 3E-4
weight_decay: 0.1
2 changes: 1 addition & 1 deletion config/gpt2_small.yaml
@@ -13,7 +13,7 @@ trainer:
tags: [ "openwebtext", "gpt2"]

mp: p=f32,c=bfloat16
-model_ici_axis_size: 1
+model_axis_size: 1
per_device_parallelism: -1

train_batch_size: 512
2 changes: 1 addition & 1 deletion config/gpt2_small_fast.yaml
@@ -14,7 +14,7 @@ trainer:
tags: [ "openwebtext", "gpt2", "itest"]

mp: p=f32,c=bfloat16
-model_ici_axis_size: 1
+model_axis_size: 1
per_device_parallelism: -1

train_batch_size: 256
2 changes: 1 addition & 1 deletion config/gpt2_small_fast_fp8.yaml
@@ -15,7 +15,7 @@ trainer:

mp: p=f32,c=bfloat16
fp8: true
-model_ici_axis_size: 1
+model_axis_size: 1
per_device_parallelism: -1

train_batch_size: 256
2 changes: 1 addition & 1 deletion config/gpt2_small_fast_mix.yaml
@@ -26,7 +26,7 @@ trainer:
tags: [ "openwebtext+wiki", "gpt2", "itest"]

mp: p=f32,c=bfloat16
-model_ici_axis_size: 1
+model_axis_size: 1

train_batch_size: 256
num_train_steps: 20000
2 changes: 1 addition & 1 deletion config/gpt2_small_fast_pile.yaml
@@ -13,7 +13,7 @@ trainer:
tags: [ "pile", "gpt2", "itest"]

mp: p=f32,c=bfloat16
-model_ici_axis_size: 1
+model_axis_size: 1

train_batch_size: 256
num_train_steps: 20000
2 changes: 1 addition & 1 deletion config/gpt2_small_fast_public.yaml
@@ -20,7 +20,7 @@ trainer:
tags: [ "openwebtext", "gpt2", "itest"]

mp: p=f32,c=bfloat16
-model_ici_axis_size: 1
+model_axis_size: 1
per_device_parallelism: -1

train_batch_size: 256
2 changes: 1 addition & 1 deletion config/gpt2_small_fast_sophia_h.yaml
@@ -13,7 +13,7 @@ trainer:
tags: [ "openwebtext", "gpt2", "itest", "sophia-h"]

mp: p=f32,c=bfloat16
-model_ici_axis_size: 1
+model_axis_size: 1

train_batch_size: 256
num_train_steps: 20000
2 changes: 1 addition & 1 deletion config/gpt2_small_fast_sophiah.yaml
@@ -13,7 +13,7 @@ trainer:
tags: [ "openwebtext", "gpt2", "itest"]

mp: p=f32,c=bfloat16
-model_ici_axis_size: 1
+model_axis_size: 1
per_device_parallelism: -1

train_batch_size: 256
2 changes: 1 addition & 1 deletion config/gpt2_small_fast_wiki.yaml
@@ -14,7 +14,7 @@ trainer:
tags: [ "openwebtext", "gpt2", "itest"]

mp: p=f32,c=bfloat16
-model_ici_axis_size: 1
+model_axis_size: 1
per_device_parallelism: -1

train_batch_size: 256
2 changes: 1 addition & 1 deletion config/gpt2_small_pile.yaml
@@ -13,7 +13,7 @@ trainer:
tags: [ "pile", "gpt2"]

mp: p=f32,c=bfloat16
-model_ici_axis_size: 1
+model_axis_size: 1

train_batch_size: 256
num_train_steps: 50000
2 changes: 1 addition & 1 deletion config/gpt2_small_pile_mixture.yaml
@@ -13,7 +13,7 @@ trainer:
tags: [ "pile", "gpt2"]

mp: p=f32,c=bfloat16
-model_ici_axis_size: 1
+model_axis_size: 1

train_batch_size: 256
num_train_steps: 50000
2 changes: 1 addition & 1 deletion config/gpt2_small_sophiah.yaml
@@ -13,7 +13,7 @@ trainer:
tags: [ "openwebtext", "gpt2", "sophia-h"]

mp: p=f32,c=bfloat16
-model_ici_axis_size: 1
+model_axis_size: 1

train_batch_size: 512
optimizer: !include optim/sophia-h_small.yaml
2 changes: 1 addition & 1 deletion config/llama2_7b_continued.yaml
@@ -13,7 +13,7 @@ trainer:

mp: p=f32,c=bfloat16

-model_ici_axis_size: 1
+model_axis_size: 1
per_device_eval_parallelism: 4

train_batch_size: 1024
2 changes: 1 addition & 1 deletion config/llama_small_fast.yaml
@@ -21,7 +21,7 @@ trainer:
tags: [ "openwebtext", "llama", "itest"]

mp: p=f32,c=bfloat16
-model_ici_axis_size: 1
+model_axis_size: 1
per_device_parallelism: -1

train_batch_size: 256
2 changes: 1 addition & 1 deletion config/lora/mpt_biomed.yaml
@@ -18,7 +18,7 @@ trainer:

mp: p=f32,c=bfloat16

-model_ici_axis_size: 1
+model_axis_size: 1
per_device_parallelism: 4
per_device_eval_parallelism: 4

2 changes: 1 addition & 1 deletion config/whisper_tiny_librispeech.yaml
@@ -15,7 +15,7 @@ trainer:
tags: [ "librispeech", "whisper"]

mp: p=f32,c=bf16
-model_ici_axis_size: 1
+model_axis_size: 1
per_device_parallelism: -1

train_batch_size: 128
8 changes: 4 additions & 4 deletions docs/Configuration-Guide.md
@@ -41,7 +41,7 @@ trainer:
tags: [ "openwebtext", "gpt2"]

mp: p=f32,c=bfloat16
-model_ici_axis_size: 1
+model_axis_size: 1
per_device_parallelism: 4

train_batch_size: 512
@@ -100,7 +100,7 @@ The following table lists some of the parameters that you might want to change.
| `seed` | The random seed | 0 |
| `num_train_steps` | The number of training steps to run | 400,000 |
| `train_batch_size` | The batch size | 32 |
-| `per_device_train_parallelism` | Number of examples to process on each device during training | `train_batch_size / (num_accelerators * model_ici_axis_size)` |
+| `per_device_train_parallelism` | Number of examples to process on each device during training | `train_batch_size / (num_accelerators * model_axis_size)` |
| `per_device_eval_parallelism` | Number of examples to process on each device during eval | `per_device_train_parallelism` |
| `steps_per_eval` | How often to evaluate the model during training | 1,000 |
| `max_eval_batches` | How many batches to evaluate during each evaluation | `None` (meaning all) |
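
As a quick aside on the `per_device_train_parallelism` default in the table above, here is a sketch of the arithmetic with assumed numbers (the accelerator count and batch size are illustrative, not values taken from this commit):

# Default: train_batch_size / (num_accelerators * model_axis_size)
train_batch_size = 512
num_accelerators = 32           # assumed number of devices in the slice
model_axis_size = 1             # no tensor parallelism
per_device_train_parallelism = train_batch_size // (num_accelerators * model_axis_size)
print(per_device_train_parallelism)   # -> 16 examples per device per step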
@@ -133,7 +133,7 @@ reasonable defaults and an "advanced" mode that gives you more control.
| `batch_axis` | The axis to shard the batch over, for distributed data parallelism | `"batch"` |
| `fsdp_axis` | The axis or axes to shard the model over, for Fully Sharded Data Parallelism | `"embed"` |
| `tensor_parallel_axes` | The axis or axes to shard the model over, for Tensor Parallelism | `None` |
-| `model_ici_axis_size` | How many devices for tensor parallelism | `1` |
+| `model_axis_size` | How many devices for tensor parallelism | `1` |

#### Advanced Mode

@@ -142,7 +142,7 @@ reasonable defaults and an "advanced" mode that gives you more control.
| `axis_resources` | Mapping from logical axis to physical axis shared by both mappings | -- |
| `parameter_axis_resources` | Mapping from logical axis to physical axis for the parameter mapping | -- |
| `compute_axis_resources` | Mapping from logical axis to physical axis for the compute mapping | -- |
-| `model_ici_axis_size` | How many devices for tensor parallelism | `1` |
+| `model_axis_size` | How many devices for tensor parallelism | `1` |

### Checkpointing and Initialization

2 changes: 1 addition & 1 deletion docs/tutorials/Training-On-Audio-Data.md
@@ -117,7 +117,7 @@ trainer:
tags: [ "librispeech", "whisper"]
mp: p=f32,c=bf16
-model_ici_axis_size: 1
+model_axis_size: 1
per_device_parallelism: -1
train_batch_size: 128
24 changes: 9 additions & 15 deletions src/levanter/trainer.py
@@ -543,10 +543,9 @@ class TrainerConfig:

"""Interchip Interconnect (ICI) & Data Center Networking (DCN) shardings"""
replica_ici_axis_size: int = 1
-model_ici_axis_size: int = 1
-"""how many devices within each slice for sharding with DP and TP. The rest of the devices is for FSDP."""
+model_axis_size: int = 1
+"""how many devices within each slice for sharding with DP. Fix TP=1, the rest of the devices is for FSDP."""
replica_dcn_axis_size: int = 1
-model_dcn_axis_size: int = 1
"""how many slices in the multislice scheme for sharding with DP and TP. The rest of the devices is for FSDP."""

# Config related to batch sizes
@@ -632,13 +631,13 @@ def device_mesh(self) -> Mesh:
is_multislice = hasattr(jax.devices()[0], "slice_index")
if is_multislice:
devices = mesh_utils.create_hybrid_device_mesh(
-(self.replica_ici_axis_size, self.data_ici_axis_size, self.model_ici_axis_size),
-(self.replica_dcn_axis_size, self.data_dcn_axis_size, self.model_dcn_axis_size),
+(self.replica_ici_axis_size, self.data_ici_axis_size, self.model_axis_size),
+(self.replica_dcn_axis_size, self.data_dcn_axis_size, 1),
allow_split_physical_axes=True,
)
else:
devices = mesh_utils.create_device_mesh(
-(self.replica_ici_axis_size, self.data_ici_axis_size, self.model_ici_axis_size),
+(self.replica_ici_axis_size, self.data_ici_axis_size, self.model_axis_size),
allow_split_physical_axes=True,
)
# devices = jax.devices()
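
To make the device_mesh hunk above concrete, here is a minimal sketch of the mesh-shape arithmetic under assumed device counts (numbers are illustrative, not taken from the commit); the resulting tuples are what get passed to mesh_utils.create_hybrid_device_mesh / create_device_mesh:

# Assumed topology: 2 slices of 8 devices each.
num_local_devices = 8
replica_ici_axis_size = 1
model_axis_size = 2                    # tensor parallelism within a slice
assert num_local_devices % (replica_ici_axis_size * model_axis_size) == 0
data_ici_axis_size = num_local_devices // (replica_ici_axis_size * model_axis_size)
ici_mesh_shape = (replica_ici_axis_size, data_ici_axis_size, model_axis_size)

num_slices = 2
replica_dcn_axis_size = 1
data_dcn_axis_size = num_slices // replica_dcn_axis_size
dcn_mesh_shape = (replica_dcn_axis_size, data_dcn_axis_size, 1)   # model axis across slices is now fixed to 1
print(ici_mesh_shape, dcn_mesh_shape)  # (1, 4, 2) (1, 2, 1)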
@@ -664,14 +663,14 @@ def num_local_devices(self):
@property
def data_ici_axis_size(self):
"""size of the FSDP axis within slices"""
-assert self.num_local_devices % (self.replica_ici_axis_size * self.model_ici_axis_size) == 0
-return self.num_local_devices // (self.replica_ici_axis_size * self.model_ici_axis_size)
+assert self.num_local_devices % (self.replica_ici_axis_size * self.model_axis_size) == 0
+return self.num_local_devices // (self.replica_ici_axis_size * self.model_axis_size)

@property
def data_dcn_axis_size(self):
"""size of the FSDP axis across slices"""
-assert self.num_slices % (self.replica_dcn_axis_size * self.model_dcn_axis_size) == 0
-return self.num_slices // (self.replica_dcn_axis_size * self.model_dcn_axis_size)
+assert self.num_slices % self.replica_dcn_axis_size == 0
+return self.num_slices // self.replica_dcn_axis_size

@property
def data_axis_size(self):
@@ -685,11 +684,6 @@ def replica_axis_size(self):
"""size of the data parallel/batch parallel axis."""
return self.replica_dcn_axis_size * self.replica_ici_axis_size

-@property
-def model_axis_size(self):
-"""size of the data parallel/batch parallel axis."""
-return self.model_dcn_axis_size * self.model_ici_axis_size

@cached_property
def compute_axis_mapping(self) -> ResourceMapping:
"""Mapping from logical axis to physical axis for compute."""
