Skip to content

Commit c930843

Browse files
diego-urgellfacebook-github-bot
authored andcommitted
Add CheckpointManager abstraction in utils/checkpoint.py
Differential Revision: D56427226
1 parent a16ea66 commit c930843

File tree

3 files changed

+616
-9
lines changed

3 files changed

+616
-9
lines changed

tests/utils/test_checkpoint.py

Lines changed: 358 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,17 @@
1010
import shutil
1111
import tempfile
1212
import unittest
13+
from unittest.mock import patch
1314

1415
import torch
1516

1617
import torch.distributed as dist
1718
from torch import nn
1819
from torchsnapshot import Snapshot
1920
from torchsnapshot.snapshot import SNAPSHOT_METADATA_FNAME
21+
from torchtnt.framework._test_utils import Batch
22+
from torchtnt.framework.state import State
23+
from torchtnt.framework.unit import TrainUnit
2024
from torchtnt.utils import get_global_rank, init_from_env
2125

2226
from torchtnt.utils.checkpoint import (
@@ -25,6 +29,8 @@
2529
_retrieve_checkpoint_dirpaths,
2630
_sort_by_metric_value,
2731
_sort_by_recency,
32+
BestCheckpointConfig,
33+
CheckpointManager,
2834
CheckpointPath,
2935
get_best_checkpoint_path,
3036
get_checkpoint_dirpaths,
@@ -190,6 +196,349 @@ def test_pickling(self) -> None:
190196
self.assertEqual(unpickled, ckpt)
191197

192198

199+
class CheckpointManagerTest(unittest.TestCase):
200+
def test_create_checkpoint_manager(self) -> None:
201+
with tempfile.TemporaryDirectory() as temp_dir:
202+
paths = [
203+
f"{temp_dir}/epoch_1_step_3",
204+
f"{temp_dir}/epoch_0_step_1",
205+
f"{temp_dir}/epoch_0_step_5_loss=-0.3",
206+
f"{temp_dir}/epoch_1_step_1",
207+
f"{temp_dir}/epoch_1_step_2_loss=0.5",
208+
f"{temp_dir}/epoch_2_step_5_loss=0.3",
209+
f"{temp_dir}/epoch_0_step_2_acc=0.7",
210+
]
211+
for path in paths:
212+
os.mkdir(path)
213+
214+
# without last_n_checkpoints
215+
ckpt_manager = CheckpointManager(temp_dir)
216+
self.assertEqual(ckpt_manager._ckpt_paths, [])
217+
218+
# with last_n_checkpoints but without metric
219+
ckpt_manager = CheckpointManager(temp_dir, keep_last_n_checkpoints=2)
220+
self.assertEqual(
221+
[x.path for x in ckpt_manager._ckpt_paths],
222+
[
223+
f"{temp_dir}/epoch_0_step_1",
224+
f"{temp_dir}/epoch_0_step_2_acc=0.7",
225+
f"{temp_dir}/epoch_0_step_5_loss=-0.3",
226+
f"{temp_dir}/epoch_1_step_1",
227+
f"{temp_dir}/epoch_1_step_2_loss=0.5",
228+
f"{temp_dir}/epoch_1_step_3",
229+
f"{temp_dir}/epoch_2_step_5_loss=0.3",
230+
],
231+
)
232+
233+
# with last_n_checkpoints and metric min
234+
ckpt_manager = CheckpointManager(
235+
temp_dir,
236+
keep_last_n_checkpoints=3,
237+
best_checkpoint_config=BestCheckpointConfig(
238+
monitored_metric="loss", mode="min"
239+
),
240+
)
241+
self.assertEqual(
242+
[x.path for x in ckpt_manager._ckpt_paths],
243+
[
244+
f"{temp_dir}/epoch_1_step_2_loss=0.5",
245+
f"{temp_dir}/epoch_2_step_5_loss=0.3",
246+
f"{temp_dir}/epoch_0_step_5_loss=-0.3",
247+
],
248+
)
249+
250+
# with last_n_checkpoints and metric max
251+
ckpt_manager = CheckpointManager(
252+
temp_dir,
253+
keep_last_n_checkpoints=3,
254+
best_checkpoint_config=BestCheckpointConfig(
255+
monitored_metric="loss", mode="max"
256+
),
257+
)
258+
self.assertEqual(
259+
[x.path for x in ckpt_manager._ckpt_paths],
260+
[
261+
f"{temp_dir}/epoch_0_step_5_loss=-0.3",
262+
f"{temp_dir}/epoch_2_step_5_loss=0.3",
263+
f"{temp_dir}/epoch_1_step_2_loss=0.5",
264+
],
265+
)
266+
267+
# with last_n_checkpoints and non previously tracked metric
268+
ckpt_manager = CheckpointManager(
269+
temp_dir,
270+
keep_last_n_checkpoints=3,
271+
best_checkpoint_config=BestCheckpointConfig(
272+
monitored_metric="foo", mode="max"
273+
),
274+
)
275+
self.assertEqual(ckpt_manager._ckpt_paths, [])
276+
277+
@skip_if_not_distributed
278+
def test_create_checkpoint_manager_distributed(self) -> None:
279+
spawn_multi_process(
280+
2,
281+
"gloo",
282+
self._test_create_checkpoint_manager_distributed,
283+
)
284+
285+
@staticmethod
286+
def _test_create_checkpoint_manager_distributed() -> None:
287+
if get_global_rank() == 0:
288+
temp_dir = tempfile.mkdtemp()
289+
paths = ["epoch_1_step_2", "epoch_0_step_1", "epoch_1_step_1"]
290+
for path in paths:
291+
os.mkdir(os.path.join(temp_dir, path))
292+
else:
293+
temp_dir = ""
294+
295+
tc = unittest.TestCase()
296+
297+
# without top k config
298+
ckpt_manager = CheckpointManager(temp_dir)
299+
tc.assertNotEqual(ckpt_manager.dirpath, "")
300+
tc.assertEqual(ckpt_manager._ckpt_paths, [])
301+
302+
# with top k config
303+
ckpt_manager = CheckpointManager(temp_dir, keep_last_n_checkpoints=1)
304+
tc.assertNotEqual(ckpt_manager.dirpath, "")
305+
tc.assertEqual(
306+
[str(x) for x in ckpt_manager._ckpt_paths],
307+
[
308+
os.path.join(ckpt_manager.dirpath, path)
309+
for path in [
310+
"epoch_0_step_1",
311+
"epoch_1_step_1",
312+
"epoch_1_step_2",
313+
]
314+
],
315+
)
316+
317+
def test_prune_surplus_checkpoints(self) -> None:
318+
# with checkpoints to delete
319+
with tempfile.TemporaryDirectory() as temp_dir:
320+
ckpt_manager = CheckpointManager(temp_dir, keep_last_n_checkpoints=1)
321+
paths = [
322+
CheckpointPath(temp_dir, 0, 0),
323+
CheckpointPath(temp_dir, 0, 1),
324+
CheckpointPath(temp_dir, 1, 0),
325+
]
326+
for path in paths:
327+
os.mkdir(path.path)
328+
329+
ckpt_manager._ckpt_paths = list(paths)
330+
warning_messages = []
331+
expected_warning_msg = (
332+
f"3 checkpoints found in {temp_dir}. ",
333+
f"Deleting {2} oldest ",
334+
"checkpoints to enforce ``keep_last_n_checkpoints`` argument.",
335+
)
336+
with patch(
337+
f"{CheckpointManager.__module__}.logging.Logger.warning",
338+
warning_messages.append,
339+
):
340+
ckpt_manager.prune_surplus_checkpoints()
341+
342+
self.assertEqual(warning_messages[0], expected_warning_msg)
343+
self.assertEqual(ckpt_manager._ckpt_paths, [paths[2]])
344+
self.assertTrue(os.path.exists(paths[2].path))
345+
self.assertFalse(os.path.exists(paths[0].path))
346+
self.assertFalse(os.path.exists(paths[1].path))
347+
348+
# without checkpoints to delete
349+
with tempfile.TemporaryDirectory() as temp_dir:
350+
ckpt_manager = CheckpointManager(temp_dir)
351+
paths = [
352+
CheckpointPath(temp_dir, 0, 0),
353+
CheckpointPath(temp_dir, 0, 1),
354+
CheckpointPath(temp_dir, 1, 0),
355+
]
356+
ckpt_manager._ckpt_paths = list(paths)
357+
ckpt_manager.prune_surplus_checkpoints()
358+
self.assertEqual(ckpt_manager._ckpt_paths, paths)
359+
360+
def test_generate_checkpoint_path(self) -> None:
361+
ckpt_manager = CheckpointManager("foo")
362+
363+
self.assertEqual(
364+
ckpt_manager.generate_checkpoint_path(1, 1).path,
365+
"foo/epoch_1_step_1",
366+
)
367+
368+
self.assertEqual(
369+
ckpt_manager.generate_checkpoint_path(1, 3).path,
370+
"foo/epoch_1_step_3",
371+
)
372+
373+
ckpt_manager._best_checkpoint_config = BestCheckpointConfig(
374+
monitored_metric="val_loss", mode="min"
375+
)
376+
self.assertEqual(
377+
ckpt_manager.generate_checkpoint_path(
378+
1, 3, MetricData("val_loss", 0.5)
379+
).path,
380+
"foo/epoch_1_step_3_val_loss=0.5",
381+
)
382+
383+
# best checkpoint config, but did not pass metric data - expect path but no metric
384+
self.assertEqual(
385+
ckpt_manager.generate_checkpoint_path(1, 2).path,
386+
"foo/epoch_1_step_2",
387+
)
388+
389+
# passed metric data is tracking a different metric than best checkpoint config - expect exception
390+
with self.assertRaisesRegex(
391+
AssertionError,
392+
"Attempted to get a checkpoint with metric 'mean', but best checkpoint config is for 'val_loss'",
393+
):
394+
ckpt_manager.generate_checkpoint_path(1, 2, MetricData("mean", 3.5))
395+
396+
# no best checkpoint config, but passed metric data - expect exception
397+
ckpt_manager._best_checkpoint_config = None
398+
with self.assertRaisesRegex(
399+
AssertionError,
400+
"Attempted to get a checkpoint with metric but best checkpoint config is not set",
401+
):
402+
ckpt_manager.generate_checkpoint_path(1, 2, MetricData("val_loss", 3.5))
403+
404+
def test_append_checkpoint_by_recency(self) -> None:
405+
ckpt_manager = CheckpointManager("foo", keep_last_n_checkpoints=2)
406+
ckpt_manager._ckpt_paths = [CheckpointPath("foo", 0, 0)]
407+
408+
# without need to remove old by recency
409+
ckpt_manager.append_checkpoint(CheckpointPath("foo", 0, 1))
410+
self.assertEqual(
411+
ckpt_manager._ckpt_paths,
412+
[CheckpointPath("foo", 0, 0), CheckpointPath("foo", 0, 1)],
413+
)
414+
415+
# removing old by recency
416+
with patch("fsspec.implementations.local.LocalFileSystem.rm") as mock_rm:
417+
ckpt_manager.append_checkpoint(CheckpointPath("foo", 0, 2))
418+
self.assertEqual(
419+
ckpt_manager._ckpt_paths,
420+
[CheckpointPath("foo", 0, 1), CheckpointPath("foo", 0, 2)],
421+
)
422+
mock_rm.assert_called_once_with("foo/epoch_0_step_0", recursive=True)
423+
424+
def test_append_checkpoint_by_metric(self) -> None:
425+
ckpt_manager = CheckpointManager(
426+
"foo",
427+
keep_last_n_checkpoints=5,
428+
best_checkpoint_config=BestCheckpointConfig(
429+
monitored_metric="val_loss", mode="max"
430+
),
431+
)
432+
paths = [
433+
CheckpointPath(
434+
"foo", 0, x, metric_data=MetricData(name="val_loss", value=0.01 * x)
435+
)
436+
for x in range(1, 7, 1)
437+
]
438+
ckpt_manager._ckpt_paths = [paths[1], paths[2], paths[4]]
439+
# without need to remove old by min metric, goes beginning
440+
ckpt_manager.append_checkpoint(paths[0])
441+
self.assertEqual(
442+
ckpt_manager._ckpt_paths,
443+
[paths[0], paths[1], paths[2], paths[4]],
444+
)
445+
# without need to remove old by min metric, goes end
446+
ckpt_manager.append_checkpoint(paths[5])
447+
self.assertEqual(
448+
ckpt_manager._ckpt_paths,
449+
[paths[0], paths[1], paths[2], paths[4], paths[5]],
450+
)
451+
# removing old max metric, goes middle
452+
with patch("fsspec.implementations.local.LocalFileSystem.rm") as mock_rm:
453+
ckpt_manager.append_checkpoint(paths[3])
454+
self.assertEqual(
455+
ckpt_manager._ckpt_paths,
456+
[paths[1], paths[2], paths[3], paths[4], paths[5]],
457+
)
458+
mock_rm.assert_called_once_with(
459+
"foo/epoch_0_step_1_val_loss=0.01", recursive=True
460+
)
461+
462+
# no metric data - noop
463+
ckpt_manager._keep_last_n_checkpoints = None
464+
ckpt_manager.append_checkpoint(CheckpointPath("foo", 0, 8))
465+
self.assertEqual(
466+
ckpt_manager._ckpt_paths,
467+
[paths[1], paths[2], paths[3], paths[4], paths[5]],
468+
)
469+
470+
def test_should_save_checkpoint(self) -> None:
471+
"""
472+
Tests basic functionality of should_save_checkpoint
473+
"""
474+
ckpt_manager = CheckpointManager("foo")
475+
476+
# test default behavior
477+
ckpt = CheckpointPath("foo", 0, 2)
478+
self.assertTrue(ckpt_manager.should_save_checkpoint(ckpt))
479+
480+
ckpt_manager._ckpt_paths = [CheckpointPath("foo", 0, 1)]
481+
self.assertTrue(ckpt_manager.should_save_checkpoint(ckpt))
482+
ckpt_manager._keep_last_n_checkpoints = 1
483+
self.assertTrue(ckpt_manager.should_save_checkpoint(ckpt))
484+
485+
ckpt_manager._ckpt_paths = [
486+
CheckpointPath(
487+
"foo", 0, 1, metric_data=MetricData(name="val_loss", value=0.01)
488+
),
489+
]
490+
ckpt_manager._best_checkpoint_config = BestCheckpointConfig(
491+
monitored_metric="val_loss",
492+
mode="min",
493+
)
494+
495+
bigger_metric = CheckpointPath(
496+
"foo", 0, 1, metric_data=MetricData(name="val_loss", value=0.02)
497+
)
498+
smaller_metric = CheckpointPath(
499+
"foo", 0, 1, metric_data=MetricData(name="val_loss", value=0.001)
500+
)
501+
ckpt_manager._keep_last_n_checkpoints = None
502+
self.assertTrue(ckpt_manager.should_save_checkpoint(bigger_metric))
503+
ckpt_manager._keep_last_n_checkpoints = 1
504+
self.assertFalse(ckpt_manager.should_save_checkpoint(bigger_metric))
505+
self.assertTrue(ckpt_manager.should_save_checkpoint(smaller_metric))
506+
ckpt_manager._keep_last_n_checkpoints = 2
507+
self.assertTrue(ckpt_manager.should_save_checkpoint(smaller_metric))
508+
self.assertTrue(ckpt_manager.should_save_checkpoint(bigger_metric))
509+
510+
# Make sure we are actually comparing against more optimal element
511+
ckpt_manager._ckpt_paths = [
512+
CheckpointPath(
513+
"foo", 0, 1, metric_data=MetricData(name="val_loss", value=0.01)
514+
),
515+
CheckpointPath(
516+
"foo", 0, 1, metric_data=MetricData(name="val_loss", value=0.05)
517+
),
518+
]
519+
520+
ckpt_manager._best_checkpoint_config = BestCheckpointConfig(
521+
monitored_metric="val_loss",
522+
mode="max",
523+
)
524+
ckpt_manager._keep_last_n_checkpoints = 2
525+
self.assertTrue(ckpt_manager.should_save_checkpoint(bigger_metric))
526+
527+
def test_remove_worst_checkpoint(self) -> None:
528+
with tempfile.TemporaryDirectory() as temp_dir:
529+
os.mkdir(os.path.join(temp_dir, "epoch_0_step_0"))
530+
os.mkdir(os.path.join(temp_dir, "epoch_0_step_1"))
531+
532+
ckpt_manager = CheckpointManager(temp_dir)
533+
ckpt_manager.append_checkpoint(CheckpointPath(temp_dir, 0, 0))
534+
ckpt_manager.append_checkpoint(CheckpointPath(temp_dir, 0, 1))
535+
536+
ckpt_manager.remove_checkpoint()
537+
self.assertFalse(os.path.exists(os.path.join(temp_dir, "epoch_0_step_0")))
538+
self.assertTrue(os.path.exists(os.path.join(temp_dir, "epoch_0_step_1")))
539+
self.assertEqual(ckpt_manager._ckpt_paths, [CheckpointPath(temp_dir, 0, 1)])
540+
541+
193542
class CheckpointUtilsTest(unittest.TestCase):
194543
@staticmethod
195544
def _create_snapshot_metadata(output_dir: str) -> None:
@@ -590,3 +939,12 @@ def test_metadata_exists(self) -> None:
590939

591940
os.remove(os.path.join(dirpath, SNAPSHOT_METADATA_FNAME))
592941
self.assertFalse(_metadata_exists(fs, dirpath, SNAPSHOT_METADATA_FNAME))
942+
943+
944+
class MyValLossUnit(TrainUnit[Batch]):
945+
def __init__(self) -> None:
946+
super().__init__()
947+
self.val_loss = 0.01
948+
949+
def train_step(self, state: State, data: Batch) -> None:
950+
return None

0 commit comments

Comments
 (0)