Skip to content

Commit 65769ab

Browse files
authored
fix prototype transforms tests with set agg_method (#6934)
* fix prototype transforms tests with set agg_method * use individual tolerances * refactor PIL reference test * increase tolerance for elastic_mask * fix autocontrast tolerances * increase tolerance for RandomAutocontrast
1 parent d72e906 commit 65769ab

File tree

4 files changed

+381
-176
lines changed

4 files changed

+381
-176
lines changed

test/prototype_common_utils.py

Lines changed: 52 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,9 @@
1212
import torch.testing
1313
from datasets_utils import combinations_grid
1414
from torch.nn.functional import one_hot
15-
from torch.testing._comparison import (
16-
assert_equal as _assert_equal,
17-
BooleanPair,
18-
ErrorMeta,
19-
NonePair,
20-
NumberPair,
21-
TensorLikePair,
22-
UnsupportedInputs,
23-
)
15+
from torch.testing._comparison import assert_equal as _assert_equal, BooleanPair, NonePair, NumberPair, TensorLikePair
2416
from torchvision.prototype import features
25-
from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor
17+
from torchvision.prototype.transforms.functional import to_image_tensor
2618
from torchvision.transforms.functional_tensor import _max_value as get_max_value
2719

2820
__all__ = [
@@ -54,7 +46,7 @@
5446
]
5547

5648

57-
class PILImagePair(TensorLikePair):
49+
class ImagePair(TensorLikePair):
5850
def __init__(
5951
self,
6052
actual,
@@ -64,44 +56,13 @@ def __init__(
6456
allowed_percentage_diff=None,
6557
**other_parameters,
6658
):
67-
if not any(isinstance(input, PIL.Image.Image) for input in (actual, expected)):
68-
raise UnsupportedInputs()
69-
70-
# This parameter is ignored to enable checking PIL images to tensor images no on the CPU
71-
other_parameters["check_device"] = False
59+
if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
60+
actual, expected = [to_image_tensor(input) for input in [actual, expected]]
7261

7362
super().__init__(actual, expected, **other_parameters)
7463
self.agg_method = getattr(torch, agg_method) if isinstance(agg_method, str) else agg_method
7564
self.allowed_percentage_diff = allowed_percentage_diff
7665

77-
def _process_inputs(self, actual, expected, *, id, allow_subclasses):
78-
actual, expected = [
79-
to_image_tensor(input) if not isinstance(input, torch.Tensor) else features.Image(input)
80-
for input in [actual, expected]
81-
]
82-
# This broadcast is needed, because `features.Mask`'s can have a 2D shape, but converting the equivalent PIL
83-
# image to a tensor adds a singleton leading dimension.
84-
# Although it looks like this belongs in `self._equalize_attributes`, it has to happen here.
85-
# `self._equalize_attributes` is called after `super()._compare_attributes` and that has an unconditional
86-
# shape check that will fail if we don't broadcast before.
87-
try:
88-
actual, expected = torch.broadcast_tensors(actual, expected)
89-
except RuntimeError:
90-
raise ErrorMeta(
91-
AssertionError,
92-
f"The image shapes are not broadcastable: {actual.shape} != {expected.shape}.",
93-
id=id,
94-
) from None
95-
return super()._process_inputs(actual, expected, id=id, allow_subclasses=allow_subclasses)
96-
97-
def _equalize_attributes(self, actual, expected):
98-
if actual.dtype != expected.dtype:
99-
dtype = torch.promote_types(actual.dtype, expected.dtype)
100-
actual = convert_dtype_image_tensor(actual, dtype)
101-
expected = convert_dtype_image_tensor(expected, dtype)
102-
103-
return super()._equalize_attributes(actual, expected)
104-
10566
def compare(self) -> None:
10667
actual, expected = self.actual, self.expected
10768

@@ -111,16 +72,24 @@ def compare(self) -> None:
11172
abs_diff = torch.abs(actual - expected)
11273

11374
if self.allowed_percentage_diff is not None:
114-
percentage_diff = (abs_diff != 0).to(torch.float).mean()
75+
percentage_diff = float((abs_diff.ne(0).to(torch.float64).mean()))
11576
if percentage_diff > self.allowed_percentage_diff:
116-
self._make_error_meta(AssertionError, "percentage mismatch")
77+
raise self._make_error_meta(
78+
AssertionError,
79+
f"{percentage_diff:.1%} elements differ, "
80+
f"but only {self.allowed_percentage_diff:.1%} is allowed",
81+
)
11782

11883
if self.agg_method is None:
11984
super()._compare_values(actual, expected)
12085
else:
121-
err = self.agg_method(abs_diff.to(torch.float64))
122-
if err > self.atol:
123-
self._make_error_meta(AssertionError, "aggregated mismatch")
86+
agg_abs_diff = float(self.agg_method(abs_diff.to(torch.float64)))
87+
if agg_abs_diff > self.atol:
88+
raise self._make_error_meta(
89+
AssertionError,
90+
f"The '{self.agg_method.__name__}' of the absolute difference is {agg_abs_diff}, "
91+
f"but only {self.atol} is allowed.",
92+
)
12493

12594

12695
def assert_close(
@@ -148,7 +117,7 @@ def assert_close(
148117
NonePair,
149118
BooleanPair,
150119
NumberPair,
151-
PILImagePair,
120+
ImagePair,
152121
TensorLikePair,
153122
),
154123
allow_subclasses=allow_subclasses,
@@ -167,6 +136,32 @@ def assert_close(
167136
assert_equal = functools.partial(assert_close, rtol=0, atol=0)
168137

169138

139+
def parametrized_error_message(*args, **kwargs):
140+
def to_str(obj):
141+
if isinstance(obj, torch.Tensor) and obj.numel() > 10:
142+
return f"tensor(shape={list(obj.shape)}, dtype={obj.dtype}, device={obj.device})"
143+
else:
144+
return repr(obj)
145+
146+
if args or kwargs:
147+
postfix = "\n".join(
148+
[
149+
"",
150+
"Failure happened for the following parameters:",
151+
"",
152+
*[to_str(arg) for arg in args],
153+
*[f"{name}={to_str(kwarg)}" for name, kwarg in kwargs.items()],
154+
]
155+
)
156+
else:
157+
postfix = ""
158+
159+
def wrapper(msg):
160+
return msg + postfix
161+
162+
return wrapper
163+
164+
170165
class ArgsKwargs:
171166
def __init__(self, *args, **kwargs):
172167
self.args = args
@@ -656,6 +651,13 @@ def get_marks(self, test_id, args_kwargs):
656651
]
657652

658653
def get_closeness_kwargs(self, test_id, *, dtype, device):
654+
if not (isinstance(test_id, tuple) and len(test_id) == 2):
655+
msg = "`test_id` should be a `Tuple[Optional[str], str]` denoting the test class and function name"
656+
if callable(test_id):
657+
msg += ". Did you forget to add the `test_id` fixture to parameters of the test?"
658+
else:
659+
msg += f", but got {test_id} instead."
660+
raise pytest.UsageError(msg)
659661
if isinstance(device, torch.device):
660662
device = device.type
661663
return self.closeness_kwargs.get((test_id, dtype, device), dict())

0 commit comments

Comments
 (0)