Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove deprecated arguments #949

Merged
merged 1 commit into from
Nov 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 1 addition & 40 deletions kauldron/train/train_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from kauldron.utils import train_property # pylint: disable=unused-import
from kauldron.utils.kdash import dashboard_utils
from kauldron.utils.sharding_utils import sharding as sharding_lib # pylint: disable=g-importing-member
from kauldron.utils.status_utils import status # pylint: disable=g-importing-member
import optax

# Do not import `trainer_lib` at runtime to avoid circular imports
Expand Down Expand Up @@ -251,15 +250,10 @@ def init( # pylint:disable=missing-function-docstring
@jax.named_call
def forward(
self,
context: context_lib.Context | None = None,
context: context_lib.Context,
*,
rngs: rngs_lib.Rngs,
is_training: bool,
# DEPRECATED variables: Should be passed through `context` instead.
params=None,
batch=None,
step: int | None = None,
collections: _Collections | None = None,
) -> tuple[float, context_lib.Context]:
"""Forward pass of the model including losses.

Expand All @@ -268,45 +262,12 @@ def forward(
`batch`, `step`, and `collections` (and optionally `opt_state`).
rngs: Random numbers to use for the forward pass.
is_training: Whether to run the model in training or eval mode.
params: DEPRECATED: Should be passed through `context` instead.
batch: DEPRECATED: Should be passed through `context` instead.
step: DEPRECATED: Should be passed through `context` instead.
collections: DEPRECATED: Should be passed through `context` instead.

Returns:
loss_total: Total loss.
context: Context with the updated `loss_total`, `loss_states`,
`interms`, and `collections`.
"""
# New API: pass everything through `context`
if isinstance(context, context_lib.Context):
if any(v is not None for v in (params, batch, step, collections)):
raise ValueError(
"When calling `model_with_aux.forward(context)`, you should not"
" pass `params`, `batch`,... through kwargs, but rather through the"
" context."
)
# Should check that params, batch,... are correctly set in the context ?
else: # Legacy API (deprecated)
status.log(
"Warning: Calling `model_with_aux.forward(params)` is deprecated and"
" will be removed soon. Instead, all inputs should be passed"
" through context directly: `model_with_aux.forward(context)`."
)
# Params can be passed either as positional or keyword arguments:
if context is None: # `forward(params=params)`
assert params is not None, "Cannot pass both `params` and `context`"
else: # `forward(params)`
assert params is None, "Cannot pass both `params` and `context`"
params = context

context = context_lib.Context(
step=step,
batch=batch,
params=params,
collections=collections,
)
del params, batch, step, collections
args, kwargs = data_utils.get_model_inputs(self.model, context)
preds, collections = self.model.apply(
{"params": context.params} | context.collections,
Expand Down