
Commit ac4eb0a

carmocca and awaelchli authored
is_overridden improvements (#7918)
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 9e932f4 commit ac4eb0a

8 files changed: 195 additions (+), 47 deletions (−)

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -35,6 +35,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `clip_grad_by_value` support for TPUs ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
 
 
+- Added support for passing any class to `is_overridden` ([#7918](https://github.com/PyTorchLightning/pytorch-lightning/pull/7918))
+
+
 - Added `sub_dir` parameter to `TensorBoardLogger` ([#6195](https://github.com/PyTorchLightning/pytorch-lightning/pull/6195))
 
 
@@ -172,6 +175,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `self.log(sync_dist_op)` in favor of `self.log(reduce_fx)`. ([#7891](https://github.com/PyTorchLightning/pytorch-lightning/pull/7891))
 
 
+- Deprecated `is_overridden(model=...)` in favor of `is_overridden(instance=...)` ([#7918](https://github.com/PyTorchLightning/pytorch-lightning/pull/7918))
+
+
 - Deprecated default value of `monitor` argument in EarlyStopping callback to enforce `monitor` as a required argument ([#7907](https://github.com/PyTorchLightning/pytorch-lightning/pull/7907))
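For orientation, a minimal before/after sketch of the deprecation recorded above (the `MyModel` class here is illustrative, not part of this commit):

from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.model_helpers import is_overridden


class MyModel(LightningModule):

    def validation_step(self, batch, batch_idx):
        ...


model = MyModel()

# deprecated by this commit, to be removed in v1.6 -- emits a rank-zero deprecation warning
is_overridden("validation_step", model=model)

# preferred: pass the object positionally (the new `instance` parameter)
is_overridden("validation_step", model)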

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 4 additions & 4 deletions
@@ -201,9 +201,9 @@ def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT
     def _should_track_batch_outputs_for_epoch_end(self) -> bool:
         model = self.trainer.lightning_module
         if self.trainer.testing:
-            return is_overridden('test_epoch_end', model=model)
+            return is_overridden('test_epoch_end', model)
         else:
-            return is_overridden('validation_epoch_end', model=model)
+            return is_overridden('validation_epoch_end', model)
 
     def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
         # inform logger the batch loop has finished
@@ -216,12 +216,12 @@ def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
         model._current_dataloader_idx = None
 
         if self.trainer.testing:
-            if is_overridden('test_epoch_end', model=model):
+            if is_overridden('test_epoch_end', model):
                 model._current_fx_name = 'test_epoch_end'
                 model.test_epoch_end(outputs)
 
         else:
-            if is_overridden('validation_epoch_end', model=model):
+            if is_overridden('validation_epoch_end', model):
                 model._current_fx_name = 'validation_epoch_end'
                 model.validation_epoch_end(outputs)
pytorch_lightning/trainer/training_loop.py

Lines changed: 3 additions & 3 deletions
@@ -216,10 +216,10 @@ def _should_add_batch_output_to_epoch_output(self) -> bool:
         # 2. The model overrides on_train_epoch_end which has `outputs` in the signature
         # TODO: in v1.5 this only needs to check if training_epoch_end is overridden
         lightning_module = self.trainer.lightning_module
-        if is_overridden("training_epoch_end", model=lightning_module):
+        if is_overridden("training_epoch_end", lightning_module):
             return True
 
-        if is_overridden("on_train_epoch_end", model=lightning_module):
+        if is_overridden("on_train_epoch_end", lightning_module):
             model_hook_fx = getattr(lightning_module, "on_train_epoch_end")
             if is_param_in_hook_signature(model_hook_fx, "outputs"):
                 return True
@@ -540,7 +540,7 @@ def on_train_epoch_end(self, epoch_output: List[List[List['ResultCollection']]])
         # get the model and call model.training_epoch_end
         model = self.trainer.lightning_module
 
-        if is_overridden('training_epoch_end', model=model):
+        if is_overridden('training_epoch_end', model):
             # run training_epoch_end
             # refresh the result for custom logging at the epoch level
             model._current_fx_name = 'training_epoch_end'

pytorch_lightning/utilities/model_helpers.py

Lines changed: 47 additions & 21 deletions
@@ -11,33 +11,59 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from typing import Union
+from functools import partial
+from typing import Optional, Type, Union
+from unittest.mock import Mock
 
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.utilities import rank_zero_deprecation
 
 
-def is_overridden(method_name: str, model: Union[LightningModule, LightningDataModule]) -> bool:
-    # if you pass DataModule instead of None or a LightningModule, we use LightningDataModule as super
-    # TODO - refactor this function to accept model_name, instance, parent so it makes more sense
-    super_object = LightningModule if not isinstance(model, LightningDataModule) else LightningDataModule
+def is_overridden(
+    method_name: str,
+    instance: Optional[object] = None,
+    parent: Optional[Type[object]] = None,
+    model: Optional[Union[LightningModule, LightningDataModule]] = None,
+) -> bool:
+    if model is not None and instance is None:
+        rank_zero_deprecation(
+            '`is_overridden(model=...)` has been deprecated and will be removed in v1.6. '
+            'Please use `is_overridden(instance=...)`'
+        )
+        instance = model
 
-    if not hasattr(model, method_name) or not hasattr(super_object, method_name):
-        # in case of calling deprecated method
+    if instance is None:
+        # if `self.lightning_module` was passed as instance, it can be `None`
         return False
 
-    instance_attr = getattr(model, method_name)
-    if not instance_attr:
+    if parent is None:
+        if isinstance(instance, LightningModule):
+            parent = LightningModule
+        elif isinstance(instance, LightningDataModule):
+            parent = LightningDataModule
+        if parent is None:
+            raise ValueError("Expected a parent")
+
+    instance_attr = getattr(instance, method_name, None)
+    # `Mock(wraps=...)` support
+    if isinstance(instance_attr, Mock):
+        # access the wrapped function
+        instance_attr = instance_attr._mock_wraps
+    # `partial` support
+    elif isinstance(instance_attr, partial):
+        instance_attr = instance_attr.func
+    if instance_attr is None:
         return False
-    super_attr = getattr(super_object, method_name)
-
-    # when code pointers are different, it was implemented
-    if hasattr(instance_attr, 'patch_loader_code'):
-        # cannot pickle __code__ so cannot verify if PatchDataloader
-        # exists which shows dataloader methods have been overwritten.
-        # so, we hack it by using the string representation
-        is_overridden = instance_attr.patch_loader_code != str(super_attr.__code__)
-    else:
-        is_overridden = instance_attr.__code__ is not super_attr.__code__
-    return is_overridden
+
+    parent_attr = getattr(parent, method_name, None)
+    if parent_attr is None:
+        raise ValueError("The parent should define the method")
+
+    # cannot pickle `__code__` so cannot verify if `PatchDataloader`
+    # exists which shows dataloader methods have been overwritten.
+    # so, we hack it by using the string representation
+    instance_code = getattr(instance_attr, 'patch_loader_code', None) or instance_attr.__code__
+    parent_code = parent_attr.__code__
+
+    return instance_code != parent_code
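A minimal sketch of the headline capability this rework enables: the new `parent` argument lets the check run against any base class instead of only the inferred `LightningModule`/`LightningDataModule`. The `Callback` usage below is illustrative and not part of this diff:

from pytorch_lightning import Callback, LightningModule
from pytorch_lightning.utilities.model_helpers import is_overridden


class MyModule(LightningModule):

    def training_step(self, batch, batch_idx):  # overrides the LightningModule hook
        ...


class MyCallback(Callback):

    def on_fit_start(self, trainer, pl_module):  # overrides the Callback hook
        ...


is_overridden("training_step", MyModule())    # True: parent inferred as LightningModule
is_overridden("validation_step", MyModule())  # False: inherited, not overridden
is_overridden("on_fit_start", MyCallback(), parent=Callback)  # True: any class via `parent`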

tests/deprecated_api/test_remove_1-6.py

Lines changed: 9 additions & 0 deletions
@@ -17,6 +17,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin
+from pytorch_lightning.utilities.model_helpers import is_overridden
 from tests.helpers import BoringDataModule, BoringModel
 
 
@@ -175,6 +176,14 @@ def prepare_data(self):
     assert dm.teardown_calls == ['validate', 'test']
 
 
+def test_v1_6_0_is_overridden_model():
+    model = BoringModel()
+    with pytest.deprecated_call(match="and will be removed in v1.6"):
+        assert is_overridden("validation_step", model=model)
+    with pytest.deprecated_call(match="and will be removed in v1.6"):
+        assert not is_overridden("foo", model=model)
+
+
 def test_v1_6_0_early_stopping_monitor(tmpdir):
     with pytest.deprecated_call(
         match=r"The `EarlyStopping\(monitor\)` argument will be required starting in v1.6."

tests/trainer/test_states.py

Lines changed: 9 additions & 6 deletions
@@ -37,25 +37,28 @@ class TestModel(BoringModel):
 
         def __init__(self, expected_fn, expected_stage):
             super().__init__()
-            self.expected_state = expected_fn
+            self.expected_fn = expected_fn
             self.expected_stage = expected_stage
             self.lr = 0.1
 
-        def on_batch_start(self, *_):
-            assert self.trainer.state == TrainerState(
-                status=TrainerStatus.RUNNING, fn=self.expected_fn, stage=self.expected_stage
-            )
-
         def on_train_batch_start(self, *_):
+            assert self.trainer.state.status == TrainerStatus.RUNNING
+            assert self.trainer.state.fn == self.expected_fn
             assert self.trainer.training
 
        def on_sanity_check_start(self, *_):
+            assert self.trainer.state.status == TrainerStatus.RUNNING
+            assert self.trainer.state.fn == self.expected_fn
             assert self.trainer.sanity_checking
 
         def on_validation_batch_start(self, *_):
+            assert self.trainer.state.status == TrainerStatus.RUNNING
+            assert self.trainer.state.fn == self.expected_fn
             assert self.trainer.validating or self.trainer.sanity_checking
 
         def on_test_batch_start(self, *_):
+            assert self.trainer.state.status == TrainerStatus.RUNNING
+            assert self.trainer.state.fn == self.expected_fn
             assert self.trainer.testing
 
     model = TestModel(TrainerFn.TUNING, RunningStage.TRAINING)

tests/trainer/test_trainer.py

Lines changed: 50 additions & 13 deletions
@@ -232,29 +232,66 @@ def test_trainer_accumulate_grad_batches_zero_grad(tmpdir, accumulate_grad_batch
 def test_gradient_accumulation_scheduling_last_batch(tmpdir, accumulate_grad_batches, limit_train_batches):
     """ Verify optimizer.step() applied to last batch while grad accumulation """
 
-    class CurrentModel(BoringModel):
+    class TestModel(BoringModel):
 
-        def on_batch_start(self, *_):
-            self.on_train_batch_start_state_dict = self.state_dict()
+        def state_dict(self, *args, **kwargs):
+            return deepcopy(super().state_dict(*args, **kwargs))
 
-        def on_batch_end(self, outputs, batch, batch_idx, *_):
-            self.on_train_batch_start_end_dict = self.state_dict()
-            for key in self.on_train_batch_start_end_dict.keys():
-                equal = torch.equal(self.on_train_batch_start_state_dict[key], self.on_train_batch_start_end_dict[key])
-                if (batch_idx + 1) == self.trainer.num_training_batches:
-                    assert equal
-                else:
-                    assert not equal
+        def check(self, d1, d2, equal=True):
+            keys = d1.keys() | d2.keys()
+            values = [torch.equal(d1[k], d2[k]) for k in keys]
+            return all(values) if equal else not any(values)
 
-    model = CurrentModel()
+        def backward(self, *args, **kwargs) -> None:
+            pre_bwd_state_dict = self.state_dict()
+            assert self.check(self.start_state_dict, pre_bwd_state_dict)
+
+            out = super().backward(*args, **kwargs)
+
+            # state dict is equal, just the gradients changed
+            assert self.check(pre_bwd_state_dict, self.state_dict())
+
+            return out
+
+        # def optimizer_step(self, *args, **kwargs):
+        #     pre_opt_step_state_dict = self.state_dict()
+        #     assert self.check(self.start_state_dict, pre_opt_step_state_dict)
+
+        #     # this calls `backward` and `on_after_backward` inside the closure
+        #     out = super().optimizer_step(*args, **kwargs)
+
+        #     # the state dict changed
+        #     assert self.check(pre_opt_step_state_dict, self.state_dict(), equal=False)
 
+        #     self.opt_step_called = True
+        #     return out
+
+        def on_after_backward(self):
+            # should override `optimizer_step` instead but can't with `accumulate_grad_batches`
+            # replace with the above after https://github.com/PyTorchLightning/pytorch-lightning/issues/6910
+            self.opt_step_called = True
+
+        def on_train_batch_start(self, *_):
+            self.start_state_dict = self.state_dict()
+            self.opt_step_called = False
+
+        def on_train_batch_end(self, outputs, batch, batch_idx, *_):
+            end_state_dict = self.state_dict()
+            is_last_batch = (batch_idx + 1) == self.trainer.num_training_batches
+
+            if is_last_batch or self.opt_step_called:
+                assert self.check(self.start_state_dict, end_state_dict, equal=False)
+            else:
+                assert self.check(self.start_state_dict, end_state_dict)
+
+    model = TestModel()
     trainer = Trainer(
         accumulate_grad_batches=accumulate_grad_batches,
         max_epochs=2,
         limit_train_batches=limit_train_batches,
         limit_val_batches=0,
-        limit_test_batches=0,
         default_root_dir=tmpdir,
+        progress_bar_refresh_rate=0,
     )
 
     trainer.fit(model)
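For readers skimming this test: with `accumulate_grad_batches=n`, `backward()` runs on every batch (gradients accumulate in place) while `optimizer.step()` runs only every n-th batch and on the final batch of the epoch, which is exactly what the state-dict comparisons above assert. A toy sketch of that scheduling in plain PyTorch, illustrative only and not Lightning's implementation:

def train_epoch(model, optimizer, loss_fn, batches, accumulate_grad_batches):
    optimizer.zero_grad()
    num_batches = len(batches)
    for batch_idx, (x, y) in enumerate(batches):
        # scale so the accumulated gradient matches a full-batch average
        loss = loss_fn(model(x), y) / accumulate_grad_batches
        loss.backward()  # gradients accumulate; weights are unchanged here
        is_last = (batch_idx + 1) == num_batches
        if (batch_idx + 1) % accumulate_grad_batches == 0 or is_last:
            optimizer.step()  # weights change only on these batches
            optimizer.zero_grad()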
tests/utilities/test_model_helpers.py

Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import partial
+from unittest.mock import Mock
+
+import pytest
+
+from pytorch_lightning import LightningDataModule, Trainer
+from pytorch_lightning.utilities.model_helpers import is_overridden
+from tests.helpers import BoringDataModule, BoringModel
+
+
+def test_is_overridden():
+    model = BoringModel()
+    datamodule = BoringDataModule()
+
+    # edge cases
+    assert not is_overridden("whatever", None)
+    with pytest.raises(ValueError, match="Expected a parent"):
+        is_overridden("whatever", object())
+    assert not is_overridden("whatever", model)
+    assert not is_overridden("whatever", model, parent=LightningDataModule)
+
+    class TestModel(BoringModel):
+
+        def foo(self):
+            pass
+
+    with pytest.raises(ValueError, match="The parent should define the method"):
+        is_overridden("foo", TestModel())
+
+    # normal usage
+    assert is_overridden("training_step", model)
+    assert is_overridden("train_dataloader", datamodule)
+
+    # `Mock` support
+    mock = Mock(spec=BoringModel, wraps=model)
+    assert is_overridden("training_step", mock)
+    mock = Mock(spec=BoringDataModule, wraps=datamodule)
+    assert is_overridden("train_dataloader", mock)
+
+    # `partial` support
+    model.training_step = partial(model.training_step)
+    assert is_overridden("training_step", model)
+
+    # `_PatchDataLoader.patch_loader_code` support
+    class TestModel(BoringModel):
+
+        def on_fit_start(self):
+            assert is_overridden("train_dataloader", self)
+            self.on_fit_start_called = True
+
+    model = TestModel()
+    trainer = Trainer(fast_dev_run=1)
+    trainer.fit(model, train_dataloader=model.train_dataloader())
+    assert model.on_fit_start_called
