Try moving all the logic into the loss fn
Aphoh committed Dec 17, 2024
1 parent d8fa2d7 commit fcf15ab
Showing 8 changed files with 109 additions and 138 deletions.
3 changes: 1 addition & 2 deletions src/levanter/eval.py
@@ -301,8 +301,7 @@ def accum_for_batch(m: LmHeadModel, state: _EvalRunningMeans, batch: LmExample,
         m = self.mp.cast_to_compute(m)
 
         with hax.axis_mapping(axis_mapping):
-            losses = compute_next_token_loss(m, batch, reduction=None, reduction_axis=())
-            mask = batch.loss_mask  # [Batch, Pos]
+            losses, mask, _extras = compute_next_token_loss(m, batch)
             this_tokens = hax.sum(mask)
             this_loss = hax.einsum("->", losses, mask)  # to scalar
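
The evaluator now receives the per-token losses, the loss mask, and an extras dict straight from the loss function and reduces them itself. A minimal sketch of reducing those outputs to a scalar, reusing the calls from the hunk above (`total_loss`, `num_tokens`, and `mean_loss` are illustrative names, not part of the change):

    losses, mask, _extras = compute_next_token_loss(m, batch)  # per-token losses and 0/1 mask over (Batch, Pos)
    total_loss = hax.einsum("->", losses, mask)                # masked sum, as in the hunk above
    num_tokens = hax.sum(mask)                                 # tokens that actually count
    mean_loss = (total_loss / num_tokens).scalar()             # token-weighted mean as a plain scalar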
13 changes: 2 additions & 11 deletions src/levanter/eval_harness.py
@@ -37,7 +37,6 @@
 import levanter.tracker
 from levanter.compat.hf_checkpoints import HFCheckpointConverter, load_tokenizer
 from levanter.models.gpt2 import Gpt2Config
-from levanter.models.loss import next_token_loss
 from levanter.utils.hf_utils import HfTokenizer
 
 
@@ -58,7 +57,7 @@
 import levanter.config
 from levanter.checkpoint import load_checkpoint
 from levanter.data import AsyncDataset, DataLoader
-from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel
+from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel, compute_next_token_loss
 from levanter.trainer import StepInfo, TrainerConfig
 from levanter.utils.jax_utils import use_cpu_device
 from levanter.utils.tree_utils import inference_mode
@@ -157,15 +156,7 @@ def _eval_loglikelihood(model: LmHeadModel, example: LmExample) -> tuple[NamedAr
         logits = logits.astype(jnp.float32)
         Pos = logits.resolve_axis(self.EvalPos.name)
 
-        loss = next_token_loss(
-            Pos=Pos,
-            Vocab=model.Vocab,
-            logits=logits,
-            true_ids=example.tokens,
-            loss_mask=example.loss_mask,
-            reduction=hax.sum,
-            reduction_axis=Pos,
-        )
+        loss, _, _ = compute_next_token_loss(model, example)
 
         not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=bool)
         pred_targets = hax.argmax(logits, axis=model.Vocab)
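
The old call reduced with `reduction=hax.sum` over `Pos`; the shared helper instead returns unreduced per-token losses plus the mask, so any reduction is now the caller's job. A hedged sketch of recovering a Pos-summed value from the new outputs (`loss_sum` is an illustrative name; this assumes haliax reductions take an `axis=` selection, as used elsewhere in the codebase):

    losses, mask, _ = compute_next_token_loss(model, example)
    loss_sum = hax.sum(losses * mask, axis=Pos)  # per-example sum over positions, like the old reduction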
82 changes: 56 additions & 26 deletions src/levanter/grad_accum.py
@@ -1,30 +1,28 @@
-import abc
 import enum
 import functools
 from typing import Callable, Optional, ParamSpec, TypeVar
 
 import equinox as eqx
 import jax
 import jax.numpy as jnp
+import jax.tree as jtu
 from jax.lax import with_sharding_constraint
 from jax.sharding import PartitionSpec
 
 import haliax as hax
+import haliax.quantization as hq
 from haliax import Axis
 from haliax.partitioning import ResourceAxis
 from haliax.util import is_named_array
 
 from levanter.utils.jax_utils import zeros_like_tree
+from levanter.utils.types import ComputeLossFunction
 
 
 Args = ParamSpec("Args")
 R = TypeVar("R")
 
 
-class NumElementsBatch(abc.ABC):
-    @abc.abstractmethod
-    def num_elements(self) -> int:
-        pass
+M_con = TypeVar("M_con", contravariant=True)  # Model
+X = TypeVar("X", contravariant=True)  # Input
 
 
 class ReductionType(enum.Enum):
@@ -33,17 +31,38 @@ class ReductionType(enum.Enum):
     # TODO: add MAX?
 
 
+def apply_updates_running(acc, r, updates, overwrites):
+    def _running_sum_updates(u, p):
+        if u is None:
+            return p
+        else:
+            return p * (1 - r) + u * r
+
+    def _is_none(x):
+        return x is None
+
+    def _apply_update(tree, update, overwrite):
+        if overwrite is not None:
+            return overwrite
+
+        return jtu.map(_running_sum_updates, update, tree, is_leaf=_is_none)
+
+    def is_leaf(x):
+        return x is None or isinstance(x, hq.OverwriteWithGradient)
+
+    return jtu.map(_apply_update, acc, updates, overwrites, is_leaf=is_leaf)
+
+
 # TODO: should we use a custom_jvp on microbatched?
 
 # cf https://github.com/google-research/t5x/blob/main/t5x/trainer.py#L617
 def microbatched(
-    fn: Callable[Args, R],
+    loss_fn: ComputeLossFunction[M_con, X],
     Batch: Axis,
     microbatch_size: int,
     accum_axis_mapping,
     compute_axis_mapping,
     patch_in_rng_key: Optional[str] = "key",
-    reduce: ReductionType = ReductionType.MEAN,
    accum_dtype: Optional[jnp.dtype] = None,
 ) -> Callable[Args, R]:
     """
@@ -85,20 +104,30 @@ def microbatched(
     num_micro_steps = batch_size // microbatch_size
 
     if num_micro_steps == 1:
-        return fn
+
+        @functools.wraps(loss_fn)
+        def no_accum_loss_fn(*args, **kwargs):
+            losses, where, extras = loss_fn(*args, **kwargs)
+            return hax.mean(losses, where=where).scalar(), extras
+
+        return eqx.filter_value_and_grad(no_accum_loss_fn, has_aux=True)
 
     Microbatch = Batch.resize(microbatch_size)
     AccumStep = Axis("accum_step", num_micro_steps)
     assert num_micro_steps * microbatch_size == batch_size
 
-    if reduce not in ReductionType:
-        raise ValueError(f"accum_type must be one of {ReductionType}")
+    @functools.wraps(loss_fn)
+    def accum_loss_fn(*args, **kwargs):
+        losses, where, extras = loss_fn(*args, **kwargs)
+        return hax.mean(losses, where=where).scalar(), (where.sum(), extras)
 
-    @functools.wraps(fn)
+    grad_fn = eqx.filter_value_and_grad(accum_loss_fn, has_aux=True)
+
+    @functools.wraps(grad_fn)
     def wrapped_fn(*args, **kwargs):
 
         # first, determine the shape and make accumulator arrays
-        r_shape = eqx.filter_eval_shape(fn, *args, **kwargs)
+        r_shape = eqx.filter_eval_shape(grad_fn, *args, **kwargs)
         acc = zeros_like_tree(r_shape, accum_axis_mapping, accum_dtype)
 
         # then, reshape the inputs from (Batch, ...) to (AccumStep, Microbatch, ...)
@@ -113,30 +142,31 @@ def wrapped_fn(*args, **kwargs):
         args = _reshape_for_microbatch(Batch, Microbatch, AccumStep, args, compute_axis_mapping)
 
         def loop(acc, microbatch_and_key):
+            (loss, (total, extras)), grads = acc
             microbatch, microbatch_kwargs, key = microbatch_and_key
             with jax.named_scope("compute"):
                 microbatch_kwargs = microbatch_kwargs.copy()
                 if key is not None:
                     microbatch_kwargs[patch_in_rng_key] = key
-                this_r = fn(*microbatch, **microbatch_kwargs)
+                (loss_mb, (n_mb, extras_mb)), grads_mb = grad_fn(*microbatch, **microbatch_kwargs)
 
             with jax.named_scope("accum"):
-                import haliax.quantization as hq
-
                 # TODO: this uses the latest value for the scale for fp8, which seems not ideal but probably ok?
-                overwrites, updates = hq.partition_for_grad_overwrite(this_r)
-                acc = hq.apply_updates(acc, updates, overwrites)
-                acc = hax.shard_with_axis_mapping(acc, accum_axis_mapping)
-
-                return acc
+                overwrites, updates = hq.partition_for_grad_overwrite(grads_mb)
+                r = n_mb / (total + n_mb)
+                loss = loss + (loss_mb - loss) * r
+                grads = apply_updates_running(grads, r, updates, overwrites)
+                grads = hax.shard_with_axis_mapping(grads, accum_axis_mapping)
+            print(loss, loss_mb, r)
+            return (loss, (total, {k: v + extras_mb[k] for k, v in extras.items()})), grads
 
         with jax.named_scope("microbatched"):
-            acc = hax.fold(loop, AccumStep)(acc, (args, kwargs, key))
-
-        if reduce == ReductionType.MEAN:
-            acc = jax.tree_util.tree_map(lambda x: x / num_micro_steps, acc)
+            (loss, (_, extras)), grads, = hax.fold(
+                loop, AccumStep
+            )(acc, (args, kwargs, key))
 
-        return acc
+        return (loss, extras), grads
 
     return wrapped_fn
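
The loop above folds each microbatch into a running, token-weighted average instead of summing and dividing at the end: with `r = n_mb / (total + n_mb)`, every accumulated value `p` becomes `p * (1 - r) + u * r` (the loss update `loss + (loss_mb - loss) * r` is the same form). A self-contained sketch of that update rule in plain Python, showing it reproduces the ordinary weighted mean when `total` tracks the count seen so far (all names here are illustrative):

    def running_weighted_mean(chunk_means, chunk_counts):
        """Fold per-chunk means into one mean with the p * (1 - r) + u * r update."""
        mean, total = 0.0, 0
        for u, n in zip(chunk_means, chunk_counts):
            r = n / (total + n)            # weight of this chunk vs. everything seen so far
            mean = mean * (1 - r) + u * r  # equivalently: mean + (u - mean) * r
            total += n
        return mean

    means, counts = [2.0, 4.0, 10.0], [1, 3, 6]
    assert abs(running_weighted_mean(means, counts) - 7.4) < 1e-9  # (2*1 + 4*3 + 10*6) / 10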
19 changes: 8 additions & 11 deletions src/levanter/models/lm_model.py
@@ -11,7 +11,6 @@
 import haliax as hax
 from haliax import Axis, NamedArray, NamedOrNumeric
 
-from levanter.grad_accum import NumElementsBatch
 from levanter.models.attention import AttentionMask
 from levanter.models.loss import maybe_fused_next_token_loss
 
@@ -20,7 +19,7 @@
 LmT = TypeVar("LmT", bound="LmHeadModel")
 
 
-class LmExample(eqx.Module, NumElementsBatch):
+class LmExample(eqx.Module):
     tokens: hax.NamedArray
     loss_mask: hax.NamedArray
     attn_mask: AttentionMask | NamedArray = AttentionMask.causal()
@@ -223,33 +222,31 @@ def compute_next_token_loss(
     example: LmExample,
     *,
     key=None,
-    reduction: Optional[hax.ReductionFunction] = hax.mean,
-    reduction_axis: Optional[hax.AxisSelection] = None,
-    batch_num_elements: Optional[int] = None,
     logsumexp_weight: Optional[float] = None,
     loss_dtype: Optional[Type[jnp.dtype]] = jnp.float32,
-) -> jnp.ndarray | NamedArray:
+) -> tuple[NamedArray, NamedArray, dict]:
     """
     Computes the cross-entropy loss for a language modeling example. If reduction is not None, the loss is reduced
     across the reduction axis (with reduction_axis=None meaning all axes). If reduction is None, the loss is not
     reduced, and the result is a named array with axes (*batch axes, sequence_length).
     """
     activations = model.activations(example.tokens, example.attn_mask, key=key)
+    if isinstance(activations, tuple):
+        activations, extras = activations
+    else:
+        extras = {}
 
-    loss = maybe_fused_next_token_loss(
+    loss, where = maybe_fused_next_token_loss(
         model.Pos,
         model.Embed,
         model.Vocab,
         activations,
         model.get_lm_head(),
         example.tokens,
         loss_mask=example.loss_mask,
-        reduction=reduction,
-        reduction_axis=reduction_axis,
-        batch_num_elements=batch_num_elements,
         logsumexp_weight=logsumexp_weight,
         dtype=loss_dtype,
         block_size=model.config.cross_entropy_block_size,
     )
 
-    return loss
+    return loss, where, extras
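
`compute_next_token_loss` now returns the unreduced per-token loss, the mask of positions that count (`where`), and a dict of auxiliary outputs from `model.activations` (empty when there are none); producing a scalar is left to callers, as `accum_loss_fn` in grad_accum.py does. A minimal sketch of that caller-side reduction, assuming `model` and `example` are already constructed:

    losses, where, extras = compute_next_token_loss(model, example)
    scalar_loss = hax.mean(losses, where=where).scalar()  # mean over unmasked positions, as in accum_loss_fn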
