From d8fa2d7d2c887724338e1757423a1e8530004db9 Mon Sep 17 00:00:00 2001
From: William Arnold
Date: Thu, 12 Dec 2024 15:19:44 -0800
Subject: [PATCH] Also sum reduce in microbatched fn

---
 src/levanter/trainer.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py
index 83d49ebdf..8f0902447 100644
--- a/src/levanter/trainer.py
+++ b/src/levanter/trainer.py
@@ -37,7 +37,7 @@
 from levanter.config import JsonAtom
 from levanter.data import AsyncDataset, DataLoader
 from levanter.distributed import DistributedConfig, RayConfig
-from levanter.grad_accum import NumElementsBatch, microbatched
+from levanter.grad_accum import NumElementsBatch, ReductionType, microbatched
 from levanter.tracker import TrackerConfig, capture_time
 from levanter.trainer_state import TrainerState, saveable_training_mask
 from levanter.utils import cloud_utils, fsspec_utils
@@ -552,15 +552,22 @@ def obj_fun(trainable_model):
     def _compute_gradients_microbatched(self, loss_fn, model: M, batch: X, **batch_kwargs) -> tuple[Scalar, M]:
         grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False)
         mbs = self.config.microbatch_size
-        if isinstance(batch, NumElementsBatch):
-            batch_kwargs["batch_num_elements"] = batch.num_elements()
-            batch_kwargs["reduction"] = hax.sum
+        reduce = ReductionType.MEAN
+        if isinstance(batch, NumElementsBatch) and mbs != self.TrainBatch.size:
+            batch_kwargs[
+                "batch_num_elements"
+            ] = batch.num_elements()  # tell the loss function how many elements are in the batch
+            batch_kwargs[
+                "reduction"
+            ] = hax.sum  # the loss fn should sum the loss and divide by the number of elements, not average
+            reduce = ReductionType.SUM  # we're already normalizing the loss
         grad_fn = microbatched(
             grad_fn,
             self.TrainBatch,
             mbs,
             self.parameter_axis_mapping,
             self.compute_axis_mapping,
+            reduce=reduce,
         )
         with hax.axis_mapping(self.compute_axis_mapping):
             return grad_fn(model, batch, **batch_kwargs)
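
The sketch below, in plain JAX, illustrates the invariant this change relies on: once each
microbatch's loss is computed as sum(loss) / batch_num_elements rather than its own mean,
summing the per-microbatch gradients (rather than averaging them) reproduces the gradient of
the full-batch mean loss. The helper names here (per_example_loss, microbatched_grad) are
illustrative only and are not levanter's API.

import jax
import jax.numpy as jnp

def per_example_loss(w, x, y):
    # simple squared error per example; stands in for the real loss function
    return (x @ w - y) ** 2

def full_batch_mean_loss(w, xs, ys):
    # reference: mean loss over the whole batch, no microbatching
    return jnp.mean(jax.vmap(per_example_loss, in_axes=(None, 0, 0))(w, xs, ys))

def microbatched_grad(w, xs, ys, microbatch_size):
    # gradient of sum(loss) / batch_num_elements per microbatch,
    # accumulated with a SUM reduction across microbatches
    batch_num_elements = xs.shape[0]

    def normalized_sum_loss(w, mb_x, mb_y):
        losses = jax.vmap(per_example_loss, in_axes=(None, 0, 0))(w, mb_x, mb_y)
        return jnp.sum(losses) / batch_num_elements  # sum, pre-normalized by the full batch size

    grad = jnp.zeros_like(w)
    for start in range(0, batch_num_elements, microbatch_size):
        mb_x = xs[start : start + microbatch_size]
        mb_y = ys[start : start + microbatch_size]
        grad = grad + jax.grad(normalized_sum_loss)(w, mb_x, mb_y)  # SUM, not MEAN
    return grad

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
w = jax.random.normal(k1, (4,))
xs = jax.random.normal(k2, (8, 4))
ys = jax.random.normal(k3, (8,))

reference = jax.grad(full_batch_mean_loss)(w, xs, ys)
accumulated = microbatched_grad(w, xs, ys, microbatch_size=2)
assert jnp.allclose(reference, accumulated, atol=1e-6)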