-
Notifications
You must be signed in to change notification settings - Fork 7.1k
add segmentation reference consistency tests #6591
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
044d120
7a9eb0c
0863741
318b15c
db9df24
8c3bfb3
641d5af
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
import enum | ||
import inspect | ||
import random | ||
from collections import defaultdict | ||
from importlib.machinery import SourceFileLoader | ||
from pathlib import Path | ||
|
||
|
@@ -16,13 +18,15 @@ | |
make_image, | ||
make_images, | ||
make_label, | ||
make_segmentation_mask, | ||
) | ||
from torchvision import transforms as legacy_transforms | ||
from torchvision._utils import sequence_to_str | ||
from torchvision.prototype import features, transforms as prototype_transforms | ||
from torchvision.prototype.transforms import functional as F | ||
from torchvision.prototype.transforms._utils import query_chw | ||
from torchvision.prototype.transforms.functional import to_image_pil | ||
|
||
|
||
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)]) | ||
|
||
|
||
|
@@ -852,10 +856,12 @@ def test_aa(self, inpt, interpolation): | |
assert_equal(expected_output, output) | ||
|
||
|
||
# Import reference detection transforms here for consistency checks | ||
# torchvision/references/detection/transforms.py | ||
ref_det_filepath = Path(__file__).parent.parent / "references" / "detection" / "transforms.py" | ||
det_transforms = SourceFileLoader(ref_det_filepath.stem, ref_det_filepath.as_posix()).load_module() | ||
def import_transforms_from_references(reference): | ||
ref_det_filepath = Path(__file__).parent.parent / "references" / reference / "transforms.py" | ||
return SourceFileLoader(ref_det_filepath.stem, ref_det_filepath.as_posix()).load_module() | ||
|
||
|
||
det_transforms = import_transforms_from_references("detection") | ||
|
||
|
||
class TestRefDetTransforms: | ||
|
@@ -873,7 +879,7 @@ def make_datapoints(self, with_mask=True): | |
|
||
yield (pil_image, target) | ||
|
||
tensor_image = torch.randint(0, 256, size=(3, *size), dtype=torch.uint8) | ||
tensor_image = torch.Tensor(make_image(size=size, color_space=features.ColorSpace.RGB)) | ||
target = { | ||
"boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), | ||
"labels": make_label(extra_dims=(num_objects,), categories=80), | ||
|
@@ -883,7 +889,7 @@ def make_datapoints(self, with_mask=True): | |
|
||
yield (tensor_image, target) | ||
|
||
feature_image = features.Image(torch.randint(0, 256, size=(3, *size), dtype=torch.uint8)) | ||
feature_image = make_image(size=size, color_space=features.ColorSpace.RGB) | ||
target = { | ||
"boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), | ||
"labels": make_label(extra_dims=(num_objects,), categories=80), | ||
|
@@ -927,3 +933,165 @@ def test_transform(self, t_ref, t, data_kwargs): | |
expected_output = t_ref(*dp) | ||
|
||
assert_equal(expected_output, output) | ||
|
||
|
||
seg_transforms = import_transforms_from_references("segmentation") | ||
|
||
|
||
# We need this transform for two reasons: | ||
# 1. transforms.RandomCrop uses a different scheme to pad images and masks of insufficient size than its name | ||
# counterpart in the detection references. Thus, we cannot use it with `pad_if_needed=True` | ||
# 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size. | ||
class PadIfSmaller(prototype_transforms.Transform): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This implements the option suggested in #6433 (comment). I've also played around with my proposal from #6433 (comment), but it was a lot more complex. This PR demonstrates that this transform in combination with our |
||
def __init__(self, size, fill=0): | ||
super().__init__() | ||
self.size = size | ||
self.fill = prototype_transforms._geometry._setup_fill_arg(fill) | ||
datumbox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def _get_params(self, sample): | ||
_, height, width = query_chw(sample) | ||
padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)] | ||
needs_padding = any(padding) | ||
return dict(padding=padding, needs_padding=needs_padding) | ||
|
||
def _transform(self, inpt, params): | ||
if not params["needs_padding"]: | ||
return inpt | ||
|
||
fill = self.fill[type(inpt)] | ||
fill = F._geometry._convert_fill_arg(fill) | ||
datumbox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return F.pad(inpt, padding=params["padding"], fill=fill) | ||
|
||
|
||
class TestRefSegTransforms: | ||
def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8): | ||
size = (256, 640) | ||
num_categories = 21 | ||
|
||
conv_fns = [] | ||
if supports_pil: | ||
conv_fns.append(to_image_pil) | ||
conv_fns.extend([torch.Tensor, lambda x: x]) | ||
|
||
for conv_fn in conv_fns: | ||
feature_image = make_image(size=size, color_space=features.ColorSpace.RGB, dtype=image_dtype) | ||
feature_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8) | ||
|
||
dp = (conv_fn(feature_image), feature_mask) | ||
dp_ref = ( | ||
to_image_pil(feature_image) if supports_pil else torch.Tensor(feature_image), | ||
to_image_pil(feature_mask), | ||
) | ||
|
||
yield dp, dp_ref | ||
|
||
def set_seed(self, seed=12): | ||
torch.manual_seed(seed) | ||
random.seed(seed) | ||
|
||
def check(self, t, t_ref, data_kwargs=None): | ||
for dp, dp_ref in self.make_datapoints(**data_kwargs or dict()): | ||
|
||
self.set_seed() | ||
output = t(dp) | ||
|
||
self.set_seed() | ||
expected_output = t_ref(*dp_ref) | ||
|
||
assert_equal(output, expected_output) | ||
|
||
@pytest.mark.parametrize( | ||
("t_ref", "t", "data_kwargs"), | ||
[ | ||
( | ||
seg_transforms.RandomHorizontalFlip(flip_prob=1.0), | ||
prototype_transforms.RandomHorizontalFlip(p=1.0), | ||
dict(), | ||
), | ||
( | ||
seg_transforms.RandomHorizontalFlip(flip_prob=0.0), | ||
prototype_transforms.RandomHorizontalFlip(p=0.0), | ||
dict(), | ||
), | ||
( | ||
seg_transforms.RandomCrop(size=480), | ||
prototype_transforms.Compose( | ||
[ | ||
PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {features.Mask: 255})), | ||
prototype_transforms.RandomCrop(size=480), | ||
] | ||
), | ||
dict(), | ||
), | ||
( | ||
seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), | ||
prototype_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), | ||
dict(supports_pil=False, image_dtype=torch.float), | ||
), | ||
], | ||
) | ||
def test_common(self, t_ref, t, data_kwargs): | ||
self.check(t, t_ref, data_kwargs) | ||
|
||
def check_resize(self, mocker, t_ref, t): | ||
mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize") | ||
mock_ref = mocker.patch("torchvision.transforms.functional.resize") | ||
|
||
for dp, dp_ref in self.make_datapoints(): | ||
mock.reset_mock() | ||
mock_ref.reset_mock() | ||
|
||
self.set_seed() | ||
t(dp) | ||
assert mock.call_count == 2 | ||
assert all( | ||
actual is expected | ||
for actual, expected in zip([call_args[0][0] for call_args in mock.call_args_list], dp) | ||
) | ||
|
||
self.set_seed() | ||
t_ref(*dp_ref) | ||
assert mock_ref.call_count == 2 | ||
assert all( | ||
actual is expected | ||
for actual, expected in zip([call_args[0][0] for call_args in mock_ref.call_args_list], dp_ref) | ||
) | ||
|
||
for args_kwargs, args_kwargs_ref in zip(mock.call_args_list, mock_ref.call_args_list): | ||
assert args_kwargs[0][1] == [args_kwargs_ref[0][1]] | ||
|
||
def test_random_resize_train(self, mocker): | ||
base_size = 520 | ||
min_size = base_size // 2 | ||
max_size = base_size * 2 | ||
|
||
randint = torch.randint | ||
|
||
def patched_randint(a, b, *other_args, **kwargs): | ||
if kwargs or len(other_args) > 1 or other_args[0] != (): | ||
return randint(a, b, *other_args, **kwargs) | ||
|
||
return random.randint(a, b) | ||
|
||
# We are patching torch.randint -> random.randint here, because we can't patch the modules that are not imported | ||
# normally | ||
t = prototype_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True) | ||
mocker.patch( | ||
"torchvision.prototype.transforms._geometry.torch.randint", | ||
new=patched_randint, | ||
) | ||
|
||
t_ref = seg_transforms.RandomResize(min_size=min_size, max_size=max_size) | ||
|
||
self.check_resize(mocker, t_ref, t) | ||
|
||
def test_random_resize_eval(self, mocker): | ||
torch.manual_seed(0) | ||
base_size = 520 | ||
|
||
t = prototype_transforms.Resize(size=base_size, antialias=True) | ||
|
||
t_ref = seg_transforms.RandomResize(min_size=base_size, max_size=base_size) | ||
|
||
self.check_resize(mocker, t_ref, t) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I missed this before. Let's use the utilities everywhere.