Skip to content

Add outputs param for on_val/test_epoch_end hooks #6120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Mar 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
c34da59
add outputs param for on_val/test_epoch_end hooks
kaushikb11 Feb 22, 2021
9320916
update changelog
kaushikb11 Feb 22, 2021
fd84be8
fix warning message
kaushikb11 Feb 22, 2021
871884a
add custom call hook
kaushikb11 Feb 22, 2021
66bc0c7
cache logged metrics
kaushikb11 Feb 22, 2021
a3d8966
add args to docstrings
kaushikb11 Feb 22, 2021
ebd8507
use warning cache
kaushikb11 Feb 22, 2021
00f829e
add utility method for param in sig check
kaushikb11 Feb 22, 2021
4a45eb9
Update CHANGELOG.md
kaushikb11 Feb 22, 2021
b68a74e
update docstring
kaushikb11 Feb 22, 2021
9f8992e
add test for eval epoch end hook
kaushikb11 Feb 22, 2021
ff63404
add types and replace model ref
kaushikb11 Feb 23, 2021
cd01f88
add deprecation test
kaushikb11 Feb 23, 2021
388aad1
fix test fx name
kaushikb11 Feb 23, 2021
1a8e5b6
add model hooks warning
kaushikb11 Feb 23, 2021
d8d01a5
add old signature model to tests
kaushikb11 Feb 23, 2021
9481ab6
add clear warning cache
kaushikb11 Feb 23, 2021
6786a43
support args param
kaushikb11 Feb 23, 2021
1dbd851
update tests
kaushikb11 Feb 23, 2021
7aeb3e4
add tests for model hooks
kaushikb11 Feb 24, 2021
053edb5
code suggestions
kaushikb11 Feb 24, 2021
f978567
add signature utils
kaushikb11 Feb 24, 2021
06b1771
fix pep8 issues
kaushikb11 Feb 25, 2021
174f767
fix pep8 issues
kaushikb11 Feb 28, 2021
bce1a2c
fix outputs issue
kaushikb11 Mar 11, 2021
edc48b8
fix tests
kaushikb11 Mar 11, 2021
8c143a1
code fixes
kaushikb11 Mar 12, 2021
ff4aade
fix validate test
kaushikb11 Mar 12, 2021
4681130
test
kaushikb11 Mar 12, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))


- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120))



### Changed

- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))
Expand Down
8 changes: 4 additions & 4 deletions pytorch_lightning/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""

import abc
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

from pytorch_lightning.core.lightning import LightningModule

Expand Down Expand Up @@ -81,23 +81,23 @@ def on_train_epoch_start(self, trainer, pl_module: LightningModule) -> None:
"""Called when the train epoch begins."""
pass

def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: Any) -> None:
def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None:
"""Called when the train epoch ends."""
pass

def on_validation_epoch_start(self, trainer, pl_module: LightningModule) -> None:
"""Called when the val epoch begins."""
pass

def on_validation_epoch_end(self, trainer, pl_module: LightningModule) -> None:
def on_validation_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None:
"""Called when the val epoch ends."""
pass

def on_test_epoch_start(self, trainer, pl_module: LightningModule) -> None:
"""Called when the test epoch begins."""
pass

def on_test_epoch_end(self, trainer, pl_module: LightningModule) -> None:
def on_test_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None:
"""Called when the test epoch ends."""
pass

Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def on_train_epoch_start(self) -> None:
"""
# do something when the epoch starts

def on_train_epoch_end(self, outputs) -> None:
def on_train_epoch_end(self, outputs: List[Any]) -> None:
"""
Called in the training loop at the very end of the epoch.
"""
Expand All @@ -252,7 +252,7 @@ def on_validation_epoch_start(self) -> None:
"""
# do something when the epoch starts

def on_validation_epoch_end(self) -> None:
def on_validation_epoch_end(self, outputs: List[Any]) -> None:
"""
Called in the validation loop at the very end of the epoch.
"""
Expand All @@ -264,7 +264,7 @@ def on_test_epoch_start(self) -> None:
"""
# do something when the epoch starts

def on_test_epoch_end(self) -> None:
def on_test_epoch_end(self, outputs: List[Any]) -> None:
"""
Called in the test loop at the very end of the epoch.
"""
Expand Down
50 changes: 41 additions & 9 deletions pytorch_lightning/trainer/callback_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@
from abc import ABC
from copy import deepcopy
from inspect import signature
from typing import Any, Callable, Dict, List, Type, Optional
from typing import Any, Callable, Dict, List, Optional, Type

from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()


class TrainerCallbackHookMixin(ABC):
Expand Down Expand Up @@ -79,8 +83,12 @@ def on_train_epoch_start(self):
for callback in self.callbacks:
callback.on_train_epoch_start(self, self.lightning_module)

def on_train_epoch_end(self, outputs):
"""Called when the epoch ends."""
def on_train_epoch_end(self, outputs: List[Any]):
    """Called when the train epoch ends.

    Args:
        outputs: List of outputs on each ``train`` epoch
    """
    module = self.lightning_module
    for cb in self.callbacks:
        cb.on_train_epoch_end(self, module, outputs)

Expand All @@ -89,20 +97,44 @@ def on_validation_epoch_start(self):
for callback in self.callbacks:
callback.on_validation_epoch_start(self, self.lightning_module)

def on_validation_epoch_end(self):
"""Called when the epoch ends."""
def on_validation_epoch_end(self, outputs: List[Any]):
    """Called when the validation epoch ends.

    Args:
        outputs: List of outputs on each ``validation`` epoch
    """
    for cb in self.callbacks:
        if not is_param_in_hook_signature(cb.on_validation_epoch_end, "outputs"):
            # old-style callback signature: warn once, call without ``outputs``
            warning_cache.warn(
                "`Callback.on_validation_epoch_end` signature has changed in v1.3."
                " `outputs` parameter has been added."
                " Support for the old signature will be removed in v1.5", DeprecationWarning
            )
            cb.on_validation_epoch_end(self, self.lightning_module)
        else:
            cb.on_validation_epoch_end(self, self.lightning_module, outputs)

def on_test_epoch_start(self):
    """Called when the test epoch begins."""
    module = self.lightning_module
    for cb in self.callbacks:
        cb.on_test_epoch_start(self, module)

def on_test_epoch_end(self):
"""Called when the epoch ends."""
def on_test_epoch_end(self, outputs: List[Any]):
    """Called when the test epoch ends.

    Args:
        outputs: List of outputs on each ``test`` epoch
    """
    for cb in self.callbacks:
        if not is_param_in_hook_signature(cb.on_test_epoch_end, "outputs"):
            # old-style callback signature: warn once, call without ``outputs``
            warning_cache.warn(
                "`Callback.on_test_epoch_end` signature has changed in v1.3."
                " `outputs` parameter has been added."
                " Support for the old signature will be removed in v1.5", DeprecationWarning
            )
            cb.on_test_epoch_end(self, self.lightning_module)
        else:
            cb.on_test_epoch_end(self, self.lightning_module, outputs)

def on_epoch_start(self):
"""Called when the epoch begins."""
Expand Down
41 changes: 34 additions & 7 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from pytorch_lightning.core.step_result import Result
from pytorch_lightning.trainer.supporters import PredictionCollection
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.warnings import WarningCache


Expand Down Expand Up @@ -202,9 +204,6 @@ def __run_eval_epoch_end(self, num_dataloaders):
# with a single dataloader don't pass an array
outputs = self.outputs

# free memory
self.outputs = []

eval_results = outputs
if num_dataloaders == 1:
eval_results = outputs[0]
Expand Down Expand Up @@ -313,13 +312,41 @@ def store_predictions(self, output, batch_idx, dataloader_idx):

def on_evaluation_epoch_end(self, *args, **kwargs):
# call the callback hook
if self.trainer.testing:
self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
else:
self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)
self.call_on_evaluation_epoch_end_hook()

self.trainer.call_hook('on_epoch_end')

def call_on_evaluation_epoch_end_hook(self):
outputs = self.outputs

# free memory
self.outputs = []

model_ref = self.trainer.lightning_module
hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"

self.trainer._reset_result_and_set_hook_fx_name(hook_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not using self.trainer.call_hook with outputs directly if the outputs is present ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.trainer.call_hook addresses callback_hooks and model_hook together, and pass the same arguments. We need to inspect both of them separately and pass outputs based on the signature.


with self.trainer.profiler.profile(hook_name):

if hasattr(self.trainer, hook_name):
on_evaluation_epoch_end_hook = getattr(self.trainer, hook_name)
on_evaluation_epoch_end_hook(outputs)

if is_overridden(hook_name, model_ref):
model_hook_fx = getattr(model_ref, hook_name)
if is_param_in_hook_signature(model_hook_fx, "outputs"):
model_hook_fx(outputs)
else:
self.warning_cache.warn(
f"`ModelHooks.{hook_name}` signature has changed in v1.3."
" `outputs` parameter has been added."
" Support for the old signature will be removed in v1.5", DeprecationWarning
)
model_hook_fx()

self.trainer._cache_logged_metrics()

def log_evaluation_step_metrics(self, output, batch_idx):
if self.trainer.sanity_checking:
return
Expand Down
22 changes: 22 additions & 0 deletions pytorch_lightning/utilities/signature_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable


def is_param_in_hook_signature(hook_fx: Callable, param: str) -> bool:
    """Check whether ``param`` can be passed positionally to ``hook_fx``.

    Returns ``True`` when the hook explicitly declares a parameter named
    ``param``, or when it accepts arbitrary positional arguments via a
    var-positional parameter (``*args``), whatever that parameter is named.

    Args:
        hook_fx: the hook callable to inspect.
        param: name of the parameter to look for.
    """
    parameters = inspect.signature(hook_fx).parameters
    if param in parameters:
        return True
    # a ``*args`` parameter swallows any extra positional argument, so the
    # hook can still receive ``param``.  Checking the parameter *kind* (and
    # not the literal name "args") also covers e.g. ``def hook(*arguments)``
    # and avoids a false positive on a plain parameter that happens to be
    # named ``args``.
    return any(p.kind == p.VAR_POSITIONAL for p in parameters.values())
63 changes: 63 additions & 0 deletions tests/callbacks/test_callback_hook_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,66 @@ def on_train_epoch_end(self, outputs) -> None:

results = trainer.fit(model)
assert results


def test_on_val_epoch_end_outputs(tmpdir):
    """``on_validation_epoch_end`` receives one output per validation batch."""

    class OutputsCheckCallback(Callback):

        def on_validation_epoch_end(self, trainer, pl_module, outputs):
            expected = (
                trainer.num_sanity_val_batches[0]
                if trainer.running_sanity_check
                else trainer.num_val_batches[0]
            )
            assert len(outputs[0]) == expected

    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=OutputsCheckCallback(),
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        weights_summary=None,
    )

    trainer.fit(BoringModel())


def test_on_test_epoch_end_outputs(tmpdir):
    """``on_test_epoch_end`` receives one output per test batch."""

    class OutputsCheckCallback(Callback):

        def on_test_epoch_end(self, trainer, pl_module, outputs):
            assert len(outputs[0]) == trainer.num_test_batches[0]

    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=OutputsCheckCallback(),
        weights_summary=None,
    )

    trainer.test(BoringModel())


def test_free_memory_on_eval_outputs(tmpdir):
    """The evaluation loop's cached outputs are released by epoch end."""

    class OutputsFreedCallback(Callback):

        def on_epoch_end(self, trainer, pl_module):
            assert len(trainer.evaluation_loop.outputs) == 0

    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=OutputsFreedCallback(),
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        weights_summary=None,
    )

    trainer.fit(BoringModel())
8 changes: 4 additions & 4 deletions tests/callbacks/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
call.on_validation_epoch_start(trainer, model),
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_validation_epoch_end(trainer, model),
call.on_validation_epoch_end(trainer, model, ANY),
call.on_epoch_end(trainer, model),
call.on_validation_end(trainer, model),
call.on_sanity_check_end(trainer, model),
Expand Down Expand Up @@ -87,7 +87,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
call.on_validation_epoch_start(trainer, model),
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_validation_epoch_end(trainer, model),
call.on_validation_epoch_end(trainer, model, ANY),
call.on_epoch_end(trainer, model),
call.on_validation_end(trainer, model),
call.on_save_checkpoint(trainer, model), # should take ANY but we are inspecting signature for BC
Expand Down Expand Up @@ -123,7 +123,7 @@ def test_trainer_callback_hook_system_test(tmpdir):
call.on_test_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_test_batch_start(trainer, model, ANY, 1, 0),
call.on_test_batch_end(trainer, model, ANY, ANY, 1, 0),
call.on_test_epoch_end(trainer, model),
call.on_test_epoch_end(trainer, model, ANY),
call.on_epoch_end(trainer, model),
call.on_test_end(trainer, model),
call.teardown(trainer, model, 'test'),
Expand Down Expand Up @@ -156,7 +156,7 @@ def test_trainer_callback_hook_system_validate(tmpdir):
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_validation_batch_start(trainer, model, ANY, 1, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 1, 0),
call.on_validation_epoch_end(trainer, model),
call.on_validation_epoch_end(trainer, model, ANY),
call.on_epoch_end(trainer, model),
call.on_validation_end(trainer, model),
call.teardown(trainer, model, 'validate'),
Expand Down
Loading