Skip to content

[proto] Added some transformations and fixed type hints #6245

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Jul 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 1 addition & 12 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
@@ -955,18 +955,7 @@ def test_adjust_gamma(device, dtype, config, channels):

@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize(
"pad",
[
2,
[
3,
],
[0, 3],
(3, 3),
[4, 2, 4, 3],
],
)
@pytest.mark.parametrize("pad", [2, [3], [0, 3], (3, 3), [4, 2, 4, 3]])
@pytest.mark.parametrize(
"config",
[
Expand Down
9 changes: 8 additions & 1 deletion test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,11 @@
import pytest
import torch
from common_utils import assert_equal
from test_prototype_transforms_functional import make_images, make_bounding_boxes, make_one_hot_labels
from test_prototype_transforms_functional import (
make_images,
make_bounding_boxes,
make_one_hot_labels,
)
from torchvision.prototype import transforms, features
from torchvision.transforms.functional import to_pil_image, pil_to_tensor

@@ -72,6 +76,9 @@ class TestSmoke:
transforms.ConvertImageDtype(),
transforms.RandomHorizontalFlip(),
transforms.Pad(5),
transforms.RandomZoomOut(),
transforms.RandomRotation(degrees=(-45, 45)),
transforms.RandomAffine(degrees=(-45, 45)),
)
def test_common(self, transform, input):
transform(input)
Expand Down
2 changes: 1 addition & 1 deletion test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
@@ -317,7 +317,7 @@ def rotate_image_tensor():
[-87, 15, 90], # angle
[True, False], # expand
[None, [12, 23]], # center
[None, [128]], # fill
[None, [128], [12.0]], # fill
):
if center is not None and expand:
# Skip warning: The provided center argument is ignored if expand is True
Expand Down
24 changes: 11 additions & 13 deletions torchvision/prototype/features/_bounding_box.py
Original file line number Diff line number Diff line change
@@ -128,13 +128,20 @@ def resized_crop(
return BoundingBox.new_like(self, output, image_size=image_size, dtype=output.dtype)

def pad(
self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
self,
padding: Union[int, Sequence[int]],
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
padding_mode: str = "constant",
) -> BoundingBox:
from torchvision.prototype.transforms import functional as _F

if padding_mode not in ["constant"]:
raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes")

# This cast does Sequence[int] -> List[int] and is required to make mypy happy
if not isinstance(padding, int):
padding = list(padding)

output = _F.pad_bounding_box(self, padding, format=self.format)

# Update output image size:
@@ -153,7 +160,7 @@ def rotate(
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
center: Optional[List[float]] = None,
) -> BoundingBox:
from torchvision.prototype.transforms import functional as _F
@@ -173,7 +180,7 @@ def affine(
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
center: Optional[List[float]] = None,
) -> BoundingBox:
from torchvision.prototype.transforms import functional as _F
@@ -194,18 +201,9 @@ def perspective(
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> BoundingBox:
from torchvision.prototype.transforms import functional as _F

output = _F.perspective_bounding_box(self, self.format, perspective_coeffs)
return BoundingBox.new_like(self, output, dtype=output.dtype)

def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> BoundingBox:
raise TypeError("Erase transformation does not support bounding boxes")

def mixup(self, lam: float) -> BoundingBox:
raise TypeError("Mixup transformation does not support bounding boxes")

def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> BoundingBox:
raise TypeError("Cutmix transformation does not support bounding boxes")
20 changes: 7 additions & 13 deletions torchvision/prototype/features/_feature.py
Original file line number Diff line number Diff line change
@@ -120,7 +120,10 @@ def resized_crop(
return self

def pad(
self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
self,
padding: Union[int, Sequence[int]],
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
padding_mode: str = "constant",
) -> Any:
return self

@@ -129,7 +132,7 @@ def rotate(
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
center: Optional[List[float]] = None,
) -> Any:
return self
@@ -141,7 +144,7 @@ def affine(
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
center: Optional[List[float]] = None,
) -> Any:
return self
@@ -150,7 +153,7 @@ def perspective(
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Any:
return self

@@ -186,12 +189,3 @@ def equalize(self) -> Any:

def invert(self) -> Any:
return self

def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> Any:
return self

def mixup(self, lam: float) -> Any:
return self

def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> Any:
return self
52 changes: 23 additions & 29 deletions torchvision/prototype/features/_image.py
Original file line number Diff line number Diff line change
@@ -164,10 +164,20 @@ def resized_crop(
return Image.new_like(self, output)

def pad(
self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
self,
padding: Union[int, Sequence[int]],
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
padding_mode: str = "constant",
) -> Image:
from torchvision.prototype.transforms import functional as _F

# This cast does Sequence[int] -> List[int] and is required to make mypy happy
if not isinstance(padding, int):
padding = list(padding)

if fill is None:
fill = 0

# PyTorch's pad supports only scalars on fill. So we need to overwrite the colour
if isinstance(fill, (int, float)):
output = _F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
@@ -183,10 +193,12 @@ def rotate(
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
center: Optional[List[float]] = None,
) -> Image:
from torchvision.prototype.transforms import functional as _F
from torchvision.prototype.transforms.functional import _geometry as _F

fill = _F._convert_fill_arg(fill)

output = _F.rotate_image_tensor(
self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center
@@ -200,10 +212,12 @@ def affine(
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
center: Optional[List[float]] = None,
) -> Image:
from torchvision.prototype.transforms import functional as _F
from torchvision.prototype.transforms.functional import _geometry as _F

fill = _F._convert_fill_arg(fill)

output = _F.affine_image_tensor(
self,
@@ -221,9 +235,11 @@ def perspective(
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image:
from torchvision.prototype.transforms import functional as _F
from torchvision.prototype.transforms.functional import _geometry as _F

fill = _F._convert_fill_arg(fill)

output = _F.perspective_image_tensor(self, perspective_coeffs, interpolation=interpolation, fill=fill)
return Image.new_like(self, output)
@@ -293,25 +309,3 @@ def invert(self) -> Image:

output = _F.invert_image_tensor(self)
return Image.new_like(self, output)

def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> Image:
from torchvision.prototype.transforms import functional as _F

output = _F.erase_image_tensor(self, i, j, h, w, v)
return Image.new_like(self, output)

def mixup(self, lam: float) -> Image:
if self.ndim < 4:
raise ValueError("Need a batch of images")
output = self.clone()
output = output.roll(1, -4).mul_(1 - lam).add_(output.mul_(lam))
return Image.new_like(self, output)

def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> Image:
if self.ndim < 4:
raise ValueError("Need a batch of images")
x1, y1, x2, y2 = box
image_rolled = self.roll(1, -4)
output = self.clone()
output[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
return Image.new_like(self, output)
13 changes: 1 addition & 12 deletions torchvision/prototype/features/_label.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Optional, Sequence, cast, Union, Tuple
from typing import Any, Optional, Sequence, cast, Union

import torch
from torchvision.prototype.utils._internal import apply_recursively
@@ -77,14 +77,3 @@ def new_like(
return super().new_like(
other, data, categories=categories if categories is not None else other.categories, **kwargs
)

def mixup(self, lam: float) -> OneHotLabel:
if self.ndim < 2:
raise ValueError("Need a batch of one hot labels")
output = self.clone()
output = output.roll(1, -2).mul_(1 - lam).add_(output.mul_(lam))
return OneHotLabel.new_like(self, output)

def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> OneHotLabel:
box # unused
return self.mixup(lam_adjusted)
27 changes: 12 additions & 15 deletions torchvision/prototype/features/_segmentation_mask.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from __future__ import annotations

from typing import Tuple, List, Optional, Union, Sequence
from typing import List, Optional, Union, Sequence

import torch
from torchvision.transforms import InterpolationMode

from ._feature import _Feature
@@ -61,10 +60,17 @@ def resized_crop(
return SegmentationMask.new_like(self, output)

def pad(
self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
self,
padding: Union[int, Sequence[int]],
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
padding_mode: str = "constant",
) -> SegmentationMask:
from torchvision.prototype.transforms import functional as _F

# This cast does Sequence[int] -> List[int] and is required to make mypy happy
if not isinstance(padding, int):
padding = list(padding)

output = _F.pad_segmentation_mask(self, padding, padding_mode=padding_mode)
return SegmentationMask.new_like(self, output)

@@ -73,7 +79,7 @@ def rotate(
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
center: Optional[List[float]] = None,
) -> SegmentationMask:
from torchvision.prototype.transforms import functional as _F
@@ -88,7 +94,7 @@ def affine(
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
center: Optional[List[float]] = None,
) -> SegmentationMask:
from torchvision.prototype.transforms import functional as _F
@@ -107,18 +113,9 @@ def perspective(
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> SegmentationMask:
from torchvision.prototype.transforms import functional as _F

output = _F.perspective_segmentation_mask(self, perspective_coeffs)
return SegmentationMask.new_like(self, output)

def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> SegmentationMask:
raise TypeError("Erase transformation does not support segmentation masks")

def mixup(self, lam: float) -> SegmentationMask:
raise TypeError("Mixup transformation does not support segmentation masks")

def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> SegmentationMask:
raise TypeError("Cutmix transformation does not support segmentation masks")
2 changes: 2 additions & 0 deletions torchvision/prototype/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -17,6 +17,8 @@
RandomVerticalFlip,
Pad,
RandomZoomOut,
RandomRotation,
RandomAffine,
)
from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
from ._misc import Identity, Normalize, ToDtype, Lambda
Expand Down
Loading