diff --git a/config/gpt2_nano_harness.yaml b/config/gpt2_nano_harness.yaml
index 5e0a0a36f..8a241a058 100644
--- a/config/gpt2_nano_harness.yaml
+++ b/config/gpt2_nano_harness.yaml
@@ -19,7 +19,7 @@ trainer:
   save_interval: 5m
   per_device_parallelism: -1

-  train_batch_size: 32
+  train_batch_size: 4
   tensor_parallel_axes: ["mlp", "heads"]
   fsdp_axis: "embed"
diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py
index 6f998deac..7c6c29f32 100644
--- a/src/levanter/models/lm_model.py
+++ b/src/levanter/models/lm_model.py
@@ -57,7 +57,7 @@ def from_prompt_and_completion(
         all_causal: bool = True,
     ) -> "LmExample":
         # mask out the prompt tokens
-        loss_mask = hax.arange(Pos) >= prompt_length
+        loss_mask = hax.arange(Pos) >= prompt_length - 1
         # also mask out the last token
         loss_mask *= 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32)
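
For reviewers: the second hunk fixes an off-by-one in the completion loss mask. With next-token prediction, the logits at position i are scored against the token at position i + 1, so the first completion token (at index prompt_length) is predicted from position prompt_length - 1; the old mask, which started at prompt_length, silently dropped that loss term. The final position is zeroed because it has no following token to predict. Below is a minimal NumPy sketch of the mask semantics (illustrative only, not Levanter code; seq_len and prompt_length are made-up values, and the positional array stands in for haliax's named Pos axis):

import numpy as np

seq_len = 8          # stands in for the size of the Pos axis
prompt_length = 3    # hypothetical 3-token prompt

positions = np.arange(seq_len)

# Old mask: starts at prompt_length, so the loss term that predicts the
# first completion token (computed at position prompt_length - 1) is lost.
old_mask = (positions >= prompt_length).astype(np.float32)

# New mask: shift the start back by one ...
new_mask = (positions >= prompt_length - 1).astype(np.float32)
# ... and zero the final position, which has no next token to predict
# (the analogue of subtracting one_hot(-1, Pos) in the diff).
new_mask[-1] = 0.0

print(old_mask)  # [0. 0. 0. 1. 1. 1. 1. 1.]
print(new_mask)  # [0. 0. 1. 1. 1. 1. 1. 0.]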