From 5cf797a1c7be1be716e83712c14cafb31d79dafa Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 3 Feb 2023 15:59:25 +0100 Subject: [PATCH 1/8] introduce heuristic for simple tensor handling of transforms v2 --- test/test_prototype_transforms.py | 34 +++++++++++++++++-- .../prototype/transforms/_transform.py | 30 +++++++++++++--- 2 files changed, 58 insertions(+), 6 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 335fbfd4fe3..727db6fd10d 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -8,8 +8,9 @@ import torch import torchvision.prototype.transforms.utils -from common_utils import assert_equal, cpu_and_gpu +from common_utils import cpu_and_gpu from prototype_common_utils import ( + assert_equal, DEFAULT_EXTRA_DIMS, make_bounding_box, make_bounding_boxes, @@ -25,7 +26,7 @@ ) from torchvision.ops.boxes import box_iou from torchvision.prototype import datapoints, transforms -from torchvision.prototype.transforms.utils import check_type +from torchvision.prototype.transforms.utils import check_type, is_simple_tensor from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims] @@ -222,6 +223,35 @@ def test_random_resized_crop(self, transform, input): transform(input) +@pytest.mark.parametrize( + ("first_input", "second_input"), + itertools.product( + [ + next(make_vanilla_tensor_images()), + make_image(), + next(make_pil_images()), + ], + repeat=2, + ), +) +def test_simple_tensor_heurisitc(first_input, second_input): + class CopyCloneTransform(transforms.Transform): + def _transform(self, inpt, params): + return inpt.clone() if isinstance(inpt, torch.Tensor) else inpt.copy() + + transform = CopyCloneTransform() + first_output, second_output = transform([first_input, second_input]) + + assert first_output is not first_input + assert_equal(first_output, first_input) + + if is_simple_tensor(second_input): + assert second_output is second_input + else: + assert second_output is not second_input + assert_equal(second_output, second_input) + + @pytest.mark.parametrize("p", [0.0, 1.0]) class TestRandomHorizontalFlip: def input_expected_image_tensor(self, p, dtype=torch.float32): diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py index 18678a5265a..00a1f9db181 100644 --- a/torchvision/prototype/transforms/_transform.py +++ b/torchvision/prototype/transforms/_transform.py @@ -7,7 +7,8 @@ import torch from torch import nn from torch.utils._pytree import tree_flatten, tree_unflatten -from torchvision.prototype.transforms.utils import check_type +from torchvision.prototype import datapoints +from torchvision.prototype.transforms.utils import check_type, is_simple_tensor from torchvision.utils import _log_api_usage_once @@ -37,9 +38,30 @@ def forward(self, *inputs: Any) -> Any: params = self._get_params(flat_inputs) - flat_outputs = [ - self._transform(inpt, params) if check_type(inpt, self._transformed_types) else inpt for inpt in flat_inputs - ] + # This is a heuristic on how to deal with simple tensor inputs: + # 1. If we haven't seen an image yet, we transform a simple tensor + # 2. If we have seen an image before, so either a simple tensor, a `datapoints.Image` or a `PIL.Image.Image`, + # we return simple tensors without modification. 
+ # The order is defined by the returned list of `tree_flatten`, which recurses depth-first through the input. + # Since in most cases the image is the first input or at least comes before any other numerical data, this + # heuristic allows users to keep any supplemental numerical data in the sample as simple tensors. We have a few + # datasets, like `Caltech101`, `CelebA`, and `Widerface`, that would need special handling without this + # heuristic, since they return the target partially or completely as tensors. + # TODO: try to get user feedback if this heuristic is confusing or it is fine to keep it + flat_outputs = [] + image_found = False + for inpt in flat_inputs: + needs_transform = False + if is_simple_tensor(inpt): + if not image_found: + image_found = True + needs_transform = True + elif check_type(inpt, self._transformed_types): + if isinstance(inpt, (datapoints.Image, PIL.Image.Image)): + image_found = True + needs_transform = True + + flat_outputs.append(self._transform(inpt, params) if needs_transform else inpt) return tree_unflatten(flat_outputs, spec) From fc39b3265cf2314bb1bfbd68abcfb66859598bed Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 6 Feb 2023 16:02:10 +0100 Subject: [PATCH 2/8] refactor heuristic --- test/test_prototype_transforms.py | 65 ++++++++++++++----- .../prototype/transforms/_transform.py | 46 ++++++------- 2 files changed, 72 insertions(+), 39 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 727db6fd10d..7bd72f08a9d 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -39,11 +39,19 @@ def make_vanilla_tensor_images(*args, **kwargs): yield image.data +def make_vanilla_tensor_image(*args, **kwargs): + return next(make_vanilla_tensor_images(*args, **kwargs)) + + def make_pil_images(*args, **kwargs): for image in make_vanilla_tensor_images(*args, **kwargs): yield to_pil_image(image) +def make_pil_image(*args, **kwargs): + return next(make_pil_images(*args, **kwargs)) + + def make_vanilla_tensor_bounding_boxes(*args, **kwargs): for bounding_box in make_bounding_boxes(*args, **kwargs): yield bounding_box.data @@ -224,32 +232,55 @@ def test_random_resized_crop(self, transform, input): @pytest.mark.parametrize( - ("first_input", "second_input"), - itertools.product( - [ - next(make_vanilla_tensor_images()), - make_image(), - next(make_pil_images()), - ], - repeat=2, - ), + "sample", + [ + [make_pil_image(), make_vanilla_tensor_image(), make_vanilla_tensor_image()], + [make_vanilla_tensor_image(), make_pil_image(), make_vanilla_tensor_image()], + [make_vanilla_tensor_image(), make_vanilla_tensor_image(), make_pil_image()], + [make_image(), make_vanilla_tensor_image(), make_vanilla_tensor_image()], + [make_vanilla_tensor_image(), make_image(), make_vanilla_tensor_image()], + [make_vanilla_tensor_image(), make_vanilla_tensor_image(), make_image()], + ], ) -def test_simple_tensor_heurisitc(first_input, second_input): +def test_simple_tensor_heuristic(sample): + def split_on_simple_tensor(to_split): + simple_tensors = [] + others = [] + for item, predicate in zip(to_split, sample): + (simple_tensors if is_simple_tensor(predicate) else others).append(item) + return simple_tensors[0], simple_tensors[1:], others + class CopyCloneTransform(transforms.Transform): def _transform(self, inpt, params): return inpt.clone() if isinstance(inpt, torch.Tensor) else inpt.copy() + @staticmethod + def was_applied(output, inpt): + identity = output is inpt + if identity: + return False + 
+ # Make sure nothing fishy is going on + assert_equal(output, inpt) + return True + + first_simple_tensor_input, other_simple_tensor_inputs, other_inputs = split_on_simple_tensor(sample) + transform = CopyCloneTransform() - first_output, second_output = transform([first_input, second_input]) + transformed_sample = transform(sample) - assert first_output is not first_input - assert_equal(first_output, first_input) + first_simple_tensor_output, other_simple_tensor_outputs, other_outputs = split_on_simple_tensor(transformed_sample) - if is_simple_tensor(second_input): - assert second_output is second_input + if other_inputs: + assert not transform.was_applied(first_simple_tensor_output, first_simple_tensor_input) else: - assert second_output is not second_input - assert_equal(second_output, second_input) + assert transform.was_applied(first_simple_tensor_output, first_simple_tensor_input) + + for output, inpt in zip(other_simple_tensor_outputs, other_simple_tensor_inputs): + assert not transform.was_applied(output, inpt) + + for input, output in zip(other_inputs, other_outputs): + assert transform.was_applied(output, input) @pytest.mark.parametrize("p", [0.0, 1.0]) diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py index 00a1f9db181..58eccb96523 100644 --- a/torchvision/prototype/transforms/_transform.py +++ b/torchvision/prototype/transforms/_transform.py @@ -8,7 +8,7 @@ from torch import nn from torch.utils._pytree import tree_flatten, tree_unflatten from torchvision.prototype import datapoints -from torchvision.prototype.transforms.utils import check_type, is_simple_tensor +from torchvision.prototype.transforms.utils import check_type, has_any, is_simple_tensor from torchvision.utils import _log_api_usage_once @@ -38,30 +38,32 @@ def forward(self, *inputs: Any) -> Any: params = self._get_params(flat_inputs) - # This is a heuristic on how to deal with simple tensor inputs: - # 1. If we haven't seen an image yet, we transform a simple tensor - # 2. If we have seen an image before, so either a simple tensor, a `datapoints.Image` or a `PIL.Image.Image`, - # we return simple tensors without modification. - # The order is defined by the returned list of `tree_flatten`, which recurses depth-first through the input. - # Since in most cases the image is the first input or at least comes before any other numerical data, this - # heuristic allows users to keep any supplemental numerical data in the sample as simple tensors. We have a few - # datasets, like `Caltech101`, `CelebA`, and `Widerface`, that would need special handling without this - # heuristic, since they return the target partially or completely as tensors. - # TODO: try to get user feedback if this heuristic is confusing or it is fine to keep it + # Below is a heuristic on how to deal with simple tensor inputs: + # 1. Simple tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image + # (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample. + # 2. If there is no explicit image or video in the sample, only the first encountered simple tensor is + # transformed as image, while the rest is passed through. The order is defined by the returned `flat_inputs` + # of `tree_flatten`, which recurses depth-first through the input. + # + # This heuristic stems from two requirements: + # 1. We need to keep BC for single input simple tensors and treat them as images. + # 2. 
We don't want to treat all simple tensors as images, because some datasets like `CelebA` or `Widerface` + # return supplemental numerical data as tensors that cannot be transformed as images. + # + # The heuristic should work well for most people in practice. The only case where it doesn't is if someone + # tries to transform multiple simple tensors at the same time, expecting them all to be treated as images. + # However, this case wasn't supported by transforms v1 either, so there is no BC concern. flat_outputs = [] - image_found = False + simple_tensor_transformed = has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image) for inpt in flat_inputs: - needs_transform = False if is_simple_tensor(inpt): - if not image_found: - image_found = True - needs_transform = True - elif check_type(inpt, self._transformed_types): - if isinstance(inpt, (datapoints.Image, PIL.Image.Image)): - image_found = True - needs_transform = True - - flat_outputs.append(self._transform(inpt, params) if needs_transform else inpt) + if simple_tensor_transformed: + flat_outputs.append(inpt) + continue + + simple_tensor_transformed = True + + flat_outputs.append(self._transform(inpt, params)) return tree_unflatten(flat_outputs, spec) From ff73560909a1b6960ff4ed2dcd9c222a0a7b1530 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 6 Feb 2023 21:43:48 +0100 Subject: [PATCH 3/8] increase test coverage --- test/test_prototype_transforms.py | 37 +++++++++++++------------------ 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 7bd72f08a9d..2e7c1bedd64 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -39,19 +39,11 @@ def make_vanilla_tensor_images(*args, **kwargs): yield image.data -def make_vanilla_tensor_image(*args, **kwargs): - return next(make_vanilla_tensor_images(*args, **kwargs)) - - def make_pil_images(*args, **kwargs): for image in make_vanilla_tensor_images(*args, **kwargs): yield to_pil_image(image) -def make_pil_image(*args, **kwargs): - return next(make_pil_images(*args, **kwargs)) - - def make_vanilla_tensor_bounding_boxes(*args, **kwargs): for bounding_box in make_bounding_boxes(*args, **kwargs): yield bounding_box.data @@ -233,14 +225,16 @@ def test_random_resized_crop(self, transform, input): @pytest.mark.parametrize( "sample", - [ - [make_pil_image(), make_vanilla_tensor_image(), make_vanilla_tensor_image()], - [make_vanilla_tensor_image(), make_pil_image(), make_vanilla_tensor_image()], - [make_vanilla_tensor_image(), make_vanilla_tensor_image(), make_pil_image()], - [make_image(), make_vanilla_tensor_image(), make_vanilla_tensor_image()], - [make_vanilla_tensor_image(), make_image(), make_vanilla_tensor_image()], - [make_vanilla_tensor_image(), make_vanilla_tensor_image(), make_image()], - ], + itertools.permutations( + [ + next(make_vanilla_tensor_images()), + next(make_vanilla_tensor_images()), + next(make_pil_images()), + make_image(), + next(make_videos()), + ], + 3, + ), ) def test_simple_tensor_heuristic(sample): def split_on_simple_tensor(to_split): @@ -248,7 +242,7 @@ def split_on_simple_tensor(to_split): others = [] for item, predicate in zip(to_split, sample): (simple_tensors if is_simple_tensor(predicate) else others).append(item) - return simple_tensors[0], simple_tensors[1:], others + return simple_tensors[0] if simple_tensors else None, simple_tensors[1:], others class CopyCloneTransform(transforms.Transform): def _transform(self, inpt, 
params): @@ -271,10 +265,11 @@ def was_applied(output, inpt): first_simple_tensor_output, other_simple_tensor_outputs, other_outputs = split_on_simple_tensor(transformed_sample) - if other_inputs: - assert not transform.was_applied(first_simple_tensor_output, first_simple_tensor_input) - else: - assert transform.was_applied(first_simple_tensor_output, first_simple_tensor_input) + if first_simple_tensor_input is not None: + if other_inputs: + assert not transform.was_applied(first_simple_tensor_output, first_simple_tensor_input) + else: + assert transform.was_applied(first_simple_tensor_output, first_simple_tensor_input) for output, inpt in zip(other_simple_tensor_outputs, other_simple_tensor_inputs): assert not transform.was_applied(output, inpt) From 087a787a81763ec95e2d51001e51218ee17f58ce Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 6 Feb 2023 21:57:55 +0100 Subject: [PATCH 4/8] improve test documentation --- test/test_prototype_transforms.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 2e7c1bedd64..bdc657e9c50 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -224,7 +224,7 @@ def test_random_resized_crop(self, transform, input): @pytest.mark.parametrize( - "sample", + "flat_inputs", itertools.permutations( [ next(make_vanilla_tensor_images()), @@ -236,12 +236,18 @@ def test_random_resized_crop(self, transform, input): 3, ), ) -def test_simple_tensor_heuristic(sample): +def test_simple_tensor_heuristic(flat_inputs): def split_on_simple_tensor(to_split): + # This takes a sequence that is structurally aligned with `flat_inputs` and splits its items into three parts: + # 1. The first simple tensor. If none is present, this will be `None` + # 2. A list of the remaining simple tensors + # 3. A list of all other items simple_tensors = [] others = [] - for item, predicate in zip(to_split, sample): - (simple_tensors if is_simple_tensor(predicate) else others).append(item) + # Splitting always happens on the original `flat_inputs` to avoid any erroneous type changes by the transform to + # affect the splitting. 
+ for item, inpt in zip(to_split, flat_inputs): + (simple_tensors if is_simple_tensor(inpt) else others).append(item) return simple_tensors[0] if simple_tensors else None, simple_tensors[1:], others class CopyCloneTransform(transforms.Transform): @@ -258,10 +264,10 @@ def was_applied(output, inpt): assert_equal(output, inpt) return True - first_simple_tensor_input, other_simple_tensor_inputs, other_inputs = split_on_simple_tensor(sample) + first_simple_tensor_input, other_simple_tensor_inputs, other_inputs = split_on_simple_tensor(flat_inputs) transform = CopyCloneTransform() - transformed_sample = transform(sample) + transformed_sample = transform(flat_inputs) first_simple_tensor_output, other_simple_tensor_outputs, other_outputs = split_on_simple_tensor(transformed_sample) From 7c94e9d87481b1e59d5f414e1d1310abb4900949 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 6 Feb 2023 23:24:19 +0100 Subject: [PATCH 5/8] fix heuristic --- torchvision/prototype/transforms/_transform.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py index 58eccb96523..c4019f3ec54 100644 --- a/torchvision/prototype/transforms/_transform.py +++ b/torchvision/prototype/transforms/_transform.py @@ -56,14 +56,17 @@ def forward(self, *inputs: Any) -> Any: flat_outputs = [] simple_tensor_transformed = has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image) for inpt in flat_inputs: - if is_simple_tensor(inpt): - if simple_tensor_transformed: - flat_outputs.append(inpt) - continue + needs_transform = True - simple_tensor_transformed = True + if not check_type(inpt, self._transformed_types): + needs_transform = False + elif is_simple_tensor(inpt): + if simple_tensor_transformed: + needs_transform = False + else: + simple_tensor_transformed = True - flat_outputs.append(self._transform(inpt, params)) + flat_outputs.append(self._transform(inpt, params) if needs_transform else inpt) return tree_unflatten(flat_outputs, spec) From cffd206568464b4944cf7f0bc17babb229198eb8 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 8 Feb 2023 16:37:52 +0100 Subject: [PATCH 6/8] add warnings and fix tests --- test/test_prototype_transforms.py | 40 ++++++++++++++++------- torchvision/prototype/transforms/_misc.py | 19 +++++++++++ 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index bdc657e9c50..71a12e3f6c7 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -1,9 +1,9 @@ import itertools +import re import numpy as np import PIL.Image - import pytest import torch @@ -1822,17 +1822,17 @@ def test__transform(self, mocker): [ ( torch.float64, - {torch.Tensor: torch.float64, datapoints.Image: torch.float64, datapoints.BoundingBox: torch.float64}, + {datapoints.Video: torch.float64, datapoints.Image: torch.float64, datapoints.BoundingBox: torch.float64}, ), ( - {torch.Tensor: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64}, - {torch.Tensor: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64}, + {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64}, + {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64}, ), ], ) def test_to_dtype(dtype, expected_dtypes): sample = dict( - 
plain_tensor=torch.testing.make_tensor(5, dtype=torch.int64, device="cpu"), + video=make_video(dtype=torch.int64), image=make_image(dtype=torch.uint8), bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, dtype=torch.float32), str="str", @@ -1855,22 +1855,27 @@ def test_to_dtype(dtype, expected_dtypes): assert transformed_value is value +@pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video]) +def test_to_dtype_plain_tensor_warning(other_type): + with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")): + transforms.ToDtype(dtype={torch.Tensor: torch.float32, other_type: torch.float64}) + + @pytest.mark.parametrize( ("dims", "inverse_dims"), [ ( - {torch.Tensor: (1, 2, 0), datapoints.Image: (2, 1, 0), datapoints.Video: None}, - {torch.Tensor: (2, 0, 1), datapoints.Image: (2, 1, 0), datapoints.Video: None}, + {datapoints.Image: (2, 1, 0), datapoints.Video: None}, + {datapoints.Image: (2, 1, 0), datapoints.Video: None}, ), ( - {torch.Tensor: (1, 2, 0), datapoints.Image: (2, 1, 0), datapoints.Video: (1, 2, 3, 0)}, - {torch.Tensor: (2, 0, 1), datapoints.Image: (2, 1, 0), datapoints.Video: (3, 0, 1, 2)}, + {datapoints.Image: (2, 1, 0), datapoints.Video: (1, 2, 3, 0)}, + {datapoints.Image: (2, 1, 0), datapoints.Video: (3, 0, 1, 2)}, ), ], ) def test_permute_dimensions(dims, inverse_dims): sample = dict( - plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"), image=make_image(), bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY), video=make_video(), @@ -1895,16 +1900,21 @@ def test_permute_dimensions(dims, inverse_dims): assert transformed_value is value +@pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video]) +def test_permute_dimensions_plain_tensor_warning(other_type): + with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")): + transforms.PermuteDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)}) + + @pytest.mark.parametrize( "dims", [ (-1, -2), - {torch.Tensor: (-1, -2), datapoints.Image: (1, 2), datapoints.Video: None}, + {datapoints.Image: (1, 2), datapoints.Video: None}, ], ) def test_transpose_dimensions(dims): sample = dict( - plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"), image=make_image(), bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY), video=make_video(), @@ -1930,6 +1940,12 @@ def test_transpose_dimensions(dims): assert transformed_value is value +@pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video]) +def test_transpose_dimensions_plain_tensor_warning(other_type): + with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")): + transforms.TransposeDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)}) + + class TestUniformTemporalSubsample: @pytest.mark.parametrize( "inpt", diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index 07ab53aff82..e7bb62da18e 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -1,3 +1,4 @@ +import warnings from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union import PIL.Image @@ -155,6 +156,12 @@ def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]]) super().__init__() if not isinstance(dtype, dict): dtype = _get_defaultdict(dtype) + if torch.Tensor in dtype and any(cls in dtype for 
cls in [datapoints.Image, datapoints.Video]): + warnings.warn( + "Got `dtype` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. " + "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) " + "in case a `datapoints.Image` or `datapoints.Video` is present in the input." + ) self.dtype = dtype def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: @@ -171,6 +178,12 @@ def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]] super().__init__() if not isinstance(dims, dict): dims = _get_defaultdict(dims) + if torch.Tensor in dims and any(cls in dims for cls in [datapoints.Image, datapoints.Video]): + warnings.warn( + "Got `dims` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. " + "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) " + "in case a `datapoints.Image` or `datapoints.Video` is present in the input." + ) self.dims = dims def _transform( @@ -189,6 +202,12 @@ def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, i super().__init__() if not isinstance(dims, dict): dims = _get_defaultdict(dims) + if torch.Tensor in dims and any(cls in dims for cls in [datapoints.Image, datapoints.Video]): + warnings.warn( + "Got `dims` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. " + "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) " + "in case a `datapoints.Image` or `datapoints.Video` is present in the input." + ) self.dims = dims def _transform( From 9a476bca45190291db6d42fd0d0032caf0d28da8 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 8 Feb 2023 16:41:28 +0100 Subject: [PATCH 7/8] clarify heuristic --- torchvision/prototype/transforms/_transform.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py index ca3c9cc3f1d..675b0787e83 100644 --- a/torchvision/prototype/transforms/_transform.py +++ b/torchvision/prototype/transforms/_transform.py @@ -54,17 +54,17 @@ def forward(self, *inputs: Any) -> Any: # tries to transform multiple simple tensors at the same time, expecting them all to be treated as images. # However, this case wasn't supported by transforms v1 either, so there is no BC concern. 
flat_outputs = [] - simple_tensor_transformed = has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image) + transform_simple_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image) for inpt in flat_inputs: needs_transform = True if not check_type(inpt, self._transformed_types): needs_transform = False elif is_simple_tensor(inpt): - if simple_tensor_transformed: - needs_transform = False + if transform_simple_tensor: + transform_simple_tensor = False else: - simple_tensor_transformed = True + needs_transform = False flat_outputs.append(self._transform(inpt, params) if needs_transform else inpt) From 4415cb4c1d575fe07e51c9e35e9e9b01152c654d Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 8 Feb 2023 19:47:04 +0100 Subject: [PATCH 8/8] add test for single plain tensor --- test/test_prototype_transforms.py | 233 +++++++++++++++++------------- 1 file changed, 129 insertions(+), 104 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 71a12e3f6c7..29c2bc1358a 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -1817,133 +1817,158 @@ def test__transform(self, mocker): ) -@pytest.mark.parametrize( - ("dtype", "expected_dtypes"), - [ - ( - torch.float64, - {datapoints.Video: torch.float64, datapoints.Image: torch.float64, datapoints.BoundingBox: torch.float64}, - ), - ( - {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64}, - {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64}, - ), - ], -) -def test_to_dtype(dtype, expected_dtypes): - sample = dict( - video=make_video(dtype=torch.int64), - image=make_image(dtype=torch.uint8), - bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, dtype=torch.float32), - str="str", - int=0, +class TestToDtype: + @pytest.mark.parametrize( + ("dtype", "expected_dtypes"), + [ + ( + torch.float64, + { + datapoints.Video: torch.float64, + datapoints.Image: torch.float64, + datapoints.BoundingBox: torch.float64, + }, + ), + ( + {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64}, + {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64}, + ), + ], ) + def test_call(self, dtype, expected_dtypes): + sample = dict( + video=make_video(dtype=torch.int64), + image=make_image(dtype=torch.uint8), + bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, dtype=torch.float32), + str="str", + int=0, + ) - transform = transforms.ToDtype(dtype) - transformed_sample = transform(sample) + transform = transforms.ToDtype(dtype) + transformed_sample = transform(sample) - for key, value in sample.items(): - value_type = type(value) - transformed_value = transformed_sample[key] + for key, value in sample.items(): + value_type = type(value) + transformed_value = transformed_sample[key] - # make sure the transformation retains the type - assert isinstance(transformed_value, value_type) + # make sure the transformation retains the type + assert isinstance(transformed_value, value_type) - if isinstance(value, torch.Tensor): - assert transformed_value.dtype is expected_dtypes[value_type] - else: - assert transformed_value is value + if isinstance(value, torch.Tensor): + assert transformed_value.dtype is expected_dtypes[value_type] + else: + assert transformed_value is value + @pytest.mark.filterwarnings("error") + def 
test_plain_tensor_call(self): + tensor = torch.empty((), dtype=torch.float32) + transform = transforms.ToDtype({torch.Tensor: torch.float64}) -@pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video]) -def test_to_dtype_plain_tensor_warning(other_type): - with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")): - transforms.ToDtype(dtype={torch.Tensor: torch.float32, other_type: torch.float64}) + assert transform(tensor).dtype is torch.float64 + @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video]) + def test_plain_tensor_warning(self, other_type): + with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")): + transforms.ToDtype(dtype={torch.Tensor: torch.float32, other_type: torch.float64}) -@pytest.mark.parametrize( - ("dims", "inverse_dims"), - [ - ( - {datapoints.Image: (2, 1, 0), datapoints.Video: None}, - {datapoints.Image: (2, 1, 0), datapoints.Video: None}, - ), - ( - {datapoints.Image: (2, 1, 0), datapoints.Video: (1, 2, 3, 0)}, - {datapoints.Image: (2, 1, 0), datapoints.Video: (3, 0, 1, 2)}, - ), - ], -) -def test_permute_dimensions(dims, inverse_dims): - sample = dict( - image=make_image(), - bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY), - video=make_video(), - str="str", - int=0, + +class TestPermuteDimensions: + @pytest.mark.parametrize( + ("dims", "inverse_dims"), + [ + ( + {datapoints.Image: (2, 1, 0), datapoints.Video: None}, + {datapoints.Image: (2, 1, 0), datapoints.Video: None}, + ), + ( + {datapoints.Image: (2, 1, 0), datapoints.Video: (1, 2, 3, 0)}, + {datapoints.Image: (2, 1, 0), datapoints.Video: (3, 0, 1, 2)}, + ), + ], ) + def test_call(self, dims, inverse_dims): + sample = dict( + image=make_image(), + bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY), + video=make_video(), + str="str", + int=0, + ) - transform = transforms.PermuteDimensions(dims) - transformed_sample = transform(sample) + transform = transforms.PermuteDimensions(dims) + transformed_sample = transform(sample) - for key, value in sample.items(): - value_type = type(value) - transformed_value = transformed_sample[key] + for key, value in sample.items(): + value_type = type(value) + transformed_value = transformed_sample[key] - if check_type( - value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video) - ): - if transform.dims.get(value_type) is not None: - assert transformed_value.permute(inverse_dims[value_type]).equal(value) - assert type(transformed_value) == torch.Tensor - else: - assert transformed_value is value + if check_type( + value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video) + ): + if transform.dims.get(value_type) is not None: + assert transformed_value.permute(inverse_dims[value_type]).equal(value) + assert type(transformed_value) == torch.Tensor + else: + assert transformed_value is value + @pytest.mark.filterwarnings("error") + def test_plain_tensor_call(self): + tensor = torch.empty((2, 3, 4)) + transform = transforms.PermuteDimensions(dims=(1, 2, 0)) -@pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video]) -def test_permute_dimensions_plain_tensor_warning(other_type): - with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")): - transforms.PermuteDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)}) + assert transform(tensor).shape == (3, 4, 2) + 
@pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video]) + def test_plain_tensor_warning(self, other_type): + with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")): + transforms.PermuteDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)}) -@pytest.mark.parametrize( - "dims", - [ - (-1, -2), - {datapoints.Image: (1, 2), datapoints.Video: None}, - ], -) -def test_transpose_dimensions(dims): - sample = dict( - image=make_image(), - bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY), - video=make_video(), - str="str", - int=0, + +class TestTransposeDimensions: + @pytest.mark.parametrize( + "dims", + [ + (-1, -2), + {datapoints.Image: (1, 2), datapoints.Video: None}, + ], ) + def test_call(self, dims): + sample = dict( + image=make_image(), + bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY), + video=make_video(), + str="str", + int=0, + ) - transform = transforms.TransposeDimensions(dims) - transformed_sample = transform(sample) + transform = transforms.TransposeDimensions(dims) + transformed_sample = transform(sample) - for key, value in sample.items(): - value_type = type(value) - transformed_value = transformed_sample[key] + for key, value in sample.items(): + value_type = type(value) + transformed_value = transformed_sample[key] - transposed_dims = transform.dims.get(value_type) - if check_type( - value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video) - ): - if transposed_dims is not None: - assert transformed_value.transpose(*transposed_dims).equal(value) - assert type(transformed_value) == torch.Tensor - else: - assert transformed_value is value + transposed_dims = transform.dims.get(value_type) + if check_type( + value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video) + ): + if transposed_dims is not None: + assert transformed_value.transpose(*transposed_dims).equal(value) + assert type(transformed_value) == torch.Tensor + else: + assert transformed_value is value + + @pytest.mark.filterwarnings("error") + def test_plain_tensor_call(self): + tensor = torch.empty((2, 3, 4)) + transform = transforms.TransposeDimensions(dims=(0, 2)) + assert transform(tensor).shape == (4, 3, 2) -@pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video]) -def test_transpose_dimensions_plain_tensor_warning(other_type): - with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")): - transforms.TransposeDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)}) + @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video]) + def test_plain_tensor_warning(self, other_type): + with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")): + transforms.TransposeDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)}) class TestUniformTemporalSubsample:
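
For readers following how the heuristic evolves across PATCH 1, 2, 5, and 7, the dispatch rule the series converges on can be summarized outside of the diff context. Below is a minimal, self-contained sketch of just that decision logic. It is illustrative only: the `Image`/`Video` stand-in classes, the `plan_transform` helper, and the example inputs are assumptions made for this sketch and are not part of torchvision, and the sketch ignores the other datapoint types (bounding boxes, masks, ...) that `check_type(inpt, self._transformed_types)` also covers in the real `Transform.forward()`.

# Sketch of the simple tensor heuristic from PATCH 7/8 (illustrative, not the torchvision code).
# `Image` and `Video` are stand-ins for `datapoints.Image` / `datapoints.Video`.
import PIL.Image
import torch


class Image(torch.Tensor):
    pass


class Video(torch.Tensor):
    pass


def is_simple_tensor(obj):
    # A "simple" tensor is a plain torch.Tensor that is not a datapoint subclass.
    return isinstance(obj, torch.Tensor) and not isinstance(obj, (Image, Video))


def plan_transform(flat_inputs):
    # Returns one bool per input: True if the transform would be applied to it.
    # Only the image / video / simple tensor part of the dispatch is modeled here.
    transform_simple_tensor = not any(
        isinstance(inpt, (Image, Video, PIL.Image.Image)) for inpt in flat_inputs
    )
    plan = []
    for inpt in flat_inputs:
        if is_simple_tensor(inpt):
            if transform_simple_tensor:
                # Only the first simple tensor encountered is treated as an image ...
                plan.append(True)
                transform_simple_tensor = False
            else:
                # ... every other simple tensor is passed through untouched.
                plan.append(False)
        else:
            plan.append(isinstance(inpt, (Image, Video, PIL.Image.Image)))
    return plan


# Two plain tensors, no explicit image: only the first one is transformed.
print(plan_transform([torch.rand(3, 4, 4), torch.rand(3, 4, 4)]))             # [True, False]
# Explicit image present: plain tensors (e.g. numerical targets) are passed through.
print(plan_transform([PIL.Image.new("RGB", (4, 4)), torch.tensor([0, 1, 2])]))  # [True, False]
# A single plain tensor is still treated as an image, keeping BC with transforms v1
# (this is the case PATCH 8/8 adds `test_plain_tensor_call` for).
print(plan_transform([torch.tensor([0, 1, 2])]))                               # [True]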