[prototype] Fix BC-breakages on input params of F #6636

Status: Merged (23 commits, Sep 28, 2022)

Commits (23)
e41e405  Fix `size` in resize. (datumbox, Sep 23, 2022)
45fbf87  Merge branch 'main' into prototype/fix-function-signatures (datumbox, Sep 23, 2022)
dd528d9  Update torchvision/prototype/features/_bounding_box.py (datumbox, Sep 23, 2022)
4b49089  Address some of the comments. (datumbox, Sep 23, 2022)
0128a3a  Fix `output_size` in center_crop. (datumbox, Sep 23, 2022)
e3f93fb  Fix `CenterCrop` transform (datumbox, Sep 23, 2022)
5d45374  Fix `size` in five_crop. (datumbox, Sep 23, 2022)
9fb9db7  Fix `size` in ten_crop. (datumbox, Sep 23, 2022)
37b415e  Fix `kernel_size` and `sigma` in gaussian_blur. (datumbox, Sep 23, 2022)
0f16b84  Fix `angle` and `shear` in affine. (datumbox, Sep 23, 2022)
df5a832  Fixing JIT-scriptability issues. (datumbox, Sep 23, 2022)
5dec03d  Update TODOs. (datumbox, Sep 23, 2022)
e848b40  Merge branch 'main' into prototype/fix-function-signatures (datumbox, Sep 26, 2022)
6a0c24b  Merge branch 'main' into prototype/fix-function-signatures (datumbox, Sep 28, 2022)
59d5166  Restore fake types for `Union[int, List[int]]` and `Union[int, float,… (datumbox, Sep 28, 2022)
58867f0  Merge branch 'main' into prototype/fix-function-signatures (datumbox, Sep 28, 2022)
ea0bc55  Fixing tests (datumbox, Sep 28, 2022)
d62f1eb  Fix linter (datumbox, Sep 28, 2022)
28119be  revert unnecessary JIT mitigations. (datumbox, Sep 28, 2022)
f0d9d85  Cherrypick Philip's 6dfc9657ce89fe9e018a11ee25a8e26c7d3d43c6 (datumbox, Sep 28, 2022)
350d47c  Linter fix (datumbox, Sep 28, 2022)
1d2ae82  Adding center float casting (datumbox, Sep 28, 2022)
352b72d  Merge branch 'main' into prototype/fix-function-signatures (pmeier, Sep 28, 2022)
6 changes: 4 additions & 2 deletions test/test_prototype_transforms.py
@@ -799,7 +799,9 @@ def test_assertions(self):
         with pytest.raises(ValueError, match="Kernel size value should be an odd and positive number"):
             transforms.GaussianBlur(4)
 
-        with pytest.raises(TypeError, match="sigma should be a single float or a list/tuple with length 2"):
+        with pytest.raises(
+            TypeError, match="sigma should be a single int or float or a list/tuple with length 2 floats."
+        ):
             transforms.GaussianBlur(3, sigma=[1, 2, 3])
 
         with pytest.raises(ValueError, match="If sigma is a single number, it must be positive"):
@@ -833,7 +835,7 @@ def test__transform(self, kernel_size, sigma, mocker):
         if isinstance(sigma, (tuple, list)):
             assert transform.sigma == sigma
         else:
-            assert transform.sigma == (sigma, sigma)
+            assert transform.sigma == [sigma, sigma]
 
         fn = mocker.patch("torchvision.prototype.transforms.functional.gaussian_blur")
         inpt = mocker.MagicMock(spec=features.Image)
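As a quick sanity check, here is what the updated assertion exercises; a minimal sketch, assuming the prototype namespace of this PR's era is importable:

```python
import pytest
from torchvision.prototype import transforms

# A 3-element sigma must still be rejected; the error message now mentions int too.
with pytest.raises(TypeError, match="sigma should be a single int or float"):
    transforms.GaussianBlur(3, sigma=[1, 2, 3])
```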
6 changes: 5 additions & 1 deletion torchvision/prototype/features/_bounding_box.py
@@ -84,6 +84,8 @@ def resize( # type: ignore[override]
         antialias: bool = False,
     ) -> BoundingBox:
         output = self._F.resize_bounding_box(self, size, image_size=self.image_size, max_size=max_size)
+        if isinstance(size, int):
+            size = [size]
         image_size = (size[0], size[0]) if len(size) == 1 else (size[0], size[1])
         return BoundingBox.new_like(self, output, image_size=image_size, dtype=output.dtype)
 
@@ -95,6 +97,8 @@ def center_crop(self, output_size: List[int]) -> BoundingBox:
         output = self._F.center_crop_bounding_box(
             self, format=self.format, output_size=output_size, image_size=self.image_size
         )
+        if isinstance(output_size, int):
+            output_size = [output_size]
         image_size = (output_size[0], output_size[0]) if len(output_size) == 1 else (output_size[0], output_size[1])
         return BoundingBox.new_like(self, output, image_size=image_size)
 
@@ -160,7 +164,7 @@ def rotate(
 
     def affine(
         self,
-        angle: float,
+        angle: Union[int, float],
         translate: List[float],
         scale: float,
         shear: List[float],
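The `image_size` computation added above reduces to a small normalization rule. A standalone sketch in plain Python (the helper name is illustrative, not part of the codebase):

```python
from typing import List, Tuple, Union

def _normalize_to_image_size(size: Union[int, List[int]]) -> Tuple[int, int]:
    # A bare int behaves like a 1-element list, and a 1-element list means a square output.
    if isinstance(size, int):
        size = [size]
    return (size[0], size[0]) if len(size) == 1 else (size[0], size[1])

assert _normalize_to_image_size(224) == (224, 224)
assert _normalize_to_image_size([224]) == (224, 224)
assert _normalize_to_image_size([224, 320]) == (224, 320)
```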
2 changes: 1 addition & 1 deletion torchvision/prototype/features/_feature.py
@@ -169,7 +169,7 @@ def rotate(
 
     def affine(
         self,
-        angle: float,
+        angle: Union[int, float],
         translate: List[float],
         scale: float,
         shear: List[float],
2 changes: 1 addition & 1 deletion torchvision/prototype/features/_image.py
@@ -198,7 +198,7 @@ def rotate(
 
     def affine(
         self,
-        angle: float,
+        angle: Union[int, float],
         translate: List[float],
         scale: float,
         shear: List[float],
2 changes: 1 addition & 1 deletion torchvision/prototype/features/_mask.py
@@ -70,7 +70,7 @@ def rotate(
 
     def affine(
         self,
-        angle: float,
+        angle: Union[int, float],
         translate: List[float],
         scale: float,
         shear: List[float],
24 changes: 4 additions & 20 deletions torchvision/prototype/transforms/_geometry.py
@@ -19,6 +19,7 @@
     _check_sequence_input,
     _setup_angle,
     _setup_fill_arg,
+    _setup_float_or_seq,
     _setup_size,
     has_all,
     has_any,
@@ -67,9 +68,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 
 
 class CenterCrop(Transform):
-    def __init__(self, size: List[int]):
+    def __init__(self, size: Union[int, Sequence[int]]):
         super().__init__()
-        self.size = size
+        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return F.center_crop(inpt, output_size=self.size)
@@ -320,7 +321,7 @@ def __init__(
         degrees: Union[numbers.Number, Sequence],
         translate: Optional[Sequence[float]] = None,
         scale: Optional[Sequence[float]] = None,
-        shear: Optional[Union[float, Sequence[float]]] = None,
+        shear: Optional[Union[int, float, Sequence[float]]] = None,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         fill: Union[features.FillType, Dict[Type, features.FillType]] = 0,
         center: Optional[List[float]] = None,
@@ -545,23 +546,6 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         )
 
 
-def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]:
-    if not isinstance(arg, (float, Sequence)):
-        raise TypeError(f"{name} should be float or a sequence of floats. Got {type(arg)}")
-    if isinstance(arg, Sequence) and len(arg) != req_size:
-        raise ValueError(f"If {name} is a sequence its length should be one of {req_size}. Got {len(arg)}")
-    if isinstance(arg, Sequence):
-        for element in arg:
-            if not isinstance(element, float):
-                raise ValueError(f"{name} should be a sequence of floats. Got {type(element)}")
-
-    if isinstance(arg, float):
-        arg = [float(arg), float(arg)]
-    if isinstance(arg, (list, tuple)) and len(arg) == 1:
-        arg = [arg[0], arg[0]]
-    return arg
-
-
 class ElasticTransform(Transform):
     def __init__(
         self,
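With `_setup_size` back in the constructor, `CenterCrop` should accept the same spellings as the stable API. A usage sketch, assuming `_setup_size` normalizes to a 2-tuple as it does in stable `torchvision.transforms`:

```python
from torchvision.prototype import transforms

transforms.CenterCrop(224)         # bare int: normalized to (224, 224)
transforms.CenterCrop((224,))      # 1-element sequence: also a square crop
transforms.CenterCrop((224, 320))  # explicit (h, w)
```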
12 changes: 6 additions & 6 deletions torchvision/prototype/transforms/_misc.py
@@ -8,7 +8,7 @@
 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, Transform
 
-from ._utils import _setup_size, has_any, query_bounding_box
+from ._utils import _setup_float_or_seq, _setup_size, has_any, query_bounding_box
 
 
 class Identity(Transform):
@@ -112,25 +112,25 @@ def forward(self, *inpts: Any) -> Any:
 
 class GaussianBlur(Transform):
     def __init__(
-        self, kernel_size: Union[int, Sequence[int]], sigma: Union[float, Sequence[float]] = (0.1, 2.0)
+        self, kernel_size: Union[int, Sequence[int]], sigma: Union[int, float, Sequence[float]] = (0.1, 2.0)
     ) -> None:
         super().__init__()
         self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
         for ks in self.kernel_size:
             if ks <= 0 or ks % 2 == 0:
                 raise ValueError("Kernel size value should be an odd and positive number.")
 
-        if isinstance(sigma, float):
+        if isinstance(sigma, (int, float)):
             if sigma <= 0:
                 raise ValueError("If sigma is a single number, it must be positive.")
-            sigma = (sigma, sigma)
+            sigma = float(sigma)
         elif isinstance(sigma, Sequence) and len(sigma) == 2:
             if not 0.0 < sigma[0] <= sigma[1]:
                 raise ValueError("sigma values should be positive and of the form (min, max).")
         else:
-            raise TypeError("sigma should be a single float or a list/tuple with length 2 floats.")
+            raise TypeError("sigma should be a single int or float or a list/tuple with length 2 floats.")
 
-        self.sigma = sigma
+        self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
 
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item()

Review comment on `sigma = float(sigma)`, datumbox (Contributor, Author): Since the _setup_float_or_seq() will do the job for us, we only need to handle the integer case.
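Putting the two steps together (cast int to float, then normalize via `_setup_float_or_seq`), a scalar sigma ends up as a 2-element float list, which is exactly what the updated test asserts. A sketch under the same namespace assumption as above:

```python
from torchvision.prototype import transforms

t = transforms.GaussianBlur(kernel_size=3, sigma=2)  # int sigma is accepted again
assert list(t.sigma) == [2.0, 2.0]                   # duplicated and cast to float
```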
17 changes: 17 additions & 0 deletions torchvision/prototype/transforms/_utils.py
@@ -16,6 +16,23 @@
 from typing_extensions import Literal
 
 
+def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]:
+    if not isinstance(arg, (float, Sequence)):
+        raise TypeError(f"{name} should be float or a sequence of floats. Got {type(arg)}")
+    if isinstance(arg, Sequence) and len(arg) != req_size:
+        raise ValueError(f"If {name} is a sequence its length should be one of {req_size}. Got {len(arg)}")
+    if isinstance(arg, Sequence):
+        for element in arg:
+            if not isinstance(element, float):
+                raise ValueError(f"{name} should be a sequence of floats. Got {type(element)}")
+
+    if isinstance(arg, float):
+        arg = [float(arg), float(arg)]
+    if isinstance(arg, (list, tuple)) and len(arg) == 1:
+        arg = [arg[0], arg[0]]
+    return arg
+
+
 def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
     if isinstance(fill, dict):
         for key, value in fill.items():

Review comment on the `_setup_float_or_seq` definition, datumbox (Contributor, Author): Moved as-is.
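For reference, the contract of the moved helper, and why the int cast in `GaussianBlur.__init__` is needed before calling it: a bare int is neither a float nor a Sequence, so the helper rejects it. A sketch, assuming the private module path is importable:

```python
from torchvision.prototype.transforms._utils import _setup_float_or_seq

assert _setup_float_or_seq(2.0, "sigma") == [2.0, 2.0]         # scalar float is duplicated
assert _setup_float_or_seq([1.0, 3.0], "sigma") == [1.0, 3.0]  # 2-element sequence passes through

try:
    _setup_float_or_seq(2, "sigma")  # bare int: rejected, hence the float(sigma) cast upstream
except TypeError as e:
    print(e)  # sigma should be float or a sequence of floats. Got <class 'int'>
```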
67 changes: 38 additions & 29 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -97,6 +97,8 @@ def resize_image_tensor(
     max_size: Optional[int] = None,
     antialias: bool = False,
 ) -> torch.Tensor:
+    if isinstance(size, int):
+        size = [size]
     num_channels, old_height, old_width = get_dimensions_image_tensor(image)
     new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)
     extra_dims = image.shape[:-3]
@@ -145,6 +147,8 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N
 def resize_bounding_box(
     bounding_box: torch.Tensor, size: List[int], image_size: Tuple[int, int], max_size: Optional[int] = None
 ) -> torch.Tensor:
+    if isinstance(size, int):
+        size = [size]
     old_height, old_width = image_size
     new_height, new_width = _compute_resized_output_size(image_size, size=size, max_size=max_size)
     ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device)
@@ -171,7 +175,7 @@ def resize(
 
 
 def _affine_parse_args(
-    angle: float,
+    angle: Union[int, float],
     translate: List[float],
     scale: float,
     shear: List[float],
@@ -214,15 +218,18 @@ def _affine_parse_args(
     if len(shear) != 2:
         raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")
 
-    if center is not None and not isinstance(center, (list, tuple)):
-        raise TypeError("Argument center should be a sequence")
+    if center is not None:
+        if not isinstance(center, (list, tuple)):
+            raise TypeError("Argument center should be a sequence")
+        else:
+            center = [float(c) for c in center]
 
     return angle, translate, shear, center
 
 
 def affine_image_tensor(
     img: torch.Tensor,
-    angle: float,
+    angle: Union[int, float],
     translate: List[float],
     scale: float,
     shear: List[float],
@@ -254,7 +261,7 @@
 @torch.jit.unused
 def affine_image_pil(
     img: PIL.Image.Image,
-    angle: float,
+    angle: Union[int, float],
     translate: List[float],
     scale: float,
     shear: List[float],
@@ -278,34 +285,26 @@ def affine_image_pil(
 def _affine_bounding_box_xyxy(
     bounding_box: torch.Tensor,
     image_size: Tuple[int, int],
-    angle: float,
-    translate: Optional[List[float]] = None,
-    scale: Optional[float] = None,
-    shear: Optional[List[float]] = None,
+    angle: Union[int, float],
+    translate: List[float],
+    scale: float,
+    shear: List[float],
     center: Optional[List[float]] = None,
     expand: bool = False,
 ) -> torch.Tensor:
-    dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
-    device = bounding_box.device
-
-    if translate is None:
-        translate = [0.0, 0.0]
-
-    if scale is None:
-        scale = 1.0
-
-    if shear is None:
-        shear = [0.0, 0.0]
+    angle, translate, shear, center = _affine_parse_args(
+        angle, translate, scale, shear, InterpolationMode.NEAREST, center
+    )
 
     if center is None:
         height, width = image_size
-        center_f = [width * 0.5, height * 0.5]
-    else:
-        center_f = [float(c) for c in center]
+        center = [width * 0.5, height * 0.5]
+
+    dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
+    device = bounding_box.device
 
-    translate_f = [float(t) for t in translate]
     affine_matrix = torch.tensor(
-        _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear, inverted=False),
+        _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False),
         dtype=dtype,
         device=device,
     ).view(2, 3)
@@ -351,7 +350,7 @@ def affine_bounding_box(
     bounding_box: torch.Tensor,
     format: features.BoundingBoxFormat,
     image_size: Tuple[int, int],
-    angle: float,
+    angle: Union[int, float],
     translate: List[float],
     scale: float,
     shear: List[float],
@@ -373,7 +372,7 @@
 
 def affine_mask(
     mask: torch.Tensor,
-    angle: float,
+    angle: Union[int, float],
     translate: List[float],
     scale: float,
     shear: List[float],
@@ -419,14 +418,15 @@ def _convert_fill_arg(fill: features.FillType) -> features.FillTypeJIT:
 
 def affine(
     inpt: features.InputTypeJIT,
-    angle: float,
+    angle: Union[int, float],
     translate: List[float],
     scale: float,
     shear: List[float],
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
     fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> features.InputTypeJIT:
+    # TODO: consider deprecating integers from angle and shear on the future
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return affine_image_tensor(
             inpt,
@@ -528,7 +528,16 @@ def rotate_bounding_box(
         bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
     ).view(-1, 4)
 
-    out_bboxes = _affine_bounding_box_xyxy(bounding_box, image_size, angle=-angle, center=center, expand=expand)
+    out_bboxes = _affine_bounding_box_xyxy(
+        bounding_box,
+        image_size,
+        angle=-angle,
+        translate=[0.0, 0.0],
+        scale=1.0,
+        shear=[0.0, 0.0],
+        center=center,
+        expand=expand,
+    )
 
     return convert_format_bounding_box(
         out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
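End to end, the widened annotations mean an integer `angle` flows through the `affine` dispatcher again. A smoke-test sketch, assuming the prototype functional namespace of this PR's era:

```python
import torch
from torchvision.prototype.transforms import functional as F

img = torch.rand(3, 32, 32)
# An int angle used to violate the float-only annotation; it is tolerated again (see the TODO above).
out = F.affine(img, angle=45, translate=[0.0, 0.0], scale=1.0, shear=[0.0, 0.0])
assert out.shape == img.shape
```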
1 change: 1 addition & 0 deletions torchvision/prototype/transforms/functional/_misc.py
@@ -23,6 +23,7 @@ def normalize(
 def gaussian_blur_image_tensor(
     img: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
 ) -> torch.Tensor:
+    # TODO: consider deprecating integers from sigma on the future
     if isinstance(kernel_size, int):
         kernel_size = [kernel_size, kernel_size]
     if len(kernel_size) != 2:
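The matching normalization on the functional side: an int `kernel_size` is expanded to `[k, k]` before validation. A sketch, assuming the kernel is exported from the prototype functional module:

```python
import torch
from torchvision.prototype.transforms.functional import gaussian_blur_image_tensor

img = torch.rand(1, 3, 32, 32)
out = gaussian_blur_image_tensor(img, kernel_size=3, sigma=[1.0, 1.0])  # int kernel_size -> [3, 3]
assert out.shape == img.shape
```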