use pytest markers instead of custom solution for prototype transforms functional tests #6653

Merged · 6 commits · Oct 5, 2022
139 changes: 69 additions & 70 deletions test/prototype_transforms_dispatcher_infos.py
@@ -2,12 +2,12 @@
import dataclasses

from collections import defaultdict

from typing import Callable, Dict, List, Optional, Sequence, Type

import pytest
import torchvision.prototype.transforms.functional as F
from prototype_common_utils import BoundingBoxLoader
from prototype_transforms_kernel_infos import KERNEL_INFOS, KernelInfo, Skip
from prototype_transforms_kernel_infos import KERNEL_INFOS, TestMark
from torchvision.prototype import features

__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
@@ -24,35 +24,27 @@ def __post_init__(self):
self.kernel_name = self.kernel_name or self.kernel.__name__


def skip_python_scalar_arg_jit(name, *, reason="Python scalar int or float is not supported when scripting"):
return Skip(
"test_scripted_smoke",
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs[name], (int, float)),
reason=reason,
)


def skip_integer_size_jit(name="size"):
return skip_python_scalar_arg_jit(name, reason="Integer size is not supported when scripting.")


@dataclasses.dataclass
class DispatcherInfo:
dispatcher: Callable
kernels: Dict[Type, Callable]
kernel_infos: Dict[Type, KernelInfo] = dataclasses.field(default=None)
pil_kernel_info: Optional[PILKernelInfo] = None
method_name: str = dataclasses.field(default=None)
skips: Sequence[Skip] = dataclasses.field(default_factory=list)
_skips_map: Dict[str, List[Skip]] = dataclasses.field(default=None, init=False)
test_marks: Sequence[TestMark] = dataclasses.field(default_factory=list)
_test_marks_map: Dict[str, List[TestMark]] = dataclasses.field(default=None, init=False)

def __post_init__(self):
self.kernel_infos = {feature_type: KERNEL_INFO_MAP[kernel] for feature_type, kernel in self.kernels.items()}
self.method_name = self.method_name or self.dispatcher.__name__
skips_map = defaultdict(list)
for skip in self.skips:
skips_map[skip.test_name].append(skip)
self._skips_map = dict(skips_map)
test_marks_map = defaultdict(list)
for test_mark in self.test_marks:
test_marks_map[test_mark.test_id].append(test_mark)
self._test_marks_map = dict(test_marks_map)

def get_marks(self, test_id, args_kwargs):
return [
test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs)
]

def sample_inputs(self, *feature_types, filter_metadata=True):
for feature_type in feature_types or self.kernels.keys():
@@ -70,17 +62,27 @@ def sample_inputs(self, *feature_types, filter_metadata=True):

yield args_kwargs

def maybe_skip(self, *, test_name, args_kwargs, device):
skips = self._skips_map.get(test_name)
if not skips:
return

for skip in skips:
if skip.condition(args_kwargs, device):
pytest.skip(skip.reason)
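For context on how this is consumed: the old `maybe_skip` called `pytest.skip` at runtime, while the new `get_marks` returns marks that are attached to the parametrization at collection time. A minimal sketch of such a consumer (the harness in the actual test module may differ; `make_params` is a hypothetical helper):

```python
import pytest

from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS


def make_params(test_id):
    # One pytest.param per (info, sample input) pair. The skip/xfail marks
    # declared by each DispatcherInfo for this test are attached at
    # collection time instead of being triggered while the test runs.
    for info in DISPATCHER_INFOS:
        for args_kwargs in info.sample_inputs():
            yield pytest.param(info, args_kwargs, marks=info.get_marks(test_id, args_kwargs))


class TestDispatchers:
    @pytest.mark.parametrize(("info", "args_kwargs"), make_params(("TestDispatchers", "test_scripted_smoke")))
    def test_scripted_smoke(self, info, args_kwargs):
        ...  # script info.dispatcher and invoke it with args_kwargs
```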
def xfail_python_scalar_arg_jit(name, *, reason=None):
reason = reason or f"Python scalar int or float for `{name}` is not supported when scripting"
return TestMark(
("TestDispatchers", "test_scripted_smoke"),
pytest.mark.xfail(reason=reason),
condition=lambda args_kwargs: isinstance(args_kwargs.kwargs[name], (int, float)),
)


def xfail_integer_size_jit(name="size"):
return xfail_python_scalar_arg_jit(name, reason=f"Integer `{name}` is not supported when scripting.")

def fill_sequence_needs_broadcast(args_kwargs, device):

skip_dispatch_feature = TestMark(
("TestDispatchers", "test_dispatch_feature"),
pytest.mark.skip(reason="Dispatcher doesn't support arbitrary feature dispatch."),
)
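`TestMark` is imported from `prototype_transforms_kernel_infos` and its definition is not part of this diff. Judging from the call sites above, it is roughly the following shape (field names and the `condition` default are inferred, so treat this as an assumption):

```python
import dataclasses
from typing import Any, Callable, Tuple


@dataclasses.dataclass
class TestMark:
    # (test class name, test name), e.g. ("TestDispatchers", "test_dispatch_pil")
    test_id: Tuple[str, str]
    # the pytest mark to attach, e.g. pytest.mark.skip(...) or pytest.mark.xfail(...)
    mark: Any
    # predicate over the sample input; the mark is applied only when it returns True
    condition: Callable = lambda args_kwargs: True
```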


def fill_sequence_needs_broadcast(args_kwargs):
(image_loader, *_), kwargs = args_kwargs
try:
fill = kwargs["fill"]
@@ -93,15 +95,12 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
return image_loader.num_channels > 1


skip_dispatch_pil_if_fill_sequence_needs_broadcast = Skip(
"test_dispatch_pil",
xfail_dispatch_pil_if_fill_sequence_needs_broadcast = TestMark(
("TestDispatchers", "test_dispatch_pil"),
pytest.mark.xfail(
reason="PIL kernel doesn't support sequences of length 1 for `fill` if the number of color channels is larger."
),
condition=fill_sequence_needs_broadcast,
reason="PIL kernel doesn't support sequences of length 1 if the number of channels is larger.",
)

skip_dispatch_feature = Skip(
"test_dispatch_feature",
reason="Dispatcher doesn't support arbitrary feature dispatch.",
)
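To make the `fill` condition concrete, a small illustration (the loader stand-in is hypothetical, and since the middle of `fill_sequence_needs_broadcast` is collapsed in this diff, the expectations are inferred from its name and usage):

```python
from types import SimpleNamespace

# Stand-in for an image loader; only num_channels matters to the condition.
rgb_loader = SimpleNamespace(num_channels=3)

# A length-1 fill sequence on a 3-channel image: the tensor kernels broadcast
# it, the PIL kernel does not, so test_dispatch_pil is expected to fail here.
assert fill_sequence_needs_broadcast(((rgb_loader,), {"fill": [128]}))

# A scalar fill needs no broadcasting, so no mark is applied.
assert not fill_sequence_needs_broadcast(((rgb_loader,), {"fill": 128}))
```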


@@ -123,8 +122,8 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
features.Mask: F.resize_mask,
},
pil_kernel_info=PILKernelInfo(F.resize_image_pil),
skips=[
skip_integer_size_jit(),
test_marks=[
xfail_integer_size_jit(),
],
),
DispatcherInfo(
@@ -135,9 +134,9 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
features.Mask: F.affine_mask,
},
pil_kernel_info=PILKernelInfo(F.affine_image_pil),
skips=[
skip_dispatch_pil_if_fill_sequence_needs_broadcast,
skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT"),
test_marks=[
xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
xfail_python_scalar_arg_jit("shear"),
],
),
DispatcherInfo(
@@ -166,16 +165,6 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
features.Mask: F.crop_mask,
},
pil_kernel_info=PILKernelInfo(F.crop_image_pil, kernel_name="crop_image_pil"),
skips=[
Skip(
"test_dispatch_feature",
condition=lambda args_kwargs, device: isinstance(args_kwargs.args[0], BoundingBoxLoader),
reason=(
"F.crop expects 4 coordinates as input, but bounding box sample inputs only generate two "
"since that is sufficient for the kernel."
),
)
],
),
DispatcherInfo(
F.resized_crop,
@@ -193,10 +182,20 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
features.BoundingBox: F.pad_bounding_box,
features.Mask: F.pad_mask,
},
skips=[
skip_dispatch_pil_if_fill_sequence_needs_broadcast,
],
pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"),
test_marks=[
TestMark(
("TestDispatchers", "test_dispatch_pil"),
pytest.mark.xfail(
reason=(
"PIL kernel doesn't support sequences of length 1 for argument `fill` and "
"`padding_mode='constant'`, if the number of color channels is larger."
)
),
condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs)
and args_kwargs.kwargs.get("padding_mode", "constant") == "constant",
)
],
),
DispatcherInfo(
F.perspective,
@@ -205,10 +204,10 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
features.BoundingBox: F.perspective_bounding_box,
features.Mask: F.perspective_mask,
},
skips=[
skip_dispatch_pil_if_fill_sequence_needs_broadcast,
],
pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
test_marks=[
xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
],
),
DispatcherInfo(
F.elastic,
@@ -227,8 +226,8 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
features.Mask: F.center_crop_mask,
},
pil_kernel_info=PILKernelInfo(F.center_crop_image_pil),
skips=[
skip_integer_size_jit("output_size"),
test_marks=[
xfail_integer_size_jit("output_size"),
],
),
DispatcherInfo(
Expand All @@ -237,9 +236,9 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
features.Image: F.gaussian_blur_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.gaussian_blur_image_pil),
skips=[
skip_python_scalar_arg_jit("kernel_size"),
skip_python_scalar_arg_jit("sigma"),
test_marks=[
xfail_python_scalar_arg_jit("kernel_size"),
xfail_python_scalar_arg_jit("sigma"),
],
),
DispatcherInfo(
@@ -290,7 +289,7 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
features.Image: F.erase_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.erase_image_pil),
skips=[
test_marks=[
skip_dispatch_feature,
],
),
@@ -335,8 +334,8 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
features.Image: F.five_crop_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.five_crop_image_pil),
skips=[
skip_integer_size_jit(),
test_marks=[
xfail_integer_size_jit(),
skip_dispatch_feature,
],
),
@@ -345,18 +344,18 @@ def fill_sequence_needs_broadcast(args_kwargs, device):
kernels={
features.Image: F.ten_crop_image_tensor,
},
pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),
skips=[
skip_integer_size_jit(),
test_marks=[
xfail_integer_size_jit(),
skip_dispatch_feature,
],
pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),
),
DispatcherInfo(
F.normalize,
kernels={
features.Image: F.normalize_image_tensor,
},
skips=[
test_marks=[
skip_dispatch_feature,
],
),
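Worth noting about this migration: most former `Skip` entries became `xfail` marks, which is stricter. A skipped test never runs; an xfailed test runs and must fail, so a silently fixed limitation surfaces as `XPASS`. A minimal standalone illustration:

```python
import pytest


@pytest.mark.skip(reason="never executed")
def test_skipped():
    raise AssertionError("pytest reports SKIPPED without running this body")


@pytest.mark.xfail(reason="known limitation")
def test_expected_failure():
    # Runs and fails -> reported as XFAIL. If the limitation is fixed and
    # this starts passing, pytest reports XPASS, flagging the stale mark.
    raise AssertionError
```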