Skip to content

Commit

Permalink
Merge pull request #23 from skorch-dev/master
Browse files Browse the repository at this point in the history
FIX Makes sure score is passed to ReduceLROnPlateau (skorch-dev#738)
  • Loading branch information
sthagen authored Feb 15, 2021
2 parents 7602ce2 + aa99dbc commit cdb92fb
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 14 deletions.
3 changes: 2 additions & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- CLI helper function now also supports normal (i.e. non-skorch) sklearn estimators
- Disabling all callbacks is now supported (which allows reducing overhead,
which is especially relevant for small models).
- `LRScheduler` now correctly passes the value being monitored to `ReduceLROnPlateau`. (#738)

### Changed

Expand All @@ -35,7 +36,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Removed support for schedulers with a `batch_step()` method in `LRScheduler`.
- Removed support for schedulers with a `batch_step()` method in `LRScheduler`.
- Raise `FutureWarning` in `CVSplit` when `random_state` is not used. Will raise an exception in a future (#620)
- The behavior of method `net.get_params` changed to make it more consistent with sklearn: it will no longer return "learned" attributes like `module_`; therefore, functions like `sklearn.base.clone`, when called with a fitted net, will no longer return a fitted net but instead an uninitialized net; if you want a copy of a fitted net, use `copy.deepcopy` instead;`net.get_params` is used under the hood by many sklearn functions and classes, such as `GridSearchCV`, whose behavior may thus be affected by the change. (#521, #527)
- Raise `FutureWarning` when using `CyclicLR` scheduler, because the default behavior has changed from taking a step every batch to taking a step every epoch. (#626)
Expand Down
12 changes: 7 additions & 5 deletions skorch/callbacks/lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,14 @@ def on_epoch_end(self, net, **kwargs):
if callable(self.monitor):
score = self.monitor(net)
else:
if self.lr_scheduler_.mode == 'max':
score = -np.inf
elif self.lr_scheduler_.mode == 'min':
score = np.inf
else:
try:
score = net.history[-1, self.monitor]
except KeyError as e:
raise ValueError(
f"'{self.monitor}' was not found in history. A "
f"Scoring callback with name='{self.monitor}' "
"should be placed before the LRScheduler callback"
) from e

self.lr_scheduler_.step(score)
# ReduceLROnPlateau does not expose the current lr so it can't be recorded
Expand Down
31 changes: 23 additions & 8 deletions skorch/tests/callbacks/test_lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,25 +290,40 @@ def test_reduce_lr_monitor_with_callable(
score = mock_step.call_args_list[0][0][0]
assert score == 55

@pytest.mark.parametrize('mode,score', [
('min', np.inf),
('max', -np.inf)
])
def test_reduce_lr_monitor_max(
self, classifier_data, classifier_module, mode, score):
@pytest.mark.parametrize('mode', ['min', 'max'])
def test_reduce_lr_monitor_passes_monitored_loss(
self, classifier_data, classifier_module, mode):
X, y = classifier_data
net = NeuralNetClassifier(
classifier_module,
callbacks=[
('scheduler', LRScheduler(
ReduceLROnPlateau, monitor='train_loss', mode=mode)),
ReduceLROnPlateau, monitor='valid_loss', mode=mode)),
],
max_epochs=1,
)
net.fit(X, y)

expected = net.history_[-1, "valid_loss"]
policy = dict(net.callbacks_)['scheduler'].lr_scheduler_
assert policy.best == score
assert policy.best == pytest.approx(expected)

def test_reduce_lr_raise_error_when_key_does_not_exist(
        self, classifier_data, classifier_module):
    """Monitoring a history key that is never recorded must raise a
    ValueError that points the user at the missing Scoring callback,
    rather than failing silently or with an opaque KeyError."""
    X, y = classifier_data

    scheduler_cb = LRScheduler(ReduceLROnPlateau, monitor='bad_key')
    net = NeuralNetClassifier(
        classifier_module,
        callbacks=[('scheduler', scheduler_cb)],
        max_epochs=1,
    )

    # Expected message text must match the one raised by LRScheduler
    # when the monitored key is absent from net.history.
    expected_msg = ("'bad_key' was not found in history. A Scoring "
                    "callback with name='bad_key' should be placed before the "
                    "LRScheduler callback")
    with pytest.raises(ValueError, match=expected_msg):
        net.fit(X, y)


class TestWarmRestartLR():
Expand Down

0 comments on commit cdb92fb

Please sign in to comment.