Fix grad accum
Aphoh committed Dec 12, 2024
1 parent 1d63849 commit dc6a871
Showing 5 changed files with 61 additions and 31 deletions.
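
What the diff appears to do (this summary is not part of the commit message): with gradient accumulation, reducing each microbatch with hax.mean weights every microbatch equally even when microbatches contain different numbers of unmasked tokens. The commit instead sums per-token losses within each microbatch and divides once by the number of loss-mask elements in the whole batch. A tiny illustration of the difference, with made-up numbers:

import numpy as np

mb1 = np.array([2.0, 2.0, 2.0])  # microbatch with 3 unmasked tokens
mb2 = np.array([8.0])            # microbatch with 1 unmasked token

mean_of_means = (mb1.mean() + mb2.mean()) / 2    # 5.0 -- each microbatch weighted equally
global_mean = (mb1.sum() + mb2.sum()) / (3 + 1)  # 3.5 -- the true per-token mean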
7 changes: 7 additions & 0 deletions src/levanter/grad_accum.py
@@ -1,3 +1,4 @@
+import abc
 import enum
 import functools
 from typing import Callable, Optional, ParamSpec, TypeVar
@@ -20,6 +21,12 @@
 R = TypeVar("R")
 
 
+class NumElementsBatch(abc.ABC):
+    @abc.abstractmethod
+    def num_elements(self) -> int:
+        pass
+
+
 class ReductionType(enum.Enum):
     SUM = enum.auto()
     MEAN = enum.auto()
8 changes: 7 additions & 1 deletion src/levanter/models/lm_model.py
@@ -11,6 +11,7 @@
 import haliax as hax
 from haliax import Axis, NamedArray, NamedOrNumeric
 
+from levanter.grad_accum import NumElementsBatch
 from levanter.models.attention import AttentionMask
 from levanter.models.loss import maybe_fused_next_token_loss

@@ -19,7 +20,7 @@
 LmT = TypeVar("LmT", bound="LmHeadModel")
 
 
-class LmExample(eqx.Module):
+class LmExample(eqx.Module, NumElementsBatch):
     tokens: hax.NamedArray
     loss_mask: hax.NamedArray
     attn_mask: AttentionMask | NamedArray = AttentionMask.causal()
@@ -88,6 +89,9 @@ def from_prompt_and_completion(

         return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask)
 
+    def num_elements(self):
+        return self.loss_mask.sum()
+
 
 # TODO: for some reason, mypy doesn't like the discover_packages_path argument?
 @dataclass(frozen=True)
@@ -221,6 +225,7 @@ def compute_next_token_loss(
     key=None,
     reduction: Optional[hax.ReductionFunction] = hax.mean,
     reduction_axis: Optional[hax.AxisSelection] = None,
+    batch_num_elements: Optional[int] = None,
     logsumexp_weight: Optional[float] = None,
     loss_dtype: Optional[Type[jnp.dtype]] = jnp.float32,
 ) -> jnp.ndarray | NamedArray:
@@ -241,6 +246,7 @@
         loss_mask=example.loss_mask,
         reduction=reduction,
         reduction_axis=reduction_axis,
+        batch_num_elements=batch_num_elements,
         logsumexp_weight=logsumexp_weight,
         dtype=loss_dtype,
         block_size=model.config.cross_entropy_block_size,
61 changes: 37 additions & 24 deletions src/levanter/models/loss.py
@@ -1,4 +1,5 @@
 import functools
+import logging
 from typing import Optional
 
 import equinox
@@ -10,6 +11,9 @@
 from haliax.nn import cross_entropy_loss_and_log_normalizers
 
 
+logger = logging.getLogger(__name__)
+
+
 def maybe_fused_next_token_loss(
     Pos: hax.AxisSelector,
     Embed: hax.AxisSelector,
@@ -20,6 +24,7 @@ def maybe_fused_next_token_loss(
     loss_mask: Optional[NamedArray] = None,
     reduction: Optional[hax.ReductionFunction] = hax.mean,
     reduction_axis: Optional[hax.AxisSelection] = None,
+    batch_num_elements: Optional[int] = None,
     logsumexp_weight: Optional[float] = None,
     block_size: Optional[int] = None,
     dtype: Optional[jnp.dtype] = jnp.float32,
@@ -36,6 +41,7 @@
         loss_mask (Optional[NamedArray]): Mask to apply to the loss.
         reduction (Optional[hax.ReductionFunction]): Reduction function.
         reduction_axis (Optional[hax.AxisSelection]): Axis to apply reduction.
+        batch_num_elements (Optional[int]): The number of elements in the batch. When passed, it is used to reduce the loss.
         logsumexp_weight (Optional[float]): Weight for logsumexp penalty.
         block_size (Optional[int]): Size of each block for processing.
@@ -45,6 +51,9 @@
     # Resolve axes
     Pos = pred_embeddings.resolve_axis(Pos)
     Vocab = pred_lm_head.resolve_axis(Vocab)
+    if batch_num_elements is not None:
+        if reduction is not hax.sum:
+            logger.warning("batch_num_elements given when reduction is not hax.sum, make sure this is intended")
 
     if block_size is None:
         # Full softmax computation
@@ -53,32 +62,36 @@
         logits = logits.astype(dtype)
 
         # Shift target tokens to predict the next token
-        return next_token_loss(Pos, Vocab, logits, true_ids, loss_mask, reduction, reduction_axis, logsumexp_weight)
-
-    # Shift target tokens to predict the next token
-    target_y = hax.roll(true_ids, -1, Pos)
-
-    # Create a mask that excludes the last token
-    not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32)  # type: ignore
-    if loss_mask is not None:
-        loss_mask = loss_mask * not_last_loss_mask
+        loss = next_token_loss(Pos, Vocab, logits, true_ids, loss_mask, reduction, reduction_axis, logsumexp_weight)
     else:
-        loss_mask = not_last_loss_mask
+        # Shift target tokens to predict the next token
+        target_y = hax.roll(true_ids, -1, Pos)
 
-    # Compute the loss with optional block-wise processing
-    return fused_cross_entropy_loss_and_logsumexp_penalty(
-        pred_embeddings,
-        pred_lm_head,
-        Contract=Embed,
-        Label=Vocab,
-        target_y=target_y,
-        reduction=reduction,
-        reduction_axis=reduction_axis,
-        where=loss_mask,
-        logsumexp_weight=logsumexp_weight,
-        block_size=block_size,
-        dtype=dtype,
-    )
+        # Create a mask that excludes the last token
+        not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32)  # type: ignore
+        if loss_mask is not None:
+            loss_mask = loss_mask * not_last_loss_mask
+        else:
+            loss_mask = not_last_loss_mask
+
+        # Compute the loss with optional block-wise processing
+        loss = fused_cross_entropy_loss_and_logsumexp_penalty(
+            pred_embeddings,
+            pred_lm_head,
+            Contract=Embed,
+            Label=Vocab,
+            target_y=target_y,
+            reduction=reduction,
+            reduction_axis=reduction_axis,
+            where=loss_mask,
+            logsumexp_weight=logsumexp_weight,
+            block_size=block_size,
+            dtype=dtype,
+        )
+
+    if batch_num_elements is not None:
+        return loss / batch_num_elements
+    return loss
 
 
 def next_token_loss(
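Distilled, the new batch_num_elements path in maybe_fused_next_token_loss amounts to the following (an illustrative sketch with invented names, not code from the repository): the caller asks for a sum reduction, and the summed loss is normalized once by the externally supplied element count.

import numpy as np

def reduce_next_token_loss(per_token_loss, loss_mask, batch_num_elements=None):
    # Sum the masked per-token losses; if the caller supplies the total number of
    # loss elements in the full batch, normalize by it (mirrors `loss / batch_num_elements`).
    summed = np.sum(per_token_loss * loss_mask)
    if batch_num_elements is not None:
        return summed / batch_num_elements
    return summed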
13 changes: 8 additions & 5 deletions src/levanter/trainer.py
@@ -37,7 +37,7 @@
 from levanter.config import JsonAtom
 from levanter.data import AsyncDataset, DataLoader
 from levanter.distributed import DistributedConfig, RayConfig
-from levanter.grad_accum import microbatched
+from levanter.grad_accum import NumElementsBatch, microbatched
 from levanter.tracker import TrackerConfig, capture_time
 from levanter.trainer_state import TrainerState, saveable_training_mask
 from levanter.utils import cloud_utils, fsspec_utils
@@ -380,7 +380,7 @@ def checkpoint_path(self) -> str:
         checkpoint_path = self.config.checkpointer.expanded_path(self.run_id)
         return checkpoint_path
 
-    def train_step(self, state: S, *batch: X, **batch_kwargs) -> StepInfo[S]:
+    def train_step(self, state: S, batch: X, **batch_kwargs) -> StepInfo[S]:
         """
         Performs a single training step.
         """
@@ -529,7 +529,7 @@ def _train_step(
         key, new_key = jax.random.split(state.training_key)
         model = inference_mode(state.model, False)
 
-        loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, *batch, **batch_kwargs, key=key)
+        loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, batch, **batch_kwargs, key=key)
 
         with hax.axis_mapping(self.parameter_axis_mapping):
             if not _no_hooks:
@@ -549,9 +549,12 @@ def obj_fun(trainable_model):
         else:
             return loss, new_state, hook_infos
 
-    def _compute_gradients_microbatched(self, loss_fn, model: M, *batch, **batch_kwargs) -> tuple[Scalar, M]:
+    def _compute_gradients_microbatched(self, loss_fn, model: M, batch: X, **batch_kwargs) -> tuple[Scalar, M]:
         grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False)
         mbs = self.config.microbatch_size
+        if isinstance(batch, NumElementsBatch):
+            batch_kwargs["batch_num_elements"] = batch.num_elements()
+            batch_kwargs["reduction"] = hax.sum
         grad_fn = microbatched(
             grad_fn,
             self.TrainBatch,
@@ -560,7 +563,7 @@
             self.compute_axis_mapping,
         )
         with hax.axis_mapping(self.compute_axis_mapping):
-            return grad_fn(model, *batch, **batch_kwargs)
+            return grad_fn(model, batch, **batch_kwargs)
 
 
 def _initialize_global_tracker(config, run_id):
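With the trainer change above, any batch type that subclasses NumElementsBatch opts into this behaviour: _compute_gradients_microbatched switches the loss reduction to hax.sum and passes the batch-wide element count, so each microbatch contributes a raw sum that is normalized exactly once. A minimal sketch of such a batch type (the class and field names are hypothetical; LmExample in this commit does the same thing with its loss mask):

import equinox as eqx
import haliax as hax

from levanter.grad_accum import NumElementsBatch


class WeightedBatch(eqx.Module, NumElementsBatch):
    tokens: hax.NamedArray
    loss_mask: hax.NamedArray  # 1.0 where a position contributes to the loss

    def num_elements(self):
        # Count of positions that actually contribute to the loss in this batch.
        return self.loss_mask.sum()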
3 changes: 2 additions & 1 deletion src/levanter/utils/types.py
@@ -51,9 +51,10 @@ class ComputeLossFunction(Protocol[M_con, X]):
     def __call__(
         self,
         model: M_con,
-        *inputs: X,
+        input: X,
         reduction: Optional[hax.ReductionFunction] = hax.mean,
         reduction_axis: Optional[hax.AxisSelection] = None,
+        batch_num_elements: Optional[int] = None,
         **kwargs,
     ) -> Scalar | hax.NamedArray:
         ...
