diff --git a/kauldron/evals/evaluators.py b/kauldron/evals/evaluators.py index b1c5744b..fe865e6b 100644 --- a/kauldron/evals/evaluators.py +++ b/kauldron/evals/evaluators.py @@ -251,7 +251,6 @@ def evaluate( step=step, aux=merged_aux, schedules={}, - model_with_aux=self.model_with_aux, log_summaries=True, ) return merged_aux diff --git a/kauldron/train/metric_writer.py b/kauldron/train/metric_writer.py index e506e3bb..a48c4190 100644 --- a/kauldron/train/metric_writer.py +++ b/kauldron/train/metric_writer.py @@ -165,7 +165,6 @@ def write_step_metrics( *, step: int, aux: train_step.Auxiliaries, - model_with_aux: train_step.ModelWithAux, schedules: Mapping[str, optax.Schedule], log_summaries: bool, timer: Optional[chrono_utils.Chrono] = None, @@ -201,15 +200,7 @@ def write_step_metrics( if log_summaries: with jax.transfer_guard("allow"): - # TODO(klausg): remove once all summaries are migrated to new protocol - # image summaries - image_summaries_old = { - name: summary.get_images(**aux.summary_kwargs[name]) - for name, summary in model_with_aux.summaries.items() - if isinstance(summary, summaries.ImageSummary) - } - - image_summaries = image_summaries_old | { + image_summaries = { name: value for name, value in aux_result.summary_values.items() if isinstance(value, Float["n h w #3"]) @@ -586,7 +577,6 @@ def write_step_metrics( *, step: int, aux: train_step.Auxiliaries, - model_with_aux: train_step.ModelWithAux, schedules: Mapping[str, optax.Schedule], log_summaries: bool, timer: Optional[chrono_utils.Chrono] = None, diff --git a/kauldron/train/train_lib.py b/kauldron/train/train_lib.py index dda08504..f7ac85db 100644 --- a/kauldron/train/train_lib.py +++ b/kauldron/train/train_lib.py @@ -136,7 +136,6 @@ def train_impl( step=i, aux=aux, schedules=trainer.schedules, - model_with_aux=trainstep.model_with_aux, timer=chrono, log_summaries=log_summaries, ) diff --git a/kauldron/train/train_step.py b/kauldron/train/train_step.py index d42debf4..21b61dd2 100644 --- a/kauldron/train/train_step.py +++ b/kauldron/train/train_step.py @@ -251,15 +251,10 @@ def init( # pylint:disable=missing-function-docstring @jax.named_call def forward( self, - context: context_lib.Context | None = None, + context: context_lib.Context, *, rngs: rngs_lib.Rngs, is_training: bool, - # DEPRECATED variables: Should be passed through `context` instead. - params=None, - batch=None, - step: int | None = None, - collections: _Collections | None = None, ) -> tuple[float, context_lib.Context]: """Forward pass of the model including losses. @@ -268,45 +263,12 @@ def forward( `batch`, `step`, and `collections` (and optionally `opt_state`). rngs: Random numbers to use for the forward pass. is_training: Whether to run the model in training or eval mode. - params: DEPRECATED: Should be passed through `context` instead. - batch: DEPRECATED: Should be passed through `context` instead. - step: DEPRECATED: Should be passed through `context` instead. - collections: DEPRECATED: Should be passed through `context` instead. Returns: loss_total: Total loss. context: Context with the updated `loss_total`, `loss_states`, `interms`, and `collections`. """ - # New API: pass everything through `context` - if isinstance(context, context_lib.Context): - if any(v is not None for v in (params, batch, step, collections)): - raise ValueError( - "When calling `model_with_aux.forward(context)`, you should not" - " pass `params`, `batch`,... through kwargs, but rather through the" - " context." - ) - # Should check that params, batch,... are correctly set in the context ? - else: # Legacy API (deprecated) - status.log( - "Warning: Calling `model_with_aux.forward(params)` is deprecated and" - " will be removed soon. Instead, all inputs should be passed" - " through context directly: `model_with_aux.forward(context)`." - ) - # Params can be passed either as positional or keyword arguments: - if context is None: # `forward(params=params)` - assert params is not None, "Cannot pass both `params` and `context`" - else: # `forward(params)` - assert params is None, "Cannot pass both `params` and `context`" - params = context - - context = context_lib.Context( - step=step, - batch=batch, - params=params, - collections=collections, - ) - del params, batch, step, collections args, kwargs = data_utils.get_model_inputs(self.model, context) preds, collections = self.model.apply( {"params": context.params} | context.collections,