11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
- from distutils .version import LooseVersion
15
14
from typing import Any , Callable , Optional
16
15
17
- import torch
18
- from torchmetrics import Metric
16
+ from torchmetrics import AUROC as _AUROC
19
17
20
- from pytorch_lightning .metrics .functional .auroc import _auroc_compute , _auroc_update
21
- from pytorch_lightning .utilities import rank_zero_warn
18
+ from pytorch_lightning .utilities .deprecation import deprecated
22
19
23
20
24
- class AUROC (Metric ):
25
- r"""Compute `Area Under the Receiver Operating Characteristic Curve (ROC AUC)
26
- <https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Further_interpretations>`_.
27
- Works for both binary, multilabel and multiclass problems. In the case of
28
- multiclass, the values will be calculated based on a one-vs-the-rest approach.
29
-
30
- Forward accepts
31
-
32
- - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor
33
- with probabilities, where C is the number of classes.
34
-
35
- - ``target`` (long tensor): ``(N, ...)`` or ``(N, C, ...)`` with integer labels
36
-
37
- For non-binary input, if the ``preds`` and ``target`` tensor have the same
38
- size the input will be interpretated as multilabel and if ``preds`` have one
39
- dimension more than the ``target`` tensor the input will be interpretated as
40
- multiclass.
41
-
42
- Args:
43
- num_classes: integer with number of classes. Not nessesary to provide
44
- for binary problems.
45
- pos_label: integer determining the positive class. Default is ``None``
46
- which for binary problem is translate to 1. For multiclass problems
47
- this argument should not be set as we iteratively change it in the
48
- range [0,num_classes-1]
49
- average:
50
- - ``'macro'`` computes metric for each class and uniformly averages them
51
- - ``'weighted'`` computes metric for each class and does a weighted-average,
52
- where each class is weighted by their support (accounts for class imbalance)
53
- - ``None`` computes and returns the metric per class
54
- max_fpr:
55
- If not ``None``, calculates standardized partial AUC over the
56
- range [0, max_fpr]. Should be a float between 0 and 1.
57
- compute_on_step:
58
- Forward only calls ``update()`` and return None if this is set to False. default: True
59
- dist_sync_on_step:
60
- Synchronize metric state across processes at each ``forward()``
61
- before returning the value at the step.
62
- process_group:
63
- Specify the process group on which synchronization is called. default: None (which selects the entire world)
64
- dist_sync_fn:
65
- Callback that performs the allgather operation on the metric state. When ``None``, DDP
66
- will be used to perform the allgather
67
-
68
- Raises:
69
- ValueError:
70
- If ``average`` is none of ``None``, ``"macro"`` or ``"weighted"``.
71
- ValueError:
72
- If ``max_fpr`` is not a ``float`` in the range ``(0, 1]``.
73
- RuntimeError:
74
- If ``PyTorch version`` is ``below 1.6`` since max_fpr requires ``torch.bucketize``
75
- which is not available below 1.6.
76
- ValueError:
77
- If the mode of data (binary, multi-label, multi-class) changes between batches.
78
-
79
- Example (binary case):
80
-
81
- >>> from pytorch_lightning.metrics import AUROC
82
- >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
83
- >>> target = torch.tensor([0, 0, 1, 1, 1])
84
- >>> auroc = AUROC(pos_label=1)
85
- >>> auroc(preds, target)
86
- tensor(0.5000)
87
-
88
- Example (multiclass case):
89
-
90
- >>> from pytorch_lightning.metrics import AUROC
91
- >>> preds = torch.tensor([[0.90, 0.05, 0.05],
92
- ... [0.05, 0.90, 0.05],
93
- ... [0.05, 0.05, 0.90],
94
- ... [0.85, 0.05, 0.10],
95
- ... [0.10, 0.10, 0.80]])
96
- >>> target = torch.tensor([0, 1, 1, 2, 2])
97
- >>> auroc = AUROC(num_classes=3)
98
- >>> auroc(preds, target)
99
- tensor(0.7778)
100
-
101
- """
21
+ class AUROC (_AUROC ):
102
22
23
+ @deprecated (target = _AUROC , ver_deprecate = "1.3.0" , ver_remove = "1.5.0" )
103
24
def __init__ (
104
25
self ,
105
26
num_classes : Optional [int ] = None ,
@@ -111,74 +32,9 @@ def __init__(
111
32
process_group : Optional [Any ] = None ,
112
33
dist_sync_fn : Callable = None ,
113
34
):
114
- super ().__init__ (
115
- compute_on_step = compute_on_step ,
116
- dist_sync_on_step = dist_sync_on_step ,
117
- process_group = process_group ,
118
- dist_sync_fn = dist_sync_fn ,
119
- )
120
-
121
- self .num_classes = num_classes
122
- self .pos_label = pos_label
123
- self .average = average
124
- self .max_fpr = max_fpr
125
-
126
- allowed_average = (None , 'macro' , 'weighted' )
127
- if self .average not in allowed_average :
128
- raise ValueError (
129
- f'Argument `average` expected to be one of the following: { allowed_average } but got { average } '
130
- )
131
-
132
- if self .max_fpr is not None :
133
- if (not isinstance (max_fpr , float ) and 0 < max_fpr <= 1 ):
134
- raise ValueError (f"`max_fpr` should be a float in range (0, 1], got: { max_fpr } " )
135
-
136
- if LooseVersion (torch .__version__ ) < LooseVersion ('1.6.0' ):
137
- raise RuntimeError (
138
- '`max_fpr` argument requires `torch.bucketize` which is not available below PyTorch version 1.6'
139
- )
140
-
141
- self .mode = None
142
- self .add_state ("preds" , default = [], dist_reduce_fx = None )
143
- self .add_state ("target" , default = [], dist_reduce_fx = None )
144
-
145
- rank_zero_warn (
146
- 'Metric `AUROC` will save all targets and predictions in buffer.'
147
- ' For large datasets this may lead to large memory footprint.'
148
- )
149
-
150
- def update (self , preds : torch .Tensor , target : torch .Tensor ):
151
35
"""
152
- Update state with predictions and targets .
36
+ This implementation refers to :class:`~torchmetrics.AUROC` .
153
37
154
- Args:
155
- preds: Predictions from model (probabilities, or labels)
156
- target: Ground truth labels
157
- """
158
- preds , target , mode = _auroc_update (preds , target )
159
-
160
- self .preds .append (preds )
161
- self .target .append (target )
162
-
163
- if self .mode is not None and self .mode != mode :
164
- raise ValueError (
165
- 'The mode of data (binary, multi-label, multi-class) should be constant, but changed'
166
- f' between batches from { self .mode } to { mode } '
167
- )
168
- self .mode = mode
169
-
170
- def compute (self ) -> torch .Tensor :
171
- """
172
- Computes AUROC based on inputs passed in to ``update`` previously.
38
+ .. deprecated::
39
+ Use :class:`~torchmetrics.AUROC`. Will be removed in v1.5.0.
173
40
"""
174
- preds = torch .cat (self .preds , dim = 0 )
175
- target = torch .cat (self .target , dim = 0 )
176
- return _auroc_compute (
177
- preds ,
178
- target ,
179
- self .mode ,
180
- self .num_classes ,
181
- self .pos_label ,
182
- self .average ,
183
- self .max_fpr ,
184
- )
0 commit comments