
Also sum reduce in microbatched fn
Aphoh committed Dec 12, 2024
1 parent dc6a871 commit d8fa2d7
Showing 1 changed file with 11 additions and 4 deletions.
src/levanter/trainer.py: 11 additions & 4 deletions
@@ -37,7 +37,7 @@
 from levanter.config import JsonAtom
 from levanter.data import AsyncDataset, DataLoader
 from levanter.distributed import DistributedConfig, RayConfig
-from levanter.grad_accum import NumElementsBatch, microbatched
+from levanter.grad_accum import NumElementsBatch, ReductionType, microbatched
 from levanter.tracker import TrackerConfig, capture_time
 from levanter.trainer_state import TrainerState, saveable_training_mask
 from levanter.utils import cloud_utils, fsspec_utils
@@ -552,15 +552,22 @@ def obj_fun(trainable_model):
     def _compute_gradients_microbatched(self, loss_fn, model: M, batch: X, **batch_kwargs) -> tuple[Scalar, M]:
         grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False)
         mbs = self.config.microbatch_size
-        if isinstance(batch, NumElementsBatch):
-            batch_kwargs["batch_num_elements"] = batch.num_elements()
-            batch_kwargs["reduction"] = hax.sum
+        reduce = ReductionType.MEAN
+        if isinstance(batch, NumElementsBatch) and mbs != self.TrainBatch.size:
+            batch_kwargs[
+                "batch_num_elements"
+            ] = batch.num_elements()  # tell the loss function how many elements are in the batch
+            batch_kwargs[
+                "reduction"
+            ] = hax.sum  # the loss fn should sum the loss and divide by the number of elements, not average
+            reduce = ReductionType.SUM  # we're already normalizing the loss
         grad_fn = microbatched(
             grad_fn,
             self.TrainBatch,
             mbs,
             self.parameter_axis_mapping,
             self.compute_axis_mapping,
             reduce=reduce,
         )
         with hax.axis_mapping(self.compute_axis_mapping):
             return grad_fn(model, batch, **batch_kwargs)
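
The inline comments spell out the intent: when gradient accumulation is active, each microbatch's loss is a sum divided by the element count of the *full* batch, so the cross-microbatch reduction must be a sum rather than a mean. A minimal sketch of that arithmetic (hypothetical names, plain NumPy, no sharding or Levanter types) follows; it is an illustration of the reduction logic, not the actual trainer code.

```python
import numpy as np

per_example_loss = np.arange(8, dtype=float)  # pretend batch of 8 examples
batch_num_elements = per_example_loss.size    # what batch.num_elements() would report
full_batch_mean = per_example_loss.mean()     # the value we want to reproduce

microbatches = np.split(per_example_loss, 4)  # microbatch size of 2

# Each microbatch sums its loss and divides by the *full* batch's element count
# (the "reduction" + "batch_num_elements" convention in the diff), so summing
# across microbatches recovers the full-batch mean exactly.
summed = sum(mb.sum() / batch_num_elements for mb in microbatches)

# Averaging across microbatches instead would shrink the loss by the number of
# microbatches, which is why the diff switches reduce to ReductionType.SUM here.
averaged = np.mean([mb.sum() / batch_num_elements for mb in microbatches])

print(full_batch_mean, summed, averaged)  # 3.5 3.5 0.875
```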
