Prune metrics base classes 2/n #6530
Changes from all commits: 819bf56, 54342c5, 3f0a943, ef0b09a, 68e5931, 968e9e0, ffc51bd, a8a1097, 007a686, 0c2f39e, dbaab4d, ce9b28d, dcda019, f32455e, e60ef42
@@ -1,14 +1,30 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from typing import Callable, Union
 
 import torch
-from torchmetrics import Metric
+from torchmetrics.metric import CompositionalMetric as __CompositionalMetric
 
+from pytorch_lightning.metrics.metric import Metric
 from pytorch_lightning.utilities import rank_zero_warn
 
 
-class CompositionalMetric(Metric):
-    """Composition of two metrics with a specific operator
-    which will be executed upon metric's compute
+class CompositionalMetric(__CompositionalMetric):
+    r"""
+    This implementation refers to :class:`~torchmetrics.metric.CompositionalMetric`.
+
+    .. warning:: This metric is deprecated, use ``torchmetrics.metric.CompositionalMetric``. Will be removed in v1.5.0.
     """
 
     def __init__(
@@ -17,76 +33,8 @@ def __init__(
         metric_a: Union[Metric, int, float, torch.Tensor],
         metric_b: Union[Metric, int, float, torch.Tensor, None],
     ):
-        """
-
-        Args:
-            operator: the operator taking in one (if metric_b is None)
-                or two arguments. Will be applied to outputs of metric_a.compute()
-                and (optionally if metric_b is not None) metric_b.compute()
-            metric_a: first metric whose compute() result is the first argument of operator
-            metric_b: second metric whose compute() result is the second argument of operator.
-                For operators taking in only one input, this should be None
-        """
-        super().__init__()
-
-        self.op = operator
-
-        if isinstance(metric_a, torch.Tensor):
-            self.register_buffer("metric_a", metric_a)
-        else:
-            self.metric_a = metric_a
-
-        if isinstance(metric_b, torch.Tensor):
-            self.register_buffer("metric_b", metric_b)
-        else:
-            self.metric_b = metric_b
-
-    def _sync_dist(self, dist_sync_fn=None):
-        # No syncing required here. syncing will be done in metric_a and metric_b
-        pass
-
-    def update(self, *args, **kwargs):
-        if isinstance(self.metric_a, Metric):
-            self.metric_a.update(*args, **self.metric_a._filter_kwargs(**kwargs))
-
-        if isinstance(self.metric_b, Metric):
-            self.metric_b.update(*args, **self.metric_b._filter_kwargs(**kwargs))
-
-    def compute(self):
-
-        # also some parsing for kwargs?
-        if isinstance(self.metric_a, Metric):
-            val_a = self.metric_a.compute()
-        else:
-            val_a = self.metric_a
-
-        if isinstance(self.metric_b, Metric):
-            val_b = self.metric_b.compute()
-        else:
-            val_b = self.metric_b
-
-        if val_b is None:
-            return self.op(val_a)
-
-        return self.op(val_a, val_b)
-
-    def reset(self):
-        if isinstance(self.metric_a, Metric):
-            self.metric_a.reset()
-
-        if isinstance(self.metric_b, Metric):
-            self.metric_b.reset()
-
-    def persistent(self, mode: bool = False):
-        if isinstance(self.metric_a, Metric):
-            self.metric_a.persistent(mode=mode)
-        if isinstance(self.metric_b, Metric):
-            self.metric_b.persistent(mode=mode)
-
-    def __repr__(self):
-        repr_str = (
-            self.__class__.__name__
-            + f"(\n {self.op.__name__}(\n {repr(self.metric_a)},\n {repr(self.metric_b)}\n )\n)"
+        rank_zero_warn(
+            "This `Metric` was deprecated since v1.3.0 in favor of `torchmetrics.Metric`."
+            " It will be removed in v1.5.0", DeprecationWarning
Review thread on lines +36 to +38:

Should we maybe introduce a temporary decorator/helper function for that? So that we can just forward all init arguments to the base class and have this function raise the warning?

That is a great point! thx :]

Well, I'll prepare it in another PR, as we would also need to limit all warnings to being raised only once, especially if they are used in functional metrics...

(Hedged sketches of such helpers follow the diff below.)
         )
-
-        return repr_str
+        super().__init__(operator=operator, metric_a=metric_a, metric_b=metric_b)
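Not part of this PR: below is a rough sketch of the temporary helper suggested in the review thread above. The decorator name _deprecated_metric and the import alias _TMCompositionalMetric are made up for illustration; the idea is simply to emit the deprecation warning in one place and forward every __init__ argument to the torchmetrics base class.

from functools import wraps

from torchmetrics.metric import CompositionalMetric as _TMCompositionalMetric

from pytorch_lightning.utilities import rank_zero_warn


def _deprecated_metric(cls):
    # Hypothetical class decorator: replace ``cls.__init__`` with a thin wrapper that
    # warns on rank zero and forwards all positional/keyword arguments to the parent class.
    base = cls.__mro__[1]

    @wraps(base.__init__)
    def __init__(self, *args, **kwargs):
        rank_zero_warn(
            f"`{cls.__name__}` was deprecated since v1.3.0 in favor of `torchmetrics.{base.__name__}`."
            " It will be removed in v1.5.0", DeprecationWarning
        )
        base.__init__(self, *args, **kwargs)

    cls.__init__ = __init__
    return cls


@_deprecated_metric
class CompositionalMetric(_TMCompositionalMetric):
    """Deprecated shim; use ``torchmetrics.metric.CompositionalMetric`` instead."""

Applied this way, each deprecated shim class shrinks to a decorated subclass with no hand-written __init__.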
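The last comment also points out that the warning should fire only once, especially for functional metrics that are called on every step. One minimal, hypothetical way to rate-limit it (again, not from this PR) is to cache the warning call, for example with functools.lru_cache:

from functools import lru_cache

from pytorch_lightning.utilities import rank_zero_warn


@lru_cache(maxsize=None)
def _warn_deprecated_once(name: str, target: str) -> None:
    # Because the call is cached on (name, target), repeated invocations -- e.g. a
    # deprecated functional metric called every training step -- warn only once.
    rank_zero_warn(
        f"`{name}` was deprecated since v1.3.0 in favor of `{target}`. It will be removed in v1.5.0",
        DeprecationWarning,
    )

Because the arguments form the cache key, _warn_deprecated_once("auroc", "torchmetrics.functional.auroc") would warn on the first call and stay silent afterwards.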
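For context on what the pruned code did: a CompositionalMetric is normally not constructed by hand; torchmetrics builds one whenever two metrics (or a metric and a scalar) are combined with an operator, and its update/compute fan out to the operands exactly as the removed implementation above shows. A minimal sketch, with a made-up SumMetric used only for illustration:

import torch
from torchmetrics import Metric


class SumMetric(Metric):
    # Toy metric used only to demonstrate composition: it accumulates a running sum.
    def __init__(self):
        super().__init__()
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, x: torch.Tensor) -> None:
        self.total += x.sum()

    def compute(self) -> torch.Tensor:
        return self.total


a, b = SumMetric(), SumMetric()
composed = a + b  # operator overloading on Metric returns a CompositionalMetric
composed.update(torch.tensor([1.0, 2.0]))  # update is forwarded to both operands
print(composed.compute())  # torch.add(a.compute(), b.compute()) -> tensor(6.)

With this PR, pytorch_lightning.metrics.compositional.CompositionalMetric only adds the deprecation warning on top of that torchmetrics behaviour.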