Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Use public aliases to define MultiTrainStep #953

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion kauldron/evals/evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,6 @@ def evaluate(
step=step,
aux=merged_aux,
schedules={},
model_with_aux=self.model_with_aux,
log_summaries=True,
)
return merged_aux
Expand Down
3 changes: 2 additions & 1 deletion kauldron/train/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from kauldron.train.setup_utils import Setup
from kauldron.train.setup_utils import TqdmInfo
from kauldron.train.train_step import Auxiliaries
from kauldron.train.train_step import ModelWithAux
from kauldron.train.train_step import forward
from kauldron.train.train_step import forward_with_loss
from kauldron.train.train_step import TrainState
from kauldron.train.train_step import TrainStep
from kauldron.train.trainer_lib import Trainer
Expand Down
5 changes: 5 additions & 0 deletions kauldron/train/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ class Context:
opt_state: The state of the optimizer prior to the update. (available after
the backward pass, e.g. for metrics). The old state is chosen to be
consistent with parameters which are also pre-update.
metric_states: The states of the metrics (after the backward pass).
summary_states: The states of the summaries (after the backward pass).
"""

# These are always available:
Expand All @@ -80,6 +82,9 @@ class Context:
grads: Any = None
updates: Any = None
opt_state: Any = None
# Become available after the metrics computation
metric_states: Any = None
summary_states: Any = None

replace = dataclasses.replace

Expand Down
12 changes: 1 addition & 11 deletions kauldron/train/metric_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ def write_step_metrics(
*,
step: int,
aux: train_step.Auxiliaries,
model_with_aux: train_step.ModelWithAux,
schedules: Mapping[str, optax.Schedule],
log_summaries: bool,
timer: Optional[chrono_utils.Chrono] = None,
Expand Down Expand Up @@ -201,15 +200,7 @@ def write_step_metrics(

if log_summaries:
with jax.transfer_guard("allow"):
# TODO(klausg): remove once all summaries are migrated to new protocol
# image summaries
image_summaries_old = {
name: summary.get_images(**aux.summary_kwargs[name])
for name, summary in model_with_aux.summaries.items()
if isinstance(summary, summaries.ImageSummary)
}

image_summaries = image_summaries_old | {
image_summaries = {
name: value
for name, value in aux_result.summary_values.items()
if isinstance(value, Float["n h w #3"])
Expand Down Expand Up @@ -586,7 +577,6 @@ def write_step_metrics(
*,
step: int,
aux: train_step.Auxiliaries,
model_with_aux: train_step.ModelWithAux,
schedules: Mapping[str, optax.Schedule],
log_summaries: bool,
timer: Optional[chrono_utils.Chrono] = None,
Expand Down
1 change: 0 additions & 1 deletion kauldron/train/train_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ def train_impl(
step=i,
aux=aux,
schedules=trainer.schedules,
model_with_aux=trainstep.model_with_aux,
timer=chrono,
log_summaries=log_summaries,
)
Expand Down
Loading
Loading