 import logging
 import os
 import re
+import time
 from copy import deepcopy
+from datetime import timedelta
 from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Union
 
@@ -101,12 +103,17 @@ class ModelCheckpoint(Callback):
             is saved (``model.save(filepath)``).
         every_n_train_steps: Number of training steps between checkpoints.
             If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training
-            To disable, set ``every_n_train_steps = 0``. This value must be ``None`` non-negative.
-            This must be mutually exclusive with ``every_n_val_epochs``.
+            To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
+            This must be mutually exclusive with ``train_time_interval`` and ``every_n_val_epochs``.
+        train_time_interval: Checkpoints are monitored at the specified time interval.
+            For all practical purposes, this cannot be smaller than the amount
+            of time it takes to process a single training batch. This is not
+            guaranteed to execute at the exact time specified, but should be close.
+            This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_val_epochs``.
         every_n_val_epochs: Number of validation epochs between checkpoints.
             If ``every_n_val_epochs == None or every_n_val_epochs == 0``, we skip saving on validation end
             To disable, set ``every_n_val_epochs = 0``. This value must be ``None`` or non-negative.
-            This must be mutually exclusive with ``every_n_train_steps``.
+            This must be mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``.
             Setting both ``ModelCheckpoint(..., every_n_val_epochs=V)`` and
             ``Trainer(max_epochs=N, check_val_every_n_epoch=M)``
             will only save checkpoints at epochs 0 < E <= N
@@ -129,6 +136,9 @@ class ModelCheckpoint(Callback):
     For example, you can change the default last checkpoint name by doing
     ``checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"``
 
+    If you want to checkpoint every N hours, every M train batches, and/or every K val epochs,
+    then you should create multiple ``ModelCheckpoint`` callbacks.
+
     Raises:
         MisconfigurationException:
             If ``save_top_k`` is neither ``None`` nor more than or equal to ``-1``,
@@ -190,6 +200,7 @@ def __init__(
         mode: str = "min",
         auto_insert_metric_name: bool = True,
         every_n_train_steps: Optional[int] = None,
+        train_time_interval: Optional[timedelta] = None,
         every_n_val_epochs: Optional[int] = None,
         period: Optional[int] = None,
     ):
@@ -201,6 +212,7 @@ def __init__(
         self.save_weights_only = save_weights_only
         self.auto_insert_metric_name = auto_insert_metric_name
         self._last_global_step_saved = -1
+        self._last_time_checked: Optional[float] = None
         self.current_score = None
         self.best_k_models = {}
         self.kth_best_model_path = ""
@@ -210,7 +222,7 @@ def __init__(
 
         self.__init_monitor_mode(mode)
         self.__init_ckpt_dir(dirpath, filename, save_top_k)
-        self.__init_triggers(every_n_train_steps, every_n_val_epochs, period)
+        self.__init_triggers(every_n_train_steps, every_n_val_epochs, train_time_interval, period)
         self.__validate_init_configuration()
         self._save_function = None
 
@@ -221,6 +233,9 @@ def on_pretrain_routine_start(self, trainer: 'pl.Trainer', pl_module: 'pl.Lightn
         self.__resolve_ckpt_dir(trainer)
         self._save_function = trainer.save_checkpoint
 
+    def on_train_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
+        self._last_time_checked = time.monotonic()
+
     def on_train_batch_end(
         self,
         trainer: 'pl.Trainer',
@@ -235,8 +250,22 @@ def on_train_batch_end(
             return
         step = trainer.global_step
         skip_batch = self._every_n_train_steps < 1 or ((step + 1) % self._every_n_train_steps != 0)
-        if skip_batch:
+
+        train_time_interval = self._train_time_interval
+        skip_time = True
+        now = time.monotonic()
+        if train_time_interval:
+            prev_time_check = self._last_time_checked
+            skip_time = (prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds())
+            # in case we have time differences across ranks
+            # broadcast the decision on whether to checkpoint from rank 0 to avoid possible hangs
+            skip_time = trainer.training_type_plugin.broadcast(skip_time)
+
+        if skip_batch and skip_time:
             return
+        if not skip_time:
+            self._last_time_checked = now
+
         self.save_checkpoint(trainer)
 
     def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
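The time-based gate above uses `time.monotonic()` rather than wall-clock time, so the interval check is immune to system clock adjustments, and rank 0's decision is broadcast so that all ranks save (or skip) together instead of deadlocking. A minimal standalone sketch of the same gating logic; the names here are illustrative stand-ins, not part of the callback's API:

```python
import time
from datetime import timedelta

# stand-ins for `train_time_interval` and `_last_time_checked`
interval = timedelta(minutes=30)
last_time_checked = time.monotonic()  # seeded once when training starts

def due_for_checkpoint() -> bool:
    # a monotonic clock only moves forward, unlike time.time()
    now = time.monotonic()
    return (now - last_time_checked) >= interval.total_seconds()
```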
@@ -322,12 +351,17 @@ def __validate_init_configuration(self) -> None:
             raise MisconfigurationException(
                 f'Invalid value for every_n_val_epochs={self._every_n_val_epochs}. Must be >= 0'
             )
-        if self._every_n_train_steps > 0 and self._every_n_val_epochs > 0:
+
+        every_n_train_steps_triggered = self._every_n_train_steps >= 1
+        every_n_val_epochs_triggered = self._every_n_val_epochs >= 1
+        train_time_interval_triggered = self._train_time_interval is not None
+        if every_n_train_steps_triggered + every_n_val_epochs_triggered + train_time_interval_triggered > 1:
             raise MisconfigurationException(
-                f'Invalid values for every_n_train_steps={self._every_n_train_steps}'
-                ' and every_n_val_epochs={self._every_n_val_epochs}.'
-                ' Both cannot be enabled at the same time.'
+                f"Combination of parameters every_n_train_steps={self._every_n_train_steps}, "
+                f"every_n_val_epochs={self._every_n_val_epochs} and train_time_interval={self._train_time_interval} "
+                "should be mutually exclusive."
             )
+
         if self.monitor is None:
             # None: save last epoch, -1: save all epochs, 0: nothing is saved
             if self.save_top_k not in (None, -1, 0):
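With this change, enabling more than one trigger on a single instance fails fast at construction time. A hypothetical snippet, assuming the callback is imported from `pytorch_lightning.callbacks`:

```python
from datetime import timedelta
from pytorch_lightning.callbacks import ModelCheckpoint

# raises MisconfigurationException: the three triggers are mutually exclusive
ModelCheckpoint(every_n_train_steps=100, train_time_interval=timedelta(hours=1))
```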
@@ -379,19 +413,22 @@ def __init_monitor_mode(self, mode: str) -> None:
         self.kth_value, self.mode = mode_dict[mode]
 
     def __init_triggers(
-        self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int], period: Optional[int]
+        self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int],
+        train_time_interval: Optional[timedelta], period: Optional[int]
     ) -> None:
 
         # Default to running once after each validation epoch if neither
         # every_n_train_steps nor every_n_val_epochs is set
-        if every_n_train_steps is None and every_n_val_epochs is None:
+        if every_n_train_steps is None and every_n_val_epochs is None and train_time_interval is None:
             self._every_n_val_epochs = 1
             self._every_n_train_steps = 0
             log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1")
         else:
             self._every_n_val_epochs = every_n_val_epochs or 0
             self._every_n_train_steps = every_n_train_steps or 0
 
+        self._train_time_interval: Optional[timedelta] = train_time_interval
+
         # period takes precedence over every_n_val_epochs for backwards compatibility
         if period is not None:
             rank_zero_deprecation(
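To combine triggers, follow the docstring's guidance and create one ``ModelCheckpoint`` per trigger. A usage sketch with illustrative paths and values:

```python
from datetime import timedelta
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# one callback per trigger, since a single instance accepts only one
time_ckpt = ModelCheckpoint(dirpath="checkpoints/time", train_time_interval=timedelta(hours=2))
step_ckpt = ModelCheckpoint(dirpath="checkpoints/steps", every_n_train_steps=1000)
epoch_ckpt = ModelCheckpoint(dirpath="checkpoints/epochs", every_n_val_epochs=1)

trainer = Trainer(callbacks=[time_ckpt, step_ckpt, epoch_ckpt])
```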