Skip to content

Commit 8ec37d2

Browse files
authored
Fix async_with_pinned_mem doesn't set staging correctly (#1783)
#1287 refactored async_with_pinned_mem to use DCP's implementation, but that PR didn't set staging correctly. This PR fixes it. Fixes #1773
1 parent 177b050 commit 8ec37d2

File tree

2 files changed

+71
-1
lines changed

2 files changed

+71
-1
lines changed

tests/unit_tests/test_checkpoint.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,20 @@ def __init__(self):
7575
self.result = mock.Mock()
7676

7777

78+
class DummyAsyncResult:
    """Stand-in for the object ``dcp.async_save`` returns in pinned-memory mode.

    The real return value exposes two futures — one for the device-to-host
    staging copy and one for the background upload — so this mock provides
    both as attributes, each backed by a ``DummyFuture``.
    """

    def __init__(self):
        # Separate futures mirror the real API: staging completes first,
        # upload completes later. Both are dummies here.
        self.staging_completion = DummyFuture()
        self.upload_completion = DummyFuture()
84+
85+
7886
def fake_async_save(*args, **kwargs):
    """Fake replacement for ``dcp.async_save`` used via ``side_effect``.

    The pinned-memory code path is distinguished by the presence of the
    ``async_stager`` keyword argument; that path gets a ``DummyAsyncResult``
    (two futures), while the plain async path gets a single ``DummyFuture``.
    """
    pinned_mem_mode = "async_stager" in kwargs
    return DummyAsyncResult() if pinned_mem_mode else DummyFuture()
8092

8193

8294
class DummyJobConfig:
@@ -410,6 +422,62 @@ def test_last_save_model_only_and_initial_load_model_only(
410422
manager1.close()
411423
manager2.close()
412424

425+
@mock.patch("torch.distributed.get_rank", return_value=0)
@mock.patch("torch.cuda.Stream")
@mock.patch("torchtitan.components.checkpoint.DefaultStager")
@mock.patch("torchtitan.components.checkpoint.dist.new_group")
@mock.patch(
    "torchtitan.components.checkpoint.dcp.async_save", side_effect=fake_async_save
)
def test_async_save_with_pinned_mem_sets_staging_flag(
    self,
    mock_async_save,
    mock_new_group,
    mock_default_stager,
    mock_cuda_stream,
    mock_rank,
):
    """
    Test that AsyncMode.ASYNC_WITH_PINNED_MEM correctly sets staging flag.

    This test verifies the bug fix where self.staging was not being set to True
    when using ASYNC_WITH_PINNED_MEM mode, which caused maybe_wait_for_staging()
    to not wait properly for staging completion.
    """
    # Build a job config whose checkpoint section requests pinned-memory async.
    cfg = DummyJobConfig(job=self.job_config.job)
    ckpt_cfg = cfg.checkpoint
    ckpt_cfg.async_mode = "async_with_pinned_mem"

    mgr = CheckpointManager(
        dataloader=self.data_loader,
        model_parts=self.model_parts,
        optimizers=self.optimizers,
        lr_schedulers=self.lr_schedulers,
        states=self.states,
        checkpoint_config=ckpt_cfg,
        sd_adapter=None,
        base_folder=self.job_config.job.dump_folder,
        ft_manager=self.ft_manager,
    )

    # Before any save, no staging is in flight.
    self.assertFalse(mgr.staging)

    # A save in pinned-memory mode must flip the staging flag on...
    mgr.save(curr_step=1, last_step=False)
    self.assertTrue(mgr.staging)

    # ...and record the future that staging completion is tracked by.
    self.assertIsNotNone(mgr.staging_future)

    # Waiting on staging should consume the future and reset the flag.
    mgr.maybe_wait_for_staging()
    self.assertFalse(mgr.staging)

    mgr.close()
480+
413481
@mock.patch("torchtitan.components.checkpoint.dist.new_group")
414482
@mock.patch(
415483
"torchtitan.components.checkpoint.dcp.async_save", side_effect=fake_async_save

torchtitan/components/checkpoint.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,7 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
494494
)
495495
self.save_future = result.upload_completion
496496
self.staging_future = result.staging_completion
497+
self.staging = True
497498
elif self.async_mode == AsyncMode.ASYNC:
498499
GarbageCollection.collect("GC collection invoked by checkpointer.")
499500
self.save_future = self.dcp_save(
@@ -615,6 +616,7 @@ def maybe_wait_for_staging(self) -> None:
615616
"""
616617
if self.enable_staging and self.staging:
617618
self.staging_future.result()
619+
self.staging = False
618620

619621
def _find_load_step(self, folder: str = "") -> int:
620622
"""Find the step to load the checkpoint for.

0 commit comments

Comments
 (0)