diff --git a/config/backpack.yaml b/config/backpack.yaml index 735d40c01..0fe93b539 100644 --- a/config/backpack.yaml +++ b/config/backpack.yaml @@ -12,7 +12,7 @@ model: trainer: tracker: project: "levanter" - tags: [ "openwebtext", "backpack" ] + tags: ["openwebtext", "backpack"] mp: p=f32,c=bfloat16 @@ -21,5 +21,5 @@ trainer: model_axis_size: 1 optimizer: - learning_rate: 6E-4 + learning_rate: 6e-4 weight_decay: 0.1 diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 135d10dd5..9ff544bd8 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -17,6 +17,7 @@ from tqdm_loggable import tqdm_logging from tqdm_loggable.auto import tqdm +import haliax as hax import haliax.nn from haliax import NamedArray, is_named_array from haliax.jax_utils import is_jax_array_like @@ -30,6 +31,8 @@ from levanter.utils import flop_utils, jax_utils from levanter.utils.jax_utils import barrier_sync, jnp_to_python from levanter.utils.logging import save_xla_dumps_to_wandb +from levanter.utils.stat_utils import RunningMean +from levanter.utils.types import Extras from levanter.visualization import compute_and_visualize_log_probs as viz_probs @@ -145,10 +148,8 @@ async def compute_length(): def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, name: Optional[str] = None): - total_loss = 0.0 - total_load_time = 0.0 - total_loss_time = 0.0 - n = 0 + loss = RunningMean(jnp.zeros(()), jnp.zeros(())) + extras: Extras = {} if name is not None: desc = f"eval {name}" @@ -159,28 +160,27 @@ def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, n pbar = tqdm(dataset, desc=desc, position=1, leave=False, total=max_batches) iter_ = iter(pbar) + n = 0 while True: - time_in = time.time() + n += 1 batch = next(iter_, None) if batch is None: break - load_time = time.time() - time_in - total_load_time += load_time - loss = loss_fn(model, batch) - total_loss += loss.item() - n += 1 - loss_time = time.time() - time_in - load_time - total_loss_time += loss_time + losses, where, extras = loss_fn(model, batch) + mean_loss = hax.mean(losses, where=where) + loss += RunningMean(mean_loss, where.sum()) + for k, v in extras.items(): + if k not in extras: + extras[k] = v + else: + extras[k] += v - pbar.set_postfix(loss=total_loss / n) + pbar.set_postfix(loss=loss.mean.item()) if max_batches is not None and n >= max_batches: break - if n > 0: - total_loss /= n - - return total_loss + return loss.item(), {k: v.item() for k, v in extras.items()} def compute_validation_loss( @@ -190,12 +190,14 @@ def compute_validation_loss( name: Optional[str] = None, ): def compute_loss(info: StepInfo): - loss = eval_loss_loop(loss_fn, info.model, dataset, max_batches=max_batches, name=name) + loss, extras = eval_loss_loop(loss_fn, info.model, dataset, max_batches=max_batches, name=name) prefix = "eval" if name: prefix += "/" + name - levanter.tracker.log({f"{prefix}/loss": loss}, step=info.step) + levanter.tracker.log( + {f"{prefix}/loss": loss} | {f"{prefix}/{k}": v for k, v in extras.items()}, step=info.step + ) if name: logger.info(f"{name} validation loss: {loss:.3f}") diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index cdcfe68cd..6cd14d496 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -134,12 +134,17 @@ def doremi_step(state: DoremiState, ref, batch, domains): proxy = inference_mode(state.model, False) with hax.axis_mapping(trainer.compute_axis_mapping): # calculate per-token losses for proxy and ref - proxy_losses, proxy_loss_bwd = eqx.filter_vjp(lambda p: loss_fn(p, batch, reduction_axis=()), proxy) - ref_losses = loss_fn(ref, batch, reduction_axis=()) + def scalar_loss_fn(p, batch): + ret, _, _ = loss_fn(p, batch) + return ret + + proxy_losses, proxy_loss_bwd = eqx.filter_vjp(lambda p: scalar_loss_fn(p, batch), proxy) + ref_losses = scalar_loss_fn(ref, batch) # calculate excess losses, aggregate per-domain losses excess_losses = proxy_losses - ref_losses clipped_losses = hax.maximum(excess_losses, 0) + print(clipped_losses.shape) per_domain_losses = _compute_per_domain_losses(clipped_losses, Domain, domains) # Update domain weights diff --git a/src/levanter/grad_accum.py b/src/levanter/grad_accum.py index 8dbdc2c30..3b77a75bd 100644 --- a/src/levanter/grad_accum.py +++ b/src/levanter/grad_accum.py @@ -1,11 +1,9 @@ -import enum import functools from typing import Callable, Optional, ParamSpec, TypeVar import equinox as eqx import jax import jax.numpy as jnp -import jax.tree as jtu from jax.lax import with_sharding_constraint from jax.sharding import PartitionSpec @@ -25,34 +23,6 @@ X = TypeVar("X", contravariant=True) # Input -class ReductionType(enum.Enum): - SUM = enum.auto() - MEAN = enum.auto() - # TODO: add MAX? - - -def apply_updates_running(acc, r, updates, overwrites): - def _running_sum_updates(u, p): - if u is None: - return p - else: - return p * (1 - r) + u * r - - def _is_none(x): - return x is None - - def _apply_update(tree, update, overwrite): - if overwrite is not None: - return overwrite - - return jtu.map(_running_sum_updates, update, tree, is_leaf=_is_none) - - def is_leaf(x): - return x is None or isinstance(x, hq.OverwriteWithGradient) - - return jtu.map(_apply_update, acc, updates, overwrites, is_leaf=is_leaf) - - # TODO: should we use a custom_jvp on microbatched? # cf https://github.com/google-research/t5x/blob/main/t5x/trainer.py#L617 @@ -108,6 +78,8 @@ def microbatched( @functools.wraps(loss_fn) def no_accum_loss_fn(*args, **kwargs): losses, where, extras = loss_fn(*args, **kwargs) + seen_tokens = where.sum().scalar() + extras["seen_tokens"] = seen_tokens return hax.mean(losses, where=where).scalar(), extras return eqx.filter_value_and_grad(no_accum_loss_fn, has_aux=True) @@ -119,7 +91,7 @@ def no_accum_loss_fn(*args, **kwargs): @functools.wraps(loss_fn) def accum_loss_fn(*args, **kwargs): losses, where, extras = loss_fn(*args, **kwargs) - return hax.mean(losses, where=where).scalar(), (where.sum(), extras) + return hax.sum(losses, where=where).scalar(), (where.sum(), extras) grad_fn = eqx.filter_value_and_grad(accum_loss_fn, has_aux=True) @@ -154,17 +126,20 @@ def loop(acc, microbatch_and_key): # TODO: this uses the latest value for the scale for fp8, which seems not ideal but probably ok? overwrites, updates = hq.partition_for_grad_overwrite(grads_mb) - r = n_mb / (total + n_mb) - loss = loss + (loss_mb - loss) * r - grads = apply_updates_running(grads, r, updates, overwrites) + grads = hq.apply_updates(grads, updates, overwrites) grads = hax.shard_with_axis_mapping(grads, accum_axis_mapping) - print(loss, loss_mb, r) + loss += loss_mb + total += n_mb + return (loss, (total, {k: v + extras_mb[k] for k, v in extras.items()})), grads with jax.named_scope("microbatched"): - (loss, (_, extras)), grads, = hax.fold( + (loss, (total, extras)), grads, = hax.fold( loop, AccumStep )(acc, (args, kwargs, key)) + grads = jax.tree_util.tree_map(lambda x: x / total, grads) + loss /= total + extras["seen_tokens"] = total return (loss, extras), grads diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 2a454362c..cbb272dcd 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -43,7 +43,7 @@ from levanter.utils import cloud_utils, fsspec_utils from levanter.utils.jax_utils import create_fsdp_mesh, zeros_like_tree from levanter.utils.tree_utils import inference_mode -from levanter.utils.types import ComputeLossFunction, FilterSpec +from levanter.utils.types import ComputeLossFunction, Extras, FilterSpec logger = pylogging.getLogger(__name__) @@ -391,10 +391,10 @@ def train_step(self, state: S, batch: X, **batch_kwargs) -> StepInfo[S]: with capture_time() as step_time: if hooks_this_time: - loss, new_state, cb_states = self._jit_train_step_fn(state, batch, batch_kwargs) + loss, new_state, extras, cb_states = self._jit_train_step_fn(state, batch, batch_kwargs) # force the loss so timing numbers are accurate. laziness isn't going to help here (i think?) else: - loss, new_state = self._jit_train_step_fn_no_hook(state, batch, batch_kwargs) + loss, new_state, extras = self._jit_train_step_fn_no_hook(state, batch, batch_kwargs) loss = loss.item() # type: ignore info = StepInfo(new_state, loss, step_time()) @@ -404,7 +404,8 @@ def train_step(self, state: S, batch: X, **batch_kwargs) -> StepInfo[S]: if hooks_this_time: self.hooks.run_jit_hooks_outside_step(info, cb_states) - levanter.tracker.log({"throughput/hook_time": hook_time()}, step=info.step) + log_items = {k: v.item() for k, v in extras.items()} | {"throughput/hook_time": hook_time()} + levanter.tracker.log(log_items, step=info.step) return info @@ -525,11 +526,13 @@ def _jit_train_step_fn_no_hook(self): def _train_step( self, state: S, batch, batch_kwargs, _no_hooks=False - ) -> tuple[Scalar, S, Sequence[CBInfo]] | tuple[Scalar, S]: + ) -> tuple[Scalar, S, Extras, Sequence[CBInfo]] | tuple[Scalar, S, Extras]: key, new_key = jax.random.split(state.training_key) model = inference_mode(state.model, False) - loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, batch, **batch_kwargs, key=key) + (loss, extras), grads = self._compute_gradients_microbatched( + self.loss_fn, model, batch, **batch_kwargs, key=key + ) with hax.axis_mapping(self.parameter_axis_mapping): if not _no_hooks: @@ -545,11 +548,13 @@ def obj_fun(trainable_model): new_state = state.take_step(grads, obj_fun=obj_fun) new_state = hax.shard(new_state, self.parameter_axis_mapping) if _no_hooks: - return loss, new_state + return loss, new_state, extras else: - return loss, new_state, hook_infos + return loss, new_state, extras, hook_infos - def _compute_gradients_microbatched(self, loss_fn, model: M, batch: X, **batch_kwargs) -> tuple[Scalar, M]: + def _compute_gradients_microbatched( + self, loss_fn, model: M, batch: X, **batch_kwargs + ) -> tuple[tuple[Scalar, Extras], M]: mbs = self.config.microbatch_size grad_fn = microbatched( loss_fn, diff --git a/src/levanter/utils/stat_utils.py b/src/levanter/utils/stat_utils.py index 6111be42e..895003cf4 100644 --- a/src/levanter/utils/stat_utils.py +++ b/src/levanter/utils/stat_utils.py @@ -1,13 +1,26 @@ -import typing +from typing import TypeAlias import equinox as eqx import jax.numpy as jnp import numpy as np +from typing_extensions import Self import haliax as hax +from levanter.utils.types import Accumulatable -Arrayish: typing.TypeAlias = hax.NamedArray | np.ndarray | jnp.ndarray + +Arrayish: TypeAlias = hax.NamedArray | np.ndarray | jnp.ndarray + + +class SumScalar(Accumulatable): + value: jnp.ndarray + + def item(self) -> float: + return self.value.item() + + def __add__(self, other: Self) -> Self: + return SumScalar(self.value + other.value) class RunningMean(eqx.Module): @@ -27,6 +40,9 @@ def add(self, x: Arrayish, total: Arrayish) -> "RunningMean": new_total = self.total + total return RunningMean(new_mean, new_total) + def item(self) -> float: + return self.mean.item() + def __add__(self, other: "RunningMean"): return self.add(other.mean, other.total) diff --git a/src/levanter/utils/types.py b/src/levanter/utils/types.py index 900b23985..dafc0791d 100644 --- a/src/levanter/utils/types.py +++ b/src/levanter/utils/types.py @@ -1,6 +1,10 @@ -from typing import Any, Callable, Protocol, Tuple, TypeVar, Union +import abc +from typing import Any, Callable, Dict, Protocol, Tuple, TypeAlias, TypeVar, Union +import equinox as eqx +import jax from jaxtyping import PyTree +from typing_extensions import Self import haliax as hax from haliax.types import Scalar @@ -10,6 +14,19 @@ M_con = TypeVar("M_con", contravariant=True) # Model X = TypeVar("X", contravariant=True) # Input + +class Accumulatable(abc.ABC, eqx.Module): + @abc.abstractmethod + def item(self) -> float: + pass + + @abc.abstractmethod + def __add__(self, other: Self) -> Self: + pass + + +Extras: TypeAlias = Dict[str, jax.Array | Accumulatable] + try: from haliax.nn.scan import BlockFoldable except ImportError: @@ -53,5 +70,5 @@ def __call__( model: M_con, input: X, **kwargs, - ) -> tuple[hax.NamedArray, hax.NamedArray, dict]: + ) -> tuple[hax.NamedArray, hax.NamedArray, Extras]: ... diff --git a/tests/test_doremi.py b/tests/test_doremi.py index 3ad4aa9ab..e60554480 100644 --- a/tests/test_doremi.py +++ b/tests/test_doremi.py @@ -128,10 +128,11 @@ def test_estimate_mixture_weights(): ds3 = LogitDataset(W3, 0.05, x3_mask, x3_bias, key=next(keys)) # TODO: remove key as a requirement for models - def compute_loss_fn(model, example, reduction=hax.mean, reduction_axis=None, key=None): + def compute_loss_fn(model, example, key=None): del key y_pred = model(example.x) - return hax.nn.binary_cross_entropy_loss(y_pred, example.y, reduction=reduction, reduction_axis=reduction_axis) + losses = hax.nn.binary_cross_entropy_loss(y_pred, example.y, reduction=None) + return losses, hax.ones_like(losses), {} tiny_trainer_config = TrainerConfig( num_train_steps=300, diff --git a/tests/test_grad_accum.py b/tests/test_grad_accum.py index aab55c2c2..77ef4df7e 100644 --- a/tests/test_grad_accum.py +++ b/tests/test_grad_accum.py @@ -61,7 +61,7 @@ def scalar_loss_fn(mlp, x): mesh = Mesh(jax.devices(), ("data",)) - # @hax.partitioning.named_jit(axis_resources=axis_mapping) + @hax.partitioning.named_jit(axis_resources=axis_mapping) def jit_grad_accum(mlp, x): grad_fn = microbatched(loss_fn, Batch, parallelism, axis_mapping, axis_mapping) return grad_fn(mlp, x) diff --git a/tests/test_hf_gpt2_serialize.py b/tests/test_hf_gpt2_serialize.py index 9cc46ca0d..29074504c 100644 --- a/tests/test_hf_gpt2_serialize.py +++ b/tests/test_hf_gpt2_serialize.py @@ -135,7 +135,8 @@ def torch_loss(model, input_ids) -> torch.Tensor: def compute_loss(model: LmHeadModel, input_ids): example = LmExample.causal(input_ids, eos_id=converter.tokenizer.eos_token_id) - return compute_next_token_loss(model, example, key=None).scalar() + loss, where, _ = compute_next_token_loss(model, example, key=None) + return hax.mean(loss, where=where).scalar() jax_compute_grad = equinox.filter_value_and_grad(compute_loss, has_aux=False) jax_grad: Gpt2LMHeadModel