11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
- from typing import Any , List , Optional , Union
14
+ from typing import Any , Optional
15
15
16
- import torch
17
- from torchmetrics import Metric
16
+ from torchmetrics import AveragePrecision as _AveragePrecision
18
17
19
- from pytorch_lightning .metrics .functional .average_precision import _average_precision_compute , _average_precision_update
20
- from pytorch_lightning .utilities import rank_zero_warn
18
+ from pytorch_lightning .utilities .deprecation import deprecated
21
19
22
20
23
- class AveragePrecision (Metric ):
24
- """
25
- Computes the average precision score, which summarises the precision recall
26
- curve into one number. Works for both binary and multiclass problems.
27
- In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach.
28
-
29
- Forward accepts
30
-
31
- - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor
32
- with probabilities, where C is the number of classes.
33
-
34
- - ``target`` (long tensor): ``(N, ...)`` with integer labels
35
-
36
- Args:
37
- num_classes: integer with number of classes. Not nessesary to provide
38
- for binary problems.
39
- pos_label: integer determining the positive class. Default is ``None``
40
- which for binary problem is translate to 1. For multiclass problems
41
- this argument should not be set as we iteratively change it in the
42
- range [0,num_classes-1]
43
- compute_on_step:
44
- Forward only calls ``update()`` and return None if this is set to False. default: True
45
- dist_sync_on_step:
46
- Synchronize metric state across processes at each ``forward()``
47
- before returning the value at the step. default: False
48
- process_group:
49
- Specify the process group on which synchronization is called. default: None (which selects the entire world)
50
-
51
- Example (binary case):
52
-
53
- >>> pred = torch.tensor([0, 1, 2, 3])
54
- >>> target = torch.tensor([0, 1, 1, 1])
55
- >>> average_precision = AveragePrecision(pos_label=1)
56
- >>> average_precision(pred, target)
57
- tensor(1.)
58
-
59
- Example (multiclass case):
60
-
61
- >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
62
- ... [0.05, 0.75, 0.05, 0.05, 0.05],
63
- ... [0.05, 0.05, 0.75, 0.05, 0.05],
64
- ... [0.05, 0.05, 0.05, 0.75, 0.05]])
65
- >>> target = torch.tensor([0, 1, 3, 2])
66
- >>> average_precision = AveragePrecision(num_classes=5)
67
- >>> average_precision(pred, target)
68
- [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]
69
-
70
- """
21
+ class AveragePrecision (_AveragePrecision ):
71
22
23
+ @deprecated (target = _AveragePrecision , ver_deprecate = "1.3.0" , ver_remove = "1.5.0" )
72
24
def __init__ (
73
25
self ,
74
26
num_classes : Optional [int ] = None ,
@@ -77,48 +29,9 @@ def __init__(
77
29
dist_sync_on_step : bool = False ,
78
30
process_group : Optional [Any ] = None ,
79
31
):
80
- super ().__init__ (
81
- compute_on_step = compute_on_step ,
82
- dist_sync_on_step = dist_sync_on_step ,
83
- process_group = process_group ,
84
- )
85
-
86
- self .num_classes = num_classes
87
- self .pos_label = pos_label
88
-
89
- self .add_state ("preds" , default = [], dist_reduce_fx = None )
90
- self .add_state ("target" , default = [], dist_reduce_fx = None )
91
-
92
- rank_zero_warn (
93
- 'Metric `AveragePrecision` will save all targets and predictions in buffer.'
94
- ' For large datasets this may lead to large memory footprint.'
95
- )
96
-
97
- def update (self , preds : torch .Tensor , target : torch .Tensor ):
98
32
"""
99
- Update state with predictions and targets.
100
-
101
- Args:
102
- preds: Predictions from model
103
- target: Ground truth values
104
- """
105
- preds , target , num_classes , pos_label = _average_precision_update (
106
- preds , target , self .num_classes , self .pos_label
107
- )
108
- self .preds .append (preds )
109
- self .target .append (target )
110
- self .num_classes = num_classes
111
- self .pos_label = pos_label
112
-
113
- def compute (self ) -> Union [torch .Tensor , List [torch .Tensor ]]:
114
- """
115
- Compute the average precision score
116
-
117
- Returns:
118
- tensor with average precision. If multiclass will return list
119
- of such tensors, one for each class
33
+ This implementation refers to :class:`~torchmetrics.AveragePrecision`.
120
34
35
+ .. deprecated::
36
+ Use :class:`~torchmetrics.AveragePrecision`. Will be removed in v1.5.0.
121
37
"""
122
- preds = torch .cat (self .preds , dim = 0 )
123
- target = torch .cat (self .target , dim = 0 )
124
- return _average_precision_compute (preds , target , self .num_classes , self .pos_label )
0 commit comments