cleanup features / transforms feature branch #5406

Merged (4 commits) on Feb 11, 2022
4 changes: 2 additions & 2 deletions test/test_prototype_transforms_kernels.py
@@ -170,11 +170,11 @@ def resize_image():
 def resize_bounding_box():
     for bounding_box in make_bounding_boxes():
         height, width = bounding_box.image_size
-        for new_image_size in [
+        for size in [
             (height, width),
             (int(height * 0.75), int(width * 1.25)),
         ]:
-            yield SampleInput(bounding_box, old_image_size=bounding_box.image_size, new_image_size=new_image_size)
+            yield SampleInput(bounding_box, size=size, image_size=bounding_box.image_size)


 class TestKernelsCommon:
3 changes: 3 additions & 0 deletions torchvision/prototype/features/_bounding_box.py
@@ -39,6 +39,9 @@ def __new__(
         return bounding_box

     def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
+        # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
+        # promote this out of the prototype state
+
         # import at runtime to avoid cyclic imports
         from torchvision.prototype.transforms.kernels import convert_bounding_box_format
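For context, `to_format` is what round-trips a box between coordinate formats. A minimal usage sketch, assuming the prototype `BoundingBox` constructor accepts a tensor plus `format` and `image_size` keyword arguments (not verified against this branch):

```python
import torch
from torchvision.prototype import features

# xywh box: x=10, y=20, w=15, h=5 (constructor arguments are an assumption)
box = features.BoundingBox(
    torch.tensor([10.0, 20.0, 15.0, 5.0]),
    format="xywh",
    image_size=(100, 100),
)
print(box.to_format("xyxy"))  # expected coordinates: 10, 20, 25, 25
```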
3 changes: 3 additions & 0 deletions torchvision/prototype/features/_encoded.py
@@ -39,6 +39,9 @@ def image_size(self) -> Tuple[int, int]:
         return self._image_size

     def decode(self) -> Image:
+        # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
+        # promote this out of the prototype state
+
         # import at runtime to avoid cyclic imports
         from torchvision.prototype.transforms.kernels import decode_image_with_pil
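The `# import at runtime to avoid cyclic imports` comment in this hunk and the previous one refers to a standard Python pattern; a generic two-module sketch of it (module and function names are illustrative, not torchvision's actual layout):

```python
# features.py
def decode(data):
    # Deferred import: resolving `kernels` at call time rather than at module
    # import time breaks the features <-> kernels import cycle.
    from kernels import decode_image_with_pil

    return decode_image_with_pil(data)


# kernels.py -- may safely import features at module level, because
# features.py no longer imports kernels while it is itself being imported.
import features
```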
28 changes: 14 additions & 14 deletions torchvision/prototype/features/_feature.py
@@ -12,20 +12,20 @@ class Feature(torch.Tensor):
     _metadata: Dict[str, Any]

     def __init_subclass__(cls) -> None:
-        # In order to help static type checkers, we require subclasses of `Feature` to add the metadata attributes
-        # as static class annotations:
-        #
-        # >>> class Foo(Feature):
-        # ...     bar: str
-        # ...     baz: Optional[str]
-        #
-        # Internally, this information is used twofold:
-        #
-        # 1. A class annotation is contained in `cls.__annotations__` but not in `cls.__dict__`. We use this difference
-        #    to automatically detect the meta data attributes and expose them as `@property`'s for convenient runtime
-        #    access. This happens in this method.
-        # 2. The information extracted in 1. is also used at creation (`__new__`) to perform an input parsing for
-        #    unknown arguments.
+        """
+        For convenient copying of metadata, we store it inside a dictionary rather than multiple individual attributes.
+        By adding the metadata attributes as class annotations on subclasses of :class:`Feature`, this method adds
+        properties to have the same convenient access as regular attributes.
+
+        >>> class Foo(Feature):
+        ...     bar: str
+        ...     baz: Optional[str]
+        >>> foo = Foo()
+        >>> foo.bar
+        >>> foo.baz
+
+        This has the additional benefit that autocomplete engines and static type checkers are aware of the metadata.
+        """
         meta_attrs = {attr for attr in cls.__annotations__.keys() - cls.__dict__.keys() if not attr.startswith("_")}
         for super_cls in cls.__mro__[1:]:
             if super_cls is Feature:
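To make the docstring's mechanism concrete, here is a self-contained sketch of annotation-driven metadata properties. It mirrors the `__annotations__`/`__dict__` difference used above, but it is a simplified stand-in, not the actual `Feature` implementation (which subclasses `torch.Tensor`):

```python
from typing import Any, Dict, Optional


class Feature:
    _metadata: Dict[str, Any]

    def __init__(self, **metadata: Any) -> None:
        self._metadata = metadata

    def __init_subclass__(cls) -> None:
        # Bare class annotations appear in cls.__annotations__ but not in
        # cls.__dict__; treat exactly those names as metadata attributes.
        meta_attrs = {
            attr
            for attr in cls.__annotations__.keys() - cls.__dict__.keys()
            if not attr.startswith("_")
        }
        for attr in meta_attrs:
            # Expose each metadata key as a read-only property.
            setattr(cls, attr, property(lambda self, attr=attr: self._metadata[attr]))


class Foo(Feature):
    bar: str
    baz: Optional[str]


foo = Foo(bar="spam", baz=None)
print(foo.bar, foo.baz)  # spam None
```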
4 changes: 4 additions & 0 deletions torchvision/prototype/features/_image.py
@@ -78,7 +78,11 @@ def guess_color_space(data: torch.Tensor) -> ColorSpace:
         return ColorSpace.OTHER

     def show(self) -> None:
+        # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
+        # promote this out of the prototype state
         to_pil_image(make_grid(self.view(-1, *self.shape[-3:]))).show()

     def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> Image:
+        # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
+        # promote this out of the prototype state
         return Image.new_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs))
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/functional/_geometry.py
@@ -41,7 +41,7 @@ def resize(input: T, *args: Any, **kwargs: Any) -> T:
     """ADDME"""
     if isinstance(input, features.BoundingBox):
         size = kwargs.pop("size")
-        output = K.resize_bounding_box(input, old_image_size=list(input.image_size), new_image_size=size)
+        output = K.resize_bounding_box(input, size=size, image_size=input.image_size)
         return cast(T, features.BoundingBox.new_like(input, output, image_size=size))

     raise RuntimeError
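After the rename, the dispatcher forwards `size` plus the box's own `image_size` metadata to the kernel. A hedged usage sketch; the `BoundingBox` constructor arguments and exact import path are assumptions about the prototype API:

```python
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

box = features.BoundingBox(
    torch.tensor([10.0, 20.0, 30.0, 40.0]),
    format="xyxy",
    image_size=(100, 200),  # (height, width)
)
resized = F.resize(box, size=(50, 100))
# Coordinates are scaled by (new_width / old_width, new_height / old_height)
# and the returned BoundingBox carries image_size=(50, 100).
```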
14 changes: 4 additions & 10 deletions torchvision/prototype/transforms/kernels/_geometry.py
@@ -51,21 +51,15 @@ def resize_image(
 def resize_segmentation_mask(
     segmentation_mask: torch.Tensor,
     size: List[int],
-    interpolation: InterpolationMode = InterpolationMode.NEAREST,
     max_size: Optional[int] = None,
-    antialias: Optional[bool] = None,
 ) -> torch.Tensor:
-    return resize_image(
-        segmentation_mask, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias
-    )
+    return resize_image(segmentation_mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)


 # TODO: handle max_size
-def resize_bounding_box(
-    bounding_box: torch.Tensor, *, old_image_size: List[int], new_image_size: List[int]
-) -> torch.Tensor:
-    old_height, old_width = old_image_size
-    new_height, new_width = new_image_size
+def resize_bounding_box(bounding_box: torch.Tensor, *, size: List[int], image_size: Tuple[int, int]) -> torch.Tensor:
+    old_height, old_width = image_size
+    new_height, new_width = size
     ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device)
     return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape)
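A quick numerical check of the renamed kernel on a plain tensor, assuming the `kernels` import path used by the dispatcher above:

```python
import torch
from torchvision.prototype.transforms import kernels as K

# One xyxy box on a 100x200 (height x width) image.
box = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
out = K.resize_bounding_box(box, size=[50, 100], image_size=(100, 200))
# Both ratios are 0.5 (width: 100/200, height: 50/100):
print(out)  # tensor([[ 5., 10., 15., 20.]])
```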