Skip to content

Commit 8266b14

Browse files
ananthsub, Borda, awaelchli
authored
[feat] Support time-based checkpointing during training (#7515)
Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 485554c commit 8266b14

File tree

3 files changed

+101
-15
lines changed

3 files changed

+101
-15
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1111

1212
- Added `KubeflowEnvironment` for use with the `PyTorchJob` operator in Kubeflow
1313

14+
1415
- Added LightningCLI support for config files on object stores ([#7521](https://github.com/PyTorchLightning/pytorch-lightning/pull/7521))
1516

1617

18+
- Added support for checkpointing based on a provided time interval during training ([#7515](https://github.com/PyTorchLightning/pytorch-lightning/pull/7515))
19+
20+
1721
- Added dataclasses for progress tracking ([#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603))
1822

1923

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
import logging
2222
import os
2323
import re
24+
import time
2425
from copy import deepcopy
26+
from datetime import timedelta
2527
from pathlib import Path
2628
from typing import Any, Callable, Dict, Optional, Union
2729

@@ -101,12 +103,17 @@ class ModelCheckpoint(Callback):
101103
is saved (``model.save(filepath)``).
102104
every_n_train_steps: Number of training steps between checkpoints.
103105
If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training
104-
To disable, set ``every_n_train_steps = 0``. This value must be ``None`` non-negative.
105-
This must be mutually exclusive with ``every_n_val_epochs``.
106+
To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
107+
This must be mutually exclusive with ``train_time_interval`` and ``every_n_val_epochs``.
108+
train_time_interval: Checkpoints are monitored at the specified time interval.
109+
For all practical purposes, this cannot be smaller than the amount
110+
of time it takes to process a single training batch. This is not
111+
guaranteed to execute at the exact time specified, but should be close.
112+
This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_val_epochs``.
106113
every_n_val_epochs: Number of validation epochs between checkpoints.
107114
If ``every_n_val_epochs == None or every_n_val_epochs == 0``, we skip saving on validation end
108115
To disable, set ``every_n_val_epochs = 0``. This value must be ``None`` or non-negative.
109-
This must be mutually exclusive with ``every_n_train_steps``.
116+
This must be mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``.
110117
Setting both ``ModelCheckpoint(..., every_n_val_epochs=V)`` and
111118
``Trainer(max_epochs=N, check_val_every_n_epoch=M)``
112119
will only save checkpoints at epochs 0 < E <= N
@@ -129,6 +136,9 @@ class ModelCheckpoint(Callback):
129136
For example, you can change the default last checkpoint name by doing
130137
``checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"``
131138
139+
If you want to checkpoint every N hours, every M train batches, and/or every K val epochs,
140+
then you should create multiple ``ModelCheckpoint`` callbacks.
141+
132142
Raises:
133143
MisconfigurationException:
134144
If ``save_top_k`` is neither ``None`` nor more than or equal to ``-1``,
@@ -190,6 +200,7 @@ def __init__(
190200
mode: str = "min",
191201
auto_insert_metric_name: bool = True,
192202
every_n_train_steps: Optional[int] = None,
203+
train_time_interval: Optional[timedelta] = None,
193204
every_n_val_epochs: Optional[int] = None,
194205
period: Optional[int] = None,
195206
):
@@ -201,6 +212,7 @@ def __init__(
201212
self.save_weights_only = save_weights_only
202213
self.auto_insert_metric_name = auto_insert_metric_name
203214
self._last_global_step_saved = -1
215+
self._last_time_checked: Optional[float] = None
204216
self.current_score = None
205217
self.best_k_models = {}
206218
self.kth_best_model_path = ""
@@ -210,7 +222,7 @@ def __init__(
210222

211223
self.__init_monitor_mode(mode)
212224
self.__init_ckpt_dir(dirpath, filename, save_top_k)
213-
self.__init_triggers(every_n_train_steps, every_n_val_epochs, period)
225+
self.__init_triggers(every_n_train_steps, every_n_val_epochs, train_time_interval, period)
214226
self.__validate_init_configuration()
215227
self._save_function = None
216228

@@ -221,6 +233,9 @@ def on_pretrain_routine_start(self, trainer: 'pl.Trainer', pl_module: 'pl.Lightn
221233
self.__resolve_ckpt_dir(trainer)
222234
self._save_function = trainer.save_checkpoint
223235

236+
def on_train_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
237+
self._last_time_checked = time.monotonic()
238+
224239
def on_train_batch_end(
225240
self,
226241
trainer: 'pl.Trainer',
@@ -235,8 +250,22 @@ def on_train_batch_end(
235250
return
236251
step = trainer.global_step
237252
skip_batch = self._every_n_train_steps < 1 or ((step + 1) % self._every_n_train_steps != 0)
238-
if skip_batch:
253+
254+
train_time_interval = self._train_time_interval
255+
skip_time = True
256+
now = time.monotonic()
257+
if train_time_interval:
258+
prev_time_check = self._last_time_checked
259+
skip_time = (prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds())
260+
# in case we have time differences across ranks
261+
# broadcast the decision on whether to checkpoint from rank 0 to avoid possible hangs
262+
skip_time = trainer.training_type_plugin.broadcast(skip_time)
263+
264+
if skip_batch and skip_time:
239265
return
266+
if not skip_time:
267+
self._last_time_checked = now
268+
240269
self.save_checkpoint(trainer)
241270

242271
def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
@@ -322,12 +351,17 @@ def __validate_init_configuration(self) -> None:
322351
raise MisconfigurationException(
323352
f'Invalid value for every_n_val_epochs={self._every_n_val_epochs}. Must be >= 0'
324353
)
325-
if self._every_n_train_steps > 0 and self._every_n_val_epochs > 0:
354+
355+
every_n_train_steps_triggered = self._every_n_train_steps >= 1
356+
every_n_val_epochs_triggered = self._every_n_val_epochs >= 1
357+
train_time_interval_triggered = self._train_time_interval is not None
358+
if (every_n_train_steps_triggered + every_n_val_epochs_triggered + train_time_interval_triggered > 1):
326359
raise MisconfigurationException(
327-
f'Invalid values for every_n_train_steps={self._every_n_train_steps}'
328-
' and every_n_val_epochs={self._every_n_val_epochs}.'
329-
' Both cannot be enabled at the same time.'
360+
f"Combination of parameters every_n_train_steps={self._every_n_train_steps}, "
361+
f"every_n_val_epochs={self._every_n_val_epochs} and train_time_interval={self._train_time_interval} "
362+
"should be mutually exclusive."
330363
)
364+
331365
if self.monitor is None:
332366
# None: save last epoch, -1: save all epochs, 0: nothing is saved
333367
if self.save_top_k not in (None, -1, 0):
@@ -379,19 +413,22 @@ def __init_monitor_mode(self, mode: str) -> None:
379413
self.kth_value, self.mode = mode_dict[mode]
380414

381415
def __init_triggers(
382-
self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int], period: Optional[int]
416+
self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int],
417+
train_time_interval: Optional[timedelta], period: Optional[int]
383418
) -> None:
384419

385420
# Default to running once after each validation epoch if none of
386421
# every_n_train_steps, every_n_val_epochs, or train_time_interval is set
387-
if every_n_train_steps is None and every_n_val_epochs is None:
422+
if every_n_train_steps is None and every_n_val_epochs is None and train_time_interval is None:
388423
self._every_n_val_epochs = 1
389424
self._every_n_train_steps = 0
390425
log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1")
391426
else:
392427
self._every_n_val_epochs = every_n_val_epochs or 0
393428
self._every_n_train_steps = every_n_train_steps or 0
394429

430+
self._train_time_interval: Optional[timedelta] = train_time_interval
431+
395432
# period takes precedence over every_n_val_epochs for backwards compatibility
396433
if period is not None:
397434
rank_zero_deprecation(

tests/checkpointing/test_model_checkpoint.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
import os
1717
import pickle
1818
import re
19+
import time
1920
from argparse import Namespace
21+
from datetime import timedelta
2022
from logging import INFO
2123
from pathlib import Path
2224
from typing import Union
@@ -564,16 +566,24 @@ def test_invalid_every_n_train_steps(tmpdir):
564566
ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2)
565567

566568

567-
def test_invalid_every_n_train_steps_val_epochs_combination(tmpdir):
569+
def test_invalid_trigger_combination(tmpdir):
568570
"""
569-
Test that a MisconfigurationException is raised if both
570-
every_n_val_epochs and every_n_train_steps are enabled together.
571+
Test that a MisconfigurationException is raised if more than one of
572+
every_n_val_epochs, every_n_train_steps, and train_time_interval are enabled together.
571573
"""
572-
with pytest.raises(MisconfigurationException, match=r'.*Both cannot be enabled at the same time'):
574+
with pytest.raises(MisconfigurationException, match=r'.*Combination of parameters every_n_train_steps'):
573575
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1, every_n_val_epochs=2)
576+
with pytest.raises(MisconfigurationException, match=r'.*Combination of parameters every_n_train_steps'):
577+
ModelCheckpoint(train_time_interval=timedelta(minutes=1), every_n_val_epochs=2)
578+
with pytest.raises(MisconfigurationException, match=r'.*Combination of parameters every_n_train_steps'):
579+
ModelCheckpoint(train_time_interval=timedelta(minutes=1), every_n_train_steps=2)
580+
574581
# These should not fail
575582
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_val_epochs=3)
576583
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=4, every_n_val_epochs=0)
584+
ModelCheckpoint(
585+
dirpath=tmpdir, every_n_train_steps=0, every_n_val_epochs=0, train_time_interval=timedelta(minutes=1)
586+
)
577587

578588

579589
def test_none_every_n_train_steps_val_epochs(tmpdir):
@@ -718,6 +728,41 @@ def test_ckpt_every_n_train_steps(tmpdir):
718728
assert set(os.listdir(tmpdir)) == set(expected)
719729

720730

731+
@mock.patch("pytorch_lightning.callbacks.model_checkpoint.time")
732+
def test_model_checkpoint_train_time_interval(mock_datetime, tmpdir) -> None:
733+
"""Tests that the checkpoints are saved at the specified time interval."""
734+
seconds_per_batch = 7
735+
start_time = time.monotonic()
736+
batches_per_epoch = 64
737+
num_epochs = 2
738+
max_batches = batches_per_epoch * num_epochs + 1
739+
mock_datetime.monotonic.side_effect = [start_time + seconds_per_batch * i for i in range(max_batches)]
740+
741+
model = BoringModel()
742+
trainer = Trainer(
743+
default_root_dir=tmpdir,
744+
min_epochs=num_epochs,
745+
max_epochs=num_epochs,
746+
progress_bar_refresh_rate=0,
747+
callbacks=[
748+
ModelCheckpoint(
749+
filename="{epoch}-{step}",
750+
dirpath=tmpdir,
751+
train_time_interval=timedelta(minutes=1),
752+
save_top_k=-1,
753+
save_last=False,
754+
)
755+
],
756+
logger=False,
757+
)
758+
759+
trainer.fit(model)
760+
# Each batch takes 7 sec and we checkpoint every minute. There are 64
761+
# batches per epoch, so total time to run is 7*64*2 = 896 sec (~14.93 minutes),
762+
# so we should have 14 checkpoints.
763+
assert len(os.listdir(tmpdir)) == 14
764+
765+
721766
def test_model_checkpoint_topk_zero(tmpdir):
722767
""" Test that no checkpoints are saved when save_top_k=0. """
723768
model = LogInTwoMethods()

0 commit comments

Comments
 (0)