From aa99dbc3af6925f8e77679a0366adff5d75c9b66 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sun, 7 Feb 2021 10:39:41 -0500 Subject: [PATCH] FIX Makes sure score is passed to ReduceLROnPlateau (#738) * FIX Makes sure score is passed to ReduceLROnPlateau --- CHANGES.md | 3 +- skorch/callbacks/lr_scheduler.py | 12 ++++---- skorch/tests/callbacks/test_lr_scheduler.py | 31 +++++++++++++++------ 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 212743327..28fd170b3 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - CLI helper function now also supports normal (i.e. non-skorch) sklearn estimators - Disabling all callbacks is now supported (which allows reducing overhead, which is especially relevant for small models). +- `LRScheduler` now correctly passes the value being monitored to `ReduceLROnPlateau`. (#738) ### Changed @@ -35,7 +36,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- Removed support for schedulers with a `batch_step()` method in `LRScheduler`. +- Removed support for schedulers with a `batch_step()` method in `LRScheduler`. - Raise `FutureWarning` in `CVSplit` when `random_state` is not used. Will raise an exception in a future (#620) - The behavior of method `net.get_params` changed to make it more consistent with sklearn: it will no longer return "learned" attributes like `module_`; therefore, functions like `sklearn.base.clone`, when called with a fitted net, will no longer return a fitted net but instead an uninitialized net; if you want a copy of a fitted net, use `copy.deepcopy` instead;`net.get_params` is used under the hood by many sklearn functions and classes, such as `GridSearchCV`, whose behavior may thus be affected by the change. (#521, #527) - Raise `FutureWarning` when using `CyclicLR` scheduler, because the default behavior has changed from taking a step every batch to taking a step every epoch. (#626) diff --git a/skorch/callbacks/lr_scheduler.py b/skorch/callbacks/lr_scheduler.py index 04ec5435f..09e475e3e 100644 --- a/skorch/callbacks/lr_scheduler.py +++ b/skorch/callbacks/lr_scheduler.py @@ -149,12 +149,14 @@ def on_epoch_end(self, net, **kwargs): if callable(self.monitor): score = self.monitor(net) else: - if self.lr_scheduler_.mode == 'max': - score = -np.inf - elif self.lr_scheduler_.mode == 'min': - score = np.inf - else: + try: score = net.history[-1, self.monitor] + except KeyError as e: + raise ValueError( + f"'{self.monitor}' was not found in history. A " + f"Scoring callback with name='{self.monitor}' " + "should be placed before the LRScheduler callback" + ) from e self.lr_scheduler_.step(score) # ReduceLROnPlateau does not expose the current lr so it can't be recorded diff --git a/skorch/tests/callbacks/test_lr_scheduler.py b/skorch/tests/callbacks/test_lr_scheduler.py index 1b140488b..352c63c8c 100644 --- a/skorch/tests/callbacks/test_lr_scheduler.py +++ b/skorch/tests/callbacks/test_lr_scheduler.py @@ -290,25 +290,40 @@ def test_reduce_lr_monitor_with_callable( score = mock_step.call_args_list[0][0][0] assert score == 55 - @pytest.mark.parametrize('mode,score', [ - ('min', np.inf), - ('max', -np.inf) - ]) - def test_reduce_lr_monitor_max( - self, classifier_data, classifier_module, mode, score): + @pytest.mark.parametrize('mode', ['min', 'max']) + def test_reduce_lr_monitor_passes_monitored_loss( + self, classifier_data, classifier_module, mode): X, y = classifier_data net = NeuralNetClassifier( classifier_module, callbacks=[ ('scheduler', LRScheduler( - ReduceLROnPlateau, monitor='train_loss', mode=mode)), + ReduceLROnPlateau, monitor='valid_loss', mode=mode)), ], max_epochs=1, ) net.fit(X, y) + expected = net.history_[-1, "valid_loss"] policy = dict(net.callbacks_)['scheduler'].lr_scheduler_ - assert policy.best == score + assert policy.best == pytest.approx(expected) + + def test_reduce_lr_raise_error_when_key_does_not_exist( + self, classifier_data, classifier_module): + X, y = classifier_data + net = NeuralNetClassifier( + classifier_module, + callbacks=[ + ('scheduler', LRScheduler( + ReduceLROnPlateau, monitor='bad_key')), + ], + max_epochs=1, + ) + msg = ("'bad_key' was not found in history. A Scoring " + "callback with name='bad_key' should be placed before the " + "LRScheduler callback") + with pytest.raises(ValueError, match=msg): + net.fit(X, y) class TestWarmRestartLR():