1
- import dataclasses
2
1
import functools
3
2
import itertools
4
3
import math
5
- from collections import defaultdict
6
- from typing import Any , Callable , Dict , Iterable , List , Optional , Sequence , Tuple
7
4
8
5
import numpy as np
9
6
import pytest
10
7
import torch .testing
11
8
import torchvision .ops
12
9
import torchvision .prototype .transforms .functional as F
13
-
14
- from _pytest .mark .structures import MarkDecorator
15
10
from common_utils import cycle_over
16
11
from datasets_utils import combinations_grid
17
12
from prototype_common_utils import (
18
13
ArgsKwargs ,
14
+ InfoBase ,
19
15
make_bounding_box_loaders ,
20
16
make_image_loader ,
21
17
make_image_loaders ,
22
18
make_mask_loaders ,
23
19
make_video_loaders ,
20
+ mark_framework_limitation ,
21
+ TestMark ,
24
22
VALID_EXTRA_DIMS ,
25
23
)
26
24
from torchvision .prototype import features
29
27
__all__ = ["KernelInfo" , "KERNEL_INFOS" ]
30
28
31
29
32
- TestID = Tuple [Optional [str ], str ]
33
-
34
-
35
- @dataclasses .dataclass
36
- class TestMark :
37
- test_id : TestID
38
- mark : MarkDecorator
39
- condition : Callable [[ArgsKwargs ], bool ] = lambda args_kwargs : True
40
-
41
-
42
- @dataclasses .dataclass
43
- class KernelInfo :
44
- kernel : Callable
45
- # Most common tests use these inputs to check the kernel. As such it should cover all valid code paths, but should
46
- # not include extensive parameter combinations to keep to overall test count moderate.
47
- sample_inputs_fn : Callable [[], Iterable [ArgsKwargs ]]
48
- # Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
49
- # TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
50
- kernel_name : str = dataclasses .field (default = None )
51
- # This function should mirror the kernel. It should have the same signature as the `kernel` and as such also take
52
- # tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should happen
53
- # inside the function. It should return a tensor or to be more precise an object that can be compared to a
54
- # tensor by `assert_close`. If omitted, no reference test will be performed.
55
- reference_fn : Optional [Callable ] = None
56
- # These inputs are only used for the reference tests and thus can be comprehensive with regard to the parameter
57
- # values to be tested. If not specified, `sample_inputs_fn` will be used.
58
- reference_inputs_fn : Optional [Callable [[], Iterable [ArgsKwargs ]]] = None
59
- # Additional parameters, e.g. `rtol=1e-3`, passed to `assert_close`.
60
- closeness_kwargs : Dict [str , Any ] = dataclasses .field (default_factory = dict )
61
- test_marks : Sequence [TestMark ] = dataclasses .field (default_factory = list )
62
- _test_marks_map : Dict [str , List [TestMark ]] = dataclasses .field (default = None , init = False )
63
-
64
- def __post_init__ (self ):
65
- self .kernel_name = self .kernel_name or self .kernel .__name__
66
- self .reference_inputs_fn = self .reference_inputs_fn or self .sample_inputs_fn
67
-
68
- test_marks_map = defaultdict (list )
69
- for test_mark in self .test_marks :
70
- test_marks_map [test_mark .test_id ].append (test_mark )
71
- self ._test_marks_map = dict (test_marks_map )
72
-
73
- def get_marks (self , test_id , args_kwargs ):
74
- return [
75
- test_mark .mark for test_mark in self ._test_marks_map .get (test_id , []) if test_mark .condition (args_kwargs )
76
- ]
30
+ class KernelInfo (InfoBase ):
31
+ def __init__ (
32
+ self ,
33
+ kernel ,
34
+ * ,
35
+ # Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
36
+ # TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
37
+ kernel_name = None ,
38
+ # Most common tests use these inputs to check the kernel. As such it should cover all valid code paths, but
39
+ # should not include extensive parameter combinations to keep to overall test count moderate.
40
+ sample_inputs_fn ,
41
+ # This function should mirror the kernel. It should have the same signature as the `kernel` and as such also
42
+ # take tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should
43
+ # happen inside the function. It should return a tensor or to be more precise an object that can be compared to
44
+ # a tensor by `assert_close`. If omitted, no reference test will be performed.
45
+ reference_fn = None ,
46
+ # These inputs are only used for the reference tests and thus can be comprehensive with regard to the parameter
47
+ # values to be tested. If not specified, `sample_inputs_fn` will be used.
48
+ reference_inputs_fn = None ,
49
+ # See InfoBase
50
+ test_marks = None ,
51
+ # See InfoBase
52
+ closeness_kwargs = None ,
53
+ ):
54
+ super ().__init__ (id = kernel_name or kernel .__name__ , test_marks = test_marks , closeness_kwargs = closeness_kwargs )
55
+ self .kernel = kernel
56
+ self .sample_inputs_fn = sample_inputs_fn
57
+ self .reference_fn = reference_fn
58
+ self .reference_inputs_fn = reference_inputs_fn
77
59
78
60
79
61
DEFAULT_IMAGE_CLOSENESS_KWARGS = dict (
@@ -97,16 +79,6 @@ def wrapper(image_tensor, *other_args, **kwargs):
97
79
return wrapper
98
80
99
81
100
- def mark_framework_limitation (test_id , reason ):
101
- # The purpose of this function is to have a single entry point for skip marks that are only there, because the test
102
- # framework cannot handle the kernel in general or a specific parameter combination.
103
- # As development progresses, we can change the `mark.skip` to `mark.xfail` from time to time to see if the skip is
104
- # still justified.
105
- # We don't want to use `mark.xfail` all the time, because that actually runs the test until an error happens. Thus,
106
- # we are wasting CI resources for no reason for most of the time.
107
- return TestMark (test_id , pytest .mark .skip (reason = reason ))
108
-
109
-
110
82
def xfail_jit_python_scalar_arg (name , * , reason = None ):
111
83
reason = reason or f"Python scalar int or float for `{ name } ` is not supported when scripting"
112
84
return TestMark (
0 commit comments