Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions test/prototype_common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import collections.abc
import dataclasses
import functools
from collections import defaultdict
from typing import Callable, Optional, Sequence, Tuple, Union

import PIL.Image
Expand Down Expand Up @@ -47,6 +48,9 @@
"make_masks",
"make_video",
"make_videos",
"TestMark",
"mark_framework_limitation",
"InfoBase",
]


Expand Down Expand Up @@ -588,3 +592,52 @@ def make_video_loaders(


make_videos = from_loaders(make_video_loaders)


class TestMark:
def __init__(
self,
# Tuple of test class name and test function name that identifies the test the mark is applied to. If there is
# no test class, i.e. a standalone test function, use `None`.
test_id,
# `pytest.mark.*` to apply, e.g. `pytest.mark.skip` or `pytest.mark.xfail`
mark,
*,
# Callable, that will be passed an `ArgsKwargs` and should return a boolean to indicate if the mark will be
# applied. If omitted, defaults to always apply.
condition=None,
):
self.test_id = test_id
self.mark = mark
self.condition = condition or (lambda args_kwargs: True)


def mark_framework_limitation(test_id, reason):
# The purpose of this function is to have a single entry point for skip marks that are only there, because the test
# framework cannot handle the kernel in general or a specific parameter combination.
# As development progresses, we can change the `mark.skip` to `mark.xfail` from time to time to see if the skip is
# still justified.
# We don't want to use `mark.xfail` all the time, because that actually runs the test until an error happens. Thus,
# we are wasting CI resources for no reason for most of the time
return TestMark(test_id, pytest.mark.skip(reason=reason))


class InfoBase:
def __init__(self, *, id, test_marks=None, closeness_kwargs=None):
# Identifier if the info that shows up the parametrization.
self.id = id
# Test markers that will be (conditionally) applied to an `ArgsKwargs` parametrization.
# See the `TestMark` class for details
self.test_marks = test_marks or []
# Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`.
self.closeness_kwargs = closeness_kwargs or dict()

test_marks_map = defaultdict(list)
for test_mark in self.test_marks:
test_marks_map[test_mark.test_id].append(test_mark)
self._test_marks_map = dict(test_marks_map)

def get_marks(self, test_id, args_kwargs):
return [
test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs)
]
94 changes: 52 additions & 42 deletions test/prototype_transforms_dispatcher_infos.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,67 @@
import collections.abc
import dataclasses

from collections import defaultdict

from typing import Callable, Dict, List, Optional, Sequence, Type

import pytest
import torchvision.prototype.transforms.functional as F
from prototype_transforms_kernel_infos import KERNEL_INFOS, TestMark
from prototype_common_utils import InfoBase, TestMark
from prototype_transforms_kernel_infos import KERNEL_INFOS
from torchvision.prototype import features

__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]

KERNEL_INFO_MAP = {info.kernel: info for info in KERNEL_INFOS}


@dataclasses.dataclass
class PILKernelInfo:
kernel: Callable
kernel_name: str = dataclasses.field(default=None)

def __post_init__(self):
self.kernel_name = self.kernel_name or self.kernel.__name__


@dataclasses.dataclass
class DispatcherInfo:
dispatcher: Callable
kernels: Dict[Type, Callable]
pil_kernel_info: Optional[PILKernelInfo] = None
method_name: str = dataclasses.field(default=None)
test_marks: Sequence[TestMark] = dataclasses.field(default_factory=list)
_test_marks_map: Dict[str, List[TestMark]] = dataclasses.field(default=None, init=False)

def __post_init__(self):
self.kernel_infos = {feature_type: KERNEL_INFO_MAP[kernel] for feature_type, kernel in self.kernels.items()}
self.method_name = self.method_name or self.dispatcher.__name__
test_marks_map = defaultdict(list)
for test_mark in self.test_marks:
test_marks_map[test_mark.test_id].append(test_mark)
self._test_marks_map = dict(test_marks_map)

def get_marks(self, test_id, args_kwargs):
return [
test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs)
]
class PILKernelInfo(InfoBase):
def __init__(
self,
kernel,
*,
# Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
# TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
kernel_name=None,
):
super().__init__(id=kernel_name or kernel.__name__)
self.kernel = kernel


class DispatcherInfo(InfoBase):
_KERNEL_INFO_MAP = {info.kernel: info for info in KERNEL_INFOS}

def __init__(
self,
dispatcher,
*,
# Dictionary of types that map to the kernel the dispatcher dispatches to.
kernels,
# If omitted, no PIL dispatch test will be performed.
pil_kernel_info=None,
# See InfoBase
test_marks=None,
# See InfoBase
closeness_kwargs=None,
):
super().__init__(id=dispatcher.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs)
self.dispatcher = dispatcher
self.kernels = kernels
self.pil_kernel_info = pil_kernel_info

kernel_infos = {}
for feature_type, kernel in self.kernels.items():
kernel_info = self._KERNEL_INFO_MAP.get(kernel)
if not kernel_info:
raise pytest.UsageError(
f"Can't register {kernel.__name__} for type {feature_type} since there is no `KernelInfo` for it. "
f"Please add a `KernelInfo` for it in `prototype_transforms_kernel_infos.py`."
)
kernel_infos[feature_type] = kernel_info
self.kernel_infos = kernel_infos

def sample_inputs(self, *feature_types, filter_metadata=True):
for feature_type in feature_types or self.kernels.keys():
if feature_type not in self.kernels:
raise pytest.UsageError(f"There is no kernel registered for type {feature_type.__name__}")
for feature_type in feature_types or self.kernel_infos.keys():
kernel_info = self.kernel_infos.get(feature_type)
if not kernel_info:
raise pytest.UsageError(f"There is no kernel registered for type {type.__name__}")

sample_inputs = kernel_info.sample_inputs_fn()

sample_inputs = self.kernel_infos[feature_type].sample_inputs_fn()
if not filter_metadata:
yield from sample_inputs
else:
Expand Down
92 changes: 32 additions & 60 deletions test/prototype_transforms_kernel_infos.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,24 @@
import dataclasses
import functools
import itertools
import math
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np
import pytest
import torch.testing
import torchvision.ops
import torchvision.prototype.transforms.functional as F

from _pytest.mark.structures import MarkDecorator
from common_utils import cycle_over
from datasets_utils import combinations_grid
from prototype_common_utils import (
ArgsKwargs,
InfoBase,
make_bounding_box_loaders,
make_image_loader,
make_image_loaders,
make_mask_loaders,
make_video_loaders,
mark_framework_limitation,
TestMark,
VALID_EXTRA_DIMS,
)
from torchvision.prototype import features
Expand All @@ -29,51 +27,35 @@
__all__ = ["KernelInfo", "KERNEL_INFOS"]


TestID = Tuple[Optional[str], str]


@dataclasses.dataclass
class TestMark:
test_id: TestID
mark: MarkDecorator
condition: Callable[[ArgsKwargs], bool] = lambda args_kwargs: True


@dataclasses.dataclass
class KernelInfo:
kernel: Callable
# Most common tests use these inputs to check the kernel. As such it should cover all valid code paths, but should
# not include extensive parameter combinations to keep to overall test count moderate.
sample_inputs_fn: Callable[[], Iterable[ArgsKwargs]]
# Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
# TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
kernel_name: str = dataclasses.field(default=None)
# This function should mirror the kernel. It should have the same signature as the `kernel` and as such also take
# tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should happen
# inside the function. It should return a tensor or to be more precise an object that can be compared to a
# tensor by `assert_close`. If omitted, no reference test will be performed.
reference_fn: Optional[Callable] = None
# These inputs are only used for the reference tests and thus can be comprehensive with regard to the parameter
# values to be tested. If not specified, `sample_inputs_fn` will be used.
reference_inputs_fn: Optional[Callable[[], Iterable[ArgsKwargs]]] = None
# Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`.
closeness_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
test_marks: Sequence[TestMark] = dataclasses.field(default_factory=list)
_test_marks_map: Dict[str, List[TestMark]] = dataclasses.field(default=None, init=False)

def __post_init__(self):
self.kernel_name = self.kernel_name or self.kernel.__name__
self.reference_inputs_fn = self.reference_inputs_fn or self.sample_inputs_fn

test_marks_map = defaultdict(list)
for test_mark in self.test_marks:
test_marks_map[test_mark.test_id].append(test_mark)
self._test_marks_map = dict(test_marks_map)

def get_marks(self, test_id, args_kwargs):
return [
test_mark.mark for test_mark in self._test_marks_map.get(test_id, []) if test_mark.condition(args_kwargs)
]
class KernelInfo(InfoBase):
def __init__(
self,
kernel,
*,
# Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
# TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
kernel_name=None,
# Most common tests use these inputs to check the kernel. As such it should cover all valid code paths, but
# should not include extensive parameter combinations to keep to overall test count moderate.
sample_inputs_fn,
# This function should mirror the kernel. It should have the same signature as the `kernel` and as such also
# take tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should
# happen inside the function. It should return a tensor or to be more precise an object that can be compared to
# a tensor by `assert_close`. If omitted, no reference test will be performed.
reference_fn=None,
# These inputs are only used for the reference tests and thus can be comprehensive with regard to the parameter
# values to be tested. If not specified, `sample_inputs_fn` will be used.
reference_inputs_fn=None,
# See InfoBase
test_marks=None,
# See InfoBase
closeness_kwargs=None,
):
super().__init__(id=kernel_name or kernel.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs)
self.kernel = kernel
self.sample_inputs_fn = sample_inputs_fn
self.reference_fn = reference_fn
self.reference_inputs_fn = reference_inputs_fn


DEFAULT_IMAGE_CLOSENESS_KWARGS = dict(
Expand All @@ -97,16 +79,6 @@ def wrapper(image_tensor, *other_args, **kwargs):
return wrapper


def mark_framework_limitation(test_id, reason):
# The purpose of this function is to have a single entry point for skip marks that are only there, because the test
# framework cannot handle the kernel in general or a specific parameter combination.
# As development progresses, we can change the `mark.skip` to `mark.xfail` from time to time to see if the skip is
# still justified.
# We don't want to use `mark.xfail` all the time, because that actually runs the test until an error happens. Thus,
# we are wasting CI resources for no reason for most of the time.
return TestMark(test_id, pytest.mark.skip(reason=reason))


def xfail_jit_python_scalar_arg(name, *, reason=None):
reason = reason or f"Python scalar int or float for `{name}` is not supported when scripting"
return TestMark(
Expand Down
Loading