Skip to content

Commit 38a2119

Browse files
authored
Prune metrics: precision & recall 6/n (#6573)
* avg precision * precision * recall * curve * tests * chlog * isort * fix
1 parent 8853a36 commit 38a2119

File tree

13 files changed

+127
-1793
lines changed

13 files changed

+127
-1793
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
7979

8080
[#6572](https://github.com/PyTorchLightning/pytorch-lightning/pull/6572),
8181

82+
[#6573](https://github.com/PyTorchLightning/pytorch-lightning/pull/6573),
83+
8284
)
8385

8486

pytorch_lightning/metrics/classification/average_precision.py

Lines changed: 8 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -11,64 +11,16 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Any, List, Optional, Union
14+
from typing import Any, Optional
1515

16-
import torch
17-
from torchmetrics import Metric
16+
from torchmetrics import AveragePrecision as _AveragePrecision
1817

19-
from pytorch_lightning.metrics.functional.average_precision import _average_precision_compute, _average_precision_update
20-
from pytorch_lightning.utilities import rank_zero_warn
18+
from pytorch_lightning.utilities.deprecation import deprecated
2119

2220

23-
class AveragePrecision(Metric):
24-
"""
25-
Computes the average precision score, which summarises the precision recall
26-
curve into one number. Works for both binary and multiclass problems.
27-
In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach.
28-
29-
Forward accepts
30-
31-
- ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor
32-
with probabilities, where C is the number of classes.
33-
34-
- ``target`` (long tensor): ``(N, ...)`` with integer labels
35-
36-
Args:
37-
num_classes: integer with number of classes. Not nessesary to provide
38-
for binary problems.
39-
pos_label: integer determining the positive class. Default is ``None``
40-
which for binary problem is translate to 1. For multiclass problems
41-
this argument should not be set as we iteratively change it in the
42-
range [0,num_classes-1]
43-
compute_on_step:
44-
Forward only calls ``update()`` and return None if this is set to False. default: True
45-
dist_sync_on_step:
46-
Synchronize metric state across processes at each ``forward()``
47-
before returning the value at the step. default: False
48-
process_group:
49-
Specify the process group on which synchronization is called. default: None (which selects the entire world)
50-
51-
Example (binary case):
52-
53-
>>> pred = torch.tensor([0, 1, 2, 3])
54-
>>> target = torch.tensor([0, 1, 1, 1])
55-
>>> average_precision = AveragePrecision(pos_label=1)
56-
>>> average_precision(pred, target)
57-
tensor(1.)
58-
59-
Example (multiclass case):
60-
61-
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
62-
... [0.05, 0.75, 0.05, 0.05, 0.05],
63-
... [0.05, 0.05, 0.75, 0.05, 0.05],
64-
... [0.05, 0.05, 0.05, 0.75, 0.05]])
65-
>>> target = torch.tensor([0, 1, 3, 2])
66-
>>> average_precision = AveragePrecision(num_classes=5)
67-
>>> average_precision(pred, target)
68-
[tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]
69-
70-
"""
21+
class AveragePrecision(_AveragePrecision):
7122

23+
@deprecated(target=_AveragePrecision, ver_deprecate="1.3.0", ver_remove="1.5.0")
7224
def __init__(
7325
self,
7426
num_classes: Optional[int] = None,
@@ -77,48 +29,9 @@ def __init__(
7729
dist_sync_on_step: bool = False,
7830
process_group: Optional[Any] = None,
7931
):
80-
super().__init__(
81-
compute_on_step=compute_on_step,
82-
dist_sync_on_step=dist_sync_on_step,
83-
process_group=process_group,
84-
)
85-
86-
self.num_classes = num_classes
87-
self.pos_label = pos_label
88-
89-
self.add_state("preds", default=[], dist_reduce_fx=None)
90-
self.add_state("target", default=[], dist_reduce_fx=None)
91-
92-
rank_zero_warn(
93-
'Metric `AveragePrecision` will save all targets and predictions in buffer.'
94-
' For large datasets this may lead to large memory footprint.'
95-
)
96-
97-
def update(self, preds: torch.Tensor, target: torch.Tensor):
9832
"""
99-
Update state with predictions and targets.
100-
101-
Args:
102-
preds: Predictions from model
103-
target: Ground truth values
104-
"""
105-
preds, target, num_classes, pos_label = _average_precision_update(
106-
preds, target, self.num_classes, self.pos_label
107-
)
108-
self.preds.append(preds)
109-
self.target.append(target)
110-
self.num_classes = num_classes
111-
self.pos_label = pos_label
112-
113-
def compute(self) -> Union[torch.Tensor, List[torch.Tensor]]:
114-
"""
115-
Compute the average precision score
116-
117-
Returns:
118-
tensor with average precision. If multiclass will return list
119-
of such tensors, one for each class
33+
This implementation refers to :class:`~torchmetrics.AveragePrecision`.
12034
35+
.. deprecated::
36+
Use :class:`~torchmetrics.AveragePrecision`. Will be removed in v1.5.0.
12137
"""
122-
preds = torch.cat(self.preds, dim=0)
123-
target = torch.cat(self.target, dim=0)
124-
return _average_precision_compute(preds, target, self.num_classes, self.pos_label)

0 commit comments

Comments
 (0)