From c34da597d5b168e0a95f090019955d11256e4a58 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Mon, 22 Feb 2021 15:46:51 +0530 Subject: [PATCH 01/29] add outputs param for on_val/test_epoch_end hooks --- pytorch_lightning/callbacks/base.py | 4 +- pytorch_lightning/trainer/callback_hook.py | 27 ++++++++++-- pytorch_lightning/trainer/evaluation_loop.py | 5 ++- tests/callbacks/test_callback_hook_outputs.py | 42 +++++++++++++++++++ 4 files changed, 70 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 0ba1fd4ff7785..342824e91f2d2 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -89,7 +89,7 @@ def on_validation_epoch_start(self, trainer, pl_module: LightningModule) -> None """Called when the val epoch begins.""" pass - def on_validation_epoch_end(self, trainer, pl_module: LightningModule) -> None: + def on_validation_epoch_end(self, trainer, pl_module: LightningModule, outputs: Any) -> None: """Called when the val epoch ends.""" pass @@ -97,7 +97,7 @@ def on_test_epoch_start(self, trainer, pl_module: LightningModule) -> None: """Called when the test epoch begins.""" pass - def on_test_epoch_end(self, trainer, pl_module: LightningModule) -> None: + def on_test_epoch_end(self, trainer, pl_module: LightningModule, outputs: Any) -> None: """Called when the test epoch ends.""" pass diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 5aa9f1a44276b..ae9307fc843c0 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect from abc import ABC from copy import deepcopy from inspect import signature @@ -89,20 +90,38 @@ def on_validation_epoch_start(self): for callback in self.callbacks: callback.on_validation_epoch_start(self, self.lightning_module) - def on_validation_epoch_end(self): + def on_validation_epoch_end(self, outputs): """Called when the epoch ends.""" for callback in self.callbacks: - callback.on_validation_epoch_end(self, self.lightning_module) + params = list(inspect.signature(callback.on_validation_epoch_end).parameters) + if "outputs" in params: + callback.on_validation_epoch_end(self, self.get_model(), outputs) + else: + rank_zero_warn( + "`Callback.on_validation_epoch_end` signature has changed in v1.3." + "`outputs` parameter has been added." + " Support for the old signature will be removed in v1.5", DeprecationWarning + ) + callback.on_validation_epoch_end(self, self.get_model()) def on_test_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: callback.on_test_epoch_start(self, self.lightning_module) - def on_test_epoch_end(self): + def on_test_epoch_end(self, outputs): """Called when the epoch ends.""" for callback in self.callbacks: - callback.on_test_epoch_end(self, self.lightning_module) + params = list(inspect.signature(callback.on_test_epoch_end).parameters) + if "outputs" in params: + callback.on_test_epoch_end(self, self.get_model(), outputs) + else: + rank_zero_warn( + "`Callback.on_test_epoch_end` signature has changed in v1.3." + "`outputs` parameter has been added." + " Support for the old signature will be removed in v1.5", DeprecationWarning + ) + callback.on_test_epoch_end(self, self.get_model()) def on_epoch_start(self): """Called when the epoch begins.""" diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 91cfc2ec757d5..f32d43c096c32 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -313,10 +313,11 @@ def store_predictions(self, output, batch_idx, dataloader_idx): def on_evaluation_epoch_end(self, *args, **kwargs): # call the callback hook + outputs = self.outputs if self.trainer.testing: - self.trainer.call_hook('on_test_epoch_end', *args, **kwargs) + self.trainer.call_hook('on_test_epoch_end', outputs, *args, **kwargs) else: - self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs) + self.trainer.call_hook('on_validation_epoch_end', outputs, *args, **kwargs) self.trainer.call_hook('on_epoch_end') diff --git a/tests/callbacks/test_callback_hook_outputs.py b/tests/callbacks/test_callback_hook_outputs.py index 78926cc9a7dd4..b52607edf280c 100644 --- a/tests/callbacks/test_callback_hook_outputs.py +++ b/tests/callbacks/test_callback_hook_outputs.py @@ -71,3 +71,45 @@ def on_train_epoch_end(self, outputs) -> None: results = trainer.fit(model) assert results + + +def test_on_val_epoch_end_outputs(tmpdir): + + class CB(Callback): + + def on_validation_epoch_end(self, trainer, pl_module, outputs): + if trainer.running_sanity_check: + assert len(outputs[0]) == trainer.num_sanity_val_batches[0] + else: + assert len(outputs[0]) == trainer.num_val_batches[0] + + model = BoringModel() + + trainer = Trainer( + callbacks=CB(), + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + weights_summary=None, + ) + + trainer.fit(model) + + +def test_on_test_epoch_end_outputs(tmpdir): + + class CB(Callback): + + def on_test_epoch_end(self, trainer, pl_module, outputs): + assert len(outputs[0]) == trainer.num_test_batches[0] + + model = BoringModel() + + trainer = Trainer( + callbacks=CB(), + default_root_dir=tmpdir, + weights_summary=None, + ) + + trainer.test(model) From 932091672a98d367a06e16a3cda588d8b3de95e8 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Mon, 22 Feb 2021 15:57:32 +0530 Subject: [PATCH 02/29] update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9dcdea4c1601d..fd8d31c26ae23 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139)) +- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120)) + + ### Changed - Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259)) From fd84be8c2eb09613ba495b52a33a76d83547a5f7 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Mon, 22 Feb 2021 16:54:19 +0530 Subject: [PATCH 03/29] fix warning message --- pytorch_lightning/trainer/callback_hook.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index ae9307fc843c0..1b1a027f0af8d 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -99,7 +99,7 @@ def on_validation_epoch_end(self, outputs): else: rank_zero_warn( "`Callback.on_validation_epoch_end` signature has changed in v1.3." - "`outputs` parameter has been added." + " `outputs` parameter has been added." " Support for the old signature will be removed in v1.5", DeprecationWarning ) callback.on_validation_epoch_end(self, self.get_model()) @@ -118,7 +118,7 @@ def on_test_epoch_end(self, outputs): else: rank_zero_warn( "`Callback.on_test_epoch_end` signature has changed in v1.3." - "`outputs` parameter has been added." + " `outputs` parameter has been added." " Support for the old signature will be removed in v1.5", DeprecationWarning ) callback.on_test_epoch_end(self, self.get_model()) From 871884aaa132030b685f1dae1257facb7d837f10 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Mon, 22 Feb 2021 17:56:40 +0530 Subject: [PATCH 04/29] add custom call hook --- pytorch_lightning/core/hooks.py | 4 ++-- pytorch_lightning/trainer/evaluation_loop.py | 24 ++++++++++++++++---- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 1399d1b3c66ba..186da37aac262 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -252,7 +252,7 @@ def on_validation_epoch_start(self) -> None: """ # do something when the epoch starts - def on_validation_epoch_end(self) -> None: + def on_validation_epoch_end(self, outputs: Any) -> None: """ Called in the validation loop at the very end of the epoch. """ @@ -264,7 +264,7 @@ def on_test_epoch_start(self) -> None: """ # do something when the epoch starts - def on_test_epoch_end(self) -> None: + def on_test_epoch_end(self, outputs: Any) -> None: """ Called in the test loop at the very end of the epoch. """ diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index f32d43c096c32..60d9084331f29 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect + import torch from pytorch_lightning.core.step_result import Result @@ -313,14 +315,26 @@ def store_predictions(self, output, batch_idx, dataloader_idx): def on_evaluation_epoch_end(self, *args, **kwargs): # call the callback hook - outputs = self.outputs - if self.trainer.testing: - self.trainer.call_hook('on_test_epoch_end', outputs, *args, **kwargs) - else: - self.trainer.call_hook('on_validation_epoch_end', outputs, *args, **kwargs) + self.call_on_evaluation_epoch_end_hook() self.trainer.call_hook('on_epoch_end') + def call_on_evaluation_epoch_end_hook(self): + outputs = self.outputs + model_ref = self.trainer.get_model() + hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end" + + with self.trainer.profiler.profile(hook_name): + + if hasattr(self.trainer, hook_name): + on_evaluation_epoch_end_hook = getattr(self.trainer, hook_name) + on_evaluation_epoch_end_hook(outputs) + + if is_overridden(hook_name, model_ref): + model_hook_fx = getattr(model_ref, hook_name) + model_hook_params = list(inspect.signature(model_hook_fx).parameters) + model_hook_fx(outputs) if "outputs" in model_hook_params else model_hook_fx() + def log_evaluation_step_metrics(self, output, batch_idx): if self.trainer.sanity_checking: return From 66bc0c759387a2cb153bfdd3c5b4fc6128758e04 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Mon, 22 Feb 2021 20:22:51 +0530 Subject: [PATCH 05/29] cache logged metrics --- pytorch_lightning/trainer/evaluation_loop.py | 4 ++++ tests/models/test_hooks.py | 8 ++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 60d9084331f29..0bd784fb8c8bf 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -324,6 +324,8 @@ def call_on_evaluation_epoch_end_hook(self): model_ref = self.trainer.get_model() hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end" + self.trainer._reset_result_and_set_hook_fx_name(hook_name) + with self.trainer.profiler.profile(hook_name): if hasattr(self.trainer, hook_name): @@ -335,6 +337,8 @@ def call_on_evaluation_epoch_end_hook(self): model_hook_params = list(inspect.signature(model_hook_fx).parameters) model_hook_fx(outputs) if "outputs" in model_hook_params else model_hook_fx() + self.trainer._cache_logged_metrics() + def log_evaluation_step_metrics(self, output, batch_idx): if self.trainer.sanity_checking: return diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 0d1c7cf40a2bf..69859547f4a1f 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -360,9 +360,9 @@ def on_validation_epoch_start(self): self.called.append(inspect.currentframe().f_code.co_name) super().on_validation_epoch_start() - def on_validation_epoch_end(self): + def on_validation_epoch_end(self, outputs): self.called.append(inspect.currentframe().f_code.co_name) - super().on_validation_epoch_end() + super().on_validation_epoch_end(outputs) def on_test_start(self): self.called.append(inspect.currentframe().f_code.co_name) @@ -380,9 +380,9 @@ def on_test_epoch_start(self): self.called.append(inspect.currentframe().f_code.co_name) super().on_test_epoch_start() - def on_test_epoch_end(self): + def on_test_epoch_end(self, outputs): self.called.append(inspect.currentframe().f_code.co_name) - super().on_test_epoch_end() + super().on_test_epoch_end(outputs) def on_validation_model_eval(self): self.called.append(inspect.currentframe().f_code.co_name) From a3d896635820d06e3171cb98226d818b34898201 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Mon, 22 Feb 2021 21:08:57 +0530 Subject: [PATCH 06/29] add args to docstrings --- pytorch_lightning/callbacks/base.py | 18 +++++++++++++++--- pytorch_lightning/trainer/callback_hook.py | 18 +++++++++++++++--- 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 342824e91f2d2..31846686358e1 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -82,7 +82,11 @@ def on_train_epoch_start(self, trainer, pl_module: LightningModule) -> None: pass def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: Any) -> None: - """Called when the train epoch ends.""" + """Called when the train epoch ends. + + Args: + outputs: List of outputs on each `train` epoch + """ pass def on_validation_epoch_start(self, trainer, pl_module: LightningModule) -> None: @@ -90,7 +94,11 @@ def on_validation_epoch_start(self, trainer, pl_module: LightningModule) -> None pass def on_validation_epoch_end(self, trainer, pl_module: LightningModule, outputs: Any) -> None: - """Called when the val epoch ends.""" + """Called when the val epoch ends. + + Args: + outputs: List of outputs on each `val` epoch + """ pass def on_test_epoch_start(self, trainer, pl_module: LightningModule) -> None: @@ -98,7 +106,11 @@ def on_test_epoch_start(self, trainer, pl_module: LightningModule) -> None: pass def on_test_epoch_end(self, trainer, pl_module: LightningModule, outputs: Any) -> None: - """Called when the test epoch ends.""" + """Called when the test epoch ends. + + Args: + outputs: List of outputs on each `test` epoch + """ pass def on_epoch_start(self, trainer, pl_module: LightningModule) -> None: diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 1b1a027f0af8d..dc62d2d45bf7c 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -81,7 +81,11 @@ def on_train_epoch_start(self): callback.on_train_epoch_start(self, self.lightning_module) def on_train_epoch_end(self, outputs): - """Called when the epoch ends.""" + """Called when the epoch ends. + + Args: + outputs: List of outputs on each `train` epoch + """ for callback in self.callbacks: callback.on_train_epoch_end(self, self.lightning_module, outputs) @@ -91,7 +95,11 @@ def on_validation_epoch_start(self): callback.on_validation_epoch_start(self, self.lightning_module) def on_validation_epoch_end(self, outputs): - """Called when the epoch ends.""" + """Called when the epoch ends. + + Args: + outputs: List of outputs on each `val` epoch + """ for callback in self.callbacks: params = list(inspect.signature(callback.on_validation_epoch_end).parameters) if "outputs" in params: @@ -110,7 +118,11 @@ def on_test_epoch_start(self): callback.on_test_epoch_start(self, self.lightning_module) def on_test_epoch_end(self, outputs): - """Called when the epoch ends.""" + """Called when the epoch ends. + + Args: + outputs: List of outputs on each `test` epoch + """ for callback in self.callbacks: params = list(inspect.signature(callback.on_test_epoch_end).parameters) if "outputs" in params: From ebd85074be37b8276edd116e9e9897959ceb4397 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Mon, 22 Feb 2021 21:26:22 +0530 Subject: [PATCH 07/29] use warning cache --- pytorch_lightning/trainer/callback_hook.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index dc62d2d45bf7c..db30ed1ac61af 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -20,7 +20,9 @@ from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.warnings import WarningCache + +warning_cache = WarningCache() class TrainerCallbackHookMixin(ABC): @@ -105,7 +107,7 @@ def on_validation_epoch_end(self, outputs): if "outputs" in params: callback.on_validation_epoch_end(self, self.get_model(), outputs) else: - rank_zero_warn( + warning_cache.warn( "`Callback.on_validation_epoch_end` signature has changed in v1.3." " `outputs` parameter has been added." " Support for the old signature will be removed in v1.5", DeprecationWarning @@ -128,7 +130,7 @@ def on_test_epoch_end(self, outputs): if "outputs" in params: callback.on_test_epoch_end(self, self.get_model(), outputs) else: - rank_zero_warn( + warning_cache.warn( "`Callback.on_test_epoch_end` signature has changed in v1.3." " `outputs` parameter has been added." " Support for the old signature will be removed in v1.5", DeprecationWarning From 00f829e095db36420e08a5bf4927bcedae45e546 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Mon, 22 Feb 2021 21:43:45 +0530 Subject: [PATCH 08/29] add utility method for param in sig check --- pytorch_lightning/trainer/callback_hook.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index db30ed1ac61af..a748e8c387d53 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -103,8 +103,7 @@ def on_validation_epoch_end(self, outputs): outputs: List of outputs on each `val` epoch """ for callback in self.callbacks: - params = list(inspect.signature(callback.on_validation_epoch_end).parameters) - if "outputs" in params: + if self._is_param_in_hook_signature(callback.on_validation_epoch_end, "outputs"): callback.on_validation_epoch_end(self, self.get_model(), outputs) else: warning_cache.warn( @@ -126,8 +125,7 @@ def on_test_epoch_end(self, outputs): outputs: List of outputs on each `test` epoch """ for callback in self.callbacks: - params = list(inspect.signature(callback.on_test_epoch_end).parameters) - if "outputs" in params: + if self._is_param_in_hook_signature(callback.on_test_epoch_end, "outputs"): callback.on_test_epoch_end(self, self.get_model(), outputs) else: warning_cache.warn( @@ -282,3 +280,10 @@ def on_before_zero_grad(self, optimizer): """ for callback in self.callbacks: callback.on_before_zero_grad(self, self.lightning_module, optimizer) + + @staticmethod + def _is_param_in_hook_signature(hook_fx: Callable, param: str) -> bool: + hook_params = list(inspect.signature(hook_fx).parameters) + if param in hook_params: + return True + return False From 4a45eb9b09182fba18010b51d2744419c40c5bd0 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Mon, 22 Feb 2021 23:22:06 +0530 Subject: [PATCH 09/29] Update CHANGELOG.md Co-authored-by: Jirka Borovec --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index fd8d31c26ae23..f6a7cb94bbd25 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,6 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120)) + ### Changed - Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259)) From b68a74ec4db2476b329e9a5b041e648c567cb6d8 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Mon, 22 Feb 2021 23:24:17 +0530 Subject: [PATCH 10/29] update docstring --- pytorch_lightning/callbacks/base.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 31846686358e1..342824e91f2d2 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -82,11 +82,7 @@ def on_train_epoch_start(self, trainer, pl_module: LightningModule) -> None: pass def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: Any) -> None: - """Called when the train epoch ends. - - Args: - outputs: List of outputs on each `train` epoch - """ + """Called when the train epoch ends.""" pass def on_validation_epoch_start(self, trainer, pl_module: LightningModule) -> None: @@ -94,11 +90,7 @@ def on_validation_epoch_start(self, trainer, pl_module: LightningModule) -> None pass def on_validation_epoch_end(self, trainer, pl_module: LightningModule, outputs: Any) -> None: - """Called when the val epoch ends. - - Args: - outputs: List of outputs on each `val` epoch - """ + """Called when the val epoch ends.""" pass def on_test_epoch_start(self, trainer, pl_module: LightningModule) -> None: @@ -106,11 +98,7 @@ def on_test_epoch_start(self, trainer, pl_module: LightningModule) -> None: pass def on_test_epoch_end(self, trainer, pl_module: LightningModule, outputs: Any) -> None: - """Called when the test epoch ends. - - Args: - outputs: List of outputs on each `test` epoch - """ + """Called when the test epoch ends.""" pass def on_epoch_start(self, trainer, pl_module: LightningModule) -> None: From 9f8992e0e569621edb7edaab4b0c0f96564830d6 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 23 Feb 2021 00:11:06 +0530 Subject: [PATCH 11/29] add test for eval epoch end hook --- tests/trainer/test_evaluation_loop.py | 42 +++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 tests/trainer/test_evaluation_loop.py diff --git a/tests/trainer/test_evaluation_loop.py b/tests/trainer/test_evaluation_loop.py new file mode 100644 index 0000000000000..3fe58afde7341 --- /dev/null +++ b/tests/trainer/test_evaluation_loop.py @@ -0,0 +1,42 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest import mock + +from pytorch_lightning import Trainer +from tests.helpers.boring_model import BoringModel + + +@mock.patch("pytorch_lightning.trainer.evaluation_loop.EvaluationLoop.call_on_evaluation_epoch_end_hook") +def test_call_on_evaluation_epoch_end_hook(eval_epoch_end_mock, tmpdir): + """ + Tests that `call_on_evaluation_epoch_end_hook` is called + for `on_validation_epoch_end` and `on_test_epoch_end` hooks + """ + model = BoringModel() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + weights_summary=None, + ) + + trainer.fit(model) + # sanity + 2 epochs + assert eval_epoch_end_mock.call_count == 3 + + trainer.test() + # sanity + 2 epochs + called once for test + assert eval_epoch_end_mock.call_count == 4 From ff634041d2fdc0f5bada04ac1fda2ceda44e14ae Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 23 Feb 2021 12:33:49 +0530 Subject: [PATCH 12/29] add types and replace model ref --- pytorch_lightning/callbacks/base.py | 8 ++++---- pytorch_lightning/trainer/callback_hook.py | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 342824e91f2d2..db507fa991446 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -17,7 +17,7 @@ """ import abc -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from pytorch_lightning.core.lightning import LightningModule @@ -81,7 +81,7 @@ def on_train_epoch_start(self, trainer, pl_module: LightningModule) -> None: """Called when the train epoch begins.""" pass - def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: Any) -> None: + def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None: """Called when the train epoch ends.""" pass @@ -89,7 +89,7 @@ def on_validation_epoch_start(self, trainer, pl_module: LightningModule) -> None """Called when the val epoch begins.""" pass - def on_validation_epoch_end(self, trainer, pl_module: LightningModule, outputs: Any) -> None: + def on_validation_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None: """Called when the val epoch ends.""" pass @@ -97,7 +97,7 @@ def on_test_epoch_start(self, trainer, pl_module: LightningModule) -> None: """Called when the test epoch begins.""" pass - def on_test_epoch_end(self, trainer, pl_module: LightningModule, outputs: Any) -> None: + def on_test_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None: """Called when the test epoch ends.""" pass diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index a748e8c387d53..20a3ea130688b 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -82,7 +82,7 @@ def on_train_epoch_start(self): for callback in self.callbacks: callback.on_train_epoch_start(self, self.lightning_module) - def on_train_epoch_end(self, outputs): + def on_train_epoch_end(self, outputs: List[Any]): """Called when the epoch ends. Args: @@ -96,7 +96,7 @@ def on_validation_epoch_start(self): for callback in self.callbacks: callback.on_validation_epoch_start(self, self.lightning_module) - def on_validation_epoch_end(self, outputs): + def on_validation_epoch_end(self, outputs: List[Any]): """Called when the epoch ends. Args: @@ -104,21 +104,21 @@ def on_validation_epoch_end(self, outputs): """ for callback in self.callbacks: if self._is_param_in_hook_signature(callback.on_validation_epoch_end, "outputs"): - callback.on_validation_epoch_end(self, self.get_model(), outputs) + callback.on_validation_epoch_end(self, self.lightning_module, outputs) else: warning_cache.warn( "`Callback.on_validation_epoch_end` signature has changed in v1.3." " `outputs` parameter has been added." " Support for the old signature will be removed in v1.5", DeprecationWarning ) - callback.on_validation_epoch_end(self, self.get_model()) + callback.on_validation_epoch_end(self, self.lightning_module) def on_test_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: callback.on_test_epoch_start(self, self.lightning_module) - def on_test_epoch_end(self, outputs): + def on_test_epoch_end(self, outputs: List[Any]): """Called when the epoch ends. Args: @@ -126,14 +126,14 @@ def on_test_epoch_end(self, outputs): """ for callback in self.callbacks: if self._is_param_in_hook_signature(callback.on_test_epoch_end, "outputs"): - callback.on_test_epoch_end(self, self.get_model(), outputs) + callback.on_test_epoch_end(self, self.lightning_module, outputs) else: warning_cache.warn( "`Callback.on_test_epoch_end` signature has changed in v1.3." " `outputs` parameter has been added." " Support for the old signature will be removed in v1.5", DeprecationWarning ) - callback.on_test_epoch_end(self, self.get_model()) + callback.on_test_epoch_end(self, self.lightning_module) def on_epoch_start(self): """Called when the epoch begins.""" From cd01f88a7be4cadf76ee69b3dafc5bdb9e1a5494 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 23 Feb 2021 13:10:06 +0530 Subject: [PATCH 13/29] add deprecation test --- tests/deprecated_api/test_remove_1-5.py | 28 +++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index e65ebbab254de..b48d9f2a42352 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -111,3 +111,31 @@ def test_v1_5_0_model_checkpoint_period(tmpdir): ModelCheckpoint(dirpath=tmpdir) with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"): ModelCheckpoint(dirpath=tmpdir, period=1) + + +def test_v1_5_0_old_callback_on_validation_epoch_end(tmpdir): + + class OldSignature(Callback): + + def on_validation_epoch_end(self, trainer, pl_module): # noqa + ... + + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature()) + + with pytest.deprecated_call(match="old signature will be removed in v1.5"): + trainer.fit(model) + + +def test_v1_5_0_old_callback_on_validation_epoch_end(tmpdir): + + class OldSignature(Callback): + + def on_test_epoch_end(self, trainer, pl_module): # noqa + ... + + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature()) + + with pytest.deprecated_call(match="old signature will be removed in v1.5"): + trainer.test(model) From 388aad1accce83249a3153dfed698d408da77cb4 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 23 Feb 2021 13:18:22 +0530 Subject: [PATCH 14/29] fix test fx name --- tests/deprecated_api/test_remove_1-5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index b48d9f2a42352..c05ec88f401c7 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -127,7 +127,7 @@ def on_validation_epoch_end(self, trainer, pl_module): # noqa trainer.fit(model) -def test_v1_5_0_old_callback_on_validation_epoch_end(tmpdir): +def test_v1_5_0_old_callback_on_test_epoch_end(tmpdir): class OldSignature(Callback): From 1a8e5b67cbb0ae0feab6df1523338cc2a9a5116f Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 23 Feb 2021 14:00:58 +0530 Subject: [PATCH 15/29] add model hooks warning --- pytorch_lightning/trainer/evaluation_loop.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 0bd784fb8c8bf..f6dd5bbf0c752 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -335,7 +335,15 @@ def call_on_evaluation_epoch_end_hook(self): if is_overridden(hook_name, model_ref): model_hook_fx = getattr(model_ref, hook_name) model_hook_params = list(inspect.signature(model_hook_fx).parameters) - model_hook_fx(outputs) if "outputs" in model_hook_params else model_hook_fx() + if "outputs" in model_hook_params: + model_hook_fx(outputs) + else: + self.warning_cache.warn( + f"`ModelHooks.{hook_name}` signature has changed in v1.3." + " `outputs` parameter has been added." + " Support for the old signature will be removed in v1.5", DeprecationWarning + ) + model_hook_fx() self.trainer._cache_logged_metrics() From d8d01a5374d8ff5b893421c15dc6d77a7c2572f2 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 23 Feb 2021 14:07:06 +0530 Subject: [PATCH 16/29] add old signature model to tests --- pytorch_lightning/core/hooks.py | 6 +++--- tests/deprecated_api/test_remove_1-5.py | 24 ++++++++++++++++++++++-- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 186da37aac262..9624f94652713 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -240,7 +240,7 @@ def on_train_epoch_start(self) -> None: """ # do something when the epoch starts - def on_train_epoch_end(self, outputs) -> None: + def on_train_epoch_end(self, outputs: List[Any]) -> None: """ Called in the training loop at the very end of the epoch. """ @@ -252,7 +252,7 @@ def on_validation_epoch_start(self) -> None: """ # do something when the epoch starts - def on_validation_epoch_end(self, outputs: Any) -> None: + def on_validation_epoch_end(self, outputs: List[Any]) -> None: """ Called in the validation loop at the very end of the epoch. """ @@ -264,7 +264,7 @@ def on_test_epoch_start(self) -> None: """ # do something when the epoch starts - def on_test_epoch_end(self, outputs: Any) -> None: + def on_test_epoch_end(self, outputs: List[Any]) -> None: """ Called in the test loop at the very end of the epoch. """ diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index c05ec88f401c7..b0d3a39bb2f6b 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -113,7 +113,7 @@ def test_v1_5_0_model_checkpoint_period(tmpdir): ModelCheckpoint(dirpath=tmpdir, period=1) -def test_v1_5_0_old_callback_on_validation_epoch_end(tmpdir): +def test_v1_5_0_old_on_validation_epoch_end(tmpdir): class OldSignature(Callback): @@ -126,8 +126,18 @@ def on_validation_epoch_end(self, trainer, pl_module): # noqa with pytest.deprecated_call(match="old signature will be removed in v1.5"): trainer.fit(model) + class OldSignatureModel(BoringModel): + + def on_validation_epoch_end(self): # noqa + ... + + model = OldSignatureModel() + + with pytest.deprecated_call(match="old signature will be removed in v1.5"): + trainer.fit(model) + -def test_v1_5_0_old_callback_on_test_epoch_end(tmpdir): +def test_v1_5_0_old_on_test_epoch_end(tmpdir): class OldSignature(Callback): @@ -139,3 +149,13 @@ def on_test_epoch_end(self, trainer, pl_module): # noqa with pytest.deprecated_call(match="old signature will be removed in v1.5"): trainer.test(model) + + class OldSignatureModel(BoringModel): + + def on_test_epoch_end(self): # noqa + ... + + model = OldSignatureModel() + + with pytest.deprecated_call(match="old signature will be removed in v1.5"): + trainer.test(model) From 9481ab6d6390b0a9c97aeb2f59285b3af10f309e Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 23 Feb 2021 16:05:31 +0530 Subject: [PATCH 17/29] add clear warning cache --- tests/deprecated_api/test_remove_1-5.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index b0d3a39bb2f6b..05042c2aa44c4 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -20,6 +20,7 @@ from pytorch_lightning import Callback, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import WandbLogger +from pytorch_lightning.trainer.callback_hook import warning_cache as hook_warning_cache from tests.helpers import BoringModel from tests.helpers.utils import no_warning_call @@ -114,6 +115,7 @@ def test_v1_5_0_model_checkpoint_period(tmpdir): def test_v1_5_0_old_on_validation_epoch_end(tmpdir): + hook_warning_cache.clear() class OldSignature(Callback): @@ -138,6 +140,7 @@ def on_validation_epoch_end(self): # noqa def test_v1_5_0_old_on_test_epoch_end(tmpdir): + hook_warning_cache.clear() class OldSignature(Callback): From 6786a4361cd074b0b44b5c796a35c9b34d0f9750 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 23 Feb 2021 19:03:04 +0530 Subject: [PATCH 18/29] sopport args param --- pytorch_lightning/trainer/callback_hook.py | 2 +- pytorch_lightning/trainer/evaluation_loop.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 20a3ea130688b..86549553241ba 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -284,6 +284,6 @@ def on_before_zero_grad(self, optimizer): @staticmethod def _is_param_in_hook_signature(hook_fx: Callable, param: str) -> bool: hook_params = list(inspect.signature(hook_fx).parameters) - if param in hook_params: + if "args" in hook_params or param in hook_params: return True return False diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index f6dd5bbf0c752..265c86ddfaf31 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -335,7 +335,7 @@ def call_on_evaluation_epoch_end_hook(self): if is_overridden(hook_name, model_ref): model_hook_fx = getattr(model_ref, hook_name) model_hook_params = list(inspect.signature(model_hook_fx).parameters) - if "outputs" in model_hook_params: + if "args" in model_hook_params or "outputs" in model_hook_params: model_hook_fx(outputs) else: self.warning_cache.warn( From 1dbd85189a592d300a7449fe7474743d6375989a Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 23 Feb 2021 19:56:36 +0530 Subject: [PATCH 19/29] update tests --- tests/callbacks/test_callbacks.py | 6 ++-- tests/deprecated_api/__init__.py | 18 ++++++++++ tests/deprecated_api/test_remove_1-5.py | 47 +++++++++++++++++++++++-- 3 files changed, 65 insertions(+), 6 deletions(-) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 626eb59dffb9c..737e81185046f 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -56,7 +56,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir): call.on_validation_epoch_start(trainer, model), call.on_validation_batch_start(trainer, model, ANY, 0, 0), call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), - call.on_validation_epoch_end(trainer, model), + call.on_validation_epoch_end(trainer, model, ANY), call.on_epoch_end(trainer, model), call.on_validation_end(trainer, model), call.on_sanity_check_end(trainer, model), @@ -87,7 +87,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir): call.on_validation_epoch_start(trainer, model), call.on_validation_batch_start(trainer, model, ANY, 0, 0), call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), - call.on_validation_epoch_end(trainer, model), + call.on_validation_epoch_end(trainer, model, ANY), call.on_epoch_end(trainer, model), call.on_validation_end(trainer, model), call.on_save_checkpoint(trainer, model), # should take ANY but we are inspecting signature for BC @@ -123,7 +123,7 @@ def test_trainer_callback_hook_system_test(tmpdir): call.on_test_batch_end(trainer, model, ANY, ANY, 0, 0), call.on_test_batch_start(trainer, model, ANY, 1, 0), call.on_test_batch_end(trainer, model, ANY, ANY, 1, 0), - call.on_test_epoch_end(trainer, model), + call.on_test_epoch_end(trainer, model, ANY), call.on_epoch_end(trainer, model), call.on_test_end(trainer, model), call.teardown(trainer, model, 'test'), diff --git a/tests/deprecated_api/__init__.py b/tests/deprecated_api/__init__.py index 99e21d1ed6b22..ccfae3ec8dcf2 100644 --- a/tests/deprecated_api/__init__.py +++ b/tests/deprecated_api/__init__.py @@ -13,9 +13,27 @@ # limitations under the License. """Test deprecated functionality which will be removed in vX.Y.Z""" import sys +from contextlib import contextmanager +from typing import Optional + +import pytest def _soft_unimport_module(str_module): # once the module is imported e.g with parsing with pytest it lives in memory if str_module in sys.modules: del sys.modules[str_module] + + +@contextmanager +def no_deprecated_call(match: Optional[str] = None): + with pytest.warns(None) as record: + yield + try: + w = record.pop(DeprecationWarning) + if match is not None and match not in str(w.message): + return + except AssertionError: + # no DeprecationWarning raised + return + raise AssertionError(f"`DeprecationWarning` was raised: {w}") diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 05042c2aa44c4..f449a37e33c25 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -20,7 +20,8 @@ from pytorch_lightning import Callback, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import WandbLogger -from pytorch_lightning.trainer.callback_hook import warning_cache as hook_warning_cache +from pytorch_lightning.trainer.callback_hook import warning_cache as callback_warning_cache +from tests.deprecated_api import no_deprecated_call from tests.helpers import BoringModel from tests.helpers.utils import no_warning_call @@ -115,7 +116,7 @@ def test_v1_5_0_model_checkpoint_period(tmpdir): def test_v1_5_0_old_on_validation_epoch_end(tmpdir): - hook_warning_cache.clear() + callback_warning_cache.clear() class OldSignature(Callback): @@ -138,9 +139,29 @@ def on_validation_epoch_end(self): # noqa with pytest.deprecated_call(match="old signature will be removed in v1.5"): trainer.fit(model) + callback_warning_cache.clear() + + class NewSignature(Callback): + + def on_validation_epoch_end(self, trainer, pl_module, outputs): + ... + + trainer.callbacks = [NewSignature()] + with no_deprecated_call(match="`Callback.on_validation_epoch_end` signature has changed in v1.3."): + trainer.fit(model) + + class NewSignatureModel(BoringModel): + + def on_validation_epoch_end(self, outputs): + ... + + model = NewSignatureModel() + with no_deprecated_call(match="`ModelHooks.on_validation_epoch_end` signature has changed in v1.3."): + trainer.fit(model) + def test_v1_5_0_old_on_test_epoch_end(tmpdir): - hook_warning_cache.clear() + callback_warning_cache.clear() class OldSignature(Callback): @@ -162,3 +183,23 @@ def on_test_epoch_end(self): # noqa with pytest.deprecated_call(match="old signature will be removed in v1.5"): trainer.test(model) + + callback_warning_cache.clear() + + class NewSignature(Callback): + + def on_test_epoch_end(self, trainer, pl_module, outputs): + ... + + trainer.callbacks = [NewSignature()] + with no_deprecated_call(match="`Callback.on_test_epoch_end` signature has changed in v1.3."): + trainer.test(model) + + class NewSignatureModel(BoringModel): + + def on_test_epoch_end(self, outputs): + ... + + model = NewSignatureModel() + with no_deprecated_call(match="`ModelHooks.on_test_epoch_end` signature has changed in v1.3."): + trainer.test(model) From 7aeb3e4c69c53a43e7d2693f4b5501291f618030 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Wed, 24 Feb 2021 20:46:18 +0530 Subject: [PATCH 20/29] add tests for model hooks --- tests/core/test_hooks.py | 55 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 tests/core/test_hooks.py diff --git a/tests/core/test_hooks.py b/tests/core/test_hooks.py new file mode 100644 index 0000000000000..4eab4cc488199 --- /dev/null +++ b/tests/core/test_hooks.py @@ -0,0 +1,55 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning import Trainer +from tests.helpers.boring_model import BoringModel + + +def test_on_val_epoch_end_outputs(tmpdir): + + class TestModel(BoringModel): + + def on_validation_epoch_end(self, outputs): + if trainer.running_sanity_check: + assert len(outputs[0]) == trainer.num_sanity_val_batches[0] + else: + assert len(outputs[0]) == trainer.num_val_batches[0] + + model = TestModel() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + weights_summary=None, + ) + + trainer.fit(model) + + +def test_on_test_epoch_end_outputs(tmpdir): + + class TestModel(BoringModel): + + def on_test_epoch_end(self, outputs): + assert len(outputs[0]) == trainer.num_test_batches[0] + + model = TestModel() + + trainer = Trainer( + default_root_dir=tmpdir, + weights_summary=None, + ) + + trainer.test(model) From 053edb5e09ea3f96187a0f48c1d77b99fcce10ad Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 25 Feb 2021 01:12:59 +0530 Subject: [PATCH 21/29] code suggestions --- pytorch_lightning/trainer/callback_hook.py | 6 +++--- pytorch_lightning/trainer/evaluation_loop.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 86549553241ba..73cca70a4bbf3 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -86,7 +86,7 @@ def on_train_epoch_end(self, outputs: List[Any]): """Called when the epoch ends. Args: - outputs: List of outputs on each `train` epoch + outputs: List of outputs on each ``train`` epoch """ for callback in self.callbacks: callback.on_train_epoch_end(self, self.lightning_module, outputs) @@ -100,7 +100,7 @@ def on_validation_epoch_end(self, outputs: List[Any]): """Called when the epoch ends. Args: - outputs: List of outputs on each `val` epoch + outputs: List of outputs on each ``validation`` epoch """ for callback in self.callbacks: if self._is_param_in_hook_signature(callback.on_validation_epoch_end, "outputs"): @@ -122,7 +122,7 @@ def on_test_epoch_end(self, outputs: List[Any]): """Called when the epoch ends. Args: - outputs: List of outputs on each `test` epoch + outputs: List of outputs on each ``test`` epoch """ for callback in self.callbacks: if self._is_param_in_hook_signature(callback.on_test_epoch_end, "outputs"): diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 265c86ddfaf31..d461372d4654e 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -321,7 +321,7 @@ def on_evaluation_epoch_end(self, *args, **kwargs): def call_on_evaluation_epoch_end_hook(self): outputs = self.outputs - model_ref = self.trainer.get_model() + model_ref = self.trainer.lightning_module hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end" self.trainer._reset_result_and_set_hook_fx_name(hook_name) From f978567bc335d1ba49c68bea553dcfddd22953ed Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 25 Feb 2021 01:31:22 +0530 Subject: [PATCH 22/29] add signature utils --- pytorch_lightning/trainer/callback_hook.py | 13 +++-------- pytorch_lightning/trainer/evaluation_loop.py | 5 ++--- .../utilities/signature_utils.py | 22 +++++++++++++++++++ tests/core/test_hooks.py | 1 + 4 files changed, 28 insertions(+), 13 deletions(-) create mode 100644 pytorch_lightning/utilities/signature_utils.py diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 73cca70a4bbf3..848ed3ad96bec 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect from abc import ABC from copy import deepcopy from inspect import signature @@ -20,6 +19,7 @@ from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.warnings import WarningCache warning_cache = WarningCache() @@ -103,7 +103,7 @@ def on_validation_epoch_end(self, outputs: List[Any]): outputs: List of outputs on each ``validation`` epoch """ for callback in self.callbacks: - if self._is_param_in_hook_signature(callback.on_validation_epoch_end, "outputs"): + if is_param_in_hook_signature(callback.on_validation_epoch_end, "outputs"): callback.on_validation_epoch_end(self, self.lightning_module, outputs) else: warning_cache.warn( @@ -125,7 +125,7 @@ def on_test_epoch_end(self, outputs: List[Any]): outputs: List of outputs on each ``test`` epoch """ for callback in self.callbacks: - if self._is_param_in_hook_signature(callback.on_test_epoch_end, "outputs"): + if is_param_in_hook_signature(callback.on_test_epoch_end, "outputs"): callback.on_test_epoch_end(self, self.lightning_module, outputs) else: warning_cache.warn( @@ -280,10 +280,3 @@ def on_before_zero_grad(self, optimizer): """ for callback in self.callbacks: callback.on_before_zero_grad(self, self.lightning_module, optimizer) - - @staticmethod - def _is_param_in_hook_signature(hook_fx: Callable, param: str) -> bool: - hook_params = list(inspect.signature(hook_fx).parameters) - if "args" in hook_params or param in hook_params: - return True - return False diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index d461372d4654e..419736d248837 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect import torch @@ -19,6 +18,7 @@ from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.warnings import WarningCache @@ -334,8 +334,7 @@ def call_on_evaluation_epoch_end_hook(self): if is_overridden(hook_name, model_ref): model_hook_fx = getattr(model_ref, hook_name) - model_hook_params = list(inspect.signature(model_hook_fx).parameters) - if "args" in model_hook_params or "outputs" in model_hook_params: + if is_param_in_hook_signature(model_hook_fx, "outputs"): model_hook_fx(outputs) else: self.warning_cache.warn( diff --git a/pytorch_lightning/utilities/signature_utils.py b/pytorch_lightning/utilities/signature_utils.py new file mode 100644 index 0000000000000..546d8e845ecb1 --- /dev/null +++ b/pytorch_lightning/utilities/signature_utils.py @@ -0,0 +1,22 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import Callable + + +def is_param_in_hook_signature(hook_fx: Callable, param: str) -> bool: + hook_params = list(inspect.signature(hook_fx).parameters) + if "args" in hook_params or param in hook_params: + return True + return False diff --git a/tests/core/test_hooks.py b/tests/core/test_hooks.py index 4eab4cc488199..191da0a1400c7 100644 --- a/tests/core/test_hooks.py +++ b/tests/core/test_hooks.py @@ -49,6 +49,7 @@ def on_test_epoch_end(self, outputs): trainer = Trainer( default_root_dir=tmpdir, + fast_dev_run=2, weights_summary=None, ) From 06b1771e8503cd46071d271500be2f83e982067e Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 25 Feb 2021 21:30:49 +0530 Subject: [PATCH 23/29] fix pep8 issues --- tests/deprecated_api/test_remove_1-5.py | 41 +++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index f449a37e33c25..6fdfa29a0a587 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -203,3 +203,44 @@ def on_test_epoch_end(self, outputs): model = NewSignatureModel() with no_deprecated_call(match="`ModelHooks.on_test_epoch_end` signature has changed in v1.3."): trainer.test(model) + + +def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir): + + class OldSignature(Callback): + + def on_save_checkpoint(self, trainer, pl_module): # noqa + ... + + model = BoringModel() + trainer_kwargs = { + "default_root_dir": tmpdir, + "checkpoint_callback": False, + "max_epochs": 1, + } + filepath = tmpdir / "test.ckpt" + + trainer = Trainer(**trainer_kwargs, callbacks=[OldSignature()]) + trainer.fit(model) + + with pytest.deprecated_call(match="old signature will be removed in v1.5"): + trainer.save_checkpoint(filepath) + + class NewSignature(Callback): + + def on_save_checkpoint(self, trainer, pl_module, checkpoint): + ... + + class ValidSignature1(Callback): + + def on_save_checkpoint(self, trainer, *args): + ... + + class ValidSignature2(Callback): + + def on_save_checkpoint(self, *args): + ... + + trainer.callbacks = [NewSignature(), ValidSignature1(), ValidSignature2()] + with no_warning_call(DeprecationWarning): + trainer.save_checkpoint(filepath) From 174f76734bb7dc34cd68f851986661d635265b7a Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Sun, 28 Feb 2021 20:28:04 +0530 Subject: [PATCH 24/29] fix pep8 issues --- tests/deprecated_api/test_remove_1-5.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 6fdfa29a0a587..cc2eb037197f9 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -205,6 +205,12 @@ def on_test_epoch_end(self, outputs): trainer.test(model) +@mock.patch('pytorch_lightning.loggers.wandb.wandb') +def test_v1_5_0_wandb_unused_sync_step(tmpdir): + with pytest.deprecated_call(match=r"v1.2.1 and will be removed in v1.5"): + WandbLogger(sync_step=True) + + def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir): class OldSignature(Callback): From bce1a2c731af02c5f8a6f60083c1d8e4e444e97f Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 11 Mar 2021 21:47:24 +0530 Subject: [PATCH 25/29] fix outputs issue --- pytorch_lightning/trainer/evaluation_loop.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 419736d248837..20c842939fe17 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -204,9 +204,6 @@ def __run_eval_epoch_end(self, num_dataloaders): # with a single dataloader don't pass an array outputs = self.outputs - # free memory - self.outputs = [] - eval_results = outputs if num_dataloaders == 1: eval_results = outputs[0] @@ -321,6 +318,10 @@ def on_evaluation_epoch_end(self, *args, **kwargs): def call_on_evaluation_epoch_end_hook(self): outputs = self.outputs + + # free memory + self.outputs = [] + model_ref = self.trainer.lightning_module hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end" From edc48b86cb681201f879c555d56f4cf2f95a283b Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 11 Mar 2021 22:38:56 +0530 Subject: [PATCH 26/29] fix tests --- tests/trainer/logging_/test_eval_loop_logging_1_0.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/trainer/logging_/test_eval_loop_logging_1_0.py b/tests/trainer/logging_/test_eval_loop_logging_1_0.py index 72084454ba10d..e5cf596a78eca 100644 --- a/tests/trainer/logging_/test_eval_loop_logging_1_0.py +++ b/tests/trainer/logging_/test_eval_loop_logging_1_0.py @@ -126,7 +126,6 @@ def validation_step_end(self, acc): def validation_epoch_end(self, outputs): self.log('g', torch.tensor(2, device=self.device), on_epoch=True) self.validation_epoch_end_called = True - assert len(self.trainer.evaluation_loop.outputs) == 0 def backward(self, loss, optimizer, optimizer_idx): return LightningModule.backward(self, loss, optimizer, optimizer_idx) From 8c143a1ac17c9ddf3e669db9e9df2cc6ca322012 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Fri, 12 Mar 2021 20:00:48 +0530 Subject: [PATCH 27/29] code fixes --- pytorch_lightning/trainer/callback_hook.py | 3 +- tests/deprecated_api/test_remove_1-5.py | 47 ---------------------- 2 files changed, 2 insertions(+), 48 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 848ed3ad96bec..8823d48a7817e 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -15,10 +15,11 @@ from abc import ABC from copy import deepcopy from inspect import signature -from typing import Any, Callable, Dict, List, Type, Optional +from typing import Any, Callable, Dict, List, Optional, Type from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.warnings import WarningCache diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index cc2eb037197f9..f449a37e33c25 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -203,50 +203,3 @@ def on_test_epoch_end(self, outputs): model = NewSignatureModel() with no_deprecated_call(match="`ModelHooks.on_test_epoch_end` signature has changed in v1.3."): trainer.test(model) - - -@mock.patch('pytorch_lightning.loggers.wandb.wandb') -def test_v1_5_0_wandb_unused_sync_step(tmpdir): - with pytest.deprecated_call(match=r"v1.2.1 and will be removed in v1.5"): - WandbLogger(sync_step=True) - - -def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir): - - class OldSignature(Callback): - - def on_save_checkpoint(self, trainer, pl_module): # noqa - ... - - model = BoringModel() - trainer_kwargs = { - "default_root_dir": tmpdir, - "checkpoint_callback": False, - "max_epochs": 1, - } - filepath = tmpdir / "test.ckpt" - - trainer = Trainer(**trainer_kwargs, callbacks=[OldSignature()]) - trainer.fit(model) - - with pytest.deprecated_call(match="old signature will be removed in v1.5"): - trainer.save_checkpoint(filepath) - - class NewSignature(Callback): - - def on_save_checkpoint(self, trainer, pl_module, checkpoint): - ... - - class ValidSignature1(Callback): - - def on_save_checkpoint(self, trainer, *args): - ... - - class ValidSignature2(Callback): - - def on_save_checkpoint(self, *args): - ... - - trainer.callbacks = [NewSignature(), ValidSignature1(), ValidSignature2()] - with no_warning_call(DeprecationWarning): - trainer.save_checkpoint(filepath) From ff4aade409b1e31f466767b2936d28ac84f60d78 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Fri, 12 Mar 2021 20:02:03 +0530 Subject: [PATCH 28/29] fix validate test --- tests/callbacks/test_callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 737e81185046f..608f7bf1051f6 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -156,7 +156,7 @@ def test_trainer_callback_hook_system_validate(tmpdir): call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), call.on_validation_batch_start(trainer, model, ANY, 1, 0), call.on_validation_batch_end(trainer, model, ANY, ANY, 1, 0), - call.on_validation_epoch_end(trainer, model), + call.on_validation_epoch_end(trainer, model, ANY), call.on_epoch_end(trainer, model), call.on_validation_end(trainer, model), call.teardown(trainer, model, 'validate'), From 4681130e6494d0b5860fa25c419b130f15444c5e Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Fri, 12 Mar 2021 20:18:32 +0530 Subject: [PATCH 29/29] test --- tests/callbacks/test_callback_hook_outputs.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/callbacks/test_callback_hook_outputs.py b/tests/callbacks/test_callback_hook_outputs.py index b52607edf280c..df0eab31aac37 100644 --- a/tests/callbacks/test_callback_hook_outputs.py +++ b/tests/callbacks/test_callback_hook_outputs.py @@ -113,3 +113,24 @@ def on_test_epoch_end(self, trainer, pl_module, outputs): ) trainer.test(model) + + +def test_free_memory_on_eval_outputs(tmpdir): + + class CB(Callback): + + def on_epoch_end(self, trainer, pl_module): + assert len(trainer.evaluation_loop.outputs) == 0 + + model = BoringModel() + + trainer = Trainer( + callbacks=CB(), + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + weights_summary=None, + ) + + trainer.fit(model)