12
12
import torch .testing
13
13
from datasets_utils import combinations_grid
14
14
from torch .nn .functional import one_hot
15
- from torch .testing ._comparison import (
16
- assert_equal as _assert_equal ,
17
- BooleanPair ,
18
- ErrorMeta ,
19
- NonePair ,
20
- NumberPair ,
21
- TensorLikePair ,
22
- UnsupportedInputs ,
23
- )
15
+ from torch .testing ._comparison import assert_equal as _assert_equal , BooleanPair , NonePair , NumberPair , TensorLikePair
24
16
from torchvision .prototype import features
25
- from torchvision .prototype .transforms .functional import convert_dtype_image_tensor , to_image_tensor
17
+ from torchvision .prototype .transforms .functional import to_image_tensor
26
18
from torchvision .transforms .functional_tensor import _max_value as get_max_value
27
19
28
20
__all__ = [
54
46
]
55
47
56
48
57
- class PILImagePair (TensorLikePair ):
49
+ class ImagePair (TensorLikePair ):
58
50
def __init__ (
59
51
self ,
60
52
actual ,
@@ -64,44 +56,13 @@ def __init__(
64
56
allowed_percentage_diff = None ,
65
57
** other_parameters ,
66
58
):
67
- if not any (isinstance (input , PIL .Image .Image ) for input in (actual , expected )):
68
- raise UnsupportedInputs ()
69
-
70
- # This parameter is ignored to enable checking PIL images to tensor images no on the CPU
71
- other_parameters ["check_device" ] = False
59
+ if all (isinstance (input , PIL .Image .Image ) for input in [actual , expected ]):
60
+ actual , expected = [to_image_tensor (input ) for input in [actual , expected ]]
72
61
73
62
super ().__init__ (actual , expected , ** other_parameters )
74
63
self .agg_method = getattr (torch , agg_method ) if isinstance (agg_method , str ) else agg_method
75
64
self .allowed_percentage_diff = allowed_percentage_diff
76
65
77
- def _process_inputs (self , actual , expected , * , id , allow_subclasses ):
78
- actual , expected = [
79
- to_image_tensor (input ) if not isinstance (input , torch .Tensor ) else features .Image (input )
80
- for input in [actual , expected ]
81
- ]
82
- # This broadcast is needed, because `features.Mask`'s can have a 2D shape, but converting the equivalent PIL
83
- # image to a tensor adds a singleton leading dimension.
84
- # Although it looks like this belongs in `self._equalize_attributes`, it has to happen here.
85
- # `self._equalize_attributes` is called after `super()._compare_attributes` and that has an unconditional
86
- # shape check that will fail if we don't broadcast before.
87
- try :
88
- actual , expected = torch .broadcast_tensors (actual , expected )
89
- except RuntimeError :
90
- raise ErrorMeta (
91
- AssertionError ,
92
- f"The image shapes are not broadcastable: { actual .shape } != { expected .shape } ." ,
93
- id = id ,
94
- ) from None
95
- return super ()._process_inputs (actual , expected , id = id , allow_subclasses = allow_subclasses )
96
-
97
- def _equalize_attributes (self , actual , expected ):
98
- if actual .dtype != expected .dtype :
99
- dtype = torch .promote_types (actual .dtype , expected .dtype )
100
- actual = convert_dtype_image_tensor (actual , dtype )
101
- expected = convert_dtype_image_tensor (expected , dtype )
102
-
103
- return super ()._equalize_attributes (actual , expected )
104
-
105
66
def compare (self ) -> None :
106
67
actual , expected = self .actual , self .expected
107
68
@@ -111,16 +72,24 @@ def compare(self) -> None:
111
72
abs_diff = torch .abs (actual - expected )
112
73
113
74
if self .allowed_percentage_diff is not None :
114
- percentage_diff = ( abs_diff != 0 ).to (torch .float ).mean ()
75
+ percentage_diff = float (( abs_diff . ne ( 0 ).to (torch .float64 ).mean ()) )
115
76
if percentage_diff > self .allowed_percentage_diff :
116
- self ._make_error_meta (AssertionError , "percentage mismatch" )
77
+ raise self ._make_error_meta (
78
+ AssertionError ,
79
+ f"{ percentage_diff :.1%} elements differ, "
80
+ f"but only { self .allowed_percentage_diff :.1%} is allowed" ,
81
+ )
117
82
118
83
if self .agg_method is None :
119
84
super ()._compare_values (actual , expected )
120
85
else :
121
- err = self .agg_method (abs_diff .to (torch .float64 ))
122
- if err > self .atol :
123
- self ._make_error_meta (AssertionError , "aggregated mismatch" )
86
+ agg_abs_diff = float (self .agg_method (abs_diff .to (torch .float64 )))
87
+ if agg_abs_diff > self .atol :
88
+ raise self ._make_error_meta (
89
+ AssertionError ,
90
+ f"The '{ self .agg_method .__name__ } ' of the absolute difference is { agg_abs_diff } , "
91
+ f"but only { self .atol } is allowed." ,
92
+ )
124
93
125
94
126
95
def assert_close (
@@ -148,7 +117,7 @@ def assert_close(
148
117
NonePair ,
149
118
BooleanPair ,
150
119
NumberPair ,
151
- PILImagePair ,
120
+ ImagePair ,
152
121
TensorLikePair ,
153
122
),
154
123
allow_subclasses = allow_subclasses ,
@@ -167,6 +136,32 @@ def assert_close(
167
136
assert_equal = functools .partial (assert_close , rtol = 0 , atol = 0 )
168
137
169
138
139
+ def parametrized_error_message (* args , ** kwargs ):
140
+ def to_str (obj ):
141
+ if isinstance (obj , torch .Tensor ) and obj .numel () > 10 :
142
+ return f"tensor(shape={ list (obj .shape )} , dtype={ obj .dtype } , device={ obj .device } )"
143
+ else :
144
+ return repr (obj )
145
+
146
+ if args or kwargs :
147
+ postfix = "\n " .join (
148
+ [
149
+ "" ,
150
+ "Failure happened for the following parameters:" ,
151
+ "" ,
152
+ * [to_str (arg ) for arg in args ],
153
+ * [f"{ name } ={ to_str (kwarg )} " for name , kwarg in kwargs .items ()],
154
+ ]
155
+ )
156
+ else :
157
+ postfix = ""
158
+
159
+ def wrapper (msg ):
160
+ return msg + postfix
161
+
162
+ return wrapper
163
+
164
+
170
165
class ArgsKwargs :
171
166
def __init__ (self , * args , ** kwargs ):
172
167
self .args = args
@@ -656,6 +651,13 @@ def get_marks(self, test_id, args_kwargs):
656
651
]
657
652
658
653
def get_closeness_kwargs (self , test_id , * , dtype , device ):
654
+ if not (isinstance (test_id , tuple ) and len (test_id ) == 2 ):
655
+ msg = "`test_id` should be a `Tuple[Optional[str], str]` denoting the test class and function name"
656
+ if callable (test_id ):
657
+ msg += ". Did you forget to add the `test_id` fixture to parameters of the test?"
658
+ else :
659
+ msg += f", but got { test_id } instead."
660
+ raise pytest .UsageError (msg )
659
661
if isinstance (device , torch .device ):
660
662
device = device .type
661
663
return self .closeness_kwargs .get ((test_id , dtype , device ), dict ())
0 commit comments