
Commit ce289d9

add automatic feature type dispatch to functional transforms (#5323)
* add auto dispatch
* fix missing arguments error message
* remove PIL kernel for erase
* automate feature-specific parameter detection
* fix typos
* clean up dispatcher call
* remove __torch_function__ from transform dispatch
* remove auto-generation
* revert unrelated changes
* remove implements decorator
* change register parameter order
* change order of transforms for readability
* add documentation for __torch_function__
* fix mypy
* inline check for support
* refactor kernel registering process
* refactor dispatch to be a regular decorator
* split kernels and dispatchers
* remove sentinels
* replace pass with ...
* appease mypy
* make single-kernel dispatchers more concise
* make dispatcher signatures more generic
* make kernel checking more strict
* revert doc changes
* address Francisco's comments
* remove inplace
* rename kernel test module
* fix inplace
* remove special casing for PIL and vanilla tensors
* address comments
* update docs
1 parent c923d77 commit ce289d9

18 files changed: +577 -145 lines

test/test_prototype_transforms_functional.py renamed to test/test_prototype_transforms_kernels.py

Lines changed: 5 additions & 5 deletions
@@ -3,7 +3,7 @@
 
 import pytest
 import torch.testing
-import torchvision.prototype.transforms.functional as F
+import torchvision.prototype.transforms.kernels as K
 from torch import jit
 from torchvision.prototype import features
 
@@ -115,7 +115,7 @@ def __init__(self, *args, **kwargs):
 class KernelInfo:
     def __init__(self, name, *, sample_inputs_fn):
         self.name = name
-        self.kernel = getattr(F, name)
+        self.kernel = getattr(K, name)
         self._sample_inputs_fn = sample_inputs_fn
 
     def sample_inputs(self):
@@ -146,16 +146,16 @@ def horizontal_flip_image():
 @register_kernel_info_from_sample_inputs_fn
 def horizontal_flip_bounding_box():
     for bounding_box in make_bounding_boxes(formats=[features.BoundingBoxFormat.XYXY]):
-        yield SampleInput(bounding_box, image_size=bounding_box.image_size)
+        yield SampleInput(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size)
 
 
 @register_kernel_info_from_sample_inputs_fn
 def resize_image():
     for image, interpolation in itertools.product(
         make_images(),
         [
-            F.InterpolationMode.BILINEAR,
-            F.InterpolationMode.NEAREST,
+            K.InterpolationMode.BILINEAR,
+            K.InterpolationMode.NEAREST,
         ],
     ):
         height, width = image.shape[-2:]
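These sample inputs also document the new calling convention: kernels operate on plain tensors and receive all feature-specific metadata explicitly. A minimal sketch of a direct kernel call, with the signature assumed from the sample inputs above:

    # hypothetical direct call; the kernel gets format and image_size passed
    # explicitly instead of reading them off a feature type
    flipped = K.horizontal_flip_bounding_box(
        bounding_box, format=bounding_box.format, image_size=bounding_box.image_size
    )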

torchvision/prototype/features/_bounding_box.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ def __new__(
 
     def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
         # import at runtime to avoid cyclic imports
-        from torchvision.prototype.transforms.functional import convert_bounding_box_format
+        from torchvision.prototype.transforms.kernels import convert_bounding_box_format
 
         if isinstance(format, str):
             format = BoundingBoxFormat[format]

torchvision/prototype/features/_encoded.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ def image_size(self) -> Tuple[int, int]:
 
     def decode(self) -> Image:
         # import at runtime to avoid cyclic imports
-        from torchvision.prototype.transforms.functional import decode_image_with_pil
+        from torchvision.prototype.transforms.kernels import decode_image_with_pil
 
         return Image(decode_image_with_pil(self))
 

torchvision/prototype/features/_feature.py

Lines changed: 42 additions & 2 deletions
@@ -1,7 +1,7 @@
-from typing import Any, cast, Dict, Set, TypeVar, Union, Optional, Type, Callable
+from typing import Any, cast, Dict, Set, TypeVar, Union, Optional, Type, Callable, Tuple, Sequence, Mapping
 
 import torch
-from torch._C import _TensorBase
+from torch._C import _TensorBase, DisableTorchFunction
 
 
 F = TypeVar("F", bound="Feature")
@@ -76,5 +76,45 @@ def new_like(
         _metadata.update(metadata)
         return cls(data, dtype=dtype or other.dtype, device=device or other.device, **_metadata)
 
+    @classmethod
+    def __torch_function__(
+        cls,
+        func: Callable[..., torch.Tensor],
+        types: Tuple[Type[torch.Tensor], ...],
+        args: Sequence[Any] = (),
+        kwargs: Optional[Mapping[str, Any]] = None,
+    ) -> torch.Tensor:
+        """For general information about how the __torch_function__ protocol works,
+        see https://pytorch.org/docs/stable/notes/extending.html#extending-torch
+
+        TL;DR: Every time a PyTorch operator is called, it goes through the inputs and looks for the
+        ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
+        ``args`` and ``kwargs`` of the original call.
+
+        The default behavior of :class:`~torch.Tensor` is to retain the custom tensor type. For the :class:`Feature`
+        use case, this has two downsides:
+
+        1. Since some :class:`Feature` subclasses require metadata to be constructed, the default wrapping, i.e.
+           ``return cls(func(*args, **kwargs))``, will fail for them.
+        2. For most operations, there is no way of knowing if the input type is still valid for the output.
+
+        For these reasons, the automatic output wrapping is turned off for most operators.
+
+        Exceptions to this are:
+
+        - :meth:`torch.Tensor.clone`
+        - :meth:`torch.Tensor.to`
+        """
+        kwargs = kwargs or dict()
+        with DisableTorchFunction():
+            output = func(*args, **kwargs)
+
+        if func is torch.Tensor.clone:
+            return cls.new_like(args[0], output)
+        elif func is torch.Tensor.to:
+            return cls.new_like(args[0], output, dtype=output.dtype, device=output.device)
+        else:
+            return output
+
     def __repr__(self) -> str:
         return cast(str, torch.Tensor.__repr__(self)).replace("tensor", type(self).__name__)
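In practice the override means that only clone() and to() keep the Feature type. A minimal sketch of the resulting behavior, assuming a features.Image can be constructed directly from a tensor:

    import torch
    from torchvision.prototype import features

    image = features.Image(torch.rand(3, 32, 32))

    # clone() and to() are the two exceptions: the output is re-wrapped via
    # new_like(), so the type and its metadata are retained
    assert type(image.clone()) is features.Image
    assert type(image.to(torch.float64)) is features.Image

    # every other operator unwraps to a plain tensor, since the result might
    # no longer be a valid Image and required metadata cannot be inferred
    assert type(image + 1) is torch.Tensor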
torchvision/prototype/transforms/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -1,4 +1,5 @@
-from . import functional
-from .functional import InterpolationMode  # usort: skip
+from . import kernels  # usort: skip
+from . import functional  # usort: skip
+from .kernels import InterpolationMode  # usort: skip
 
 from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
torchvision/prototype/transforms/functional/__init__.py

Lines changed: 12 additions & 25 deletions
@@ -1,27 +1,14 @@
-from ._augment import erase_image, mixup_image, mixup_one_hot_label, cutmix_image, cutmix_one_hot_label
+from ._augment import erase, mixup, cutmix
 from ._color import (
-    adjust_brightness_image,
-    adjust_contrast_image,
-    adjust_saturation_image,
-    adjust_sharpness_image,
-    posterize_image,
-    solarize_image,
-    autocontrast_image,
-    equalize_image,
-    invert_image,
+    adjust_brightness,
+    adjust_contrast,
+    adjust_saturation,
+    adjust_sharpness,
+    posterize,
+    solarize,
+    autocontrast,
+    equalize,
+    invert,
 )
-from ._geometry import (
-    horizontal_flip_bounding_box,
-    horizontal_flip_image,
-    resize_bounding_box,
-    resize_image,
-    resize_segmentation_mask,
-    center_crop_image,
-    resized_crop_image,
-    InterpolationMode,
-    affine_image,
-    rotate_image,
-)
-from ._meta_conversion import convert_color_space, convert_bounding_box_format
-from ._misc import normalize_image
-from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot
+from ._geometry import horizontal_flip, resize, center_crop, resized_crop, affine, rotate
+from ._misc import normalize
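The split leaves two namespaces: torchvision.prototype.transforms.kernels holds the low-level, type-specific kernels (the old functional names), while torchvision.prototype.transforms.functional now holds the auto-dispatching wrappers. For example:

    from torchvision.prototype.transforms import kernels as K
    from torchvision.prototype.transforms import functional as F

    K.horizontal_flip_image         # explicit kernel for images
    K.horizontal_flip_bounding_box  # explicit kernel for bounding boxes
    F.horizontal_flip               # dispatches on the type of its input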
torchvision/prototype/transforms/functional/_augment.py

Lines changed: 42 additions & 37 deletions
@@ -1,52 +1,57 @@
-from typing import Tuple
+from typing import TypeVar, Any
 
 import torch
+from torchvision.prototype import features
+from torchvision.prototype.transforms import kernels as K
 from torchvision.transforms import functional as _F
 
+from ._utils import dispatch
 
-erase_image = _F.erase
+T = TypeVar("T", bound=features.Feature)
 
 
-def _mixup(input: torch.Tensor, batch_dim: int, lam: float, inplace: bool) -> torch.Tensor:
-    if not inplace:
-        input = input.clone()
-
-    input_rolled = input.roll(1, batch_dim)
-    return input.mul_(lam).add_(input_rolled.mul_(1 - lam))
-
-
-def mixup_image(image_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor:
-    if image_batch.ndim < 4:
-        raise ValueError("Need a batch of images")
-
-    return _mixup(image_batch, -4, lam, inplace)
-
-
-def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor:
-    if one_hot_label_batch.ndim < 2:
-        raise ValueError("Need a batch of one hot labels")
-
-    return _mixup(one_hot_label_batch, -2, lam, inplace)
-
-
-def cutmix_image(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int], inplace: bool = False) -> torch.Tensor:
-    if image_batch.ndim < 4:
-        raise ValueError("Need a batch of images")
-
-    if not inplace:
-        image_batch = image_batch.clone()
-
-    x1, y1, x2, y2 = box
-    image_rolled = image_batch.roll(1, -4)
-
-    image_batch[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
-    return image_batch
-
-
-def cutmix_one_hot_label(
-    one_hot_label_batch: torch.Tensor, *, lam_adjusted: float, inplace: bool = False
-) -> torch.Tensor:
-    if one_hot_label_batch.ndim < 2:
-        raise ValueError("Need a batch of one hot labels")
-
-    return _mixup(one_hot_label_batch, -2, lam_adjusted, inplace)
+@dispatch(
+    {
+        torch.Tensor: _F.erase,
+        features.Image: K.erase_image,
+    }
+)
+def erase(input: T, *args: Any, **kwargs: Any) -> T:
+    """ADDME"""
+    ...
+
+
+@dispatch(
+    {
+        features.Image: K.mixup_image,
+        features.OneHotLabel: K.mixup_one_hot_label,
+    }
+)
+def mixup(input: T, *args: Any, **kwargs: Any) -> T:
+    """ADDME"""
+    ...
+
+
+@dispatch(
+    {
+        features.Image: K.cutmix_image,
+        features.OneHotLabel: K.cutmix_one_hot_label,
+    }
+)
+def cutmix(input: T, *args: Any, **kwargs: Any) -> T:
+    """Perform the CutMix operation as introduced in the paper
+    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" <https://arxiv.org/abs/1905.04899>`_.
+
+    Dispatch to the corresponding kernels happens according to this table:
+
+    .. table::
+       :widths: 30 70
+
+       ===================================================== =======================================================================
+       :class:`~torchvision.prototype.features.Image`        :func:`~torchvision.prototype.transforms.kernels.cutmix_image`
+       :class:`~torchvision.prototype.features.OneHotLabel`  :func:`~torchvision.prototype.transforms.kernels.cutmix_one_hot_label`
+       ===================================================== =======================================================================
+
+    Please refer to the kernel documentation for a detailed explanation of the functionality and parameters.
+    """
+    ...
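The dispatch helper imported from ._utils is not part of this excerpt. The following is a minimal sketch of how such a decorator could route calls by input type, purely for illustration; the committed implementation also handles feature-specific parameter extraction and output wrapping, which this sketch omits:

    import functools
    from typing import Any, Callable, Dict, TypeVar

    T = TypeVar("T")

    def dispatch(kernels: Dict[type, Callable[..., Any]]) -> Callable[[Callable[..., T]], Callable[..., T]]:
        """Turn the decorated function into a dispatcher that routes calls to a
        kernel chosen by the type of the first argument."""

        def decorator(dispatcher: Callable[..., T]) -> Callable[..., T]:
            @functools.wraps(dispatcher)
            def wrapper(input: Any, *args: Any, **kwargs: Any) -> T:
                # walk the MRO so a kernel registered for a base class, e.g.
                # torch.Tensor, also covers subclasses without their own kernel
                for cls in type(input).__mro__:
                    if cls in kernels:
                        return kernels[cls](input, *args, **kwargs)
                raise TypeError(f"{dispatcher.__name__}() does not support inputs of type {type(input)}")

            return wrapper

        return decorator

Note that under this scheme the decorated dispatcher body is never executed, which is why the dispatchers above can use ... as their body.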
torchvision/prototype/transforms/functional/_color.py

Lines changed: 108 additions & 9 deletions
@@ -1,20 +1,119 @@
+from typing import TypeVar, Any
+
+import PIL.Image
+import torch
+from torchvision.prototype import features
+from torchvision.prototype.transforms import kernels as K
 from torchvision.transforms import functional as _F
 
+from ._utils import dispatch
+
+T = TypeVar("T", bound=features.Feature)
 
-adjust_brightness_image = _F.adjust_brightness
+
+@dispatch(
+    {
+        torch.Tensor: _F.adjust_brightness,
+        PIL.Image.Image: _F.adjust_brightness,
+        features.Image: K.adjust_brightness_image,
+    }
+)
+def adjust_brightness(input: T, *args: Any, **kwargs: Any) -> T:
+    """ADDME"""
+    ...
 
-adjust_saturation_image = _F.adjust_saturation
+
+@dispatch(
+    {
+        torch.Tensor: _F.adjust_saturation,
+        PIL.Image.Image: _F.adjust_saturation,
+        features.Image: K.adjust_saturation_image,
+    }
+)
+def adjust_saturation(input: T, *args: Any, **kwargs: Any) -> T:
+    """ADDME"""
+    ...
 
-adjust_contrast_image = _F.adjust_contrast
+
+@dispatch(
+    {
+        torch.Tensor: _F.adjust_contrast,
+        PIL.Image.Image: _F.adjust_contrast,
+        features.Image: K.adjust_contrast_image,
+    }
+)
+def adjust_contrast(input: T, *args: Any, **kwargs: Any) -> T:
+    """ADDME"""
+    ...
 
-adjust_sharpness_image = _F.adjust_sharpness
+
+@dispatch(
+    {
+        torch.Tensor: _F.adjust_sharpness,
+        PIL.Image.Image: _F.adjust_sharpness,
+        features.Image: K.adjust_sharpness_image,
+    }
+)
+def adjust_sharpness(input: T, *args: Any, **kwargs: Any) -> T:
+    """ADDME"""
+    ...
 
-posterize_image = _F.posterize
+
+@dispatch(
+    {
+        torch.Tensor: _F.posterize,
+        PIL.Image.Image: _F.posterize,
+        features.Image: K.posterize_image,
+    }
+)
+def posterize(input: T, *args: Any, **kwargs: Any) -> T:
+    """ADDME"""
+    ...
 
-solarize_image = _F.solarize
+
+@dispatch(
+    {
+        torch.Tensor: _F.solarize,
+        PIL.Image.Image: _F.solarize,
+        features.Image: K.solarize_image,
+    }
+)
+def solarize(input: T, *args: Any, **kwargs: Any) -> T:
+    """ADDME"""
+    ...
 
-autocontrast_image = _F.autocontrast
+
+@dispatch(
+    {
+        torch.Tensor: _F.autocontrast,
+        PIL.Image.Image: _F.autocontrast,
+        features.Image: K.autocontrast_image,
+    }
+)
+def autocontrast(input: T, *args: Any, **kwargs: Any) -> T:
+    """ADDME"""
+    ...
 
-equalize_image = _F.equalize
+
+@dispatch(
+    {
+        torch.Tensor: _F.equalize,
+        PIL.Image.Image: _F.equalize,
+        features.Image: K.equalize_image,
+    }
+)
+def equalize(input: T, *args: Any, **kwargs: Any) -> T:
+    """ADDME"""
+    ...
 
-invert_image = _F.invert
+
+@dispatch(
+    {
+        torch.Tensor: _F.invert,
+        PIL.Image.Image: _F.invert,
+        features.Image: K.invert_image,
+    }
+)
+def invert(input: T, *args: Any, **kwargs: Any) -> T:
+    """ADDME"""
+    ...
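With these dispatchers in place, a single entry point accepts plain tensors, PIL images, and the new feature types alike. An illustrative call, a sketch whose parameter name brightness_factor comes from the stable torchvision.transforms.functional API:

    import PIL.Image
    import torch
    import torchvision.prototype.transforms.functional as F
    from torchvision.prototype import features

    pil_image = PIL.Image.new("RGB", (32, 32))
    tensor_image = torch.rand(3, 32, 32)
    feature_image = features.Image(torch.rand(3, 32, 32))

    # the same call dispatches to _F.adjust_brightness for the first two
    # inputs and to K.adjust_brightness_image for the feature type
    F.adjust_brightness(pil_image, brightness_factor=1.5)
    F.adjust_brightness(tensor_image, brightness_factor=1.5)
    F.adjust_brightness(feature_image, brightness_factor=1.5)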
