1
1
from abc import abstractmethod
2
- from typing import Callable , Union
2
+ from typing import Callable , Tuple
3
3
4
4
import torch
5
5
6
6
from ignite .metrics import EpochMetric , Metric
7
7
from ignite .metrics .metric import reinit__is_reduced
8
8
9
9
10
- def _check_output_shapes (output ):
10
+ def _check_output_shapes (output : Tuple [ torch . Tensor , torch . Tensor ] ):
11
11
y_pred , y = output
12
12
if y_pred .shape != y .shape :
13
13
raise ValueError ("Input data shapes should be the same, but given {} and {}" .format (y_pred .shape , y .shape ))
@@ -21,7 +21,7 @@ def _check_output_shapes(output):
21
21
raise ValueError ("Input y should have shape (N,) or (N, 1), but given {}" .format (y .shape ))
22
22
23
23
24
- def _check_output_types (output ):
24
+ def _check_output_types (output : Tuple [ torch . Tensor , torch . Tensor ] ):
25
25
y_pred , y = output
26
26
if y_pred .dtype not in (torch .float16 , torch .float32 , torch .float64 ):
27
27
raise TypeError ("Input y_pred dtype should be float 16, 32 or 64, but given {}" .format (y_pred .dtype ))
@@ -36,7 +36,7 @@ class _BaseRegression(Metric):
36
36
# method `_update`.
37
37
38
38
@reinit__is_reduced
39
- def update (self , output ):
39
+ def update (self , output : Tuple [ torch . Tensor , torch . Tensor ] ):
40
40
_check_output_shapes (output )
41
41
_check_output_types (output )
42
42
y_pred , y = output [0 ].detach (), output [1 ].detach ()
@@ -50,7 +50,7 @@ def update(self, output):
50
50
self ._update ((y_pred , y ))
51
51
52
52
@abstractmethod
53
- def _update (self , output ):
53
+ def _update (self , output : Tuple [ torch . Tensor , torch . Tensor ] ):
54
54
pass
55
55
56
56
@@ -59,14 +59,16 @@ class _BaseRegressionEpoch(EpochMetric):
59
59
# `update` method check the shapes and call internal overloaded method `_update`.
60
60
# Class internally stores complete history of predictions and targets of type float32.
61
61
62
- def __init__ (self , compute_fn , output_transform = lambda x : x , check_compute_fn : bool = True ):
62
+ def __init__ (
63
+ self , compute_fn : Callable , output_transform : Callable = lambda x : x , check_compute_fn : bool = True ,
64
+ ):
63
65
super (_BaseRegressionEpoch , self ).__init__ (
64
66
compute_fn = compute_fn , output_transform = output_transform , check_compute_fn = check_compute_fn
65
67
)
66
68
67
- def _check_type (self , output ):
69
+ def _check_type (self , output : Tuple [ torch . Tensor , torch . Tensor ] ):
68
70
_check_output_types (output )
69
71
super (_BaseRegressionEpoch , self )._check_type (output )
70
72
71
- def _check_shape (self , output ):
73
+ def _check_shape (self , output : Tuple [ torch . Tensor , torch . Tensor ] ):
72
74
_check_output_shapes (output )
0 commit comments