
Commit 0031675

RaviTez, usdesrozis, and vfdev-5 authored
[3] [contrib/metrics] setup typing in contrib part of the library (#1363)
* [3] [contrib/metrics] setup typing in contrib part of the library
* review changes
* Update gpu_info.py

Co-authored-by: Sylvain Desroziers <[email protected]>
Co-authored-by: vfdev <[email protected]>
1 parent f9e236e commit 0031675

20 files changed: +86 additions, -46 deletions

ignite/contrib/metrics/average_precision.py

Lines changed: 7 additions & 3 deletions
@@ -1,7 +1,11 @@
+from typing import Callable
+
+import torch
+
 from ignite.metrics import EpochMetric
 
 
-def average_precision_compute_fn(y_preds, y_targets):
+def average_precision_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor):
     try:
         from sklearn.metrics import average_precision_score
     except ImportError:
@@ -22,7 +26,7 @@ class AveragePrecision(EpochMetric):
             :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
             form expected by the metric. This can be useful if, for example, you have a multi-output model and
             you want to compute the metric with respect to one of the outputs.
-        check_compute_fn (bool): Optional default False. If True, `average_precision_score
+        check_compute_fn (bool): Default False. If True, `average_precision_score
             <http://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html
             #sklearn.metrics.average_precision_score>`_ is run on the first batch of data to ensure there are
             no issues. User will be warned in case there are any issues computing the function.
@@ -41,7 +45,7 @@ def activated_output_transform(output):
 
     """
 
-    def __init__(self, output_transform=lambda x: x, check_compute_fn: bool = False):
+    def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False):
         super(AveragePrecision, self).__init__(
             average_precision_compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn
         )
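A minimal usage sketch (not part of the commit) showing how the now-typed constructor is called; the sigmoid transform follows the `activated_output_transform` example referenced in the docstring above:

import torch

from ignite.contrib.metrics import AveragePrecision

def activated_output_transform(output):
    # Apply a sigmoid so raw logits become probabilities before the metric sees them.
    y_pred, y = output
    return torch.sigmoid(y_pred), y

# output_transform is a Callable and check_compute_fn a bool, per the new annotations.
avg_precision = AveragePrecision(output_transform=activated_output_transform, check_compute_fn=True)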

ignite/contrib/metrics/gpu_info.py

Lines changed: 5 additions & 4 deletions
@@ -1,9 +1,10 @@
 # -*- coding: utf-8 -*-
 import warnings
+from typing import Tuple, Union
 
 import torch
 
-from ignite.engine import Events
+from ignite.engine import Engine, EventEnum, Events
 from ignite.metrics import Metric
 
 
@@ -54,7 +55,7 @@ def __init__(self):
     def reset(self):
         pass
 
-    def update(self, output):
+    def update(self, output: Tuple[torch.Tensor, torch.Tensor]):
         pass
 
     def compute(self):
@@ -64,7 +65,7 @@ def compute(self):
             return []
         return data["gpu"]
 
-    def completed(self, engine, name):
+    def completed(self, engine: Engine, name: str):
         data = self.compute()
         if len(data) < 1:
             warnings.warn("No GPU information available")
@@ -103,5 +104,5 @@ def completed(self, engine, name):
         # Do not set GPU utilization information
         pass
 
-    def attach(self, engine, name="gpu", event_name=Events.ITERATION_COMPLETED):
+    def attach(self, engine: Engine, name: str = "gpu", event_name: Union[str, EventEnum] = Events.ITERATION_COMPLETED):
         engine.add_event_handler(event_name, self.completed, name)
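A hedged attach sketch (assumed usage, not taken from the diff): `event_name` now accepts either a string or an `EventEnum` member such as `Events.ITERATION_COMPLETED`. Note that `GpuInfo` needs `pynvml` installed to initialize.

from ignite.contrib.metrics import GpuInfo
from ignite.engine import Engine, Events

def train_step(engine, batch):
    # Placeholder training step, for illustration only.
    return batch

trainer = Engine(train_step)
# event_name: Union[str, EventEnum]; an EventEnum member is passed here.
GpuInfo().attach(trainer, name="gpu", event_name=Events.ITERATION_COMPLETED)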

ignite/contrib/metrics/precision_recall_curve.py

Lines changed: 7 additions & 3 deletions
@@ -1,7 +1,11 @@
+from typing import Callable
+
+import torch
+
 from ignite.metrics import EpochMetric
 
 
-def precision_recall_curve_compute_fn(y_preds, y_targets):
+def precision_recall_curve_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor):
     try:
         from sklearn.metrics import precision_recall_curve
     except ImportError:
@@ -23,7 +27,7 @@ class PrecisionRecallCurve(EpochMetric):
             :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
             form expected by the metric. This can be useful if, for example, you have a multi-output model and
             you want to compute the metric with respect to one of the outputs.
-        check_compute_fn (bool): Optional default False. If True, `precision_recall_curve
+        check_compute_fn (bool): Default False. If True, `precision_recall_curve
             <http://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_curve.html
             #sklearn.metrics.precision_recall_curve>`_ is run on the first batch of data to ensure there are
             no issues. User will be warned in case there are any issues computing the function.
@@ -42,7 +46,7 @@ def activated_output_transform(output):
 
     """
 
-    def __init__(self, output_transform=lambda x: x, check_compute_fn: bool = False):
+    def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False):
         super(PrecisionRecallCurve, self).__init__(
             precision_recall_curve_compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn
         )

ignite/contrib/metrics/regression/_base.py

Lines changed: 10 additions & 8 deletions
@@ -1,13 +1,13 @@
 from abc import abstractmethod
-from typing import Callable, Union
+from typing import Callable, Tuple
 
 import torch
 
 from ignite.metrics import EpochMetric, Metric
 from ignite.metrics.metric import reinit__is_reduced
 
 
-def _check_output_shapes(output):
+def _check_output_shapes(output: Tuple[torch.Tensor, torch.Tensor]):
     y_pred, y = output
     if y_pred.shape != y.shape:
         raise ValueError("Input data shapes should be the same, but given {} and {}".format(y_pred.shape, y.shape))
@@ -21,7 +21,7 @@ def _check_output_shapes(output):
         raise ValueError("Input y should have shape (N,) or (N, 1), but given {}".format(y.shape))
 
 
-def _check_output_types(output):
+def _check_output_types(output: Tuple[torch.Tensor, torch.Tensor]):
     y_pred, y = output
     if y_pred.dtype not in (torch.float16, torch.float32, torch.float64):
         raise TypeError("Input y_pred dtype should be float 16, 32 or 64, but given {}".format(y_pred.dtype))
@@ -36,7 +36,7 @@ class _BaseRegression(Metric):
     # method `_update`.
 
     @reinit__is_reduced
-    def update(self, output):
+    def update(self, output: Tuple[torch.Tensor, torch.Tensor]):
         _check_output_shapes(output)
         _check_output_types(output)
         y_pred, y = output[0].detach(), output[1].detach()
@@ -50,7 +50,7 @@ def update(self, output):
         self._update((y_pred, y))
 
     @abstractmethod
-    def _update(self, output):
+    def _update(self, output: Tuple[torch.Tensor, torch.Tensor]):
         pass
 
 
@@ -59,14 +59,16 @@ class _BaseRegressionEpoch(EpochMetric):
     # `update` method check the shapes and call internal overloaded method `_update`.
     # Class internally stores complete history of predictions and targets of type float32.
 
-    def __init__(self, compute_fn, output_transform=lambda x: x, check_compute_fn: bool = True):
+    def __init__(
+        self, compute_fn: Callable, output_transform: Callable = lambda x: x, check_compute_fn: bool = True,
+    ):
         super(_BaseRegressionEpoch, self).__init__(
             compute_fn=compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn
         )
 
-    def _check_type(self, output):
+    def _check_type(self, output: Tuple[torch.Tensor, torch.Tensor]):
         _check_output_types(output)
         super(_BaseRegressionEpoch, self)._check_type(output)
 
-    def _check_shape(self, output):
+    def _check_shape(self, output: Tuple[torch.Tensor, torch.Tensor]):
         _check_output_shapes(output)
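To see what the typed hooks buy a subclass, here is a hypothetical regression metric (`MeanError` is illustrative only, not part of the library) built on `_BaseRegression`, assuming the private module path stays as shown above:

from typing import Tuple

import torch

from ignite.contrib.metrics.regression._base import _BaseRegression

class MeanError(_BaseRegression):  # hypothetical metric, for illustration only
    def reset(self):
        self._sum_of_errors = 0.0
        self._num_examples = 0

    def _update(self, output: Tuple[torch.Tensor, torch.Tensor]):
        # Shapes and dtypes were already validated by _BaseRegression.update.
        y_pred, y = output
        self._sum_of_errors += torch.sum(y - y_pred).item()
        self._num_examples += y.shape[0]

    def compute(self):
        return self._sum_of_errors / self._num_examples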

ignite/contrib/metrics/regression/canberra_metric.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-from typing import Callable, Union
+from typing import Callable, Tuple, Union
 
 import torch
 
@@ -34,7 +34,7 @@ def __init__(
     def reset(self):
         self._sum_of_errors = torch.tensor(0.0, device=self._device)
 
-    def _update(self, output):
+    def _update(self, output: Tuple[torch.Tensor, torch.Tensor]):
         y_pred, y = output
         errors = torch.abs(y - y_pred) / (torch.abs(y_pred) + torch.abs(y))
         self._sum_of_errors += torch.sum(errors).to(self._device)
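A small numeric sketch (assumed usage, not from the commit) of the typed contract: `output` reaching `_update` is a `Tuple[torch.Tensor, torch.Tensor]` of `(y_pred, y)`, fed here through the public `update` call.

import torch

from ignite.contrib.metrics.regression import CanberraMetric

m = CanberraMetric()
m.reset()
# (y_pred, y): |1-1|/(1+1) = 0 and |4-2|/(2+4) = 1/3, so the accumulated sum is ~0.333.
m.update((torch.tensor([1.0, 2.0]), torch.tensor([1.0, 4.0])))
print(m.compute())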

ignite/contrib/metrics/regression/fractional_absolute_error.py

Lines changed: 3 additions & 1 deletion
@@ -1,3 +1,5 @@
+from typing import Tuple
+
 import torch
 
 from ignite.contrib.metrics.regression._base import _BaseRegression
@@ -24,7 +26,7 @@ def reset(self):
         self._sum_of_errors = 0.0
         self._num_examples = 0
 
-    def _update(self, output):
+    def _update(self, output: Tuple[torch.Tensor, torch.Tensor]):
         y_pred, y = output
         errors = 2 * torch.abs(y.view_as(y_pred) - y_pred) / (torch.abs(y_pred) + torch.abs(y.view_as(y_pred)))
         self._sum_of_errors += torch.sum(errors).item()

ignite/contrib/metrics/regression/fractional_bias.py

Lines changed: 3 additions & 1 deletion
@@ -1,3 +1,5 @@
+from typing import Tuple
+
 import torch
 
 from ignite.contrib.metrics.regression._base import _BaseRegression
@@ -25,7 +27,7 @@ def reset(self):
         self._sum_of_errors = 0.0
         self._num_examples = 0
 
-    def _update(self, output):
+    def _update(self, output: Tuple[torch.Tensor, torch.Tensor]):
         y_pred, y = output
         errors = 2 * (y.view_as(y_pred) - y_pred) / (y_pred + y.view_as(y_pred))
         self._sum_of_errors += torch.sum(errors).item()

ignite/contrib/metrics/regression/geometric_mean_absolute_error.py

Lines changed: 3 additions & 1 deletion
@@ -1,3 +1,5 @@
+from typing import Tuple
+
 import torch
 
 from ignite.contrib.metrics.regression._base import _BaseRegression
@@ -24,7 +26,7 @@ def reset(self):
         self._sum_of_errors = 0.0
         self._num_examples = 0
 
-    def _update(self, output):
+    def _update(self, output: Tuple[torch.Tensor, torch.Tensor]):
         y_pred, y = output
         errors = torch.log(torch.abs(y.view_as(y_pred) - y_pred))
         self._sum_of_errors += torch.sum(errors)

ignite/contrib/metrics/regression/geometric_mean_relative_absolute_error.py

Lines changed: 4 additions & 1 deletion
@@ -1,6 +1,9 @@
+from typing import Tuple
+
 import torch
 
 from ignite.contrib.metrics.regression._base import _BaseRegression
+from ignite.exceptions import NotComputableError
 
 
 class GeometricMeanRelativeAbsoluteError(_BaseRegression):
@@ -26,7 +29,7 @@ def reset(self):
         self._num_examples = 0
         self._sum_of_errors = 0.0
 
-    def _update(self, output):
+    def _update(self, output: Tuple[torch.Tensor, torch.Tensor]):
         y_pred, y = output
         self._sum_y += y.sum()
         self._num_examples += y.shape[0]
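The hunks above add the `NotComputableError` import without showing its use; the guard below is a sketch of the usual ignite pattern that import suggests (assumed, since the actual guard falls outside the displayed hunks):

from ignite.exceptions import NotComputableError

class _GuardSketch:  # illustration only, not the library class
    def __init__(self):
        self._num_examples = 0
        self._sum_of_errors = 0.0

    def compute(self) -> float:
        if self._num_examples == 0:
            # Raised when compute() is called before any update().
            raise NotComputableError("The metric must have at least one example before it can be computed.")
        return self._sum_of_errors / self._num_examples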

ignite/contrib/metrics/regression/manhattan_distance.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-from typing import Callable, Union
+from typing import Callable, Tuple, Union
 
 import torch
 
@@ -33,7 +33,7 @@ def __init__(
     def reset(self):
         self._sum_of_errors = torch.tensor(0.0, device=self._device)
 
-    def _update(self, output):
+    def _update(self, output: Tuple[torch.Tensor, torch.Tensor]):
         y_pred, y = output
         errors = torch.abs(y - y_pred)
         self._sum_of_errors += torch.sum(errors).to(self._device)
