@@ -75,8 +75,20 @@ def __init__(self):
7575 self .result = mock .Mock ()
7676
7777
class DummyAsyncResult:
    """Test double for what dcp.async_save returns in pinned-memory mode.

    Carries two separate DummyFuture instances, one per attribute:
    ``upload_completion`` and ``staging_completion``.
    """

    def __init__(self):
        # Each attribute gets its own future object, created in the same
        # order as before (upload first, then staging).
        self.upload_completion, self.staging_completion = (
            DummyFuture(),
            DummyFuture(),
        )
84+
85+
def fake_async_save(*args, **kwargs):
    """Stand-in for dcp.async_save used as a ``side_effect`` in the tests.

    Returns a DummyAsyncResult when the caller passes an ``async_stager``
    keyword (the marker for async_with_pinned_mem mode); otherwise returns
    a plain DummyFuture. Positional arguments are accepted and ignored.
    """
    # No stager supplied: not the pinned-memory path, so hand back a bare future.
    if "async_stager" not in kwargs:
        return DummyFuture()
    return DummyAsyncResult()
8092
8193
8294class DummyJobConfig :
@@ -410,6 +422,62 @@ def test_last_save_model_only_and_initial_load_model_only(
410422 manager1 .close ()
411423 manager2 .close ()
412424
@mock.patch("torch.distributed.get_rank", return_value=0)
@mock.patch("torch.cuda.Stream")
@mock.patch("torchtitan.components.checkpoint.DefaultStager")
@mock.patch("torchtitan.components.checkpoint.dist.new_group")
@mock.patch(
    "torchtitan.components.checkpoint.dcp.async_save", side_effect=fake_async_save
)
def test_async_save_with_pinned_mem_sets_staging_flag(
    self,
    mock_async_save,
    mock_new_group,
    mock_default_stager,
    mock_cuda_stream,
    mock_rank,
):
    """
    Test that AsyncMode.ASYNC_WITH_PINNED_MEM correctly sets staging flag.

    This test verifies the bug fix where self.staging was not being set to True
    when using ASYNC_WITH_PINNED_MEM mode, which caused maybe_wait_for_staging()
    to not wait properly for staging completion.

    The mock_* parameters are injected by the mock.patch decorators (bottom
    decorator first) and are not inspected directly.
    """
    # Configure async mode with pinned memory
    job_config = DummyJobConfig(job=self.job_config.job)
    checkpoint_config = job_config.checkpoint
    checkpoint_config.async_mode = "async_with_pinned_mem"

    manager = CheckpointManager(
        dataloader=self.data_loader,
        model_parts=self.model_parts,
        optimizers=self.optimizers,
        lr_schedulers=self.lr_schedulers,
        states=self.states,
        checkpoint_config=checkpoint_config,
        sd_adapter=None,
        base_folder=self.job_config.job.dump_folder,
        ft_manager=self.ft_manager,
    )

    # Close the manager even when an assertion or save() fails, so a failing
    # run does not leave an unclosed manager behind for later tests.
    try:
        # Initially staging should be False
        self.assertFalse(manager.staging)

        # After save, staging should be set to True
        manager.save(curr_step=1, last_step=False)
        self.assertTrue(manager.staging)

        # Verify that staging_future exists
        self.assertIsNotNone(manager.staging_future)

        # Verify that maybe_wait_for_staging actually waits when staging is True
        manager.maybe_wait_for_staging()
        # After waiting, staging should be set back to False
        self.assertFalse(manager.staging)
    finally:
        manager.close()
480+
413481 @mock .patch ("torchtitan.components.checkpoint.dist.new_group" )
414482 @mock .patch (
415483 "torchtitan.components.checkpoint.dcp.async_save" , side_effect = fake_async_save
0 commit comments