Commit 646cf2f

[refactor] Move save_function to accelerator 1/n [DeepSpeed] (#6689)

* move save_checkpoint responsibility to accelerator
* update

1 parent 3a4c424 commit 646cf2f

File tree: 4 files changed, +27 -24 lines
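Taken together, the four diffs below change where the checkpoint file gets written: the connector still dumps the trainer state into a dictionary, but the write is now delegated to the accelerator, which forwards it to its training type plugin (so the TPU plugin today, and DeepSpeed in follow-up PRs, can own the file write). A condensed, self-contained sketch of the resulting call chain, with stub classes standing in for the real Lightning ones:

# Condensed sketch of the new save path; the class stubs are illustrative,
# only the hook names and the delegation order match the diffs below.

class TrainingTypePlugin:
    # default implementation added in training_type_plugin.py (simplified here)
    def save_checkpoint(self, checkpoint, filepath):
        print(f"writing {len(checkpoint)} keys to {filepath}")

class Accelerator:
    def __init__(self, training_type_plugin):
        self.training_type_plugin = training_type_plugin

    # new hook from accelerator.py: forward straight to the plugin
    def save_checkpoint(self, checkpoint, filepath):
        self.training_type_plugin.save_checkpoint(checkpoint, filepath)

class CheckpointConnector:
    def __init__(self, trainer):
        self.trainer = trainer

    def dump_checkpoint(self, weights_only=False):
        return {"state_dict": {}, "epoch": 0}  # stand-in for the real state dump

    # checkpoint_connector.py after this commit: dump, then delegate the write
    def save_checkpoint(self, filepath, weights_only=False):
        _checkpoint = self.dump_checkpoint(weights_only)
        self.trainer.accelerator.save_checkpoint(_checkpoint, filepath)

class _Trainer:
    pass  # hypothetical container, only used to wire the pieces together

trainer = _Trainer()
trainer.accelerator = Accelerator(TrainingTypePlugin())
CheckpointConnector(trainer).save_checkpoint("model.ckpt")
# prints: writing 2 keys to model.ckpt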

pytorch_lightning/accelerators/accelerator.py
Lines changed: 3 additions & 0 deletions

@@ -466,3 +466,6 @@ def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None:
             ' It will be removed in v1.5.'
         )
         self.setup_precision_plugin(plugin)
+
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath) -> None:
+        self.training_type_plugin.save_checkpoint(checkpoint, filepath)

pytorch_lightning/plugins/training_type/tpu_spawn.py
Lines changed: 3 additions & 6 deletions

@@ -106,8 +106,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         trainer.accelerator.setup_optimizers(trainer)
         trainer.precision_plugin.connect(self._model, None, None)
 
-        # replace trainer save_checkpoint to use `xm.save`
-        trainer.save_checkpoint = self.save_checkpoint
         self.barrier("pre-run-stage")
 
         results = trainer.run_stage()

@@ -298,14 +296,13 @@ def test_step(self, *args, **kwargs):
     def predict_step(self, *args, **kwargs):
        return self.lightning_module.predict_step(*args, **kwargs)
 
-    def save_checkpoint(self, filepath, weights_only: bool = False):
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
         Args:
+            trainer: PyTorch Lightning Trainer
             filepath: write-target file's path
             weights_only: saving model weights only
         """
-        # dump states as a checkpoint dictionary object
-        _checkpoint = self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only)
         # Todo: TypeError: 'mappingproxy' object does not support item assignment
-        self.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath)
+        self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)
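Note the removed lines in new_process: the plugin no longer monkey-patches trainer.save_checkpoint, because the trainer-side path now reaches the plugin through the accelerator. The plugin's save_checkpoint still filters out the "callbacks" entry (the mappingproxy Todo above) and hands the rest to self.save which, per the removed comment, goes through xm.save so the write is done in the XLA-aware way. A rough sketch of what that helper amounts to; this is an assumption about the plugin's internals, not code from this diff:

# Illustrative only: the plugin's `save` helper is assumed to be a thin wrapper
# around xm.save, which moves XLA tensors to CPU and writes from the master
# ordinal. Requires a TPU/XLA environment.
import torch_xla.core.xla_model as xm

def save_on_tpu(state_dict, path):
    xm.save(state_dict, path)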

pytorch_lightning/plugins/training_type/training_type_plugin.py
Lines changed: 18 additions & 0 deletions

@@ -22,6 +22,8 @@
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.base import unwrap_lightning_module
 from pytorch_lightning.plugins.base_plugin import Plugin
+from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.cloud_io import atomic_save
 
 if TYPE_CHECKING:
     from pytorch_lightning.trainer.trainer import Trainer

@@ -192,3 +194,19 @@ def setup_optimizers_in_pre_dispatch(self) -> bool:
         Returns: If True, delay setup optimizers till pre_dispatch, else call within setup.
         """
         return False
+
+    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
+        # dump states as a checkpoint dictionary object
+        if self.is_global_zero:
+            checkpoint = self.on_save(checkpoint)
+            try:
+                # write the checkpoint dictionary on the file
+                atomic_save(checkpoint, filepath)
+            except AttributeError as err:
+                if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
+                    del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
+                rank_zero_warn(
+                    'Warning, `hyper_parameters` dropped from checkpoint.'
+                    f' An attribute is not picklable {err}'
+                )
+                atomic_save(checkpoint, filepath)
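The default save_checkpoint above writes on global rank zero via atomic_save (a thin wrapper around torch.save) and, when pickling fails with an AttributeError, drops hyper_parameters and retries. A small self-contained repro of that failure mode, using plain torch.save instead of Lightning's helper and assuming, as in CPython's pickle, that a locally defined function is unpicklable and surfaces as an AttributeError:

import torch

def make_unpicklable():
    def local_fn():  # locally defined functions cannot be pickled
        return 42
    return local_fn

checkpoint = {
    "state_dict": {"weight": torch.zeros(2, 2)},
    "hyper_parameters": {"postprocess": make_unpicklable()},
}

try:
    torch.save(checkpoint, "demo.ckpt")
except AttributeError as err:
    # same recovery the plugin performs: drop hparams, warn, retry the write
    print(f"hyper_parameters dropped from checkpoint: {err}")
    checkpoint.pop("hyper_parameters", None)
    torch.save(checkpoint, "demo.ckpt")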

pytorch_lightning/trainer/connectors/checkpoint_connector.py
Lines changed: 3 additions & 18 deletions

@@ -386,27 +386,12 @@ def get_max_ckpt_path_from_folder(self, folder_path: Union[str, Path]) -> str:
         ckpt_number = max_suffix if max_suffix is not None else 0
         return f'{folder_path}/hpc_ckpt_{ckpt_number}.ckpt'
 
-    def save_checkpoint(self, filepath, weights_only: bool = False):
+    def save_checkpoint(self, filepath, weights_only: bool = False) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
         Args:
             filepath: write-target file's path
             weights_only: saving model weights only
         """
-        # dump states as a checkpoint dictionary object
-        checkpoint = self.dump_checkpoint(weights_only)
-        if self.trainer.is_global_zero:
-            # write the checkpoint dictionary on the file
-
-            if self.trainer.training_type_plugin:
-                checkpoint = self.trainer.training_type_plugin.on_save(checkpoint)
-            try:
-                atomic_save(checkpoint, filepath)
-            except AttributeError as err:
-                if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
-                    del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
-                rank_zero_warn(
-                    'Warning, `hyper_parameters` dropped from checkpoint.'
-                    f' An attribute is not picklable {err}'
-                )
-                atomic_save(checkpoint, filepath)
+        _checkpoint = self.dump_checkpoint(weights_only)
+        self.trainer.accelerator.save_checkpoint(_checkpoint, filepath)
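User-facing behavior is unchanged: Trainer.save_checkpoint, which forwards to the connector method shown above, keeps its signature; only the internal write path gains the accelerator hop. A short usage note, assuming an already configured trainer:

# assumes `trainer` is a fitted pytorch_lightning.Trainer instance
trainer.save_checkpoint("example.ckpt")                     # full training state
trainer.save_checkpoint("weights.ckpt", weights_only=True)  # model weights only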
