From 20113b0bbfc5590ed508a2b670514be612ade8e9 Mon Sep 17 00:00:00 2001
From: Etienne Pot
Date: Sat, 23 Nov 2024 03:20:37 -0800
Subject: [PATCH] Have `._step` return the full `Context`

PiperOrigin-RevId: 699444877
---
 kauldron/evals/evaluators.py    |   1 -
 kauldron/train/__init__.py      |   4 +-
 kauldron/train/context.py       |  21 ++
 kauldron/train/metric_writer.py |  12 +-
 kauldron/train/train_lib.py     |   1 -
 kauldron/train/train_step.py    | 364 ++++++++++++++++----------------
 kauldron/train/trainer_lib.py   |  13 +-
 7 files changed, 213 insertions(+), 203 deletions(-)

diff --git a/kauldron/evals/evaluators.py b/kauldron/evals/evaluators.py
index b1c5744b..fe865e6b 100644
--- a/kauldron/evals/evaluators.py
+++ b/kauldron/evals/evaluators.py
@@ -251,7 +251,6 @@ def evaluate(
         step=step,
         aux=merged_aux,
         schedules={},
-        model_with_aux=self.model_with_aux,
         log_summaries=True,
     )
     return merged_aux
diff --git a/kauldron/train/__init__.py b/kauldron/train/__init__.py
index 80eecf4e..d53fbc09 100644
--- a/kauldron/train/__init__.py
+++ b/kauldron/train/__init__.py
@@ -21,7 +21,9 @@
 from kauldron.train.setup_utils import Setup
 from kauldron.train.setup_utils import TqdmInfo
 from kauldron.train.train_step import Auxiliaries
-from kauldron.train.train_step import ModelWithAux
+from kauldron.train.train_step import AuxiliariesRef
+from kauldron.train.train_step import forward
+from kauldron.train.train_step import forward_with_loss
 from kauldron.train.train_step import TrainState
 from kauldron.train.train_step import TrainStep
 from kauldron.train.trainer_lib import Trainer
diff --git a/kauldron/train/context.py b/kauldron/train/context.py
index 767f8c41..7ca53211 100644
--- a/kauldron/train/context.py
+++ b/kauldron/train/context.py
@@ -62,6 +62,8 @@ class Context:
    opt_state: The state of the optimizer prior to the update. (available
      after the backward pass, e.g. for metrics). The old state is chosen to
      be consistent with parameters which are also pre-update.
+ metric_states: The states of the metrics (after the backward pass) + summary_states: The states of the summaries (after the backward pass) """ # These are always available: @@ -80,6 +82,9 @@ class Context: grads: Any = None updates: Any = None opt_state: Any = None + # Become available after the metrics computation + metric_states: Any = None + summary_states: Any = None replace = dataclasses.replace @@ -100,3 +105,19 @@ def from_state_and_batch( def flatten(self) -> dict[str, Any]: return kontext.flatten_with_path(self) + + def get_aux( + self, + *, + return_losses: bool = False, + return_metrics: bool = False, + return_summaries: bool = False, + ) -> train_step.Auxiliaries: + """Returns the auxiliaries for the step.""" + from kauldron.train import train_step # pylint: disable=g-import-not-at-top + + return train_step.Auxiliaries( + loss_states=self.loss_states if return_losses else None, + metric_states=self.metric_states if return_metrics else None, + summary_states=self.summary_states if return_summaries else None, + ) diff --git a/kauldron/train/metric_writer.py b/kauldron/train/metric_writer.py index e506e3bb..a48c4190 100644 --- a/kauldron/train/metric_writer.py +++ b/kauldron/train/metric_writer.py @@ -165,7 +165,6 @@ def write_step_metrics( *, step: int, aux: train_step.Auxiliaries, - model_with_aux: train_step.ModelWithAux, schedules: Mapping[str, optax.Schedule], log_summaries: bool, timer: Optional[chrono_utils.Chrono] = None, @@ -201,15 +200,7 @@ def write_step_metrics( if log_summaries: with jax.transfer_guard("allow"): - # TODO(klausg): remove once all summaries are migrated to new protocol - # image summaries - image_summaries_old = { - name: summary.get_images(**aux.summary_kwargs[name]) - for name, summary in model_with_aux.summaries.items() - if isinstance(summary, summaries.ImageSummary) - } - - image_summaries = image_summaries_old | { + image_summaries = { name: value for name, value in aux_result.summary_values.items() if isinstance(value, Float["n h w #3"]) @@ -586,7 +577,6 @@ def write_step_metrics( *, step: int, aux: train_step.Auxiliaries, - model_with_aux: train_step.ModelWithAux, schedules: Mapping[str, optax.Schedule], log_summaries: bool, timer: Optional[chrono_utils.Chrono] = None, diff --git a/kauldron/train/train_lib.py b/kauldron/train/train_lib.py index dda08504..f7ac85db 100644 --- a/kauldron/train/train_lib.py +++ b/kauldron/train/train_lib.py @@ -136,7 +136,6 @@ def train_impl( step=i, aux=aux, schedules=trainer.schedules, - model_with_aux=trainstep.model_with_aux, timer=chrono, log_summaries=log_summaries, ) diff --git a/kauldron/train/train_step.py b/kauldron/train/train_step.py index d42debf4..4534d97c 100644 --- a/kauldron/train/train_step.py +++ b/kauldron/train/train_step.py @@ -41,7 +41,6 @@ from kauldron.utils import train_property # pylint: disable=unused-import from kauldron.utils.kdash import dashboard_utils from kauldron.utils.sharding_utils import sharding as sharding_lib # pylint: disable=g-importing-member -from kauldron.utils.status_utils import status # pylint: disable=g-importing-member import optax # Do not import `trainer_lib` at runtime to avoid circular imports @@ -77,6 +76,10 @@ def replace(self, **changes: Any) -> TrainState: return dataclasses.replace(self, **changes) +# TODO(epot): Move auxiliaries to a separate file (`Auxiliaries`, +# `AuxiliariesOutput`, `AuxiliariesRef`) + + @flax.struct.dataclass class Auxiliaries: """Auxiliaries (intermediate states to be accumulated).""" @@ -213,11 +216,11 @@ def 
_gather_kwargs_with_reraise(k, summary, context): return summary.gather_kwargs(context) +# TODO(epot): Not sure about the name. @klausg any ideas ? @dataclasses.dataclass(kw_only=True, eq=True, frozen=True) -class ModelWithAux(config_util.UpdateFromRootCfg): - """Wrapper around model which also compute the summaries and metrics.""" +class AuxiliariesRef(config_util.UpdateFromRootCfg): + """Wrapper around the losses, summaries and metrics.""" - model: nn.Module = config_util.ROOT_CFG_REF.model losses: Mapping[str, kd_losses.Loss] = config_util.ROOT_CFG_REF.train_losses metrics: Mapping[str, kd_metrics.Metric] = ( config_util.ROOT_CFG_REF.train_metrics @@ -226,172 +229,50 @@ class ModelWithAux(config_util.UpdateFromRootCfg): config_util.ROOT_CFG_REF.train_summaries ) - def init( # pylint:disable=missing-function-docstring - self, - init_rngs: rngs_lib.Rngs, - batch: PyTree[jax.Array], - model_method: Optional[str] = None, - ) -> tuple[_Params, _Collections]: - self._assert_root_cfg_resolved() - args, kwargs = data_utils.get_model_inputs_from_batch(self.model, batch) - collections = self.model.init( - init_rngs, - *args, - method=model_method, - is_training_property=True, - capture_intermediates=True, - **kwargs, - ) - collections = flax.core.unfreeze(collections) - params = collections.pop("params", {}) - collections.pop("intermediates", None) # Remove intermediates - - return params, collections - @jax.named_call - def forward( - self, - context: context_lib.Context | None = None, - *, - rngs: rngs_lib.Rngs, - is_training: bool, - # DEPRECATED variables: Should be passed through `context` instead. - params=None, - batch=None, - step: int | None = None, - collections: _Collections | None = None, - ) -> tuple[float, context_lib.Context]: - """Forward pass of the model including losses. - - Arguments: - context: Context to use for the forward pass. Should contain `params`, - `batch`, `step`, and `collections` (and optionally `opt_state`). - rngs: Random numbers to use for the forward pass. - is_training: Whether to run the model in training or eval mode. - params: DEPRECATED: Should be passed through `context` instead. - batch: DEPRECATED: Should be passed through `context` instead. - step: DEPRECATED: Should be passed through `context` instead. - collections: DEPRECATED: Should be passed through `context` instead. + def compute(self, context: context_lib.Context) -> context_lib.Context: + """Get auxilaries.""" - Returns: - loss_total: Total loss. - context: Context with the updated `loss_total`, `loss_states`, - `interms`, and `collections`. - """ - # New API: pass everything through `context` - if isinstance(context, context_lib.Context): - if any(v is not None for v in (params, batch, step, collections)): - raise ValueError( - "When calling `model_with_aux.forward(context)`, you should not" - " pass `params`, `batch`,... through kwargs, but rather through the" - " context." - ) - # Should check that params, batch,... are correctly set in the context ? - else: # Legacy API (deprecated) - status.log( - "Warning: Calling `model_with_aux.forward(params)` is deprecated and" - " will be removed soon. Instead, all inputs should be passed" - " through context directly: `model_with_aux.forward(context)`." 
- ) - # Params can be passed either as positional or keyword arguments: - if context is None: # `forward(params=params)` - assert params is not None, "Cannot pass both `params` and `context`" - else: # `forward(params)` - assert params is None, "Cannot pass both `params` and `context`" - params = context - - context = context_lib.Context( - step=step, - batch=batch, - params=params, - collections=collections, - ) - del params, batch, step, collections - args, kwargs = data_utils.get_model_inputs(self.model, context) - preds, collections = self.model.apply( - {"params": context.params} | context.collections, - *args, - rngs=rngs, - mutable=True, - capture_intermediates=True, # TODO(klausg): check if need a filter here - is_training_property=is_training, - **kwargs, - ) - # Note the params can be mutable if the model call the same sub-model - # internally but with different params. However, the updates are never - # propagated - collections.pop("params", None) - interms = collections.pop("intermediates") - context = context.replace( - preds=preds, - interms=interms, - collections=collections, - ) - loss_total, loss_states = kd_losses.compute_losses( - losses=self.losses, context=context + # TODO(epot): Cleanup loss-states: + # * Re-compute the states here if `context.loss_states` is None (e.g. + # if in eval) + # * Split `kd/losses/base:compute_losses` into `get_state` and + # `compute_losses(loss_states) -> float` + # * Unify all the `m.get_state_from_context` patterns for metrics, + # summaries, and losses. + + metric_states = jax.tree.map( + lambda m: m.get_state_from_context(context), self.metrics ) - return loss_total, context.replace( - loss_states=loss_states, - loss_total=loss_total, + summary_states = jax.tree.map( + lambda m: m.get_state_from_context(context), self.summaries ) - @jax.named_call - def get_aux( - self, - context: context_lib.Context, - *, - # TODO(epot): Better signature - return_losses: bool = False, - return_metrics: bool = False, - return_summaries: bool = False, - ) -> Auxiliaries: - """Get auxilaries.""" - aux = Auxiliaries() - if return_losses: - aux = aux.replace(loss_states=context.loss_states) - - if return_metrics: - aux = aux.replace( - metric_states=jax.tree.map( - lambda m: m.get_state_from_context(context), self.metrics - ) - ) - - if return_summaries: - # TODO(klausg): remove legacy summaries protocol once all are migrated - # legacy summaries protocol: - aux = aux.replace( - summary_kwargs={ - k: _gather_kwargs_with_reraise(k, summary, context) - for k, summary in self.summaries.items() - } - ) - # new summaries as metrics protocol: - def _get_summary_state(summary): - if isinstance(summary, kd_metrics.Metric): - return summary.get_state_from_context(context) - else: - return kd_metrics.EmptyState() - - aux = aux.replace( - summary_states=jax.tree.map(_get_summary_state, self.summaries) - ) - return aux + return dataclasses.replace( + context, + metric_states=metric_states, + summary_states=summary_states, + ) @dataclasses.dataclass(kw_only=True, eq=True, frozen=True) class TrainStep(config_util.UpdateFromRootCfg): - """Training Step.""" + """Base Training Step. + + Subclasses can overwrite the `_step` method to implement custom training + steps. 
+  """
 
-  model_with_aux: ModelWithAux = dataclasses.field(default_factory=ModelWithAux)
+  model: nn.Module = config_util.ROOT_CFG_REF.model
   optimizer: optax.GradientTransformation = config_util.ROOT_CFG_REF.optimizer
   rng_streams: rngs_lib.RngStreams = config_util.ROOT_CFG_REF.rng_streams
   sharding: sharding_lib.ShardingStrategy = config_util.ROOT_CFG_REF.sharding
   init_transforms: Mapping[str, partial_loader.AbstractPartialLoader] = (
       config_util.ROOT_CFG_REF.init_transforms
   )
+  aux: AuxiliariesRef = dataclasses.field(default_factory=AuxiliariesRef)
 
-  __root_cfg_fields_to_recurse__ = ("model_with_aux",)
+  __root_cfg_fields_to_recurse__ = ("aux",)
 
   def init(
       self,
@@ -437,11 +318,19 @@ def _init_model(
   ) -> TrainState:
     """Initialize the model and return the initial TrainState."""
     batch = data_utils.mock_batch_from_elem_spec(elem_spec, self.sharding.ds)
-    params, collections = self.model_with_aux.init(
+    args, kwargs = data_utils.get_model_inputs_from_batch(self.model, batch)
+    collections = self.model.init(
         self.rng_streams.init_rngs(),
-        batch,
-        model_method=model_method,
+        *args,
+        method=model_method,
+        is_training_property=True,
+        capture_intermediates=True,
+        **kwargs,
     )
+    collections = flax.core.unfreeze(collections)
+    params = collections.pop("params", {})
+    collections.pop("intermediates", None)  # Remove intermediates
+
     state = TrainState(  # pytype: disable=wrong-arg-types
         step=jnp.asarray(0),
         params=params,
@@ -497,42 +386,44 @@ def step(
       ] = frozenset(),
   ) -> tuple[TrainState, Auxiliaries]:
     """Training step: forward, losses, gradients, update, and metrics."""
+    # This function is just a small wrapper around `_step` for:
+    # * Checkify error handling
+    # * Selecting which auxiliary metrics to return
+    # * Sharding
+    # If reading the code, you can likely skip this function and go directly
+    # to `_step`.
+
     if checkify_error_categories:
       step_fn = checkify.checkify(self._step, errors=checkify_error_categories)
-      error, (state, aux) = step_fn(
-          state,
-          batch,
-          return_losses=return_losses,
-          return_metrics=return_metrics,
-          return_summaries=return_summaries,
-      )
-      aux = aux.replace(error=error)
+      error, (state, ctx) = step_fn(state, batch)
     else:
-      state, aux = self._step(
-          state,
-          batch,
-          return_losses=return_losses,
-          return_metrics=return_metrics,
-          return_summaries=return_summaries,
-      )
+      error = None
+      state, ctx = self._step(state, batch)
 
-    return state, aux
+    # TODO(epot): More flexible way to select the subset of context to return.
+    # Should also have a way to return the full context.
+    aux = ctx.get_aux(
+        return_losses=return_losses,
+        return_metrics=return_metrics,
+        return_summaries=return_summaries,
+    )
+    aux = aux.replace(error=error)
+    return sharding_lib.with_sharding_constraint(
+        (state, aux),
+        (self.sharding.state, self.sharding.aux),
+    )
 
   def _step(
       self,
       state: TrainState,
       batch: PyTree[Any],
-      *,
-      return_losses: bool = False,
-      return_metrics: bool = False,
-      return_summaries: bool = False
-  ) -> tuple[TrainState, Auxiliaries]:
+  ) -> tuple[TrainState, context_lib.Context]:
     """Training step to be wrapped by checkify and called by `step`."""
     # TODO(epot): Should `jax.named_call` be moved downstream directly in optax?
     # NOTE: ensure that evaluation metrics are computed from the OLD model state
     # *before* backprop gradients are applied.
     grad_fn = jax.grad(
-        self.model_with_aux.forward,
+        forward_with_loss,
         argnums=0,
         has_aux=True,
         allow_int=True,
@@ -542,6 +433,8 @@ def _step(
     context = context_lib.Context.from_state_and_batch(state=state, batch=batch)
     context_grads, context = grad_fn(
         context,
+        model=self.model,
+        losses=self.aux.losses,
         rngs=self.rng_streams.train_rngs(state.step),
         is_training=True,
     )
@@ -566,14 +459,111 @@ def _step(
         opt_state=state.opt_state,
     )
 
-    aux = self.model_with_aux.get_aux(
-        context,
-        return_losses=return_losses,
-        return_metrics=return_metrics,
-        return_summaries=return_summaries,
-    )
+    context = self.aux.compute(context)
 
-    return sharding_lib.with_sharding_constraint(
-        (next_state, aux),
-        (self.sharding.state, self.sharding.aux),
+    return next_state, context
+
+
+def forward(
+    context: context_lib.Context,
+    *,
+    model: nn.Module,
+    rngs: rngs_lib.Rngs,
+    is_training: bool,
+) -> context_lib.Context:
+  """Forward pass of the model.
+
+  Arguments:
+    context: Context to use for the forward pass. Should contain `params`,
+      `batch`, `step`, and `collections` (and optionally `opt_state`).
+    model: Model to use for the forward pass.
+    rngs: Random numbers to use for the forward pass.
+    is_training: Whether to run the model in training or eval mode.
+
+  Returns:
+    context: Context with the updated `preds`, `interms`, and
+      `collections`. Losses are not computed here (see
+      `forward_with_loss`).
+  """
+  args, kwargs = data_utils.get_model_inputs(model, context)
+  preds, collections = model.apply(
+      {"params": context.params} | context.collections,
+      *args,
+      rngs=rngs,
+      mutable=True,
+      capture_intermediates=True,  # TODO(klausg): check if need a filter here
+      is_training_property=is_training,
+      **kwargs,
+  )
+  # Note the params can be mutable if the model calls the same sub-model
+  # internally but with different params. However, the updates are never
+  # propagated.
+  collections.pop("params", None)
+  interms = collections.pop("intermediates")
+  context = context.replace(
+      preds=preds,
+      interms=interms,
+      collections=collections,
+  )
+  return context
+
+
+def forward_with_loss(
+    context: context_lib.Context,
+    *,
+    model: nn.Module,
+    losses: Mapping[str, kd_losses.Loss],
+    rngs: rngs_lib.Rngs,
+    is_training: bool,
+) -> tuple[float, context_lib.Context]:
+  """Forward pass of the model, including losses.
+
+  Arguments:
+    context: Context to use for the forward pass. Should contain `params`,
+      `batch`, `step`, and `collections` (and optionally `opt_state`).
+    model: Model to use for the forward pass.
+    losses: Losses to compute.
+    rngs: Random numbers to use for the forward pass.
+    is_training: Whether to run the model in training or eval mode.
+
+  Returns:
+    loss_total: Total loss.
+    context: Context with the updated `loss_total`, `loss_states`,
+      `interms`, and `collections`.
+  """
+  context = forward(
+      context=context,
+      model=model,
+      rngs=rngs,
+      is_training=is_training,
+  )
+  loss_total, loss_states = kd_losses.compute_losses(
+      losses=losses, context=context
+  )
+  return loss_total, context.replace(
+      loss_states=loss_states,
+      loss_total=loss_total,
+  )
+
+
+@dataclasses.dataclass(kw_only=True, eq=True, frozen=True)
+class ModelWithAux(AuxiliariesRef):
+  """Model with aux.
+
+  DEPRECATED: Do not use.
+  """
+
+  # TODO(epot): Deprecate this class in eval.
+
+  model: nn.Module
+
+  def forward(self, context, **kwargs):
+    return forward_with_loss(
+        context=context,
+        model=self.model,
+        losses=self.losses,
+        **kwargs,
     )
+
+  def get_aux(self, context, **kwargs):
+    return self.compute(context).get_aux(**kwargs)
diff --git a/kauldron/train/trainer_lib.py b/kauldron/train/trainer_lib.py
index 342d17f4..7f88069c 100644
--- a/kauldron/train/trainer_lib.py
+++ b/kauldron/train/trainer_lib.py
@@ -396,7 +396,6 @@ def context_specs(self) -> context_lib.Context:
     elem_spec = self.train_ds.element_spec
     elem_sharding = self.sharding.ds
     rngs = self.rng_streams.init_rngs()
-    mwa = self.trainstep.model_with_aux
 
     state_specs = self.state_specs
 
@@ -407,7 +406,17 @@ def context_specs(self) -> context_lib.Context:
         batch=m_batch,
     )
     _, context = jax.eval_shape(
-        functools.partial(mwa.forward, is_training=True),
+        # TODO(epot): Instead add an option for `trainer.trainstep.step` to
+        # return the `context`. For example, could simplify the `._step` to
+        # always compute all summaries and return `context`, and only select
+        # the subset of metrics to write in `step` (so `return_summaries=`,...
+        # would not be propagated to `._step()` but only in `.step()`)
+        functools.partial(
+            train_step.forward_with_loss,
+            is_training=True,
+            model=self.model,
+            losses=self.train_losses,
+        ),
        context,
        rngs=rngs,
    )
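Note (not part of the diff): a minimal usage sketch of the contract introduced above, assuming an already-configured `kd.train.Trainer` named `trainer` and a `state`/`batch` pair obtained elsewhere; `trainer`, `state`, and `batch` are hypothetical names, while `_step`, `Context.get_aux`, and `step` are the APIs added or changed by this patch.

    # `_step` now returns the full `Context` instead of a pre-filtered
    # `Auxiliaries`:
    state, ctx = trainer.trainstep._step(state, batch)

    # The caller decides afterwards which auxiliaries to materialize:
    aux = ctx.get_aux(return_losses=True, return_metrics=True)

    # The public `step()` wrapper keeps the previous `(state, aux)` return
    # signature, selecting the auxiliaries internally:
    state, aux = trainer.trainstep.step(state, batch, return_losses=True)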