Skip to content

Commit f88ab12

Browse files
authored
Add PermuteDimensions and TransposeDimensions transforms (#6800)
* Add PermuteDimensions and TransposeDimensions transforms
* Strip Subclass info.
* Apply changes from code review.
1 parent 3761855 commit f88ab12

File tree

4 files changed

+137
-18
lines changed

4 files changed

+137
-18
lines changed

test/test_prototype_transforms.py

Lines changed: 73 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,10 +18,12 @@
1818
make_masks,
1919
make_one_hot_labels,
2020
make_segmentation_mask,
21+
make_video,
2122
make_videos,
2223
)
2324
from torchvision.ops.boxes import box_iou
2425
from torchvision.prototype import features, transforms
26+
from torchvision.prototype.transforms._utils import _isinstance
2527
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
2628

2729
BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -1826,3 +1828,74 @@ def test_to_dtype(dtype, expected_dtypes):
18261828
assert transformed_value.dtype is expected_dtypes[value_type]
18271829
else:
18281830
assert transformed_value is value
1831+
1832+
1833+
@pytest.mark.parametrize(
1834+
("dims", "inverse_dims"),
1835+
[
1836+
(
1837+
{torch.Tensor: (1, 2, 0), features.Image: (2, 1, 0), features.Video: None},
1838+
{torch.Tensor: (2, 0, 1), features.Image: (2, 1, 0), features.Video: None},
1839+
),
1840+
(
1841+
{torch.Tensor: (1, 2, 0), features.Image: (2, 1, 0), features.Video: (1, 2, 3, 0)},
1842+
{torch.Tensor: (2, 0, 1), features.Image: (2, 1, 0), features.Video: (3, 0, 1, 2)},
1843+
),
1844+
],
1845+
)
1846+
def test_permute_dimensions(dims, inverse_dims):
1847+
sample = dict(
1848+
plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"),
1849+
image=make_image(),
1850+
bounding_box=make_bounding_box(format=features.BoundingBoxFormat.XYXY),
1851+
video=make_video(),
1852+
str="str",
1853+
int=0,
1854+
)
1855+
1856+
transform = transforms.PermuteDimensions(dims)
1857+
transformed_sample = transform(sample)
1858+
1859+
for key, value in sample.items():
1860+
value_type = type(value)
1861+
transformed_value = transformed_sample[key]
1862+
1863+
if _isinstance(value, (features.Image, features.is_simple_tensor, features.Video)):
1864+
if transform.dims.get(value_type) is not None:
1865+
assert transformed_value.permute(inverse_dims[value_type]).equal(value)
1866+
assert type(transformed_value) == torch.Tensor
1867+
else:
1868+
assert transformed_value is value
1869+
1870+
1871+
@pytest.mark.parametrize(
1872+
"dims",
1873+
[
1874+
(-1, -2),
1875+
{torch.Tensor: (-1, -2), features.Image: (1, 2), features.Video: None},
1876+
],
1877+
)
1878+
def test_transpose_dimensions(dims):
1879+
sample = dict(
1880+
plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"),
1881+
image=make_image(),
1882+
bounding_box=make_bounding_box(format=features.BoundingBoxFormat.XYXY),
1883+
video=make_video(),
1884+
str="str",
1885+
int=0,
1886+
)
1887+
1888+
transform = transforms.TransposeDimensions(dims)
1889+
transformed_sample = transform(sample)
1890+
1891+
for key, value in sample.items():
1892+
value_type = type(value)
1893+
transformed_value = transformed_sample[key]
1894+
1895+
transposed_dims = transform.dims.get(value_type)
1896+
if _isinstance(value, (features.Image, features.is_simple_tensor, features.Video)):
1897+
if transposed_dims is not None:
1898+
assert transformed_value.transpose(*transposed_dims).equal(value)
1899+
assert type(transformed_value) == torch.Tensor
1900+
else:
1901+
assert transformed_value is value

torchvision/prototype/transforms/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,17 @@
4040
TenCrop,
4141
)
4242
from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertImageDtype
43-
from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, RemoveSmallBoundingBoxes, ToDtype
43+
from ._misc import (
44+
GaussianBlur,
45+
Identity,
46+
Lambda,
47+
LinearTransformation,
48+
Normalize,
49+
PermuteDimensions,
50+
RemoveSmallBoundingBoxes,
51+
ToDtype,
52+
TransposeDimensions,
53+
)
4454
from ._type_conversion import DecodeImage, LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage
4555

4656
from ._deprecated import Grayscale, RandomGrayscale, ToTensor # usort: skip

torchvision/prototype/transforms/_misc.py

Lines changed: 40 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,4 @@
1-
import functools
2-
from collections import defaultdict
3-
from typing import Any, Callable, Dict, List, Sequence, Type, Union
1+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
42

53
import PIL.Image
64

@@ -9,7 +7,7 @@
97
from torchvision.prototype import features
108
from torchvision.prototype.transforms import functional as F, Transform
119

12-
from ._utils import _setup_float_or_seq, _setup_size, has_any, query_bounding_box
10+
from ._utils import _get_defaultdict, _setup_float_or_seq, _setup_size, has_any, query_bounding_box
1311

1412

1513
class Identity(Transform):
@@ -145,15 +143,10 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
145143
class ToDtype(Transform):
146144
_transformed_types = (torch.Tensor,)
147145

148-
def _default_dtype(self, dtype: torch.dtype) -> torch.dtype:
149-
return dtype
150-
151-
def __init__(self, dtype: Union[torch.dtype, Dict[Type, torch.dtype]]) -> None:
146+
def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]]) -> None:
152147
super().__init__()
153148
if not isinstance(dtype, dict):
154-
# This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
155-
# If it were possible, we could replace this with `defaultdict(lambda: dtype)`
156-
dtype = defaultdict(functools.partial(self._default_dtype, dtype))
149+
dtype = _get_defaultdict(dtype)
157150
self.dtype = dtype
158151

159152
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
@@ -163,6 +156,42 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
163156
return inpt.to(dtype=dtype)
164157

165158

159+
class PermuteDimensions(Transform):
160+
_transformed_types = (features.is_simple_tensor, features.Image, features.Video)
161+
162+
def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]]) -> None:
163+
super().__init__()
164+
if not isinstance(dims, dict):
165+
dims = _get_defaultdict(dims)
166+
self.dims = dims
167+
168+
def _transform(
169+
self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
170+
) -> torch.Tensor:
171+
dims = self.dims[type(inpt)]
172+
if dims is None:
173+
return inpt.as_subclass(torch.Tensor)
174+
return inpt.permute(*dims)
175+
176+
177+
class TransposeDimensions(Transform):
178+
_transformed_types = (features.is_simple_tensor, features.Image, features.Video)
179+
180+
def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, int]]]]) -> None:
181+
super().__init__()
182+
if not isinstance(dims, dict):
183+
dims = _get_defaultdict(dims)
184+
self.dims = dims
185+
186+
def _transform(
187+
self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
188+
) -> torch.Tensor:
189+
dims = self.dims[type(inpt)]
190+
if dims is None:
191+
return inpt.as_subclass(torch.Tensor)
192+
return inpt.transpose(*dims)
193+
194+
166195
class RemoveSmallBoundingBoxes(Transform):
167196
_transformed_types = (features.BoundingBox, features.Mask, features.Label, features.OneHotLabel)
168197

torchvision/prototype/transforms/_utils.py

Lines changed: 13 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import functools
22
import numbers
33
from collections import defaultdict
4-
from typing import Any, Callable, Dict, List, Sequence, Tuple, Type, Union
4+
from typing import Any, Callable, Dict, List, Sequence, Tuple, Type, TypeVar, Union
55

66
import PIL.Image
77

@@ -42,8 +42,17 @@ def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
4242
raise TypeError("Got inappropriate fill arg")
4343

4444

45-
def _default_fill(fill: FillType) -> FillType:
46-
return fill
45+
T = TypeVar("T")
46+
47+
48+
def _default_arg(value: T) -> T:
49+
return value
50+
51+
52+
def _get_defaultdict(default: T) -> Dict[Any, T]:
53+
# This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
54+
# If it were possible, we could replace this with `defaultdict(lambda: default)`
55+
return defaultdict(functools.partial(_default_arg, default))
4756

4857

4958
def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]:
@@ -52,9 +61,7 @@ def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, F
5261
if isinstance(fill, dict):
5362
return fill
5463

55-
# This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
56-
# If it were possible, we could replace this with `defaultdict(lambda: fill)`
57-
return defaultdict(functools.partial(_default_fill, fill))
64+
return _get_defaultdict(fill)
5865

5966

6067
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:

0 commit comments

Comments
 (0)