@@ -25,6 +25,27 @@
 from tests_pytorch.helpers.runif import RunIf
 
 
+class HookedDataModule(BoringDataModule):
+    def __init__(self, called):
+        super().__init__()
+
+        def call(hook, fn, *args, **kwargs):
+            out = fn(*args, **kwargs)
+            d = {"name": hook}
+            if args:
+                d["args"] = args
+            if kwargs:
+                d["kwargs"] = kwargs
+            called.append(d)
+            return out
+
+        for h in get_members(LightningDataModule):
+            attr = getattr(self, h)
+            partial_h = partial(call, h, attr)
+            update_wrapper(partial_h, attr)
+            setattr(self, h, partial_h)
+
+
 @pytest.mark.parametrize("max_steps", [1, 2, 3])
 def test_on_before_zero_grad_called(tmpdir, max_steps):
     class CurrentTestModel(BoringModel):
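Note on the hoisted class: HookedDataModule intercepts every LightningDataModule hook by replacing each bound method with a functools.partial that records the call before delegating to the original. A minimal standalone sketch of that recording pattern, using only the standard library (Greeter, record, and calls are illustrative names, not part of this change):

from functools import partial, update_wrapper


class Greeter:
    def hello(self, name):
        return f"hello {name}"


calls = []
obj = Greeter()


def record(hook, fn, *args, **kwargs):
    # Run the real method, then log the hook name plus any arguments.
    out = fn(*args, **kwargs)
    d = {"name": hook}
    if args:
        d["args"] = args
    if kwargs:
        d["kwargs"] = kwargs
    calls.append(d)
    return out


# Shadow the bound method on the instance; update_wrapper copies __name__,
# __doc__, etc. so the wrapper still introspects like the original method.
wrapped = partial(record, "hello", obj.hello)
update_wrapper(wrapped, obj.hello)
setattr(obj, "hello", wrapped)

assert obj.hello("world") == "hello world"
assert calls == [{"name": "hello", "args": ("world",)}]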
@@ -911,26 +932,6 @@ def predict_dataloader(self):
 def test_trainer_datamodule_hook_system(tmpdir):
     """Test the LightningDataModule hook system."""
 
-    class HookedDataModule(BoringDataModule):
-        def __init__(self, called):
-            super().__init__()
-
-            def call(hook, fn, *args, **kwargs):
-                out = fn(*args, **kwargs)
-                d = {"name": hook}
-                if args:
-                    d["args"] = args
-                if kwargs:
-                    d["kwargs"] = kwargs
-                called.append(d)
-                return out
-
-            for h in get_members(LightningDataModule):
-                attr = getattr(self, h)
-                partial_h = partial(call, h, attr)
-                update_wrapper(partial_h, attr)
-                setattr(self, h, partial_h)
-
     model = BoringModel()
     batches = 2
     trainer = Trainer(
@@ -991,3 +992,43 @@ def call(hook, fn, *args, **kwargs):
         dict(name="teardown", kwargs=dict(stage="predict")),
     ]
     assert called == expected
+
+
+def test_load_from_checkpoint_hook_calls(tmpdir):
+    class CustomHookedDataModule(HookedDataModule):
+        def state_dict(self):
+            return {"foo": "bar"}
+
+    lm_called, ldm_called = [], []
+    model = HookedModel(lm_called)
+    datamodule = CustomHookedDataModule(ldm_called)
+    trainer = Trainer()
+    trainer.strategy.connect(model)
+    trainer._data_connector.attach_data(model, datamodule=datamodule)
+    ckpt_path = str(tmpdir / "file.ckpt")
+    trainer.save_checkpoint(ckpt_path)
+
+    datamodule_state_dict_key = datamodule.__class__.__qualname__
+    saved_ckpt = {
+        "callbacks": ANY,
+        "epoch": 0,
+        "global_step": 0,
+        "lr_schedulers": ANY,
+        "optimizer_states": ANY,
+        "pytorch-lightning_version": __version__,
+        "state_dict": ANY,
+        "loops": ANY,
+        datamodule_state_dict_key: {"foo": "bar"},
+    }
+
+    assert lm_called == [dict(name="on_save_checkpoint", args=(saved_ckpt,))]
+    assert ldm_called == [dict(name="state_dict"), dict(name="on_save_checkpoint", args=(saved_ckpt,))]
+
+    lm_called, ldm_called = [], []
+    model = HookedModel.load_from_checkpoint(ckpt_path, called=lm_called)
+    datamodule = CustomHookedDataModule.load_from_checkpoint(ckpt_path, called=ldm_called)
+    assert lm_called == [dict(name="on_load_checkpoint", args=({**saved_ckpt, "hyper_parameters": ANY},))]
+    assert ldm_called == [
+        dict(name="on_load_checkpoint", args=({**saved_ckpt, "datamodule_hyper_parameters": ANY},)),
+        dict(name="load_state_dict", args=(saved_ckpt[datamodule_state_dict_key],)),
+    ]
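Note on the new test: it checks that Trainer.save_checkpoint stores the datamodule's state_dict under a key equal to the datamodule class's __qualname__, and that load_from_checkpoint replays that entry through load_state_dict. A minimal sketch of a datamodule opting into that round trip (MyDataModule and vocab_size are illustrative names, not from this change):

import pytorch_lightning as pl


class MyDataModule(pl.LightningDataModule):
    def __init__(self, vocab_size=0):
        super().__init__()
        self.vocab_size = vocab_size

    def state_dict(self):
        # Persisted in the checkpoint under self.__class__.__qualname__.
        return {"vocab_size": self.vocab_size}

    def load_state_dict(self, state_dict):
        # Receives the dict stored under that same key at load time.
        self.vocab_size = state_dict["vocab_size"]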