Commit 3a70e5d

Call LightningDataModule.load_state_dict hook while restoring checkpoint using LightningDataModule.load_from_checkpoint (#14883)
1 parent 93e802a commit 3a70e5d
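
In practice, the change means a datamodule restored standalone via `LightningDataModule.load_from_checkpoint` now has its saved state replayed through `load_state_dict`. A minimal sketch of that round trip, assuming the hypothetical `RangeDataModule` below (only the hook names and the restore behavior come from this commit; the class and its fields are invented for illustration):

    import pytorch_lightning as pl

    class RangeDataModule(pl.LightningDataModule):
        """Hypothetical datamodule carrying a bit of checkpointable state."""

        def __init__(self):
            super().__init__()
            self.start, self.stop = 0, 10

        def state_dict(self):
            # Stored in the Trainer checkpoint under this class's __qualname__.
            return {"start": self.start, "stop": self.stop}

        def load_state_dict(self, state_dict):
            # Before this commit, load_from_checkpoint never invoked this hook.
            self.start = state_dict["start"]
            self.stop = state_dict["stop"]

    # After a Trainer run saves "demo.ckpt", restoring the datamodule alone
    # now routes the saved dict back through load_state_dict:
    # dm = RangeDataModule.load_from_checkpoint("demo.ckpt")
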

3 files changed: +66 −20 lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -276,6 +276,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an attribute error when running the tuner together with the `StochasticWeightAveraging` callback ([#14836](https://github.com/Lightning-AI/lightning/pull/14836))
 
 
+- Called `LightningDataModule.load_state_dict` hook while restoring checkpoint using `LightningDataModule.load_from_checkpoint` ([#14883](https://github.com/Lightning-AI/lightning/pull/14883))
+
+
 - Fixed torchscript error with containers of LightningModules ([#14904](https://github.com/Lightning-AI/lightning/pull/14904))
 
 
src/pytorch_lightning/core/saving.py

Lines changed: 2 additions & 0 deletions
@@ -226,6 +226,8 @@ def _load_state(
         obj.on_load_checkpoint(checkpoint)
 
     if isinstance(obj, pl.LightningDataModule):
+        if obj.__class__.__qualname__ in checkpoint:
+            obj.load_state_dict(checkpoint[obj.__class__.__qualname__])
         return obj
 
     # load the state_dict on the model automatically
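
The new guard keys the datamodule state by `obj.__class__.__qualname__`, the same key under which the Trainer stores `datamodule.state_dict()` at save time (see the test below). A quick way to see that key on disk, sketched with a placeholder path and the hypothetical class name from the example above:

    import torch

    # Sketch: inspect a saved Trainer checkpoint; "demo.ckpt" is a placeholder.
    checkpoint = torch.load("demo.ckpt", map_location="cpu")
    # The entry only exists if the datamodule's state_dict() returned state.
    if "RangeDataModule" in checkpoint:  # i.e. obj.__class__.__qualname__
        print(checkpoint["RangeDataModule"])
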

tests/tests_pytorch/models/test_hooks.py

Lines changed: 61 additions & 20 deletions
@@ -25,6 +25,27 @@
 from tests_pytorch.helpers.runif import RunIf
 
 
+class HookedDataModule(BoringDataModule):
+    def __init__(self, called):
+        super().__init__()
+
+        def call(hook, fn, *args, **kwargs):
+            out = fn(*args, **kwargs)
+            d = {"name": hook}
+            if args:
+                d["args"] = args
+            if kwargs:
+                d["kwargs"] = kwargs
+            called.append(d)
+            return out
+
+        for h in get_members(LightningDataModule):
+            attr = getattr(self, h)
+            partial_h = partial(call, h, attr)
+            update_wrapper(partial_h, attr)
+            setattr(self, h, partial_h)
+
+
 @pytest.mark.parametrize("max_steps", [1, 2, 3])
 def test_on_before_zero_grad_called(tmpdir, max_steps):
     class CurrentTestModel(BoringModel):
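
`HookedDataModule` (now hoisted to module level so multiple tests can reuse it) wraps every datamodule hook so each call is recorded with its name, args, and kwargs while still delegating to the original method. The same pattern in isolation, as a small sketch with stand-in hook names instead of `get_members(LightningDataModule)`:

    from functools import partial, update_wrapper

    class Recorded:
        """Sketch of the call-recording pattern used by HookedDataModule."""

        def __init__(self, called):
            def call(hook, fn, *args, **kwargs):
                out = fn(*args, **kwargs)  # delegate to the real hook
                called.append({"name": hook, "args": args, "kwargs": kwargs})
                return out

            for h in ("setup", "teardown"):  # stand-ins for get_members(...)
                attr = getattr(self, h)
                wrapped = partial(call, h, attr)
                update_wrapper(wrapped, attr)  # preserve name/docstring
                setattr(self, h, wrapped)

        def setup(self, stage=None): ...
        def teardown(self, stage=None): ...

    calls = []
    dm = Recorded(calls)
    dm.setup(stage="fit")
    assert calls == [{"name": "setup", "args": (), "kwargs": {"stage": "fit"}}]
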
@@ -911,26 +932,6 @@ def predict_dataloader(self):
 def test_trainer_datamodule_hook_system(tmpdir):
     """Test the LightningDataModule hook system."""
 
-    class HookedDataModule(BoringDataModule):
-        def __init__(self, called):
-            super().__init__()
-
-            def call(hook, fn, *args, **kwargs):
-                out = fn(*args, **kwargs)
-                d = {"name": hook}
-                if args:
-                    d["args"] = args
-                if kwargs:
-                    d["kwargs"] = kwargs
-                called.append(d)
-                return out
-
-            for h in get_members(LightningDataModule):
-                attr = getattr(self, h)
-                partial_h = partial(call, h, attr)
-                update_wrapper(partial_h, attr)
-                setattr(self, h, partial_h)
-
     model = BoringModel()
     batches = 2
     trainer = Trainer(
 
@@ -991,3 +992,43 @@ def call(hook, fn, *args, **kwargs):
         dict(name="teardown", kwargs=dict(stage="predict")),
     ]
     assert called == expected
+
+
+def test_load_from_checkpoint_hook_calls(tmpdir):
+    class CustomHookedDataModule(HookedDataModule):
+        def state_dict(self):
+            return {"foo": "bar"}
+
+    lm_called, ldm_called = [], []
+    model = HookedModel(lm_called)
+    datamodule = CustomHookedDataModule(ldm_called)
+    trainer = Trainer()
+    trainer.strategy.connect(model)
+    trainer._data_connector.attach_data(model, datamodule=datamodule)
+    ckpt_path = str(tmpdir / "file.ckpt")
+    trainer.save_checkpoint(ckpt_path)
+
+    datamodule_state_dict_key = datamodule.__class__.__qualname__
+    saved_ckpt = {
+        "callbacks": ANY,
+        "epoch": 0,
+        "global_step": 0,
+        "lr_schedulers": ANY,
+        "optimizer_states": ANY,
+        "pytorch-lightning_version": __version__,
+        "state_dict": ANY,
+        "loops": ANY,
+        datamodule_state_dict_key: {"foo": "bar"},
+    }
+
+    assert lm_called == [dict(name="on_save_checkpoint", args=(saved_ckpt,))]
+    assert ldm_called == [dict(name="state_dict"), dict(name="on_save_checkpoint", args=(saved_ckpt,))]
+
+    lm_called, ldm_called = [], []
+    model = HookedModel.load_from_checkpoint(ckpt_path, called=lm_called)
+    datamodule = CustomHookedDataModule.load_from_checkpoint(ckpt_path, called=ldm_called)
+    assert lm_called == [dict(name="on_load_checkpoint", args=({**saved_ckpt, "hyper_parameters": ANY},))]
+    assert ldm_called == [
+        dict(name="on_load_checkpoint", args=({**saved_ckpt, "datamodule_hyper_parameters": ANY},)),
+        dict(name="load_state_dict", args=(saved_ckpt[datamodule_state_dict_key],)),
+    ]
