
Commit f795349

cleanup features / transforms feature branch (#5406)
* mark candidates for removal
* align signature of resize_bounding_box with corresponding image kernel
* fix documentation of Feature
* remove interpolation mode and antialias option from resize_segmentation_mask
1 parent f3b80ef commit f795349

File tree
7 files changed: +31 -27 lines


test/test_prototype_transforms_kernels.py
Lines changed: 2 additions & 2 deletions

@@ -170,11 +170,11 @@ def resize_image():
 def resize_bounding_box():
     for bounding_box in make_bounding_boxes():
         height, width = bounding_box.image_size
-        for new_image_size in [
+        for size in [
             (height, width),
             (int(height * 0.75), int(width * 1.25)),
         ]:
-            yield SampleInput(bounding_box, old_image_size=bounding_box.image_size, new_image_size=new_image_size)
+            yield SampleInput(bounding_box, size=size, image_size=bounding_box.image_size)


 class TestKernelsCommon:

torchvision/prototype/features/_bounding_box.py
Lines changed: 3 additions & 0 deletions

@@ -39,6 +39,9 @@ def __new__(
         return bounding_box

     def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
+        # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
+        # promote this out of the prototype state
+
         # import at runtime to avoid cyclic imports
         from torchvision.prototype.transforms.kernels import convert_bounding_box_format


torchvision/prototype/features/_encoded.py
Lines changed: 3 additions & 0 deletions

@@ -39,6 +39,9 @@ def image_size(self) -> Tuple[int, int]:
         return self._image_size

     def decode(self) -> Image:
+        # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
+        # promote this out of the prototype state
+
         # import at runtime to avoid cyclic imports
         from torchvision.prototype.transforms.kernels import decode_image_with_pil


torchvision/prototype/features/_feature.py
Lines changed: 14 additions & 14 deletions

@@ -12,20 +12,20 @@ class Feature(torch.Tensor):
     _metadata: Dict[str, Any]

     def __init_subclass__(cls) -> None:
-        # In order to help static type checkers, we require subclasses of `Feature` to add the metadata attributes
-        # as static class annotations:
-        #
-        # >>> class Foo(Feature):
-        # ...     bar: str
-        # ...     baz: Optional[str]
-        #
-        # Internally, this information is used twofold:
-        #
-        # 1. A class annotation is contained in `cls.__annotations__` but not in `cls.__dict__`. We use this difference
-        #    to automatically detect the meta data attributes and expose them as `@property`'s for convenient runtime
-        #    access. This happens in this method.
-        # 2. The information extracted in 1. is also used at creation (`__new__`) to perform an input parsing for
-        #    unknown arguments.
+        """
+        For convenient copying of metadata, we store it inside a dictionary rather than multiple individual attributes.
+        By adding the metadata attributes as class annotations on subclasses of :class:`Feature`, this method adds
+        properties to have the same convenient access as regular attributes.
+
+        >>> class Foo(Feature):
+        ...     bar: str
+        ...     baz: Optional[str]
+        >>> foo = Foo()
+        >>> foo.bar
+        >>> foo.baz
+
+        This has the additional benefit that autocomplete engines and static type checkers are aware of the metadata.
+        """
         meta_attrs = {attr for attr in cls.__annotations__.keys() - cls.__dict__.keys() if not attr.startswith("_")}
         for super_cls in cls.__mro__[1:]:
             if super_cls is Feature:
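As a usage note on the reworked docstring (a minimal sketch, not part of this commit; the import path and the class and attribute names are assumptions about the prototype layout at this point): annotating a metadata attribute on a Feature subclass is all that is needed for __init_subclass__ to expose it as a property.

from typing import Optional

from torchvision.prototype.features import Feature  # assumed export of the prototype package

class Labeled(Feature):
    # Class annotations only: they end up in `Labeled.__annotations__` but not
    # in `Labeled.__dict__`, which is how `__init_subclass__` detects them.
    category: str
    score: Optional[float]

# The detected annotations are exposed as properties that read from `_metadata`.
assert isinstance(Labeled.category, property)
assert isinstance(Labeled.score, property)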

torchvision/prototype/features/_image.py
Lines changed: 4 additions & 0 deletions

@@ -78,7 +78,11 @@ def guess_color_space(data: torch.Tensor) -> ColorSpace:
         return ColorSpace.OTHER

     def show(self) -> None:
+        # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
+        # promote this out of the prototype state
         to_pil_image(make_grid(self.view(-1, *self.shape[-3:]))).show()

     def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> Image:
+        # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
+        # promote this out of the prototype state
         return Image.new_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs))

torchvision/prototype/transforms/functional/_geometry.py
Lines changed: 1 addition & 1 deletion

@@ -41,7 +41,7 @@ def resize(input: T, *args: Any, **kwargs: Any) -> T:
     """ADDME"""
     if isinstance(input, features.BoundingBox):
         size = kwargs.pop("size")
-        output = K.resize_bounding_box(input, old_image_size=list(input.image_size), new_image_size=size)
+        output = K.resize_bounding_box(input, size=size, image_size=input.image_size)
         return cast(T, features.BoundingBox.new_like(input, output, image_size=size))

     raise RuntimeError
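To see the aligned signature from the caller's side, here is a hypothetical sketch (the functional import path, the BoundingBox constructor keywords, and the example values are assumptions about the prototype API at this commit, not taken from the diff):

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

box = features.BoundingBox(
    torch.tensor([0.0, 0.0, 50.0, 50.0]),
    format=features.BoundingBoxFormat.XYXY,
    image_size=(100, 100),
)
# `resize` pops `size` from its kwargs and forwards it unchanged to the kernel:
# K.resize_bounding_box(box, size=size, image_size=box.image_size)
resized = F.resize(box, size=(200, 200))
print(resized.image_size)  # (200, 200), attached via `BoundingBox.new_like`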

torchvision/prototype/transforms/kernels/_geometry.py
Lines changed: 4 additions & 10 deletions

@@ -51,21 +51,15 @@ def resize_image(
 def resize_segmentation_mask(
     segmentation_mask: torch.Tensor,
     size: List[int],
-    interpolation: InterpolationMode = InterpolationMode.NEAREST,
     max_size: Optional[int] = None,
-    antialias: Optional[bool] = None,
 ) -> torch.Tensor:
-    return resize_image(
-        segmentation_mask, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias
-    )
+    return resize_image(segmentation_mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)


 # TODO: handle max_size
-def resize_bounding_box(
-    bounding_box: torch.Tensor, *, old_image_size: List[int], new_image_size: List[int]
-) -> torch.Tensor:
-    old_height, old_width = old_image_size
-    new_height, new_width = new_image_size
+def resize_bounding_box(bounding_box: torch.Tensor, *, size: List[int], image_size: Tuple[int, int]) -> torch.Tensor:
+    old_height, old_width = image_size
+    new_height, new_width = size
     ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device)
     return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape)
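For reference, the arithmetic behind the renamed kernel, as a self-contained sketch with made-up numbers (only the argument names size and image_size come from this diff): the box is viewed as two (x, y) points and multiplied by the per-axis scale factors.

import torch

bounding_box = torch.tensor([10.0, 20.0, 30.0, 60.0])  # one box in xyxy form
image_size = (100, 200)  # (old_height, old_width)
size = (50, 400)         # (new_height, new_width)

old_height, old_width = image_size
new_height, new_width = size
# x coordinates scale by new_width / old_width, y coordinates by new_height / old_height.
ratios = torch.tensor((new_width / old_width, new_height / old_height))
resized = bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape)
print(resized)  # tensor([20., 10., 60., 30.]): x doubled, y halved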
