diff --git a/kauldron/evals/evaluators.py b/kauldron/evals/evaluators.py index b1c5744b..fe865e6b 100644 --- a/kauldron/evals/evaluators.py +++ b/kauldron/evals/evaluators.py @@ -251,7 +251,6 @@ def evaluate( step=step, aux=merged_aux, schedules={}, - model_with_aux=self.model_with_aux, log_summaries=True, ) return merged_aux diff --git a/kauldron/train/__init__.py b/kauldron/train/__init__.py index 80eecf4e..93babd7e 100644 --- a/kauldron/train/__init__.py +++ b/kauldron/train/__init__.py @@ -21,7 +21,8 @@ from kauldron.train.setup_utils import Setup from kauldron.train.setup_utils import TqdmInfo from kauldron.train.train_step import Auxiliaries -from kauldron.train.train_step import ModelWithAux +from kauldron.train.train_step import forward +from kauldron.train.train_step import forward_with_loss from kauldron.train.train_step import TrainState from kauldron.train.train_step import TrainStep from kauldron.train.trainer_lib import Trainer diff --git a/kauldron/train/metric_writer.py b/kauldron/train/metric_writer.py index e506e3bb..a48c4190 100644 --- a/kauldron/train/metric_writer.py +++ b/kauldron/train/metric_writer.py @@ -165,7 +165,6 @@ def write_step_metrics( *, step: int, aux: train_step.Auxiliaries, - model_with_aux: train_step.ModelWithAux, schedules: Mapping[str, optax.Schedule], log_summaries: bool, timer: Optional[chrono_utils.Chrono] = None, @@ -201,15 +200,7 @@ def write_step_metrics( if log_summaries: with jax.transfer_guard("allow"): - # TODO(klausg): remove once all summaries are migrated to new protocol - # image summaries - image_summaries_old = { - name: summary.get_images(**aux.summary_kwargs[name]) - for name, summary in model_with_aux.summaries.items() - if isinstance(summary, summaries.ImageSummary) - } - - image_summaries = image_summaries_old | { + image_summaries = { name: value for name, value in aux_result.summary_values.items() if isinstance(value, Float["n h w #3"]) @@ -586,7 +577,6 @@ def write_step_metrics( *, step: int, aux: train_step.Auxiliaries, - model_with_aux: train_step.ModelWithAux, schedules: Mapping[str, optax.Schedule], log_summaries: bool, timer: Optional[chrono_utils.Chrono] = None, diff --git a/kauldron/train/train_lib.py b/kauldron/train/train_lib.py index dda08504..f7ac85db 100644 --- a/kauldron/train/train_lib.py +++ b/kauldron/train/train_lib.py @@ -136,7 +136,6 @@ def train_impl( step=i, aux=aux, schedules=trainer.schedules, - model_with_aux=trainstep.model_with_aux, timer=chrono, log_summaries=log_summaries, ) diff --git a/kauldron/train/train_step.py b/kauldron/train/train_step.py index d42debf4..1b888ee9 100644 --- a/kauldron/train/train_step.py +++ b/kauldron/train/train_step.py @@ -41,7 +41,6 @@ from kauldron.utils import train_property # pylint: disable=unused-import from kauldron.utils.kdash import dashboard_utils from kauldron.utils.sharding_utils import sharding as sharding_lib # pylint: disable=g-importing-member -from kauldron.utils.status_utils import status # pylint: disable=g-importing-member import optax # Do not import `trainer_lib` at runtime to avoid circular imports @@ -77,6 +76,10 @@ def replace(self, **changes: Any) -> TrainState: return dataclasses.replace(self, **changes) +# TODO(epot): Move auxiliaries to a separate file (`Auxiliaries`, +# `AuxiliariesOutput`, `AuxiliariesRef`) + + @flax.struct.dataclass class Auxiliaries: """Auxiliaries (intermediate states to be accumulated).""" @@ -213,11 +216,11 @@ def _gather_kwargs_with_reraise(k, summary, context): return summary.gather_kwargs(context) +# TODO(epot): Not sure about the name. @klausg any ideas ? @dataclasses.dataclass(kw_only=True, eq=True, frozen=True) -class ModelWithAux(config_util.UpdateFromRootCfg): - """Wrapper around model which also compute the summaries and metrics.""" +class AuxiliariesRef(config_util.UpdateFromRootCfg): + """Wrapper around the losses, summaries and metrics.""" - model: nn.Module = config_util.ROOT_CFG_REF.model losses: Mapping[str, kd_losses.Loss] = config_util.ROOT_CFG_REF.train_losses metrics: Mapping[str, kd_metrics.Metric] = ( config_util.ROOT_CFG_REF.train_metrics @@ -226,117 +229,8 @@ class ModelWithAux(config_util.UpdateFromRootCfg): config_util.ROOT_CFG_REF.train_summaries ) - def init( # pylint:disable=missing-function-docstring - self, - init_rngs: rngs_lib.Rngs, - batch: PyTree[jax.Array], - model_method: Optional[str] = None, - ) -> tuple[_Params, _Collections]: - self._assert_root_cfg_resolved() - args, kwargs = data_utils.get_model_inputs_from_batch(self.model, batch) - collections = self.model.init( - init_rngs, - *args, - method=model_method, - is_training_property=True, - capture_intermediates=True, - **kwargs, - ) - collections = flax.core.unfreeze(collections) - params = collections.pop("params", {}) - collections.pop("intermediates", None) # Remove intermediates - - return params, collections - - @jax.named_call - def forward( - self, - context: context_lib.Context | None = None, - *, - rngs: rngs_lib.Rngs, - is_training: bool, - # DEPRECATED variables: Should be passed through `context` instead. - params=None, - batch=None, - step: int | None = None, - collections: _Collections | None = None, - ) -> tuple[float, context_lib.Context]: - """Forward pass of the model including losses. - - Arguments: - context: Context to use for the forward pass. Should contain `params`, - `batch`, `step`, and `collections` (and optionally `opt_state`). - rngs: Random numbers to use for the forward pass. - is_training: Whether to run the model in training or eval mode. - params: DEPRECATED: Should be passed through `context` instead. - batch: DEPRECATED: Should be passed through `context` instead. - step: DEPRECATED: Should be passed through `context` instead. - collections: DEPRECATED: Should be passed through `context` instead. - - Returns: - loss_total: Total loss. - context: Context with the updated `loss_total`, `loss_states`, - `interms`, and `collections`. - """ - # New API: pass everything through `context` - if isinstance(context, context_lib.Context): - if any(v is not None for v in (params, batch, step, collections)): - raise ValueError( - "When calling `model_with_aux.forward(context)`, you should not" - " pass `params`, `batch`,... through kwargs, but rather through the" - " context." - ) - # Should check that params, batch,... are correctly set in the context ? - else: # Legacy API (deprecated) - status.log( - "Warning: Calling `model_with_aux.forward(params)` is deprecated and" - " will be removed soon. Instead, all inputs should be passed" - " through context directly: `model_with_aux.forward(context)`." - ) - # Params can be passed either as positional or keyword arguments: - if context is None: # `forward(params=params)` - assert params is not None, "Cannot pass both `params` and `context`" - else: # `forward(params)` - assert params is None, "Cannot pass both `params` and `context`" - params = context - - context = context_lib.Context( - step=step, - batch=batch, - params=params, - collections=collections, - ) - del params, batch, step, collections - args, kwargs = data_utils.get_model_inputs(self.model, context) - preds, collections = self.model.apply( - {"params": context.params} | context.collections, - *args, - rngs=rngs, - mutable=True, - capture_intermediates=True, # TODO(klausg): check if need a filter here - is_training_property=is_training, - **kwargs, - ) - # Note the params can be mutable if the model call the same sub-model - # internally but with different params. However, the updates are never - # propagated - collections.pop("params", None) - interms = collections.pop("intermediates") - context = context.replace( - preds=preds, - interms=interms, - collections=collections, - ) - loss_total, loss_states = kd_losses.compute_losses( - losses=self.losses, context=context - ) - return loss_total, context.replace( - loss_states=loss_states, - loss_total=loss_total, - ) - @jax.named_call - def get_aux( + def compute( self, context: context_lib.Context, *, @@ -348,6 +242,13 @@ def get_aux( """Get auxilaries.""" aux = Auxiliaries() if return_losses: + # TODO(epot): Cleanup loss-states: + # * Re-compute the states here if `context.loss_states` is None (e.g. + # if in eval) + # * Split `kd/losses/base:compute_losses` into `get_state` and + # `compute_losses(loss_states) -> float` + # * Unify all the `m.get_state_from_context` patterns for metrics, + # summaries, and losses. aux = aux.replace(loss_states=context.loss_states) if return_metrics: @@ -381,17 +282,22 @@ def _get_summary_state(summary): @dataclasses.dataclass(kw_only=True, eq=True, frozen=True) class TrainStep(config_util.UpdateFromRootCfg): - """Training Step.""" + """Base Training Step. + + Subclasses can overwrite the `_step` method to implement custom training + steps. + """ - model_with_aux: ModelWithAux = dataclasses.field(default_factory=ModelWithAux) + model: nn.Module = config_util.ROOT_CFG_REF.model optimizer: optax.GradientTransformation = config_util.ROOT_CFG_REF.optimizer rng_streams: rngs_lib.RngStreams = config_util.ROOT_CFG_REF.rng_streams sharding: sharding_lib.ShardingStrategy = config_util.ROOT_CFG_REF.sharding init_transforms: Mapping[str, partial_loader.AbstractPartialLoader] = ( config_util.ROOT_CFG_REF.init_transforms ) + aux: AuxiliariesRef = dataclasses.field(default_factory=AuxiliariesRef) - __root_cfg_fields_to_recurse__ = ("model_with_aux",) + __root_cfg_fields_to_recurse__ = ("aux",) def init( self, @@ -437,11 +343,19 @@ def _init_model( ) -> TrainState: """Initialize the model and return the initial TrainState.""" batch = data_utils.mock_batch_from_elem_spec(elem_spec, self.sharding.ds) - params, collections = self.model_with_aux.init( + args, kwargs = data_utils.get_model_inputs_from_batch(self.model, batch) + collections = self.model.init( self.rng_streams.init_rngs(), - batch, - model_method=model_method, + *args, + method=model_method, + is_training_property=True, + capture_intermediates=True, + **kwargs, ) + collections = flax.core.unfreeze(collections) + params = collections.pop("params", {}) + collections.pop("intermediates", None) # Remove intermediates + state = TrainState( # pytype: disable=wrong-arg-types step=jnp.asarray(0), params=params, @@ -532,7 +446,7 @@ def _step( # NOTE: ensure that evaluation metrics are computed from the OLD model state # *before* backprop gradients are applied. grad_fn = jax.grad( - self.model_with_aux.forward, + forward_with_loss, argnums=0, has_aux=True, allow_int=True, @@ -542,6 +456,8 @@ def _step( context = context_lib.Context.from_state_and_batch(state=state, batch=batch) context_grads, context = grad_fn( context, + model=self.model, + losses=self.aux.losses, rngs=self.rng_streams.train_rngs(state.step), is_training=True, ) @@ -566,7 +482,7 @@ def _step( opt_state=state.opt_state, ) - aux = self.model_with_aux.get_aux( + aux = self.aux.compute( context, return_losses=return_losses, return_metrics=return_metrics, @@ -577,3 +493,108 @@ def _step( (next_state, aux), (self.sharding.state, self.sharding.aux), ) + + +def forward( + context: context_lib.Context, + *, + model: nn.Module, + rngs: rngs_lib.Rngs, + is_training: bool, +) -> context_lib.Context: + """Forward pass of the model. + + Arguments: + context: Context to use for the forward pass. Should contain `params`, + `batch`, `step`, and `collections` (and optionally `opt_state`). + model: Model to use for the forward pass. + rngs: Random numbers to use for the forward pass. + is_training: Whether to run the model in training or eval mode. + + Returns: + loss_total: Total loss. + context: Context with the updated `loss_total`, `loss_states`, + `interms`, and `collections`. + """ + args, kwargs = data_utils.get_model_inputs(model, context) + preds, collections = model.apply( + {"params": context.params} | context.collections, + *args, + rngs=rngs, + mutable=True, + capture_intermediates=True, # TODO(klausg): check if need a filter here + is_training_property=is_training, + **kwargs, + ) + # Note the params can be mutable if the model call the same sub-model + # internally but with different params. However, the updates are never + # propagated + collections.pop("params", None) + interms = collections.pop("intermediates") + context = context.replace( + preds=preds, + interms=interms, + collections=collections, + ) + return context + + +def forward_with_loss( + context: context_lib.Context, + *, + model: nn.Module, + losses: Mapping[str, kd_losses.Loss], + rngs: rngs_lib.Rngs, + is_training: bool, +) -> tuple[float, context_lib.Context]: + """Forward pass of the model, including losses. + + Arguments: + context: Context to use for the forward pass. Should contain `params`, + `batch`, `step`, and `collections` (and optionally `opt_state`). + model: Model to use for the forward pass. + losses: Losses to compute. + rngs: Random numbers to use for the forward pass. + is_training: Whether to run the model in training or eval mode. + + Returns: + loss_total: Total loss. + context: Context with the updated `loss_total`, `loss_states`, + `interms`, and `collections`. + """ + context = forward( + context=context, + model=model, + rngs=rngs, + is_training=is_training, + ) + loss_total, loss_states = kd_losses.compute_losses( + losses=losses, context=context + ) + return loss_total, context.replace( + loss_states=loss_states, + loss_total=loss_total, + ) + + +@dataclasses.dataclass(kw_only=True, eq=True, frozen=True) +class ModelWithAux(AuxiliariesRef): + """Model with aux. + + DEPRECATED: Do not use. + """ + + # TODO(epot): Deprecate this class in eval. + + model: nn.Module + + def forward(self, context, **kwargs): + return forward_with_loss( + context=context, + model=self.model, + losses=self.losses, + **kwargs, + ) + + def get_aux(self, context, **kwargs): + return self.compute(context, **kwargs) diff --git a/kauldron/train/trainer_lib.py b/kauldron/train/trainer_lib.py index 342d17f4..7f88069c 100644 --- a/kauldron/train/trainer_lib.py +++ b/kauldron/train/trainer_lib.py @@ -396,7 +396,6 @@ def context_specs(self) -> context_lib.Context: elem_spec = self.train_ds.element_spec elem_sharding = self.sharding.ds rngs = self.rng_streams.init_rngs() - mwa = self.trainstep.model_with_aux state_specs = self.state_specs @@ -407,7 +406,17 @@ def context_specs(self) -> context_lib.Context: batch=m_batch, ) _, context = jax.eval_shape( - functools.partial(mwa.forward, is_training=True), + # TODO(epot): Instead add an option for `trainer.trainstep.step` to + # return the `context`. For example, could simplify the `._step` to + # always compute all summaries and return `context`, and only select + # the subset of metrics to write in `step` (so `return_summaries=`,... + # would not be propagated to `._step()` but only in `.step()`) + functools.partial( + train_step.forward_with_loss, + is_training=True, + model=self.model, + losses=self.train_losses, + ), context, rngs=rngs, )