Add `outputs` param for `on_val`/`test_epoch_end` hooks #6120

This PR passes the outputs collected during evaluation into the `on_validation_epoch_end` / `on_test_epoch_end` hooks. Hooks that still use the old no-`outputs` signature keep working but trigger a `DeprecationWarning`; support for the old signature is slated for removal in v1.5.
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
`pytorch_lightning/trainer/evaluation_loop.py`:

```diff
@@ -11,12 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import torch

 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.trainer.supporters import PredictionCollection
 from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.warnings import WarningCache
@@ -202,9 +204,6 @@ def __run_eval_epoch_end(self, num_dataloaders):
         # with a single dataloader don't pass an array
         outputs = self.outputs

-        # free memory
-        self.outputs = []
-
         eval_results = outputs
         if num_dataloaders == 1:
             eval_results = outputs[0]
@@ -313,13 +312,41 @@ def store_predictions(self, output, batch_idx, dataloader_idx):

     def on_evaluation_epoch_end(self, *args, **kwargs):
-        # call the callback hook
-        if self.trainer.testing:
-            self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
-        else:
-            self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)
+        self.call_on_evaluation_epoch_end_hook()

         self.trainer.call_hook('on_epoch_end')

+    def call_on_evaluation_epoch_end_hook(self):
+        outputs = self.outputs
+
+        # free memory
+        self.outputs = []
+
+        model_ref = self.trainer.lightning_module
+        hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"
+
+        self.trainer._reset_result_and_set_hook_fx_name(hook_name)
+
+        with self.trainer.profiler.profile(hook_name):
+
+            if hasattr(self.trainer, hook_name):
+                on_evaluation_epoch_end_hook = getattr(self.trainer, hook_name)
+                on_evaluation_epoch_end_hook(outputs)
+
+            if is_overridden(hook_name, model_ref):
+                model_hook_fx = getattr(model_ref, hook_name)
+                if is_param_in_hook_signature(model_hook_fx, "outputs"):
+                    model_hook_fx(outputs)
+                else:
+                    self.warning_cache.warn(
+                        f"`ModelHooks.{hook_name}` signature has changed in v1.3."
+                        " `outputs` parameter has been added."
+                        " Support for the old signature will be removed in v1.5", DeprecationWarning
+                    )
+                    model_hook_fx()
+
+        self.trainer._cache_logged_metrics()
+
     def log_evaluation_step_metrics(self, output, batch_idx):
         if self.trainer.sanity_checking:
             return
```

A reviewer left a question on the new `call_on_evaluation_epoch_end_hook` dispatch (the thread was later marked resolved by kaushikb11):

> Why not using `self.trainer.call_hook` with `outputs` directly if the outputs is present?
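For context, here is a minimal sketch (not part of the diff) of how user code is affected, assuming the v1.3 behavior implemented above; `NewStyleModel` and `OldStyleModel` are hypothetical names:

```python
import pytorch_lightning as pl


class NewStyleModel(pl.LightningModule):
    # New v1.3 signature: the loop detects the `outputs` parameter via
    # `is_param_in_hook_signature` and passes the collected step outputs.
    def on_validation_epoch_end(self, outputs):
        print(f"received outputs for {len(outputs)} validation batches")


class OldStyleModel(pl.LightningModule):
    # Old signature: still called without arguments, but the loop emits a
    # DeprecationWarning; support is scheduled for removal in v1.5.
    def on_validation_epoch_end(self):
        print("no outputs received")
```

A hook declared as `def on_validation_epoch_end(self, *args)` would also be treated as accepting `outputs`, because the helper introduced below checks for a catch-all `args` parameter.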
`pytorch_lightning/utilities/signature_utils.py` (new file, `@@ -0,0 +1,22 @@`):

```python
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable


def is_param_in_hook_signature(hook_fx: Callable, param: str) -> bool:
    hook_params = list(inspect.signature(hook_fx).parameters)
    if "args" in hook_params or param in hook_params:
        return True
    return False
```
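A quick, self-contained illustration of the helper's behavior (the function is copied from the file above; the hook names are made up for the example):

```python
import inspect
from typing import Callable


def is_param_in_hook_signature(hook_fx: Callable, param: str) -> bool:
    hook_params = list(inspect.signature(hook_fx).parameters)
    if "args" in hook_params or param in hook_params:
        return True
    return False


def new_hook(self, outputs): ...      # explicit `outputs` parameter
def catch_all_hook(self, *args): ...  # `*args` appears as "args" in the signature
def old_hook(self): ...               # old-style, no `outputs`

assert is_param_in_hook_signature(new_hook, "outputs")
assert is_param_in_hook_signature(catch_all_hook, "outputs")
assert not is_param_in_hook_signature(old_hook, "outputs")
```

The `"args"` check is what lets hooks declared with a generic `*args` keep receiving `outputs` without hitting the deprecation path.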