Skip to content

add support for BoundingBox and SegmentationMask to RandomResizedCrop #6041

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
import pytest
import torch
from common_utils import assert_equal
from test_prototype_transforms_functional import make_images, make_bounding_boxes, make_one_hot_labels
from test_prototype_transforms_functional import (
make_images,
make_bounding_boxes,
make_one_hot_labels,
make_segmentation_masks,
)
from torchvision.prototype import transforms, features
from torchvision.transforms.functional import to_pil_image, pil_to_tensor

Expand Down Expand Up @@ -153,6 +158,8 @@ def test_normalize(self, transform, input):
transforms.RandomResizedCrop([16, 16]),
itertools.chain(
make_images(extra_dims=[(4,)]),
make_bounding_boxes(),
make_segmentation_masks(),
make_vanilla_tensor_images(),
make_pil_images(),
),
Expand Down
12 changes: 6 additions & 6 deletions torchvision/prototype/transforms/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,19 +183,19 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
input, **params, size=list(self.size), interpolation=self.interpolation
)
return features.Image.new_like(input, output)
elif isinstance(input, features.BoundingBox):
output = F.resized_crop_bounding_box(input, **params, size=list(self.size))
return features.BoundingBox.new_like(input, output, image_size=cast(Tuple[int, int], tuple(self.size)))
elif isinstance(input, features.SegmentationMask):
output = F.resized_crop_segmentation_mask(input, **params, size=list(self.size))
return features.SegmentationMask.new_like(input, output)
elif is_simple_tensor(input):
return F.resized_crop_image_tensor(input, **params, size=list(self.size), interpolation=self.interpolation)
elif isinstance(input, PIL.Image.Image):
return F.resized_crop_image_pil(input, **params, size=list(self.size), interpolation=self.interpolation)
else:
return input

def forward(self, *inputs: Any) -> Any:
    """Reject unsupported feature types, then delegate to the base forward.

    Raises:
        TypeError: if the sample contains a ``features.BoundingBox`` or a
            ``features.SegmentationMask`` anywhere in its (possibly nested)
            structure, as detected by ``has_any``.
    """
    # Callers may pass the sample either as one object or unpacked (*args);
    # normalize to a single sample before inspecting it.
    sample = inputs if len(inputs) > 1 else inputs[0]
    if has_any(sample, features.BoundingBox, features.SegmentationMask):
        raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
    return super().forward(sample)


class MultiCropResult(list):
"""Helper class for :class:`~torchvision.prototype.transforms.BatchMultiCrop`.
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def fn(
try:
return next(query_recursively(fn, sample))[1]
except StopIteration:
raise TypeError("No image was found in the sample")
raise TypeError("No image was found in the sample") from None
Copy link
Collaborator

@vfdev-5 vfdev-5 May 18, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you add "from None" here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I forgot to add a comment. The `StopIteration` is an internal detail that does not need to be propagated outside. By adding `from None`, the error will look like a normal exception:

try:
    next(iter([]))
except StopIteration:
    raise TypeError("Argh!") from None
TypeError: Argh!

Not doing that gives

try:
    next(iter([]))
except StopIteration:
    raise TypeError("Argh!")
StopIteration

During handling of the above exception, another exception occurred:

[...]
TypeError: Argh!



def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
Expand Down