|
10 | 10 | import shutil
|
11 | 11 | import tempfile
|
12 | 12 | import unittest
|
| 13 | +from unittest.mock import patch |
13 | 14 |
|
14 | 15 | import torch
|
15 | 16 |
|
16 | 17 | import torch.distributed as dist
|
17 | 18 | from torch import nn
|
18 | 19 | from torchsnapshot import Snapshot
|
19 | 20 | from torchsnapshot.snapshot import SNAPSHOT_METADATA_FNAME
|
| 21 | +from torchtnt.framework._test_utils import Batch |
| 22 | +from torchtnt.framework.state import State |
| 23 | +from torchtnt.framework.unit import TrainUnit |
20 | 24 | from torchtnt.utils import get_global_rank, init_from_env
|
21 | 25 |
|
22 | 26 | from torchtnt.utils.checkpoint import (
|
|
25 | 29 | _retrieve_checkpoint_dirpaths,
|
26 | 30 | _sort_by_metric_value,
|
27 | 31 | _sort_by_recency,
|
| 32 | + BestCheckpointConfig, |
| 33 | + CheckpointManager, |
28 | 34 | CheckpointPath,
|
29 | 35 | get_best_checkpoint_path,
|
30 | 36 | get_checkpoint_dirpaths,
|
@@ -190,6 +196,349 @@ def test_pickling(self) -> None:
|
190 | 196 | self.assertEqual(unpickled, ckpt)
|
191 | 197 |
|
192 | 198 |
|
class CheckpointManagerTest(unittest.TestCase):
    """Tests for ``CheckpointManager``: construction, pruning, path generation,
    appending, save decisions and removal of tracked checkpoints."""

    def test_create_checkpoint_manager(self) -> None:
        """Check which pre-existing checkpoint directories a freshly built
        manager tracks, and in which order, for several configurations."""
        with tempfile.TemporaryDirectory() as temp_dir:
            paths = [
                f"{temp_dir}/epoch_1_step_3",
                f"{temp_dir}/epoch_0_step_1",
                f"{temp_dir}/epoch_0_step_5_loss=-0.3",
                f"{temp_dir}/epoch_1_step_1",
                f"{temp_dir}/epoch_1_step_2_loss=0.5",
                f"{temp_dir}/epoch_2_step_5_loss=0.3",
                f"{temp_dir}/epoch_0_step_2_acc=0.7",
            ]
            for path in paths:
                os.mkdir(path)

            # without last_n_checkpoints: no existing checkpoint is tracked
            ckpt_manager = CheckpointManager(temp_dir)
            self.assertEqual(ckpt_manager._ckpt_paths, [])

            # with last_n_checkpoints but without metric: every directory is
            # listed, ordered by (epoch, step) regardless of any metric suffix
            ckpt_manager = CheckpointManager(temp_dir, keep_last_n_checkpoints=2)
            self.assertEqual(
                [x.path for x in ckpt_manager._ckpt_paths],
                [
                    f"{temp_dir}/epoch_0_step_1",
                    f"{temp_dir}/epoch_0_step_2_acc=0.7",
                    f"{temp_dir}/epoch_0_step_5_loss=-0.3",
                    f"{temp_dir}/epoch_1_step_1",
                    f"{temp_dir}/epoch_1_step_2_loss=0.5",
                    f"{temp_dir}/epoch_1_step_3",
                    f"{temp_dir}/epoch_2_step_5_loss=0.3",
                ],
            )

            # with last_n_checkpoints and metric min: only "loss" checkpoints,
            # largest value first and smallest value last
            ckpt_manager = CheckpointManager(
                temp_dir,
                keep_last_n_checkpoints=3,
                best_checkpoint_config=BestCheckpointConfig(
                    monitored_metric="loss", mode="min"
                ),
            )
            self.assertEqual(
                [x.path for x in ckpt_manager._ckpt_paths],
                [
                    f"{temp_dir}/epoch_1_step_2_loss=0.5",
                    f"{temp_dir}/epoch_2_step_5_loss=0.3",
                    f"{temp_dir}/epoch_0_step_5_loss=-0.3",
                ],
            )

            # with last_n_checkpoints and metric max: same checkpoints, with
            # the order reversed (smallest value first, largest value last)
            ckpt_manager = CheckpointManager(
                temp_dir,
                keep_last_n_checkpoints=3,
                best_checkpoint_config=BestCheckpointConfig(
                    monitored_metric="loss", mode="max"
                ),
            )
            self.assertEqual(
                [x.path for x in ckpt_manager._ckpt_paths],
                [
                    f"{temp_dir}/epoch_0_step_5_loss=-0.3",
                    f"{temp_dir}/epoch_2_step_5_loss=0.3",
                    f"{temp_dir}/epoch_1_step_2_loss=0.5",
                ],
            )

            # with last_n_checkpoints and non previously tracked metric:
            # no directory carries a "foo" metric, so nothing is tracked
            ckpt_manager = CheckpointManager(
                temp_dir,
                keep_last_n_checkpoints=3,
                best_checkpoint_config=BestCheckpointConfig(
                    monitored_metric="foo", mode="max"
                ),
            )
            self.assertEqual(ckpt_manager._ckpt_paths, [])

    @skip_if_not_distributed
    def test_create_checkpoint_manager_distributed(self) -> None:
        """Run the manager construction checks in two ``gloo`` processes."""
        spawn_multi_process(
            2,
            "gloo",
            self._test_create_checkpoint_manager_distributed,
        )

    @staticmethod
    def _test_create_checkpoint_manager_distributed() -> None:
        """Per-process body of the distributed construction test.

        Only rank 0 creates the checkpoint directory; the other rank passes an
        empty string and is expected to end up with a non-empty ``dirpath``
        and the same tracked checkpoint list.
        """
        is_rank_zero = get_global_rank() == 0
        temp_dir = tempfile.mkdtemp() if is_rank_zero else ""

        try:
            if is_rank_zero:
                paths = ["epoch_1_step_2", "epoch_0_step_1", "epoch_1_step_1"]
                for path in paths:
                    os.mkdir(os.path.join(temp_dir, path))

            tc = unittest.TestCase()

            # without top k config
            ckpt_manager = CheckpointManager(temp_dir)
            tc.assertNotEqual(ckpt_manager.dirpath, "")
            tc.assertEqual(ckpt_manager._ckpt_paths, [])

            # with top k config
            ckpt_manager = CheckpointManager(temp_dir, keep_last_n_checkpoints=1)
            tc.assertNotEqual(ckpt_manager.dirpath, "")
            tc.assertEqual(
                [str(x) for x in ckpt_manager._ckpt_paths],
                [
                    os.path.join(ckpt_manager.dirpath, path)
                    for path in [
                        "epoch_0_step_1",
                        "epoch_1_step_1",
                        "epoch_1_step_2",
                    ]
                ],
            )
        finally:
            # mkdtemp() never cleans up after itself, so the creating rank
            # must remove the directory even when an assertion fails.
            # NOTE(review): assumes non-zero ranks do not read the directory
            # from disk after rank 0's constructor returns — confirm against
            # CheckpointManager's implementation.
            if is_rank_zero:
                shutil.rmtree(temp_dir, ignore_errors=True)

    def test_prune_surplus_checkpoints(self) -> None:
        """Check that pruning deletes the oldest surplus checkpoints and logs a
        warning, and that it is a no-op when no retention limit is set."""
        # with checkpoints to delete
        with tempfile.TemporaryDirectory() as temp_dir:
            ckpt_manager = CheckpointManager(temp_dir, keep_last_n_checkpoints=1)
            paths = [
                CheckpointPath(temp_dir, 0, 0),
                CheckpointPath(temp_dir, 0, 1),
                CheckpointPath(temp_dir, 1, 0),
            ]
            for path in paths:
                os.mkdir(path.path)

            ckpt_manager._ckpt_paths = list(paths)
            warning_messages = []
            # NOTE(review): the trailing commas make this a tuple of three
            # fragments, not one concatenated string. The assertion below only
            # passes if the implementation hands an identical tuple to
            # ``Logger.warning`` — confirm that this is intended.
            expected_warning_msg = (
                f"3 checkpoints found in {temp_dir}. ",
                f"Deleting {2} oldest ",
                "checkpoints to enforce ``keep_last_n_checkpoints`` argument.",
            )
            # Capture whatever is passed to Logger.warning instead of logging it.
            with patch(
                f"{CheckpointManager.__module__}.logging.Logger.warning",
                warning_messages.append,
            ):
                ckpt_manager.prune_surplus_checkpoints()

            self.assertEqual(warning_messages[0], expected_warning_msg)
            self.assertEqual(ckpt_manager._ckpt_paths, [paths[2]])
            self.assertTrue(os.path.exists(paths[2].path))
            self.assertFalse(os.path.exists(paths[0].path))
            self.assertFalse(os.path.exists(paths[1].path))

        # without checkpoints to delete: no retention limit, list is untouched
        with tempfile.TemporaryDirectory() as temp_dir:
            ckpt_manager = CheckpointManager(temp_dir)
            paths = [
                CheckpointPath(temp_dir, 0, 0),
                CheckpointPath(temp_dir, 0, 1),
                CheckpointPath(temp_dir, 1, 0),
            ]
            ckpt_manager._ckpt_paths = list(paths)
            ckpt_manager.prune_surplus_checkpoints()
            self.assertEqual(ckpt_manager._ckpt_paths, paths)

    def test_generate_checkpoint_path(self) -> None:
        """Check generated checkpoint paths with and without metric data, and
        the assertion errors raised for inconsistent metric arguments."""
        ckpt_manager = CheckpointManager("foo")

        self.assertEqual(
            ckpt_manager.generate_checkpoint_path(1, 1).path,
            "foo/epoch_1_step_1",
        )

        self.assertEqual(
            ckpt_manager.generate_checkpoint_path(1, 3).path,
            "foo/epoch_1_step_3",
        )

        ckpt_manager._best_checkpoint_config = BestCheckpointConfig(
            monitored_metric="val_loss", mode="min"
        )
        self.assertEqual(
            ckpt_manager.generate_checkpoint_path(
                1, 3, MetricData("val_loss", 0.5)
            ).path,
            "foo/epoch_1_step_3_val_loss=0.5",
        )

        # best checkpoint config, but did not pass metric data - expect path but no metric
        self.assertEqual(
            ckpt_manager.generate_checkpoint_path(1, 2).path,
            "foo/epoch_1_step_2",
        )

        # passed metric data is tracking a different metric than best checkpoint config - expect exception
        with self.assertRaisesRegex(
            AssertionError,
            "Attempted to get a checkpoint with metric 'mean', but best checkpoint config is for 'val_loss'",
        ):
            ckpt_manager.generate_checkpoint_path(1, 2, MetricData("mean", 3.5))

        # no best checkpoint config, but passed metric data - expect exception
        ckpt_manager._best_checkpoint_config = None
        with self.assertRaisesRegex(
            AssertionError,
            "Attempted to get a checkpoint with metric but best checkpoint config is not set",
        ):
            ckpt_manager.generate_checkpoint_path(1, 2, MetricData("val_loss", 3.5))

    def test_append_checkpoint_by_recency(self) -> None:
        """Check that, without a metric config, appending keeps checkpoints in
        insertion order and evicts the oldest one once the limit is reached."""
        ckpt_manager = CheckpointManager("foo", keep_last_n_checkpoints=2)
        ckpt_manager._ckpt_paths = [CheckpointPath("foo", 0, 0)]

        # without need to remove old by recency
        ckpt_manager.append_checkpoint(CheckpointPath("foo", 0, 1))
        self.assertEqual(
            ckpt_manager._ckpt_paths,
            [CheckpointPath("foo", 0, 0), CheckpointPath("foo", 0, 1)],
        )

        # removing old by recency; filesystem deletion is mocked out because
        # the "foo" checkpoint directories are never created in this test
        with patch("fsspec.implementations.local.LocalFileSystem.rm") as mock_rm:
            ckpt_manager.append_checkpoint(CheckpointPath("foo", 0, 2))
            self.assertEqual(
                ckpt_manager._ckpt_paths,
                [CheckpointPath("foo", 0, 1), CheckpointPath("foo", 0, 2)],
            )
            mock_rm.assert_called_once_with("foo/epoch_0_step_0", recursive=True)

    def test_append_checkpoint_by_metric(self) -> None:
        """Check that, with a ``max`` metric config, appended checkpoints are
        kept sorted by metric value and the smallest one is evicted at the limit."""
        ckpt_manager = CheckpointManager(
            "foo",
            keep_last_n_checkpoints=5,
            best_checkpoint_config=BestCheckpointConfig(
                monitored_metric="val_loss", mode="max"
            ),
        )
        # paths[i] has step i + 1 and val_loss 0.01 * (i + 1), so a higher
        # index means a larger metric value
        paths = [
            CheckpointPath(
                "foo", 0, x, metric_data=MetricData(name="val_loss", value=0.01 * x)
            )
            for x in range(1, 7, 1)
        ]
        ckpt_manager._ckpt_paths = [paths[1], paths[2], paths[4]]
        # below the limit, nothing removed; smallest value goes to the beginning
        ckpt_manager.append_checkpoint(paths[0])
        self.assertEqual(
            ckpt_manager._ckpt_paths,
            [paths[0], paths[1], paths[2], paths[4]],
        )
        # below the limit, nothing removed; largest value goes to the end
        ckpt_manager.append_checkpoint(paths[5])
        self.assertEqual(
            ckpt_manager._ckpt_paths,
            [paths[0], paths[1], paths[2], paths[4], paths[5]],
        )
        # limit reached: the smallest-valued entry (paths[0]) is removed and
        # the new checkpoint is inserted in the middle
        with patch("fsspec.implementations.local.LocalFileSystem.rm") as mock_rm:
            ckpt_manager.append_checkpoint(paths[3])
            self.assertEqual(
                ckpt_manager._ckpt_paths,
                [paths[1], paths[2], paths[3], paths[4], paths[5]],
            )
            mock_rm.assert_called_once_with(
                "foo/epoch_0_step_1_val_loss=0.01", recursive=True
            )

        # checkpoint without metric data, retention limit disabled: the
        # tracked list is expected to stay unchanged
        ckpt_manager._keep_last_n_checkpoints = None
        ckpt_manager.append_checkpoint(CheckpointPath("foo", 0, 8))
        self.assertEqual(
            ckpt_manager._ckpt_paths,
            [paths[1], paths[2], paths[3], paths[4], paths[5]],
        )

    def test_should_save_checkpoint(self) -> None:
        """
        Tests basic functionality of should_save_checkpoint
        """
        ckpt_manager = CheckpointManager("foo")

        # test default behavior: without a metric config the answer is always
        # True, whatever the tracked list or retention limit
        ckpt = CheckpointPath("foo", 0, 2)
        self.assertTrue(ckpt_manager.should_save_checkpoint(ckpt))

        ckpt_manager._ckpt_paths = [CheckpointPath("foo", 0, 1)]
        self.assertTrue(ckpt_manager.should_save_checkpoint(ckpt))
        ckpt_manager._keep_last_n_checkpoints = 1
        self.assertTrue(ckpt_manager.should_save_checkpoint(ckpt))

        # one tracked checkpoint with val_loss=0.01, monitored in "min" mode
        ckpt_manager._ckpt_paths = [
            CheckpointPath(
                "foo", 0, 1, metric_data=MetricData(name="val_loss", value=0.01)
            ),
        ]
        ckpt_manager._best_checkpoint_config = BestCheckpointConfig(
            monitored_metric="val_loss",
            mode="min",
        )

        bigger_metric = CheckpointPath(
            "foo", 0, 1, metric_data=MetricData(name="val_loss", value=0.02)
        )
        smaller_metric = CheckpointPath(
            "foo", 0, 1, metric_data=MetricData(name="val_loss", value=0.001)
        )
        # no retention limit: always save
        ckpt_manager._keep_last_n_checkpoints = None
        self.assertTrue(ckpt_manager.should_save_checkpoint(bigger_metric))
        # limit of 1 already reached: only a better (smaller) value is saved
        ckpt_manager._keep_last_n_checkpoints = 1
        self.assertFalse(ckpt_manager.should_save_checkpoint(bigger_metric))
        self.assertTrue(ckpt_manager.should_save_checkpoint(smaller_metric))
        # limit of 2 not yet reached: both are saved
        ckpt_manager._keep_last_n_checkpoints = 2
        self.assertTrue(ckpt_manager.should_save_checkpoint(smaller_metric))
        self.assertTrue(ckpt_manager.should_save_checkpoint(bigger_metric))

        # Make sure we are actually comparing against more optimal element
        ckpt_manager._ckpt_paths = [
            CheckpointPath(
                "foo", 0, 1, metric_data=MetricData(name="val_loss", value=0.01)
            ),
            CheckpointPath(
                "foo", 0, 1, metric_data=MetricData(name="val_loss", value=0.05)
            ),
        ]

        ckpt_manager._best_checkpoint_config = BestCheckpointConfig(
            monitored_metric="val_loss",
            mode="max",
        )
        # in "max" mode with the limit reached, 0.02 beats the tracked 0.01
        # but not the tracked 0.05, and is still expected to be saved
        ckpt_manager._keep_last_n_checkpoints = 2
        self.assertTrue(ckpt_manager.should_save_checkpoint(bigger_metric))

    def test_remove_worst_checkpoint(self) -> None:
        """Check that ``remove_checkpoint`` deletes the first tracked checkpoint
        from disk and drops it from the tracked list."""
        with tempfile.TemporaryDirectory() as temp_dir:
            os.mkdir(os.path.join(temp_dir, "epoch_0_step_0"))
            os.mkdir(os.path.join(temp_dir, "epoch_0_step_1"))

            ckpt_manager = CheckpointManager(temp_dir)
            ckpt_manager.append_checkpoint(CheckpointPath(temp_dir, 0, 0))
            ckpt_manager.append_checkpoint(CheckpointPath(temp_dir, 0, 1))

            ckpt_manager.remove_checkpoint()
            self.assertFalse(os.path.exists(os.path.join(temp_dir, "epoch_0_step_0")))
            self.assertTrue(os.path.exists(os.path.join(temp_dir, "epoch_0_step_1")))
            self.assertEqual(ckpt_manager._ckpt_paths, [CheckpointPath(temp_dir, 0, 1)])
| 540 | + |
| 541 | + |
193 | 542 | class CheckpointUtilsTest(unittest.TestCase):
|
194 | 543 | @staticmethod
|
195 | 544 | def _create_snapshot_metadata(output_dir: str) -> None:
|
@@ -590,3 +939,12 @@ def test_metadata_exists(self) -> None:
|
590 | 939 |
|
591 | 940 | os.remove(os.path.join(dirpath, SNAPSHOT_METADATA_FNAME))
|
592 | 941 | self.assertFalse(_metadata_exists(fs, dirpath, SNAPSHOT_METADATA_FNAME))
|
| 942 | + |
| 943 | + |
class MyValLossUnit(TrainUnit[Batch]):
    """Minimal ``TrainUnit`` carrying a fixed ``val_loss`` attribute.

    The training step does nothing.
    """

    def __init__(self) -> None:
        super().__init__()
        # Constant value; presumably read by tests as a monitored metric —
        # verify against the callers of this helper.
        self.val_loss = 0.01

    def train_step(self, state: State, data: Batch) -> None:
        """No-op training step; ``state`` and ``data`` are ignored."""
        return
0 commit comments