From d008bcf29d75b79be3c304461c00355919b33f9c Mon Sep 17 00:00:00 2001 From: "George E. Dahl" Date: Wed, 13 Nov 2024 11:20:53 -0800 Subject: [PATCH] Re-enable more of test_early_stopping. The test logic wasn't correct when min_steps was set. PiperOrigin-RevId: 696206571 --- init2winit/trainer_lib/test_trainer.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/init2winit/trainer_lib/test_trainer.py b/init2winit/trainer_lib/test_trainer.py index a4b4fda..7ba594c 100644 --- a/init2winit/trainer_lib/test_trainer.py +++ b/init2winit/trainer_lib/test_trainer.py @@ -918,7 +918,7 @@ def mock_evaluate_batch(params, batch_stats, batch): @parameterized.parameters(False, True) def test_early_stopping(self, min_steps): """Test training early stopping on MNIST with a small model.""" - rng = jax.random.PRNGKey(0) + rng = jax.random.PRNGKey(5) # Set the numpy seed to make the fake data deterministc. mocking.mock_data # ultimately calls numpy.random. @@ -1022,14 +1022,14 @@ def as_dataset(self, *args, **kwargs): # With min steps, we should've run an extra 10 steps. self.assertLen(epoch_reports, 4) epoch_reports.pop() - # TODO(b/373692442) - # self.assertLen(epoch_reports, 3) - # self.assertGreater( - # epoch_reports[-2][early_stopping_target_name], - # early_stopping_target_value) - # self.assertLess( - # epoch_reports[-1][early_stopping_target_name], - # early_stopping_target_value) + self.assertGreaterEqual(len(epoch_reports), 2) + if not min_steps: + self.assertGreater( + epoch_reports[-2][early_stopping_target_name], + early_stopping_target_value) + self.assertLess( + epoch_reports[-1][early_stopping_target_name], + early_stopping_target_value) if __name__ == '__main__':