diff --git a/torchtitan/train.py b/torchtitan/train.py index 807dea8bc5..f55530a083 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -575,6 +575,10 @@ def train(self): logger.warning("Ran out of data; last step was canceled.") break + self.checkpointer.save( + self.step, last_step=(self.step == job_config.training.steps) + ) + # Run validation if validator is available if ( self.job_config.validation.enabled @@ -582,10 +586,6 @@ def train(self): ): self.validator.validate(self.model_parts, self.step) - self.checkpointer.save( - self.step, last_step=(self.step == job_config.training.steps) - ) - # signal the profiler that the next profiling step has started if torch_profiler: torch_profiler.step()