Skip to content

Commit a3aa31c

Browse files
committed
Fix async_with_pinned_mem doesn't set staging correctly
#1287 refactors async_with_pinned_mem to use DCP's implementation but that PR didn't set staging correctly. This PR fixes it.
1 parent fa21894 commit a3aa31c

File tree

2 files changed

+65
-1
lines changed

2 files changed

+65
-1
lines changed

tests/unit_tests/test_checkpoint.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,20 @@ def __init__(self):
7575
self.result = mock.Mock()
7676

7777

78+
class DummyAsyncResult:
79+
"""Mock object that mimics the return value of dcp.async_save with pinned memory"""
80+
81+
def __init__(self):
82+
self.upload_completion = DummyFuture()
83+
self.staging_completion = DummyFuture()
84+
85+
7886
def fake_async_save(*args, **kwargs):
79-
return DummyFuture()
87+
# Check if this is async_with_pinned_mem mode by looking for async_stager parameter
88+
if "async_stager" in kwargs:
89+
return DummyAsyncResult()
90+
else:
91+
return DummyFuture()
8092

8193

8294
class DummyJobConfig:
@@ -410,6 +422,56 @@ def test_last_save_model_only_and_initial_load_model_only(
410422
manager1.close()
411423
manager2.close()
412424

425+
@mock.patch("torch.distributed.get_rank", return_value=0)
426+
@mock.patch("torch.cuda.Stream")
427+
@mock.patch("torchtitan.components.checkpoint.dist.new_group")
428+
@mock.patch(
429+
"torchtitan.components.checkpoint.dcp.async_save", side_effect=fake_async_save
430+
)
431+
def test_async_save_with_pinned_mem_sets_staging_flag(
432+
self, mock_async_save, mock_new_group, mock_cuda_stream, mock_rank
433+
):
434+
"""
435+
Test that AsyncMode.ASYNC_WITH_PINNED_MEM correctly sets staging flag.
436+
437+
This test verifies the bug fix where self.staging was not being set to True
438+
when using ASYNC_WITH_PINNED_MEM mode, which caused maybe_wait_for_staging()
439+
to not wait properly for staging completion.
440+
"""
441+
# Configure async mode with pinned memory
442+
job_config = DummyJobConfig(job=self.job_config.job)
443+
checkpoint_config = job_config.checkpoint
444+
checkpoint_config.async_mode = "async_with_pinned_mem"
445+
446+
manager = CheckpointManager(
447+
dataloader=self.data_loader,
448+
model_parts=self.model_parts,
449+
optimizers=self.optimizers,
450+
lr_schedulers=self.lr_schedulers,
451+
states=self.states,
452+
checkpoint_config=checkpoint_config,
453+
sd_adapter=None,
454+
base_folder=self.job_config.job.dump_folder,
455+
ft_manager=self.ft_manager,
456+
)
457+
458+
# Initially staging should be False
459+
self.assertFalse(manager.staging)
460+
461+
# After save, staging should be set to True
462+
manager.save(curr_step=1, last_step=False)
463+
self.assertTrue(manager.staging)
464+
465+
# Verify that staging_future exists
466+
self.assertIsNotNone(manager.staging_future)
467+
468+
# Verify that maybe_wait_for_staging actually waits when staging is True
469+
manager.maybe_wait_for_staging()
470+
# After waiting, staging should be set back to False
471+
self.assertFalse(manager.staging)
472+
473+
manager.close()
474+
413475
@mock.patch("torchtitan.components.checkpoint.dist.new_group")
414476
@mock.patch(
415477
"torchtitan.components.checkpoint.dcp.async_save", side_effect=fake_async_save

torchtitan/components/checkpoint.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,7 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
494494
)
495495
self.save_future = result.upload_completion
496496
self.staging_future = result.staging_completion
497+
self.staging = True
497498
elif self.async_mode == AsyncMode.ASYNC:
498499
GarbageCollection.collect("GC collection invoked by checkpointer.")
499500
self.save_future = self.dcp_save(
@@ -615,6 +616,7 @@ def maybe_wait_for_staging(self) -> None:
615616
"""
616617
if self.enable_staging and self.staging:
617618
self.staging_future.result()
619+
self.staging = False
618620

619621
def _find_load_step(self, folder: str = "") -> int:
620622
"""Find the step to load the checkpoint for.

0 commit comments

Comments
 (0)