 Monitor a metric and stop training when it stops improving.

 """
-from typing import Any, Dict
+import logging
+from typing import Any, Dict, Optional, Tuple

 import numpy as np
 import torch

 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

+log = logging.getLogger(__name__)
+

 class EarlyStopping(Callback):
     r"""
@@ -53,6 +56,9 @@ class EarlyStopping(Callback):
             monitored has stopped decreasing and in ``'max'`` mode it will stop when the quantity
             monitored has stopped increasing.
         strict: whether to crash the training if `monitor` is not found in the validation metrics.
+        check_finite: When set ``True``, stops training when the monitor becomes NaN or infinite.
+        stopping_threshold: Stop training immediately once the monitored quantity reaches this threshold.
+        divergence_threshold: Stop training as soon as the monitored quantity becomes worse than this threshold.

     Raises:
         MisconfigurationException:
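
The three new arguments compose with the existing `monitor` / `mode` / `patience` options. A minimal usage sketch, assuming a `LightningModule` that logs a metric under `val_loss` (the metric name and the `model` object are placeholders, not part of this diff):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# Stop when val_loss plateaus for 3 validation checks, immediately once it
# reaches 0.05, as soon as it diverges past 10.0, or if it becomes NaN/inf.
early_stopping = EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=3,
    stopping_threshold=0.05,
    divergence_threshold=10.0,
    check_finite=True,  # the default
)

trainer = Trainer(callbacks=[early_stopping])
trainer.fit(model)  # `model`: any LightningModule that logs `val_loss`
```
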
@@ -72,6 +78,11 @@ class EarlyStopping(Callback):
         'max': torch.gt,
     }

+    order_dict = {
+        'min': "<",
+        'max': ">",
+    }
+
     def __init__(
         self,
         monitor: str = 'early_stop_on',
@@ -80,16 +91,22 @@ def __init__(
         verbose: bool = False,
         mode: str = 'min',
         strict: bool = True,
+        check_finite: bool = True,
+        stopping_threshold: Optional[float] = None,
+        divergence_threshold: Optional[float] = None,
     ):
         super().__init__()
         self.monitor = monitor
+        self.min_delta = min_delta
         self.patience = patience
         self.verbose = verbose
+        self.mode = mode
         self.strict = strict
-        self.min_delta = min_delta
+        self.check_finite = check_finite
+        self.stopping_threshold = stopping_threshold
+        self.divergence_threshold = divergence_threshold
         self.wait_count = 0
         self.stopped_epoch = 0
-        self.mode = mode

         if self.mode not in self.mode_dict:
             raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
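
As a quick illustration of the validation above: constructing the callback with an unsupported mode raises immediately. A hypothetical REPL sketch (the printed message paraphrases the f-string in this hunk; `'median'` is just an example of an invalid value):

```python
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
    EarlyStopping(monitor="val_loss", mode="median")  # 'median' is not a mode_dict key
except MisconfigurationException as err:
    print(err)  # -> `mode` can be min, max, got median
```
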
@@ -160,15 +177,50 @@ def _run_early_stopping_check(self, trainer):
         # when in dev debugging
         trainer.dev_debugger.track_early_stopping_history(self, current)

-        if self.monitor_op(current - self.min_delta, self.best_score):
+        should_stop, reason = self._evaluate_stopping_criteria(current)
+
+        # stop every ddp process if any world process decides to stop
+        should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
+        trainer.should_stop = trainer.should_stop or should_stop
+        if should_stop:
+            self.stopped_epoch = trainer.current_epoch
+        if reason:
+            log.info(f"[{trainer.global_rank}] {reason}")
+
+    def _evaluate_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, Optional[str]]:
+        should_stop = False
+        reason = None
+        if self.check_finite and not torch.isfinite(current):
+            should_stop = True
+            reason = (
+                f"Monitored metric {self.monitor} = {current} is not finite."
+                f" Previous best value was {self.best_score:.3f}. Signaling Trainer to stop."
+            )
+        elif self.stopping_threshold is not None and self.monitor_op(current, self.stopping_threshold):
+            should_stop = True
+            reason = (
+                "Stopping threshold reached:"
+                f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.stopping_threshold}."
+                " Signaling Trainer to stop."
+            )
+        elif self.divergence_threshold is not None and self.monitor_op(-current, -self.divergence_threshold):
+            should_stop = True
+            reason = (
+                "Divergence threshold reached:"
+                f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}."
+                " Signaling Trainer to stop."
+            )
+        elif self.monitor_op(current - self.min_delta, self.best_score):
+            should_stop = False
             self.best_score = current
             self.wait_count = 0
         else:
             self.wait_count += 1
-
             if self.wait_count >= self.patience:
-                self.stopped_epoch = trainer.current_epoch
-                trainer.should_stop = True
+                should_stop = True
+                reason = (
+                    f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} epochs."
+                    f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."
+                )

-        # stop every ddp process if any world process decides to stop
-        trainer.should_stop = trainer.training_type_plugin.reduce_boolean_decision(trainer.should_stop)
+        return should_stop, reason
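
One subtle line above is the divergence check: rather than keeping a second table of comparison operators, it reuses `monitor_op` on negated operands, since `a < b` exactly when `-a > -b`. A standalone sketch of that identity with plain `torch` (the values are illustrative, not from this diff):

```python
import torch

monitor_op = torch.lt          # 'min' mode: improvement means "less than"
current = torch.tensor(12.0)   # monitored value, e.g. a diverging val_loss
divergence_threshold = 10.0

# In 'min' mode, "worse than the threshold" means current > threshold.
# Negating both operands expresses that with the same 'less than' operator:
diverged = monitor_op(-current, -torch.tensor(divergence_threshold))
assert bool(diverged)  # 12.0 > 10.0, so the callback would signal a stop
```
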