Skip to content

Commit

Permalink
Re-enable more of test_early_stopping. The test logic wasn't correct …
Browse files Browse the repository at this point in the history
…when min_steps was set.

PiperOrigin-RevId: 696206571
  • Loading branch information
georgedahl authored and copybara-github committed Nov 13, 2024
1 parent 4293990 commit d008bcf
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions init2winit/trainer_lib/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,7 @@ def mock_evaluate_batch(params, batch_stats, batch):
@parameterized.parameters(False, True)
def test_early_stopping(self, min_steps):
"""Test training early stopping on MNIST with a small model."""
rng = jax.random.PRNGKey(0)
rng = jax.random.PRNGKey(5)

# Set the numpy seed to make the fake data deterministc. mocking.mock_data
# ultimately calls numpy.random.
Expand Down Expand Up @@ -1022,14 +1022,14 @@ def as_dataset(self, *args, **kwargs):
# With min steps, we should've run an extra 10 steps.
self.assertLen(epoch_reports, 4)
epoch_reports.pop()
# TODO(b/373692442)
# self.assertLen(epoch_reports, 3)
# self.assertGreater(
# epoch_reports[-2][early_stopping_target_name],
# early_stopping_target_value)
# self.assertLess(
# epoch_reports[-1][early_stopping_target_name],
# early_stopping_target_value)
self.assertGreaterEqual(len(epoch_reports), 2)
if not min_steps:
self.assertGreater(
epoch_reports[-2][early_stopping_target_name],
early_stopping_target_value)
self.assertLess(
epoch_reports[-1][early_stopping_target_name],
early_stopping_target_value)


if __name__ == '__main__':
Expand Down

0 comments on commit d008bcf

Please sign in to comment.