Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
eegli committed Feb 3, 2025
1 parent 2efb696 commit 974a741
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions tests/e2e/trainer/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,23 @@ def assert_on_model_run_output(output_dir: Path) -> None:

run_config = load_yml(yml_config_file, parse_to=TrainOutputConfig)

assert len(checkpoints) == run_config.io.num_models_to_save, (
f"Expected {run_config.io.num_models_to_save} model checkpoints"
)
assert (
len(checkpoints) == run_config.io.num_models_to_save
), f"Expected {run_config.io.num_models_to_save} model checkpoints"

# assert that we get the specified number of training loss logs
csv_log = pl.read_csv(csv_loss_file)

num_train_loss_entries = csv_log.filter(pl.col("kind") == "train").select(pl.len()).item()
assert num_train_loss_entries == run_config.io.log_train_loss_amount, (
f"Expected {run_config.io.log_train_loss_amount} train loss entries, received {num_train_loss_entries}"
)
assert (
num_train_loss_entries == run_config.io.log_train_loss_amount
), f"Expected {run_config.io.log_train_loss_amount} train loss entries, received {num_train_loss_entries}"

# assert that we get the specified number of validation loss logs
num_valid_loss_entries = csv_log.filter(pl.col("kind") == "valid").select(pl.len()).item()
assert num_valid_loss_entries == run_config.io.validate_amount, (
f"Expected {run_config.io.validate_amount} validation loss entries, received {num_valid_loss_entries}"
)
assert (
num_valid_loss_entries == run_config.io.validate_amount
), f"Expected {run_config.io.validate_amount} validation loss entries, received {num_valid_loss_entries}"

# assert that we the test loss is logged
num_test_loss_entries = csv_log.filter(pl.col("kind") == "test").select(pl.len()).item()
Expand Down

0 comments on commit 974a741

Please sign in to comment.