Skip to content

Commit 9e35f97

Browse files
authored
Prune metrics: AUC & AUROC (#6572)
* class: AUC AUROC * func: auc auroc * format * tests
1 parent 2f6ce1a commit 9e35f97

File tree

14 files changed

+101
-1057
lines changed

14 files changed

+101
-1057
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
7777

7878
[#6515](https://github.com/PyTorchLightning/pytorch-lightning/pull/6515),
7979

80+
[#6572](https://github.com/PyTorchLightning/pytorch-lightning/pull/6572),
81+
8082
)
8183

8284

pytorch_lightning/metrics/classification/auc.py

Lines changed: 7 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -13,36 +13,14 @@
1313
# limitations under the License.
1414
from typing import Any, Callable, Optional
1515

16-
import torch
17-
from torchmetrics import Metric
16+
from torchmetrics import AUC as _AUC
1817

19-
from pytorch_lightning.metrics.functional.auc import _auc_compute, _auc_update
20-
from pytorch_lightning.utilities import rank_zero_warn
18+
from pytorch_lightning.utilities.deprecation import deprecated
2119

2220

23-
class AUC(Metric):
24-
r"""
25-
Computes Area Under the Curve (AUC) using the trapezoidal rule
26-
27-
Forward accepts two input tensors that should be 1D and have the same number
28-
of elements
29-
30-
Args:
31-
reorder: AUC expects its first input to be sorted. If this is not the case,
32-
setting this argument to ``True`` will use a stable sorting algorithm to
33-
sort the input in decending order
34-
compute_on_step:
35-
Forward only calls ``update()`` and return None if this is set to False.
36-
dist_sync_on_step:
37-
Synchronize metric state across processes at each ``forward()``
38-
before returning the value at the step.
39-
process_group:
40-
Specify the process group on which synchronization is called. default: None (which selects the entire world)
41-
dist_sync_fn:
42-
Callback that performs the allgather operation on the metric state. When ``None``, DDP
43-
will be used to perform the allgather
44-
"""
21+
class AUC(_AUC):
4522

23+
@deprecated(target=_AUC, ver_deprecate="1.3.0", ver_remove="1.5.0")
4624
def __init__(
4725
self,
4826
reorder: bool = False,
@@ -51,40 +29,9 @@ def __init__(
5129
process_group: Optional[Any] = None,
5230
dist_sync_fn: Callable = None,
5331
):
54-
super().__init__(
55-
compute_on_step=compute_on_step,
56-
dist_sync_on_step=dist_sync_on_step,
57-
process_group=process_group,
58-
dist_sync_fn=dist_sync_fn,
59-
)
60-
61-
self.reorder = reorder
62-
63-
self.add_state("x", default=[], dist_reduce_fx=None)
64-
self.add_state("y", default=[], dist_reduce_fx=None)
65-
66-
rank_zero_warn(
67-
'Metric `AUC` will save all targets and predictions in buffer.'
68-
' For large datasets this may lead to large memory footprint.'
69-
)
70-
71-
def update(self, x: torch.Tensor, y: torch.Tensor):
72-
"""
73-
Update state with predictions and targets.
74-
75-
Args:
76-
x: Predictions from model (probabilities, or labels)
77-
y: Ground truth labels
7832
"""
79-
x, y = _auc_update(x, y)
33+
This implementation refers to :class:`~torchmetrics.AUC`.
8034
81-
self.x.append(x)
82-
self.y.append(y)
83-
84-
def compute(self) -> torch.Tensor:
85-
"""
86-
Computes AUC based on inputs passed in to ``update`` previously.
35+
.. deprecated::
36+
Use :class:`~torchmetrics.AUC`. Will be removed in v1.5.0.
8737
"""
88-
x = torch.cat(self.x, dim=0)
89-
y = torch.cat(self.y, dim=0)
90-
return _auc_compute(x, y, reorder=self.reorder)

pytorch_lightning/metrics/classification/auroc.py

Lines changed: 7 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -11,95 +11,16 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from distutils.version import LooseVersion
1514
from typing import Any, Callable, Optional
1615

17-
import torch
18-
from torchmetrics import Metric
16+
from torchmetrics import AUROC as _AUROC
1917

20-
from pytorch_lightning.metrics.functional.auroc import _auroc_compute, _auroc_update
21-
from pytorch_lightning.utilities import rank_zero_warn
18+
from pytorch_lightning.utilities.deprecation import deprecated
2219

2320

24-
class AUROC(Metric):
25-
r"""Compute `Area Under the Receiver Operating Characteristic Curve (ROC AUC)
26-
<https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Further_interpretations>`_.
27-
Works for both binary, multilabel and multiclass problems. In the case of
28-
multiclass, the values will be calculated based on a one-vs-the-rest approach.
29-
30-
Forward accepts
31-
32-
- ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor
33-
with probabilities, where C is the number of classes.
34-
35-
- ``target`` (long tensor): ``(N, ...)`` or ``(N, C, ...)`` with integer labels
36-
37-
For non-binary input, if the ``preds`` and ``target`` tensor have the same
38-
size the input will be interpretated as multilabel and if ``preds`` have one
39-
dimension more than the ``target`` tensor the input will be interpretated as
40-
multiclass.
41-
42-
Args:
43-
num_classes: integer with number of classes. Not nessesary to provide
44-
for binary problems.
45-
pos_label: integer determining the positive class. Default is ``None``
46-
which for binary problem is translate to 1. For multiclass problems
47-
this argument should not be set as we iteratively change it in the
48-
range [0,num_classes-1]
49-
average:
50-
- ``'macro'`` computes metric for each class and uniformly averages them
51-
- ``'weighted'`` computes metric for each class and does a weighted-average,
52-
where each class is weighted by their support (accounts for class imbalance)
53-
- ``None`` computes and returns the metric per class
54-
max_fpr:
55-
If not ``None``, calculates standardized partial AUC over the
56-
range [0, max_fpr]. Should be a float between 0 and 1.
57-
compute_on_step:
58-
Forward only calls ``update()`` and return None if this is set to False. default: True
59-
dist_sync_on_step:
60-
Synchronize metric state across processes at each ``forward()``
61-
before returning the value at the step.
62-
process_group:
63-
Specify the process group on which synchronization is called. default: None (which selects the entire world)
64-
dist_sync_fn:
65-
Callback that performs the allgather operation on the metric state. When ``None``, DDP
66-
will be used to perform the allgather
67-
68-
Raises:
69-
ValueError:
70-
If ``average`` is none of ``None``, ``"macro"`` or ``"weighted"``.
71-
ValueError:
72-
If ``max_fpr`` is not a ``float`` in the range ``(0, 1]``.
73-
RuntimeError:
74-
If ``PyTorch version`` is ``below 1.6`` since max_fpr requires ``torch.bucketize``
75-
which is not available below 1.6.
76-
ValueError:
77-
If the mode of data (binary, multi-label, multi-class) changes between batches.
78-
79-
Example (binary case):
80-
81-
>>> from pytorch_lightning.metrics import AUROC
82-
>>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
83-
>>> target = torch.tensor([0, 0, 1, 1, 1])
84-
>>> auroc = AUROC(pos_label=1)
85-
>>> auroc(preds, target)
86-
tensor(0.5000)
87-
88-
Example (multiclass case):
89-
90-
>>> from pytorch_lightning.metrics import AUROC
91-
>>> preds = torch.tensor([[0.90, 0.05, 0.05],
92-
... [0.05, 0.90, 0.05],
93-
... [0.05, 0.05, 0.90],
94-
... [0.85, 0.05, 0.10],
95-
... [0.10, 0.10, 0.80]])
96-
>>> target = torch.tensor([0, 1, 1, 2, 2])
97-
>>> auroc = AUROC(num_classes=3)
98-
>>> auroc(preds, target)
99-
tensor(0.7778)
100-
101-
"""
21+
class AUROC(_AUROC):
10222

23+
@deprecated(target=_AUROC, ver_deprecate="1.3.0", ver_remove="1.5.0")
10324
def __init__(
10425
self,
10526
num_classes: Optional[int] = None,
@@ -111,74 +32,9 @@ def __init__(
11132
process_group: Optional[Any] = None,
11233
dist_sync_fn: Callable = None,
11334
):
114-
super().__init__(
115-
compute_on_step=compute_on_step,
116-
dist_sync_on_step=dist_sync_on_step,
117-
process_group=process_group,
118-
dist_sync_fn=dist_sync_fn,
119-
)
120-
121-
self.num_classes = num_classes
122-
self.pos_label = pos_label
123-
self.average = average
124-
self.max_fpr = max_fpr
125-
126-
allowed_average = (None, 'macro', 'weighted')
127-
if self.average not in allowed_average:
128-
raise ValueError(
129-
f'Argument `average` expected to be one of the following: {allowed_average} but got {average}'
130-
)
131-
132-
if self.max_fpr is not None:
133-
if (not isinstance(max_fpr, float) and 0 < max_fpr <= 1):
134-
raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}")
135-
136-
if LooseVersion(torch.__version__) < LooseVersion('1.6.0'):
137-
raise RuntimeError(
138-
'`max_fpr` argument requires `torch.bucketize` which is not available below PyTorch version 1.6'
139-
)
140-
141-
self.mode = None
142-
self.add_state("preds", default=[], dist_reduce_fx=None)
143-
self.add_state("target", default=[], dist_reduce_fx=None)
144-
145-
rank_zero_warn(
146-
'Metric `AUROC` will save all targets and predictions in buffer.'
147-
' For large datasets this may lead to large memory footprint.'
148-
)
149-
150-
def update(self, preds: torch.Tensor, target: torch.Tensor):
15135
"""
152-
Update state with predictions and targets.
36+
This implementation refers to :class:`~torchmetrics.AUROC`.
15337
154-
Args:
155-
preds: Predictions from model (probabilities, or labels)
156-
target: Ground truth labels
157-
"""
158-
preds, target, mode = _auroc_update(preds, target)
159-
160-
self.preds.append(preds)
161-
self.target.append(target)
162-
163-
if self.mode is not None and self.mode != mode:
164-
raise ValueError(
165-
'The mode of data (binary, multi-label, multi-class) should be constant, but changed'
166-
f' between batches from {self.mode} to {mode}'
167-
)
168-
self.mode = mode
169-
170-
def compute(self) -> torch.Tensor:
171-
"""
172-
Computes AUROC based on inputs passed in to ``update`` previously.
38+
.. deprecated::
39+
Use :class:`~torchmetrics.AUROC`. Will be removed in v1.5.0.
17340
"""
174-
preds = torch.cat(self.preds, dim=0)
175-
target = torch.cat(self.target, dim=0)
176-
return _auroc_compute(
177-
preds,
178-
target,
179-
self.mode,
180-
self.num_classes,
181-
self.pos_label,
182-
self.average,
183-
self.max_fpr,
184-
)

0 commit comments

Comments
 (0)