Skip to content

Commit 0d583e4

Browse files
committed
fix checkpoint UT
1 parent 436b35a commit 0d583e4

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

tests/unit_tests/test_checkpoint.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -420,15 +420,15 @@ def test_async_save_calls_async_wait(self, mock_async_save, mock_new_group):
420420

421421
# First save schedules async
422422
manager.save(curr_step=10, last_step=False)
423-
future = manager.async_future
423+
future = manager.save_future
424424
future.result.assert_not_called()
425425

426426
# Second save should wait
427427
manager.save(curr_step=20, last_step=False)
428428
future.result.assert_called_once()
429429

430430
# New future created
431-
new_future = manager.async_future
431+
new_future = manager.save_future
432432
new_future.result.assert_not_called()
433433

434434
@mock.patch("torch.cuda.Stream")
@@ -461,16 +461,16 @@ def test_ft_async_save_calls_async_wait(
461461
)
462462

463463
# Initially no future
464-
self.assertIsNone(manager.async_future)
464+
self.assertIsNone(manager.save_future)
465465
manager.save(curr_step=5, last_step=False)
466-
self.assertIsNotNone(manager.async_future)
466+
self.assertIsNotNone(manager.save_future)
467467

468-
manager.async_future.result.assert_not_called()
469-
prev_future = manager.async_future
468+
manager.save_future.result.assert_not_called()
469+
prev_future = manager.save_future
470470
manager.save(curr_step=6, last_step=False)
471471
prev_future.result.assert_called_once()
472-
self.assertIsNotNone(manager.async_future)
473-
manager.async_future.result.assert_not_called()
472+
self.assertIsNotNone(manager.save_future)
473+
manager.save_future.result.assert_not_called()
474474

475475
@mock.patch("torch.distributed.get_rank", return_value=0)
476476
@mock.patch("torchtitan.components.checkpoint.dcp.save")

0 commit comments

Comments
 (0)