introduce heuristic for simple tensor handling of transforms v2 #7170


Merged 13 commits on Feb 8, 2023
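
In short, the heuristic: a plain `torch.Tensor` is only handled as an image by a v2 transform if the sample contains no `datapoints.Image`, `datapoints.Video`, or `PIL.Image.Image`; even then, only the first plain tensor is transformed, and all remaining ones are passed through untouched. The test and warning changes below pin down exactly this behavior.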
test/test_prototype_transforms.py (297 changes: 200 additions & 97 deletions)
@@ -1,15 +1,16 @@
import itertools
import re

import numpy as np

import PIL.Image

import pytest
import torch

import torchvision.prototype.transforms.utils
from common_utils import cpu_and_gpu
from prototype_common_utils import (
    assert_equal,
    DEFAULT_EXTRA_DIMS,
    make_bounding_box,
    make_bounding_boxes,
@@ -25,7 +26,7 @@
)
from torchvision.ops.boxes import box_iou
from torchvision.prototype import datapoints, transforms
from torchvision.prototype.transforms.utils import check_type, is_simple_tensor
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -222,6 +223,67 @@ def test_random_resized_crop(self, transform, input):
        transform(input)


@pytest.mark.parametrize(
    "flat_inputs",
    itertools.permutations(
        [
            next(make_vanilla_tensor_images()),
            next(make_vanilla_tensor_images()),
            next(make_pil_images()),
            make_image(),
            next(make_videos()),
        ],
        3,
    ),
)
def test_simple_tensor_heuristic(flat_inputs):
    def split_on_simple_tensor(to_split):
        # This takes a sequence that is structurally aligned with `flat_inputs` and splits its items into
        # three parts:
        # 1. The first simple tensor. If none is present, this will be `None`.
        # 2. A list of the remaining simple tensors.
        # 3. A list of all other items.
        simple_tensors = []
        others = []
        # Splitting always happens on the original `flat_inputs` so that erroneous type changes made by the
        # transform cannot affect the splitting.
        for item, inpt in zip(to_split, flat_inputs):
            (simple_tensors if is_simple_tensor(inpt) else others).append(item)
        return simple_tensors[0] if simple_tensors else None, simple_tensors[1:], others

    class CopyCloneTransform(transforms.Transform):
        def _transform(self, inpt, params):
            return inpt.clone() if isinstance(inpt, torch.Tensor) else inpt.copy()

        @staticmethod
        def was_applied(output, inpt):
            identity = output is inpt
            if identity:
                return False

            # Make sure nothing fishy is going on
            assert_equal(output, inpt)
            return True

    first_simple_tensor_input, other_simple_tensor_inputs, other_inputs = split_on_simple_tensor(flat_inputs)

    transform = CopyCloneTransform()
    transformed_sample = transform(flat_inputs)

    first_simple_tensor_output, other_simple_tensor_outputs, other_outputs = split_on_simple_tensor(transformed_sample)

    if first_simple_tensor_input is not None:
        if other_inputs:
            assert not transform.was_applied(first_simple_tensor_output, first_simple_tensor_input)
        else:
            assert transform.was_applied(first_simple_tensor_output, first_simple_tensor_input)

    for output, inpt in zip(other_simple_tensor_outputs, other_simple_tensor_inputs):
        assert not transform.was_applied(output, inpt)

    for input, output in zip(other_inputs, other_outputs):
        assert transform.was_applied(output, input)
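
For context, the dispatch rule this test exercises can be sketched as follows. This is an illustrative reimplementation, not the code merged in this PR; the helper name `needs_transform_list` is made up here, while `check_type` and `is_simple_tensor` are the real utilities imported above.

import PIL.Image

from torchvision.prototype import datapoints
from torchvision.prototype.transforms.utils import check_type, is_simple_tensor


def needs_transform_list(flat_inputs):
    # A plain tensor only qualifies for transformation if no real image or video type is present.
    transform_simple_tensor = not any(
        check_type(inpt, (datapoints.Image, datapoints.Video, PIL.Image.Image)) for inpt in flat_inputs
    )
    needs_transform = []
    for inpt in flat_inputs:
        if is_simple_tensor(inpt):
            needs_transform.append(transform_simple_tensor)
            # At most the first plain tensor is transformed; all later ones pass through.
            transform_simple_tensor = False
        else:
            needs_transform.append(True)
    return needs_transform

Checked against the test above: with any PIL image, `datapoints.Image`, or `datapoints.Video` in `flat_inputs`, every plain tensor is passed through; otherwise exactly the first one is transformed.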


@pytest.mark.parametrize("p", [0.0, 1.0])
class TestRandomHorizontalFlip:
    def input_expected_image_tensor(self, p, dtype=torch.float32):
@@ -1755,117 +1817,158 @@ def test__transform(self, mocker):
)


class TestToDtype:
    @pytest.mark.parametrize(
        ("dtype", "expected_dtypes"),
        [
            (
                torch.float64,
                {
                    datapoints.Video: torch.float64,
                    datapoints.Image: torch.float64,
                    datapoints.BoundingBox: torch.float64,
                },
            ),
            (
                {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
                {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
            ),
        ],
    )
    def test_call(self, dtype, expected_dtypes):
        sample = dict(
            video=make_video(dtype=torch.int64),
            image=make_image(dtype=torch.uint8),
            bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, dtype=torch.float32),
            str="str",
            int=0,
        )

        transform = transforms.ToDtype(dtype)
        transformed_sample = transform(sample)

        for key, value in sample.items():
            value_type = type(value)
            transformed_value = transformed_sample[key]

            # make sure the transformation retains the type
            assert isinstance(transformed_value, value_type)

            if isinstance(value, torch.Tensor):
                assert transformed_value.dtype is expected_dtypes[value_type]
            else:
                assert transformed_value is value

    @pytest.mark.filterwarnings("error")
    def test_plain_tensor_call(self):
        tensor = torch.empty((), dtype=torch.float32)
        transform = transforms.ToDtype({torch.Tensor: torch.float64})

        assert transform(tensor).dtype is torch.float64

    @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
    def test_plain_tensor_warning(self, other_type):
        with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
            transforms.ToDtype(dtype={torch.Tensor: torch.float32, other_type: torch.float64})


class TestPermuteDimensions:
    @pytest.mark.parametrize(
        ("dims", "inverse_dims"),
        [
            (
                {datapoints.Image: (2, 1, 0), datapoints.Video: None},
                {datapoints.Image: (2, 1, 0), datapoints.Video: None},
            ),
            (
                {datapoints.Image: (2, 1, 0), datapoints.Video: (1, 2, 3, 0)},
                {datapoints.Image: (2, 1, 0), datapoints.Video: (3, 0, 1, 2)},
            ),
        ],
    )
    def test_call(self, dims, inverse_dims):
        sample = dict(
            image=make_image(),
            bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY),
            video=make_video(),
            str="str",
            int=0,
        )

        transform = transforms.PermuteDimensions(dims)
        transformed_sample = transform(sample)

        for key, value in sample.items():
            value_type = type(value)
            transformed_value = transformed_sample[key]

            if check_type(
                value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video)
            ):
                if transform.dims.get(value_type) is not None:
                    assert transformed_value.permute(inverse_dims[value_type]).equal(value)
                assert type(transformed_value) == torch.Tensor
            else:
                assert transformed_value is value

    @pytest.mark.filterwarnings("error")
    def test_plain_tensor_call(self):
        tensor = torch.empty((2, 3, 4))
        transform = transforms.PermuteDimensions(dims=(1, 2, 0))

        assert transform(tensor).shape == (3, 4, 2)

    @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
    def test_plain_tensor_warning(self, other_type):
        with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
            transforms.PermuteDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})


class TestTransposeDimensions:
    @pytest.mark.parametrize(
        "dims",
        [
            (-1, -2),
            {datapoints.Image: (1, 2), datapoints.Video: None},
        ],
    )
    def test_call(self, dims):
        sample = dict(
            image=make_image(),
            bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY),
            video=make_video(),
            str="str",
            int=0,
        )

        transform = transforms.TransposeDimensions(dims)
        transformed_sample = transform(sample)

        for key, value in sample.items():
            value_type = type(value)
            transformed_value = transformed_sample[key]

            transposed_dims = transform.dims.get(value_type)
            if check_type(
                value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video)
            ):
                if transposed_dims is not None:
                    assert transformed_value.transpose(*transposed_dims).equal(value)
                assert type(transformed_value) == torch.Tensor
            else:
                assert transformed_value is value

    @pytest.mark.filterwarnings("error")
    def test_plain_tensor_call(self):
        tensor = torch.empty((2, 3, 4))
        transform = transforms.TransposeDimensions(dims=(0, 2))

        assert transform(tensor).shape == (4, 3, 2)

    @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
    def test_plain_tensor_warning(self, other_type):
        with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
            transforms.TransposeDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})


class TestUniformTemporalSubsample:
torchvision/prototype/transforms/_misc.py (19 changes: 19 additions & 0 deletions)
@@ -1,3 +1,4 @@
import warnings
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union

import PIL.Image
@@ -155,6 +156,12 @@ def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]])
        super().__init__()
        if not isinstance(dtype, dict):
            dtype = _get_defaultdict(dtype)
        if torch.Tensor in dtype and any(cls in dtype for cls in [datapoints.Image, datapoints.Video]):
            warnings.warn(
                "Got `dtype` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
                "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
                "in case a `datapoints.Image` or `datapoints.Video` is present in the input."
            )
        self.dtype = dtype

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
@@ -171,6 +178,12 @@ def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]
        super().__init__()
        if not isinstance(dims, dict):
            dims = _get_defaultdict(dims)
        if torch.Tensor in dims and any(cls in dims for cls in [datapoints.Image, datapoints.Video]):
            warnings.warn(
                "Got `dims` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
                "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
                "in case a `datapoints.Image` or `datapoints.Video` is present in the input."
            )
        self.dims = dims

    def _transform(
@@ -189,6 +202,12 @@ def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, i
        super().__init__()
        if not isinstance(dims, dict):
            dims = _get_defaultdict(dims)
        if torch.Tensor in dims and any(cls in dims for cls in [datapoints.Image, datapoints.Video]):
            warnings.warn(
                "Got `dims` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
                "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
                "in case a `datapoints.Image` or `datapoints.Video` is present in the input."
            )
        self.dims = dims

    def _transform(
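
Taken together with the heuristic, here is a minimal sketch of the behavior these warnings guard against (assuming the prototype API at this commit; the sample keys and dtypes are illustrative):

import torch

from torchvision.prototype import datapoints, transforms

# Emits the UserWarning above, since both `torch.Tensor` and `datapoints.Image` were given a dtype.
transform = transforms.ToDtype({torch.Tensor: torch.float64, datapoints.Image: torch.float16})

sample = {"plain": torch.zeros(3, 8, 8), "image": datapoints.Image(torch.zeros(3, 8, 8))}
output = transform(sample)

# The plain tensor is *not* converted, because a `datapoints.Image` is present in the sample.
assert output["plain"].dtype is torch.float32
assert output["image"].dtype is torch.float16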