3 changes: 2 additions & 1 deletion test/prototype_common_utils.py
@@ -69,7 +69,8 @@ def __init__(

    def _process_inputs(self, actual, expected, *, id, allow_subclasses):
        actual, expected = [
-            to_image_tensor(input) if not isinstance(input, torch.Tensor) else input for input in [actual, expected]
+            to_image_tensor(input) if not isinstance(input, torch.Tensor) else features.Image(input)
+            for input in [actual, expected]
        ]
        # This broadcast is needed, because `features.Mask`'s can have a 2D shape, but converting the equivalent PIL
        # image to a tensor adds a singleton leading dimension.
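For illustration, a minimal sketch (hypothetical, not part of the diff) of the shape mismatch this broadcast compensates for; it assumes the prototype conversion helpers round-trip a 2D mask through a PIL "L" image:

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms.functional import to_image_pil, to_image_tensor

mask = features.Mask(torch.zeros(32, 32, dtype=torch.uint8))  # masks can be 2D: (H, W)
roundtripped = to_image_tensor(to_image_pil(mask))  # the round trip adds a leading dim: (1, 32, 32)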
182 changes: 175 additions & 7 deletions test/test_prototype_transforms_consistency.py
@@ -1,5 +1,7 @@
import enum
import inspect
+import random
+from collections import defaultdict
from importlib.machinery import SourceFileLoader
from pathlib import Path

@@ -16,13 +18,15 @@
    make_image,
    make_images,
    make_label,
+    make_segmentation_mask,
)
from torchvision import transforms as legacy_transforms
from torchvision._utils import sequence_to_str
from torchvision.prototype import features, transforms as prototype_transforms
from torchvision.prototype.transforms import functional as F
+from torchvision.prototype.transforms._utils import query_chw
from torchvision.prototype.transforms.functional import to_image_pil


DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)])


@@ -852,10 +856,12 @@ def test_aa(self, inpt, interpolation):
        assert_equal(expected_output, output)


-# Import reference detection transforms here for consistency checks
-# torchvision/references/detection/transforms.py
-ref_det_filepath = Path(__file__).parent.parent / "references" / "detection" / "transforms.py"
-det_transforms = SourceFileLoader(ref_det_filepath.stem, ref_det_filepath.as_posix()).load_module()
+def import_transforms_from_references(reference):
+    ref_det_filepath = Path(__file__).parent.parent / "references" / reference / "transforms.py"
+    return SourceFileLoader(ref_det_filepath.stem, ref_det_filepath.as_posix()).load_module()
+
+
+det_transforms = import_transforms_from_references("detection")


class TestRefDetTransforms:
@@ -873,7 +879,7 @@ def make_datapoints(self, with_mask=True):

        yield (pil_image, target)

-        tensor_image = torch.randint(0, 256, size=(3, *size), dtype=torch.uint8)
+        tensor_image = torch.Tensor(make_image(size=size, color_space=features.ColorSpace.RGB))
Review comment (author): I missed this before. Let's use the utilities everywhere.

        target = {
            "boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
@@ -883,7 +889,7 @@ def make_datapoints(self, with_mask=True):

        yield (tensor_image, target)

-        feature_image = features.Image(torch.randint(0, 256, size=(3, *size), dtype=torch.uint8))
+        feature_image = make_image(size=size, color_space=features.ColorSpace.RGB)
        target = {
            "boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
@@ -927,3 +933,165 @@ def test_transform(self, t_ref, t, data_kwargs):
        expected_output = t_ref(*dp)

        assert_equal(expected_output, output)


seg_transforms = import_transforms_from_references("segmentation")


# We need this transform for two reasons:
# 1. transforms.RandomCrop uses a different scheme to pad images and masks of insufficient size than its namesake
#    in the detection references. Thus, we cannot use it with `pad_if_needed=True`.
# 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size.
class PadIfSmaller(prototype_transforms.Transform):
Review comment (author): This implements the option suggested in #6433 (comment).

I've also played around with my proposal from #6433 (comment), but it was a lot more complex.

This PR demonstrates that this transform in combination with our transforms.RandomCrop can be used to mimic the behavior of the RandomCrop from the segmentation references. Thus, if we go through with this PR, I will port this transform to #6433.

    def __init__(self, size, fill=0):
        super().__init__()
        self.size = size
        self.fill = prototype_transforms._geometry._setup_fill_arg(fill)

    def _get_params(self, sample):
        _, height, width = query_chw(sample)
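        # torchvision's pad expects [left, top, right, bottom], so the image only grows on the right and bottom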
        padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
        needs_padding = any(padding)
        return dict(padding=padding, needs_padding=needs_padding)

    def _transform(self, inpt, params):
        if not params["needs_padding"]:
            return inpt

        fill = self.fill[type(inpt)]
        fill = F._geometry._convert_fill_arg(fill)

        return F.pad(inpt, padding=params["padding"], fill=fill)
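For illustration, a minimal sketch (hypothetical, not part of the diff) of the combination described in the review comment above, reusing the names imported by this test module and the 480/255 values from the parametrization below:

mimic_reference_random_crop = prototype_transforms.Compose(
    [
        PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {features.Mask: 255})),
        prototype_transforms.RandomCrop(size=480),
    ]
)

image = make_image(size=(256, 640), color_space=features.ColorSpace.RGB)
mask = make_segmentation_mask(size=(256, 640), num_categories=21)
# PadIfSmaller first grows the 256-pixel side to 480, so RandomCrop can always cut a 480x480 patch
image_out, mask_out = mimic_reference_random_crop((image, mask))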


class TestRefSegTransforms:
    def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8):
        size = (256, 640)
        num_categories = 21

        conv_fns = []
        if supports_pil:
            conv_fns.append(to_image_pil)
        conv_fns.extend([torch.Tensor, lambda x: x])

        for conv_fn in conv_fns:
            feature_image = make_image(size=size, color_space=features.ColorSpace.RGB, dtype=image_dtype)
            feature_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)

            dp = (conv_fn(feature_image), feature_mask)
            dp_ref = (
                to_image_pil(feature_image) if supports_pil else torch.Tensor(feature_image),
                to_image_pil(feature_mask),
            )

            yield dp, dp_ref

    def set_seed(self, seed=12):
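        # seed both RNGs: the prototype transforms draw randomness from torch, the reference transforms from `random`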
        torch.manual_seed(seed)
        random.seed(seed)

    def check(self, t, t_ref, data_kwargs=None):
        for dp, dp_ref in self.make_datapoints(**data_kwargs or dict()):

            self.set_seed()
            output = t(dp)

            self.set_seed()
            expected_output = t_ref(*dp_ref)

            assert_equal(output, expected_output)

    @pytest.mark.parametrize(
        ("t_ref", "t", "data_kwargs"),
        [
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=1.0),
                prototype_transforms.RandomHorizontalFlip(p=1.0),
                dict(),
            ),
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=0.0),
                prototype_transforms.RandomHorizontalFlip(p=0.0),
                dict(),
            ),
            (
                seg_transforms.RandomCrop(size=480),
                prototype_transforms.Compose(
                    [
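                        # pad masks with 255 (the ignore index in the segmentation references), everything else with 0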
                        PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {features.Mask: 255})),
                        prototype_transforms.RandomCrop(size=480),
                    ]
                ),
                dict(),
            ),
            (
                seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                prototype_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                dict(supports_pil=False, image_dtype=torch.float),
            ),
        ],
    )
    def test_common(self, t_ref, t, data_kwargs):
        self.check(t, t_ref, data_kwargs)

    def check_resize(self, mocker, t_ref, t):
        mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
        mock_ref = mocker.patch("torchvision.transforms.functional.resize")

        for dp, dp_ref in self.make_datapoints():
            mock.reset_mock()
            mock_ref.reset_mock()

            self.set_seed()
            t(dp)
            assert mock.call_count == 2
            assert all(
                actual is expected
                for actual, expected in zip([call_args[0][0] for call_args in mock.call_args_list], dp)
            )

            self.set_seed()
            t_ref(*dp_ref)
            assert mock_ref.call_count == 2
            assert all(
                actual is expected
                for actual, expected in zip([call_args[0][0] for call_args in mock_ref.call_args_list], dp_ref)
            )

            for args_kwargs, args_kwargs_ref in zip(mock.call_args_list, mock_ref.call_args_list):
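                # the prototype transform passes the resize size as a list, the reference as a bare int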
                assert args_kwargs[0][1] == [args_kwargs_ref[0][1]]

    def test_random_resize_train(self, mocker):
        base_size = 520
        min_size = base_size // 2
        max_size = base_size * 2

        randint = torch.randint

        def patched_randint(a, b, *other_args, **kwargs):
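            # delegate anything that is not a plain scalar draw (size == ()) to the real torch.randint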
            if kwargs or len(other_args) != 1 or other_args[0] != ():
                return randint(a, b, *other_args, **kwargs)

            return random.randint(a, b)

        # We patch torch.randint -> random.randint here because the reference transforms draw their sizes
        # from the `random` module, and a module loaded through `SourceFileLoader` cannot be patched the
        # normal way. With both sides seeded identically, they then sample the same random sizes.
        t = prototype_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True)
        mocker.patch(
            "torchvision.prototype.transforms._geometry.torch.randint",
            new=patched_randint,
        )

        t_ref = seg_transforms.RandomResize(min_size=min_size, max_size=max_size)

        self.check_resize(mocker, t_ref, t)

    def test_random_resize_eval(self, mocker):
        torch.manual_seed(0)
        base_size = 520

        t = prototype_transforms.Resize(size=base_size, antialias=True)

        t_ref = seg_transforms.RandomResize(min_size=base_size, max_size=base_size)

        self.check_resize(mocker, t_ref, t)