
Commit d12c6cf

awaelchli and carmocca authored
more early stopping options (convergence and divergence threshold) (#6868)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 60c1c8f commit d12c6cf

3 files changed: +109 −14 lines


CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -102,6 +102,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `max_time` Trainer argument to limit training time ([#6823](https://github.com/PyTorchLightning/pytorch-lightning/pull/6823))
 
 
+- Added new `EarlyStopping` parameters `stopping_threshold` and `divergence_threshold` ([#6868](https://github.com/PyTorchLightning/pytorch-lightning/pull/6868))
+
+
+
 ### Changed
 
 - Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 61 additions & 9 deletions
@@ -18,7 +18,8 @@
 Monitor a metric and stop training when it stops improving.
 
 """
-from typing import Any, Dict
+import logging
+from typing import Any, Dict, Optional, Tuple
 
 import numpy as np
 import torch
@@ -27,6 +28,8 @@
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
+log = logging.getLogger(__name__)
+
 
 class EarlyStopping(Callback):
     r"""
@@ -53,6 +56,9 @@ class EarlyStopping(Callback):
             monitored has stopped decreasing and in ``'max'`` mode it will stop when the quantity
             monitored has stopped increasing.
         strict: whether to crash the training if `monitor` is not found in the validation metrics.
+        check_finite: When set ``True``, stops training when the monitor becomes NaN or infinite.
+        stopping_threshold: Stop training immediately once the monitored quantity reaches this threshold.
+        divergence_threshold: Stop training as soon as the monitored quantity becomes worse than this threshold.
 
     Raises:
         MisconfigurationException:
@@ -72,6 +78,11 @@ class EarlyStopping(Callback):
         'max': torch.gt,
     }
 
+    order_dict = {
+        'min': "<",
+        'max': ">",
+    }
+
     def __init__(
         self,
         monitor: str = 'early_stop_on',
@@ -80,16 +91,22 @@ def __init__(
         verbose: bool = False,
         mode: str = 'min',
         strict: bool = True,
+        check_finite: bool = True,
+        stopping_threshold: Optional[float] = None,
+        divergence_threshold: Optional[float] = None,
     ):
         super().__init__()
         self.monitor = monitor
+        self.min_delta = min_delta
         self.patience = patience
         self.verbose = verbose
+        self.mode = mode
         self.strict = strict
-        self.min_delta = min_delta
+        self.check_finite = check_finite
+        self.stopping_threshold = stopping_threshold
+        self.divergence_threshold = divergence_threshold
         self.wait_count = 0
         self.stopped_epoch = 0
-        self.mode = mode
 
         if self.mode not in self.mode_dict:
             raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
@@ -160,15 +177,50 @@ def _run_early_stopping_check(self, trainer):
         # when in dev debugging
         trainer.dev_debugger.track_early_stopping_history(self, current)
 
-        if self.monitor_op(current - self.min_delta, self.best_score):
+        should_stop, reason = self._evalute_stopping_criteria(current)
+
+        # stop every ddp process if any world process decides to stop
+        should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
+        trainer.should_stop = trainer.should_stop or should_stop
+        if should_stop:
+            self.stopped_epoch = trainer.current_epoch
+        if reason:
+            log.info(f"[{trainer.global_rank}] {reason}")
+
+    def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
+        should_stop = False
+        reason = None
+        if self.check_finite and not torch.isfinite(current):
+            should_stop = True
+            reason = (
+                f"Monitored metric {self.monitor} = {current} is not finite."
+                f" Previous best value was {self.best_score:.3f}. Signaling Trainer to stop."
+            )
+        elif self.stopping_threshold is not None and self.monitor_op(current, self.stopping_threshold):
+            should_stop = True
+            reason = (
+                "Stopping threshold reached:"
+                f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.stopping_threshold}."
+                " Signaling Trainer to stop."
+            )
+        elif self.divergence_threshold is not None and self.monitor_op(-current, -self.divergence_threshold):
+            should_stop = True
+            reason = (
+                "Divergence threshold reached:"
+                f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}."
+                " Signaling Trainer to stop."
+            )
+        elif self.monitor_op(current - self.min_delta, self.best_score):
+            should_stop = False
             self.best_score = current
             self.wait_count = 0
         else:
             self.wait_count += 1
-
             if self.wait_count >= self.patience:
-                self.stopped_epoch = trainer.current_epoch
-                trainer.should_stop = True
+                should_stop = True
+                reason = (
+                    f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} epochs."
+                    f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."
+                )
 
-        # stop every ddp process if any world process decides to stop
-        trainer.should_stop = trainer.training_type_plugin.reduce_boolean_decision(trainer.should_stop)
+        return should_stop, reason
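
A note on the divergence check above: monitor_op always compares in the direction of improvement (torch.lt for mode 'min', torch.gt for mode 'max'), so negating both operands flips the comparison into the direction of worsening. The standalone sketch below is not part of the commit; it only illustrates why monitor_op(-current, -divergence_threshold) fires once a minimized metric exceeds the threshold:

import torch

# In 'min' mode the callback's monitor_op is torch.lt ("smaller is better").
monitor_op = torch.lt
divergence_threshold = 15.9
current = torch.tensor(32.0)  # a validation loss that has blown up

# The commit's check: negate both sides to turn "improving" into "worsening".
diverged = monitor_op(-current, -torch.tensor(divergence_threshold))

# For 'min' mode this is equivalent to: current > divergence_threshold.
assert bool(diverged) == bool(current > divergence_threshold)
print(diverged)  # tensor(True) -> the callback would signal the Trainer to stop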

tests/callbacks/test_early_stopping.py

Lines changed: 44 additions & 5 deletions
@@ -213,25 +213,64 @@ def test_early_stopping_no_val_step(tmpdir):
     assert trainer.current_epoch < trainer.max_epochs - 1
 
 
-def test_early_stopping_functionality(tmpdir):
+@pytest.mark.parametrize("stopping_threshold,divergence_theshold,losses,expected_epoch", [
+    (None, None, [8, 4, 2, 3, 4, 5, 8, 10], 5),
+    (2.9, None, [9, 8, 7, 6, 5, 6, 4, 3, 2, 1], 8),
+    (None, 15.9, [9, 4, 2, 16, 32, 64], 3),
+])
+def test_early_stopping_thresholds(tmpdir, stopping_threshold, divergence_theshold, losses, expected_epoch):
 
     class CurrentModel(BoringModel):
 
         def validation_epoch_end(self, outputs):
-            losses = [8, 4, 2, 3, 4, 5, 8, 10]
             val_loss = losses[self.current_epoch]
             self.log('abc', val_loss)
 
     model = CurrentModel()
-
+    early_stopping = EarlyStopping(
+        monitor='abc',
+        stopping_threshold=stopping_threshold,
+        divergence_threshold=divergence_theshold,
+    )
     trainer = Trainer(
         default_root_dir=tmpdir,
-        callbacks=[EarlyStopping(monitor='abc')],
+        callbacks=[early_stopping],
         overfit_batches=0.20,
         max_epochs=20,
     )
     trainer.fit(model)
-    assert trainer.current_epoch == 5, 'early_stopping failed'
+    assert trainer.current_epoch == expected_epoch, 'early_stopping failed'
+
+
+@pytest.mark.parametrize("stop_value", [
+    torch.tensor(np.inf),
+    torch.tensor(np.nan),
+])
+def test_early_stopping_on_non_finite_monitor(tmpdir, stop_value):
+
+    losses = [4, 3, stop_value, 2, 1]
+    expected_stop_epoch = 2
+
+    class CurrentModel(BoringModel):
+
+        def validation_epoch_end(self, outputs):
+            val_loss = losses[self.current_epoch]
+            self.log('val_loss', val_loss)
+
+    model = CurrentModel()
+    early_stopping = EarlyStopping(
+        monitor='val_loss',
+        check_finite=True,
+    )
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[early_stopping],
+        overfit_batches=0.20,
+        max_epochs=10,
+    )
+    trainer.fit(model)
+    assert trainer.current_epoch == expected_stop_epoch
+    assert early_stopping.stopped_epoch == expected_stop_epoch
 
 
 @pytest.mark.parametrize('step_freeze, min_steps, min_epochs', [(5, 1, 1), (5, 1, 3), (3, 15, 1)])
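
For reference, a minimal usage sketch of the new options (not part of this commit; the monitored metric name and the numeric thresholds are arbitrary examples):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# 'min' mode: stop once val_loss drops below 0.05 (stopping_threshold), or as soon as
# it exceeds 10.0 (divergence_threshold); check_finite additionally stops training if
# val_loss ever becomes NaN or infinite.
early_stopping = EarlyStopping(
    monitor='val_loss',
    mode='min',
    stopping_threshold=0.05,
    divergence_threshold=10.0,
    check_finite=True,
)

trainer = Trainer(callbacks=[early_stopping], max_epochs=100)
# trainer.fit(model)  # model: any LightningModule that logs 'val_loss' during validation

Either threshold can be left as None (the default) to disable that check.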
