
Commit 42c7f27

awaelchli, pre-commit-ci[bot], and carmocca authored
refactor checkpoint loading for training type plugins (#7928)
* plugin loading logic
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* integrate loading for test
* fix
* fix
* unused iport

* Update pytorch_lightning/trainer/connectors/checkpoint_connector.py

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent ac4eb0a commit 42c7f27

File tree: 4 files changed (+58, -72 lines changed)

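Taken together, the diff replaces the single `restore_model_state_from_ckpt_path` hook with three narrower hooks on the training type plugin: `load_checkpoint_file`, `load_model_state_dict`, and `load_optimizer_state_dict`. The following is a minimal sketch of the resulting call order, condensed from the changed files below; it is not the actual Lightning source, and the real CheckpointConnector performs additional bookkeeping that is omitted here.

# Condensed sketch of the resume flow after this refactor.
# `plugin` stands for trainer.training_type_plugin, `module` for the LightningModule.
def sketch_resume(plugin, module, checkpoint_path):
    # resume_start(): the plugin decides how the raw checkpoint file is read
    checkpoint = plugin.load_checkpoint_file(checkpoint_path)

    # restore_model_state(): user hooks still see the loaded dict first ...
    module.on_load_checkpoint(checkpoint)

    # ... then the plugin decides how the model weights are restored
    plugin.load_model_state_dict(checkpoint)

    # optimizer state restoration can likewise be delegated through
    # plugin.load_optimizer_state_dict(checkpoint) (a no-op for DeepSpeed)
    return checkpoint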

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 26 additions & 29 deletions
@@ -18,7 +18,7 @@
 from collections import OrderedDict
 from pathlib import Path
 from types import SimpleNamespace
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, Union

 import torch

@@ -524,37 +524,34 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
         else:
             super().save_checkpoint(checkpoint, filepath)

-    def restore_model_state_from_ckpt_path(
-        self,
-        ckpt_path: str,
-        map_location: Callable = lambda storage, loc: storage,
-    ) -> Tuple[Dict, bool]:
-        if not self.save_full_weights and self.world_size > 1:
-            # Rely on deepspeed to load the checkpoint and necessary information
-            from pytorch_lightning.trainer.states import TrainerFn
-            is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING
-            save_dir = self._filepath_to_dir(ckpt_path)
-
-            if self.zero_stage_3:
-                # TODO: Currently required as this call is missing within the deepspeed engine.
-                self.deepspeed_engine.optimizer._partition_all_parameters()
-
-            _, client_state = self.deepspeed_engine.load_checkpoint(
-                save_dir, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting
-            )
+    def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
+        if self.save_full_weights or self.world_size == 1:
+            # Broadcast to ensure we load from the rank 0 checkpoint
+            # This doesn't have to be the case when using deepspeed sharded checkpointing
+            checkpoint_path = self.broadcast(checkpoint_path)
+            return super().load_checkpoint_file(checkpoint_path)
+
+        # Rely on deepspeed to load the checkpoint and necessary information
+        from pytorch_lightning.trainer.states import TrainerFn
+        is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING
+        save_dir = self._filepath_to_dir(checkpoint_path)

-            # restore datamodule states
-            if self.lightning_module.trainer.datamodule is not None:
-                self.lightning_module.trainer.datamodule.on_load_checkpoint(client_state)
+        if self.zero_stage_3:
+            # TODO: Currently required as this call is missing within the deepspeed engine.
+            self.deepspeed_engine.optimizer._partition_all_parameters()
+
+        _, client_state = self.deepspeed_engine.load_checkpoint(
+            save_dir, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting
+        )
+        return client_state

-            # hook: give user access to checkpoint if needed.
-            self.lightning_module.on_load_checkpoint(client_state)
-            return client_state, False
+    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
+        # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint_file()`
+        pass

-        # Broadcast to ensure we load from the rank 0 checkpoint
-        # This doesn't have to be the case when using deepspeed sharded checkpointing
-        ckpt_path = self.broadcast(ckpt_path)
-        return super().restore_model_state_from_ckpt_path(ckpt_path, map_location=map_location)
+    def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
+        # override to do nothing, deepspeed engine already loaded the states in `load_checkpoint_file()`
+        pass

     def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
         if self._original_accumulate_grad_batches is None:
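In short, the DeepSpeed plugin now only overrides what it handles differently: consolidated ("full weights") or single-process checkpoints still go through the base-class path after a rank-0 broadcast, while sharded checkpoints are handed to the DeepSpeed engine, and the model/optimizer state hooks become no-ops because the engine has already restored those states. Below is a hedged usage sketch, assuming the DeepSpeedPlugin arguments and Trainer flags available around Lightning 1.3; the path and settings are placeholders.

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

# save_full_weights=False keeps per-rank shards, so resuming takes the
# deepspeed_engine.load_checkpoint() branch shown in the diff above;
# save_full_weights=True would take the broadcast + super() branch instead.
trainer = Trainer(
    gpus=2,
    precision=16,
    plugins=DeepSpeedPlugin(stage=3, save_full_weights=False),
    resume_from_checkpoint="checkpoints/last.ckpt",  # placeholder path
)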

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 13 additions & 28 deletions
@@ -13,7 +13,8 @@
 # limitations under the License.
 import contextlib
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, TypeVar, Union
+from pathlib import Path
+from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, TypeVar, Union

 import torch
 from torch import Tensor

@@ -148,6 +149,17 @@ def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
     def rpc_enabled(self) -> bool:
         return False

+    def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
+        return pl_load(checkpoint_path, map_location=(lambda storage, loc: storage))
+
+    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
+        self.lightning_module.load_state_dict(checkpoint["state_dict"])
+
+    def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
+        optimizer_states = checkpoint["optimizer_states"]
+        for optimizer, opt_state in zip(self.lightning_module.trainer.accelerator.optimizers, optimizer_states):
+            optimizer.load_state_dict(opt_state)
+
     def start_training(self, trainer: 'pl.Trainer') -> None:
         # double dispatch to initiate the training loop
         self._results = trainer.run_stage()

@@ -227,33 +239,6 @@ def setup_optimizers_in_pre_dispatch(self) -> bool:
         """
         return False

-    def restore_model_state_from_ckpt_path(
-        self,
-        ckpt_path: str,
-        map_location: Callable = lambda storage, loc: storage,
-    ) -> Tuple[Dict, bool]:
-        """
-        This function is used to load and restore the model state.
-
-        Args:
-            ckpt_path: Path to a checkpoint
-            map_location: lambda function to map checkpoint location
-
-        Return
-            checkpoint: Return loaded checkpoint
-            bool: Wether to load optimizer / lr_schedulers states from checkpoint
-
-        """
-        ckpt = pl_load(ckpt_path, map_location=map_location)
-        # restore datamodule states
-        if self.lightning_module.trainer.datamodule is not None:
-            self.lightning_module.trainer.datamodule.on_load_checkpoint(ckpt)
-
-        # hook: give user access to checkpoint if needed.
-        self.lightning_module.on_load_checkpoint(ckpt)
-        self.lightning_module.load_state_dict(ckpt['state_dict'])
-        return ckpt, True
-
     def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
         """
         Provide a hook to count optimizer step calls.
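Because the base class now splits file reading, weight restoration, and optimizer restoration into separate hooks, a custom plugin can override just the step it needs. A minimal, hypothetical sketch follows; the class is not part of this commit, and `_fetch_decrypted` is an assumed helper, not a Lightning API.

from pathlib import Path
from typing import Any, Dict, Union

import torch

from pytorch_lightning.plugins.training_type import SingleDevicePlugin


class EncryptedCheckpointPlugin(SingleDevicePlugin):
    """Hypothetical plugin that only customises how the checkpoint file is read."""

    def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
        # assumed helper that downloads/decrypts the file and returns a local path
        local_path = self._fetch_decrypted(checkpoint_path)
        return torch.load(local_path, map_location=(lambda storage, loc: storage))

    # load_model_state_dict() and load_optimizer_state_dict() are inherited from
    # TrainingTypePlugin, so "state_dict" and "optimizer_states" are restored as usual.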

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 18 additions & 12 deletions
@@ -38,10 +38,7 @@ def __init__(self, trainer, resume_from_checkpoint: Optional[Union[str, Path]] =
         self.resume_checkpoint_path = resume_from_checkpoint
         # used to validate checkpointing logic
         self.has_trained = False
-
         self._loaded_checkpoint = dict()
-        # FIXME: remove in https://github.com/PyTorchLightning/pytorch-lightning/pull/7652
-        self._load_optimizer_states = True

     @property
     def hpc_resume_path(self) -> Optional[str]:

@@ -76,11 +73,7 @@ def resume_start(self) -> None:
             raise FileNotFoundError(f"Checkpoint at {checkpoint_path} not found. Aborting training.")

         rank_zero_info(f"Restoring states from the checkpoint file at {checkpoint_path}")
-        checkpoint, load_optimizer_states = self.trainer.training_type_plugin.restore_model_state_from_ckpt_path(
-            checkpoint_path, map_location=lambda storage, loc: storage
-        )
-        self._loaded_checkpoint = checkpoint
-        self._load_optimizer_states = load_optimizer_states
+        self._loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint_file(checkpoint_path)

     def resume_end(self) -> None:
         """ Signal the connector that all states have resumed and memory for the checkpoint object can be released. """

@@ -110,6 +103,8 @@ def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> bool:
         self.resume_start()
         model = self.trainer.lightning_module

+        self.restore_model_state(model, self._loaded_checkpoint)
+
         if self.trainer._device_type == DeviceType.GPU:
             model.cuda(self.trainer.root_gpu)

@@ -124,6 +119,8 @@ def restore_model_state(self, model: LightningModule, checkpoint) -> None:
         """
         Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object
         """
+        if not checkpoint:
+            return

         # restore datamodule states
         if self.trainer.datamodule is not None:

@@ -133,7 +130,16 @@ def restore_model_state(self, model: LightningModule, checkpoint) -> None:
         model.on_load_checkpoint(checkpoint)

         # restore model state_dict
-        model.load_state_dict(checkpoint['state_dict'])
+        self.trainer.training_type_plugin.load_model_state_dict(checkpoint)
+
+    def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) -> None:
+        """ Restore only the model weights. """
+        checkpoint = self._loaded_checkpoint
+        if checkpoint_path is not None:
+            checkpoint = self.trainer.training_type_plugin.load_checkpoint_file(checkpoint_path)
+
+        self.trainer.lightning_module.on_load_checkpoint(checkpoint)
+        self.trainer.training_type_plugin.load_model_state_dict(checkpoint)

     def restore_training_state(self, checkpoint: Dict[str, Any]) -> None:
         """

@@ -199,7 +205,7 @@ def restore_progress(self) -> None:

     def restore_optimizers_and_schedulers(self) -> None:
         """ Restores the optimizers and learning rate scheduler states from the pre-loaded checkpoint. """
-        if not self._load_optimizer_states or not self._loaded_checkpoint:
+        if not self._loaded_checkpoint:
             return

         # validation

@@ -213,7 +219,7 @@ def restore_optimizers_and_schedulers(self) -> None:

     def restore_optimizers(self) -> None:
         """ Restores the optimizer states from the pre-loaded checkpoint. """
-        if not self._load_optimizer_states or not self._loaded_checkpoint:
+        if not self._loaded_checkpoint:
             return

         # restore the optimizers

@@ -231,7 +237,7 @@ def restore_optimizers(self) -> None:

     def restore_lr_schedulers(self) -> None:
         """ Restores the learning rate scheduler states from the pre-loaded checkpoint. """
-        if not self._load_optimizer_states or not self._loaded_checkpoint:
+        if not self._loaded_checkpoint:
             return

         # restore the lr schedulers
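Note that the user-facing hooks are unchanged by the refactor: `restore_model_state` still calls the datamodule and LightningModule `on_load_checkpoint` hooks before delegating the actual weight loading to the plugin. A minimal sketch of where user code still plugs in (the module below is illustrative only, not part of the commit):

import pytorch_lightning as pl


class ResumeAwareModule(pl.LightningModule):
    # Illustrative only: on_load_checkpoint still receives the full checkpoint
    # dict before the plugin's load_model_state_dict() restores the weights.
    def on_load_checkpoint(self, checkpoint):
        self.restored_epoch = checkpoint.get("epoch")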

pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 3 deletions
@@ -1154,9 +1154,7 @@ def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]:
         if not self._device_type == DeviceType.TPU:
             self.training_type_plugin.barrier()

-        self.training_type_plugin.restore_model_state_from_ckpt_path(
-            ckpt_path, map_location=lambda storage, loc: storage
-        )
+        self.checkpoint_connector.restore_model_weights(ckpt_path)
         return ckpt_path

     def _call_setup_hook(self, model: LightningModule) -> None:
