From e24d71adb143f4313b6552bee487a5e2a52dd1bc Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 16 Sep 2022 11:58:25 +0200 Subject: [PATCH 1/9] Updated fill arg typehint for affine, perspective and elastic ops --- test/test_prototype_transforms_functional.py | 34 +++++++++++++++++-- .../transforms/functional/_geometry.py | 13 ++++--- torchvision/transforms/functional_tensor.py | 12 +++++-- 3 files changed, 47 insertions(+), 12 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index d76c90340a2..bb83348995b 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -82,6 +82,28 @@ def resize_mask(): yield ArgsKwargs(mask, size=size, max_size=max_size) +@register_kernel_info_from_sample_inputs_fn +def affine_image_tensor(): + for image, angle, translate, scale, shear in itertools.product( + make_images(), + [-87, 15, 90], # angle + [5, -5], # translate + [0.77, 1.27], # scale + [0, 12], # shear + ): + yield ArgsKwargs(image, angle=angle, translate=(translate, translate), scale=scale, shear=(shear, shear)) + + for fill in [None, 128.0, 128, [12.0], [1.0, 2.0, 3.0]]: + yield ArgsKwargs( + image, angle=angle, translate=(translate, translate), scale=scale, shear=(shear, shear), fill=fill + ) + + for center in [None, [12, 23]]: + yield ArgsKwargs( + image, angle=angle, translate=(translate, translate), scale=scale, shear=(shear, shear), center=center + ) + + @register_kernel_info_from_sample_inputs_fn def affine_mask(): for mask, angle, translate, scale, shear in itertools.product( @@ -262,8 +284,12 @@ def perspective_image_tensor(): [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018], [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063], ], - [None, [128], [12.0]], # fill + [None, 128.0, 128, [12.0], [1.0, 2.0, 3.0]], # fill ): + if isinstance(fill, list) and len(fill) == 3 and image.shape[1] != 3: + # skip the test with non-broadcastable fill value + continue + yield ArgsKwargs(image, perspective_coeffs=perspective_coeffs, fill=fill) @@ -302,8 +328,12 @@ def perspective_mask(): def elastic_image_tensor(): for image, fill in itertools.product( make_images(extra_dims=((), (4,))), - [None, [128], [12.0]], # fill + [None, 128.0, 128, [12.0], [1.0, 2.0, 3.0]], # fill ): + if isinstance(fill, list) and len(fill) == 3 and image.shape[1] != 3: + # skip the test with non-broadcastable fill value + continue + h, w = image.shape[-2:] displacement = torch.rand(1, h, w, 2) yield ArgsKwargs(image, displacement=displacement, fill=fill) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index e7ca7463b79..8a0717310b8 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -232,7 +232,7 @@ def affine_image_tensor( scale: float, shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[List[float]] = None, + fill: Optional[Union[int, float, List[float]]] = None, center: Optional[List[float]] = None, ) -> torch.Tensor: if img.numel() == 0: @@ -405,7 +405,9 @@ def affine_mask( return output -def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]) -> Optional[List[float]]: +def _convert_fill_arg( + fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] +) -> Optional[Union[int, float, List[float]]]: # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517 # So, we can't reassign fill to 0 # if fill is None: @@ -416,9 +418,6 @@ def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[f # This cast does Sequence -> List[float] to please mypy and torch.jit.script if not isinstance(fill, (int, float)): fill = [float(v) for v in list(fill)] - else: - # It is OK to cast int to float as later we use inpt.dtype - fill = [float(fill)] return fill @@ -739,7 +738,7 @@ def perspective_image_tensor( img: torch.Tensor, perspective_coeffs: List[float], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[List[float]] = None, + fill: Optional[Union[int, float, List[float]]] = None, ) -> torch.Tensor: return _FT.perspective(img, perspective_coeffs, interpolation=interpolation.value, fill=fill) @@ -878,7 +877,7 @@ def elastic_image_tensor( img: torch.Tensor, displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[List[float]] = None, + fill: Optional[Union[int, float, List[float]]] = None, ) -> torch.Tensor: return _FT.elastic_transform(img, displacement, interpolation=interpolation.value, fill=fill) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 23d1a4b0edf..aabefcc1e69 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -600,7 +600,10 @@ def _gen_affine_grid( def affine( - img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[List[float]] = None + img: Tensor, + matrix: List[float], + interpolation: str = "nearest", + fill: Optional[Union[int, float, List[float]]] = None, ) -> Tensor: _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) @@ -693,7 +696,10 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, def perspective( - img: Tensor, perspective_coeffs: List[float], interpolation: str = "bilinear", fill: Optional[List[float]] = None + img: Tensor, + perspective_coeffs: List[float], + interpolation: str = "bilinear", + fill: Optional[Union[int, float, List[float]]] = None, ) -> Tensor: if not (isinstance(img, torch.Tensor)): raise TypeError("Input img should be Tensor.") @@ -950,7 +956,7 @@ def elastic_transform( img: Tensor, displacement: Tensor, interpolation: str = "bilinear", - fill: Optional[List[float]] = None, + fill: Optional[Union[int, float, List[float]]] = None, ) -> Tensor: if not (isinstance(img, torch.Tensor)): From 8bf31ba2ad39c75020aa0aaffe6a0077e731f230 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 16 Sep 2022 13:16:13 +0200 Subject: [PATCH 2/9] Updated pad op on prototype side --- torchvision/prototype/features/_image.py | 9 ++--- torchvision/prototype/features/_mask.py | 9 +---- .../transforms/functional/_geometry.py | 39 +++++++++++++++---- 3 files changed, 36 insertions(+), 21 deletions(-) diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 3f92d777db7..a33a39c806c 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -177,12 +177,11 @@ def pad( if not isinstance(padding, int): padding = list(padding) - # PyTorch's pad supports only scalars on fill. So we need to overwrite the colour - if isinstance(fill, (int, float)) or fill is None: - output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode) - else: - output = self._F._geometry._pad_with_vector_fill(self, padding, fill=fill, padding_mode=padding_mode) + # This cast does Sequence -> List and is required to make mypy happy + if not (fill is None or isinstance(fill, (int, float))): + fill = list(fill) + output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode) return Image.new_like(self, output) def rotate( diff --git a/torchvision/prototype/features/_mask.py b/torchvision/prototype/features/_mask.py index 022915798e1..04d888e75a2 100644 --- a/torchvision/prototype/features/_mask.py +++ b/torchvision/prototype/features/_mask.py @@ -58,14 +58,7 @@ def pad( if not isinstance(padding, int): padding = list(padding) - if isinstance(fill, (int, float)) or fill is None: - if fill is None: - fill = 0 - output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill) - else: - # Let's raise an error for vector fill on masks - raise ValueError("Non-scalar fill value is not supported") - + output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill) return Mask.new_like(self, output) def rotate( diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 8a0717310b8..db4dc84a78c 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -590,7 +590,23 @@ def rotate( def pad_image_tensor( img: torch.Tensor, padding: Union[int, List[int]], - fill: Optional[Union[int, float]] = 0, + fill: Optional[Union[int, float, List[float]]] = None, + padding_mode: str = "constant", +) -> torch.Tensor: + if fill is None: + # This JIT workaround + return _pad_with_scalar_fill(img, padding, fill=None, padding_mode=padding_mode) + elif isinstance(fill, (int, float)) or len(fill) == 1: + fill_number = fill[0] if isinstance(fill, list) else fill + return _pad_with_scalar_fill(img, padding, fill=fill_number, padding_mode=padding_mode) + else: + return _pad_with_vector_fill(img, padding, fill=fill, padding_mode=padding_mode) + + +def _pad_with_scalar_fill( + img: torch.Tensor, + padding: Union[int, List[int]], + fill: Optional[Union[int, float]] = None, padding_mode: str = "constant", ) -> torch.Tensor: num_channels, height, width = img.shape[-3:] @@ -613,13 +629,13 @@ def pad_image_tensor( def _pad_with_vector_fill( img: torch.Tensor, padding: Union[int, List[int]], - fill: Sequence[float] = [0.0], + fill: List[float], padding_mode: str = "constant", ) -> torch.Tensor: if padding_mode != "constant": raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar") - output = pad_image_tensor(img, padding, fill=0, padding_mode="constant") + output = _pad_with_scalar_fill(img, padding, fill=0, padding_mode="constant") left, right, top, bottom = _parse_pad_padding(padding) fill = torch.tensor(fill, dtype=img.dtype, device=img.device).view(-1, 1, 1) @@ -638,8 +654,14 @@ def pad_mask( mask: torch.Tensor, padding: Union[int, List[int]], padding_mode: str = "constant", - fill: Optional[Union[int, float]] = 0, + fill: Optional[Union[int, float, List[float]]] = None, ) -> torch.Tensor: + if fill is None: + fill = 0 + + if isinstance(fill, list): + raise ValueError("Non-scalar fill value is not supported") + if mask.ndim < 3: mask = mask.unsqueeze(0) needs_squeeze = True @@ -692,10 +714,11 @@ def pad( if not isinstance(padding, int): padding = list(padding) - # TODO: PyTorch's pad supports only scalars on fill. So we need to overwrite the colour - if isinstance(fill, (int, float)) or fill is None: - return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode) - return _pad_with_vector_fill(inpt, padding, fill=fill, padding_mode=padding_mode) + # This cast does Sequence -> List and is required to make mypy happy + if not (fill is None or isinstance(fill, (int, float))): + fill = list(fill) + + return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode) crop_image_tensor = _FT.crop From a36da7dced00682281e86059ef1691f3867317d0 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 16 Sep 2022 13:22:53 +0200 Subject: [PATCH 3/9] Code updates --- torchvision/prototype/features/_image.py | 4 +--- torchvision/prototype/features/_mask.py | 2 ++ torchvision/prototype/transforms/functional/_geometry.py | 6 ++---- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index a33a39c806c..dc1c80ef0cc 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -177,9 +177,7 @@ def pad( if not isinstance(padding, int): padding = list(padding) - # This cast does Sequence -> List and is required to make mypy happy - if not (fill is None or isinstance(fill, (int, float))): - fill = list(fill) + fill = self._F._geometry._convert_fill_arg(fill) output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode) return Image.new_like(self, output) diff --git a/torchvision/prototype/features/_mask.py b/torchvision/prototype/features/_mask.py index 04d888e75a2..0d8a360615f 100644 --- a/torchvision/prototype/features/_mask.py +++ b/torchvision/prototype/features/_mask.py @@ -58,6 +58,8 @@ def pad( if not isinstance(padding, int): padding = list(padding) + fill = self._F._geometry._convert_fill_arg(fill) + output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill) return Mask.new_like(self, output) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index db4dc84a78c..f26b132337e 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -594,7 +594,7 @@ def pad_image_tensor( padding_mode: str = "constant", ) -> torch.Tensor: if fill is None: - # This JIT workaround + # This is a JIT workaround return _pad_with_scalar_fill(img, padding, fill=None, padding_mode=padding_mode) elif isinstance(fill, (int, float)) or len(fill) == 1: fill_number = fill[0] if isinstance(fill, list) else fill @@ -714,9 +714,7 @@ def pad( if not isinstance(padding, int): padding = list(padding) - # This cast does Sequence -> List and is required to make mypy happy - if not (fill is None or isinstance(fill, (int, float))): - fill = list(fill) + fill = _convert_fill_arg(fill) return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode) From 6d6261277bb38da7943d0c8662b18ab70491db46 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 16 Sep 2022 15:32:21 +0200 Subject: [PATCH 4/9] Few other minor updates --- torchvision/prototype/transforms/functional/_geometry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index f26b132337e..72eded62903 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -606,7 +606,7 @@ def pad_image_tensor( def _pad_with_scalar_fill( img: torch.Tensor, padding: Union[int, List[int]], - fill: Optional[Union[int, float]] = None, + fill: Union[int, float, None], padding_mode: str = "constant", ) -> torch.Tensor: num_channels, height, width = img.shape[-3:] From b81d454916aa35b5f7dbb61c38a115775d1ee983 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 15 Sep 2022 17:22:22 +0200 Subject: [PATCH 5/9] WIP --- test/test_prototype_transforms_functional.py | 7 +- torchvision/prototype/features/_image.py | 6 ++ torchvision/prototype/features/_mask.py | 4 ++ torchvision/prototype/transforms/_geometry.py | 65 +++++++++---------- .../transforms/functional/_geometry.py | 2 +- torchvision/transforms/functional_tensor.py | 12 ++-- 6 files changed, 54 insertions(+), 42 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 81bae521b35..cc900d7c835 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -107,7 +107,7 @@ def rotate_image_tensor(): [-87, 15, 90], # angle [True, False], # expand [None, [12, 23]], # center - [None, [128], [12.0]], # fill + [None, 128, [12.0]], # fill ): if center is not None and expand: # Skip warning: The provided center argument is ignored if expand is True @@ -227,9 +227,12 @@ def pad_image_tensor(): for image, padding, fill, padding_mode in itertools.product( make_images(), [[1], [1, 1], [1, 1, 2, 2]], # padding - [None, 12, 12.0], # fill + [None, [12.0], [12.0, 13.0, 14.0]], # fill ["constant", "symmetric", "edge", "reflect"], # padding mode, ): + if padding_mode != ["constant"] and fill is not None and len(fill) > 1: + # ValueError: Padding mode 'reflect' is not supported if fill is not scalar + continue yield ArgsKwargs(image, padding=padding, fill=fill, padding_mode=padding_mode) diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 3f92d777db7..6a795e44a74 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -172,7 +172,9 @@ def pad( padding: Union[int, Sequence[int]], fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, padding_mode: str = "constant", + fill_for_mask: Optional[Union[int, float]] = None, # fill_for_mask has a different type vs others on purpose ) -> Image: + # This cast does Sequence[int] -> List[int] and is required to make mypy happy if not isinstance(padding, int): padding = list(padding) @@ -192,6 +194,7 @@ def rotate( expand: bool = False, fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, center: Optional[List[float]] = None, + fill_for_mask: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, ) -> Image: fill = self._F._geometry._convert_fill_arg(fill) @@ -209,6 +212,7 @@ def affine( interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, center: Optional[List[float]] = None, + fill_for_mask: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, ) -> Image: fill = self._F._geometry._convert_fill_arg(fill) @@ -229,6 +233,7 @@ def perspective( perspective_coeffs: List[float], interpolation: InterpolationMode = InterpolationMode.BILINEAR, fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, + fill_for_mask: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, ) -> Image: fill = self._F._geometry._convert_fill_arg(fill) @@ -242,6 +247,7 @@ def elastic( displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, + fill_for_mask: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, ) -> Image: fill = self._F._geometry._convert_fill_arg(fill) diff --git a/torchvision/prototype/features/_mask.py b/torchvision/prototype/features/_mask.py index 022915798e1..e6d384e8893 100644 --- a/torchvision/prototype/features/_mask.py +++ b/torchvision/prototype/features/_mask.py @@ -58,6 +58,10 @@ def pad( if not isinstance(padding, int): padding = list(padding) + + + output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill) + if isinstance(fill, (int, float)) or fill is None: if fill is None: fill = 0 diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index d351fd66f38..32f2f79ca4f 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -209,10 +209,12 @@ def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None: def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]: + _check_fill_arg(fill) + if isinstance(fill, dict): return fill - else: - return defaultdict(lambda: fill, {features.Mask: 0}) # type: ignore[arg-type, return-value] + + return defaultdict(lambda: fill) # type: ignore[arg-type, return-value] def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None: @@ -240,7 +242,6 @@ def __init__( super().__init__() _check_padding_arg(padding) - _check_fill_arg(fill) _check_padding_mode_arg(padding_mode) self.padding = padding @@ -261,7 +262,6 @@ def __init__( ) -> None: super().__init__(p=p) - _check_fill_arg(fill) self.fill = _setup_fill_arg(fill) _check_sequence_input(side_range, "side_range", req_sizes=(2,)) @@ -297,7 +297,7 @@ def __init__( degrees: Union[numbers.Number, Sequence], interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: Union[int, float, Sequence[int], Sequence[float]] = 0, + fill: Union[FillType, Dict[Type, FillType]] = 0, center: Optional[List[float]] = None, ) -> None: super().__init__() @@ -305,9 +305,7 @@ def __init__( self.interpolation = interpolation self.expand = expand - _check_fill_arg(fill) - - self.fill = fill + self.fill = _setup_fill_arg(fill) if center is not None: _check_sequence_input(center, "center", req_sizes=(2,)) @@ -319,12 +317,13 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(angle=angle) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + fill = self.fill[type(inpt)] return F.rotate( inpt, **params, interpolation=self.interpolation, expand=self.expand, - fill=self.fill, + fill=fill, center=self.center, ) @@ -337,7 +336,7 @@ def __init__( scale: Optional[Sequence[float]] = None, shear: Optional[Union[float, Sequence[float]]] = None, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Union[int, float, Sequence[int], Sequence[float]] = 0, + fill: Union[FillType, Dict[Type, FillType]] = 0, center: Optional[List[float]] = None, ) -> None: super().__init__() @@ -361,10 +360,7 @@ def __init__( self.shear = shear self.interpolation = interpolation - - _check_fill_arg(fill) - - self.fill = fill + self.fill = _setup_fill_arg(fill) if center is not None: _check_sequence_input(center, "center", req_sizes=(2,)) @@ -402,11 +398,12 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(angle=angle, translate=translate, scale=scale, shear=shear) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + fill = self.fill[type(inpt)] return F.affine( inpt, **params, interpolation=self.interpolation, - fill=self.fill, + fill=fill, center=self.center, ) @@ -417,7 +414,7 @@ def __init__( size: Union[int, Sequence[int]], padding: Optional[Union[int, Sequence[int]]] = None, pad_if_needed: bool = False, - fill: Union[int, float, Sequence[int], Sequence[float]] = 0, + fill: Union[FillType, Dict[Type, FillType]] = 0, padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", ) -> None: super().__init__() @@ -427,12 +424,11 @@ def __init__( if pad_if_needed or padding is not None: if padding is not None: _check_padding_arg(padding) - _check_fill_arg(fill) _check_padding_mode_arg(padding_mode) self.padding = padding self.pad_if_needed = pad_if_needed - self.fill = fill + self.fill = _setup_fill_arg(fill) self.padding_mode = padding_mode def _get_params(self, sample: Any) -> Dict[str, Any]: @@ -481,17 +477,18 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: # TODO: (PERF) check for speed optimization if we avoid repeated pad calls + fill = self.fill[type(inpt)] if self.padding is not None: - inpt = F.pad(inpt, padding=self.padding, fill=self.fill, padding_mode=self.padding_mode) + inpt = F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) if self.pad_if_needed: input_width, input_height = params["input_width"], params["input_height"] if input_width < self.size[1]: padding = [self.size[1] - input_width, 0] - inpt = F.pad(inpt, padding=padding, fill=self.fill, padding_mode=self.padding_mode) + inpt = F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode) if input_height < self.size[0]: padding = [0, self.size[0] - input_height] - inpt = F.pad(inpt, padding=padding, fill=self.fill, padding_mode=self.padding_mode) + inpt = F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode) return F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]) @@ -500,19 +497,18 @@ class RandomPerspective(_RandomApplyTransform): def __init__( self, distortion_scale: float = 0.5, - fill: Union[int, float, Sequence[int], Sequence[float]] = 0, + fill: Union[FillType, Dict[Type, FillType]] = 0, interpolation: InterpolationMode = InterpolationMode.BILINEAR, p: float = 0.5, ) -> None: super().__init__(p=p) - _check_fill_arg(fill) if not (0 <= distortion_scale <= 1): raise ValueError("Argument distortion_scale value should be between 0 and 1") self.distortion_scale = distortion_scale self.interpolation = interpolation - self.fill = fill + self.fill = _setup_fill_arg(fill) def _get_params(self, sample: Any) -> Dict[str, Any]: # Get image size @@ -544,10 +540,11 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(startpoints=startpoints, endpoints=endpoints) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + fill = self.fill[type(inpt)] return F.perspective( inpt, **params, - fill=self.fill, + fill=fill, interpolation=self.interpolation, ) @@ -574,17 +571,15 @@ def __init__( self, alpha: Union[float, Sequence[float]] = 50.0, sigma: Union[float, Sequence[float]] = 5.0, - fill: Union[int, float, Sequence[int], Sequence[float]] = 0, + fill: Union[FillType, Dict[Type, FillType]] = 0, interpolation: InterpolationMode = InterpolationMode.BILINEAR, ) -> None: super().__init__() self.alpha = _setup_float_or_seq(alpha, "alpha", 2) self.sigma = _setup_float_or_seq(sigma, "sigma", 2) - _check_fill_arg(fill) - self.interpolation = interpolation - self.fill = fill + self.fill = _setup_fill_arg(fill) def _get_params(self, sample: Any) -> Dict[str, Any]: # Get image size @@ -612,10 +607,11 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(displacement=displacement) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + fill = self.fill[type(inpt)] return F.elastic( inpt, **params, - fill=self.fill, + fill=fill, interpolation=self.interpolation, ) @@ -787,14 +783,16 @@ class FixedSizeCrop(Transform): def __init__( self, size: Union[int, Sequence[int]], - fill: Union[int, float, Sequence[int], Sequence[float]] = 0, + fill: Union[FillType, Dict[Type, FillType]] = 0, padding_mode: str = "constant", ) -> None: super().__init__() size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")) self.crop_height = size[0] self.crop_width = size[1] - self.fill = fill # TODO: Fill is currently respected only on PIL. Apply tensor patch. + + self.fill = _setup_fill_arg(fill) + self.padding_mode = padding_mode def _get_params(self, sample: Any) -> Dict[str, Any]: @@ -867,7 +865,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: ) if params["needs_pad"]: - inpt = F.pad(inpt, params["padding"], fill=self.fill, padding_mode=self.padding_mode) + fill = self.fill[type(inpt)] + inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode) return inpt diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index a0ed43056ea..e7ca7463b79 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -467,7 +467,7 @@ def rotate_image_tensor( angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: Optional[List[float]] = None, + fill: Optional[Union[int, float, List[float]]] = None, center: Optional[List[float]] = None, ) -> torch.Tensor: num_channels, height, width = img.shape[-3:] diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index bdc02ae6bcc..5ec24545b1d 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -475,7 +475,7 @@ def _assert_grid_transform_inputs( img: Tensor, matrix: Optional[List[float]], interpolation: str, - fill: Optional[List[float]], + fill: Optional[Union[int, float, List[float]]], supported_interpolation_modes: List[str], coeffs: Optional[List[float]] = None, ) -> None: @@ -499,7 +499,7 @@ def _assert_grid_transform_inputs( # Check fill num_channels = get_dimensions(img)[0] - if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels): + if fill is not None and isinstance(fill, (tuple, list)) and len(fill) > 1 and len(fill) != num_channels: msg = ( "The number of elements in 'fill' cannot broadcast to match the number of " "channels of the image ({} != {})" @@ -539,7 +539,7 @@ def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtyp return img -def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[List[float]]) -> Tensor: +def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[Union[int, float, List[float]]]) -> Tensor: img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype]) @@ -559,8 +559,8 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[L mask = img[:, -1:, :, :] # N * 1 * H * W img = img[:, :-1, :, :] # N * C * H * W mask = mask.expand_as(img) - len_fill = len(fill) if isinstance(fill, (tuple, list)) else 1 - fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img) + fill_list, len_fill = (fill, len(fill)) if isinstance(fill, (tuple, list)) else ([float(fill)], 1) + fill_img = torch.tensor(fill_list, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img) if mode == "nearest": mask = mask < 0.5 img[mask] = fill_img[mask] @@ -648,7 +648,7 @@ def rotate( matrix: List[float], interpolation: str = "nearest", expand: bool = False, - fill: Optional[List[float]] = None, + fill: Optional[Union[int, float, List[float]]] = None, ) -> Tensor: _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) w, h = img.shape[-1], img.shape[-2] From efae8069a2e589d1053b3b26b92771843a797253 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 16 Sep 2022 17:33:56 +0200 Subject: [PATCH 6/9] WIP --- test/test_prototype_transforms_functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 64b51f1c9a7..37f59f9ca5f 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -195,10 +195,10 @@ def pad_image_tensor(): for image, padding, fill, padding_mode in itertools.product( make_images(), [[1], [1, 1], [1, 1, 2, 2]], # padding - [None, [12.0], [12.0, 13.0, 14.0]], # fill + [None, 128.0, 128, [12.0], [12.0, 13.0, 14.0]], # fill ["constant", "symmetric", "edge", "reflect"], # padding mode, ): - if padding_mode != ["constant"] and fill is not None and len(fill) > 1: + if padding_mode != "constant" and fill is not None: # ValueError: Padding mode 'reflect' is not supported if fill is not scalar continue yield ArgsKwargs(image, padding=padding, fill=fill, padding_mode=padding_mode) From bacc7dd8564f9adec139b8bffffe3191fcd1a413 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 16 Sep 2022 17:41:48 +0200 Subject: [PATCH 7/9] Updates --- test/test_prototype_transforms_functional.py | 4 ++++ torchvision/prototype/features/_image.py | 5 ----- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 9670b54c3de..ab8c58e1f94 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -201,6 +201,10 @@ def pad_image_tensor(): if padding_mode != "constant" and fill is not None: # ValueError: Padding mode 'reflect' is not supported if fill is not scalar continue + + if isinstance(fill, list) and len(fill) != image.shape[-3]: + continue + yield ArgsKwargs(image, padding=padding, fill=fill, padding_mode=padding_mode) diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 945d6c6c839..0c09c6ea76c 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -172,7 +172,6 @@ def pad( padding: Union[int, Sequence[int]], fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, padding_mode: str = "constant", - fill_for_mask: Optional[Union[int, float]] = None, # fill_for_mask has a different type vs others on purpose ) -> Image: # This cast does Sequence[int] -> List[int] and is required to make mypy happy @@ -191,7 +190,6 @@ def rotate( expand: bool = False, fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, center: Optional[List[float]] = None, - fill_for_mask: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, ) -> Image: fill = self._F._geometry._convert_fill_arg(fill) @@ -209,7 +207,6 @@ def affine( interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, center: Optional[List[float]] = None, - fill_for_mask: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, ) -> Image: fill = self._F._geometry._convert_fill_arg(fill) @@ -230,7 +227,6 @@ def perspective( perspective_coeffs: List[float], interpolation: InterpolationMode = InterpolationMode.BILINEAR, fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, - fill_for_mask: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, ) -> Image: fill = self._F._geometry._convert_fill_arg(fill) @@ -244,7 +240,6 @@ def elastic( displacement: torch.Tensor, interpolation: InterpolationMode = InterpolationMode.BILINEAR, fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, - fill_for_mask: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, ) -> Image: fill = self._F._geometry._convert_fill_arg(fill) From 24ce001f8f7b1e525ff5ba9a51011a5e382f49ba Mon Sep 17 00:00:00 2001 From: vfdev Date: Mon, 19 Sep 2022 13:51:31 +0200 Subject: [PATCH 8/9] Update _image.py --- torchvision/prototype/features/_image.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 0c09c6ea76c..dc1c80ef0cc 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -173,7 +173,6 @@ def pad( fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, padding_mode: str = "constant", ) -> Image: - # This cast does Sequence[int] -> List[int] and is required to make mypy happy if not isinstance(padding, int): padding = list(padding) From 268436dcde056a0ab220a400aa052dac2bcd7574 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 19 Sep 2022 14:54:02 +0200 Subject: [PATCH 9/9] Fixed tests --- test/test_prototype_transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 80be193c13d..537dd5dc53e 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -391,7 +391,7 @@ def test__transform_image_mask(self, fill, mocker): if isinstance(fill, int): calls = [ mocker.call(image, padding=1, fill=fill, padding_mode="constant"), - mocker.call(mask, padding=1, fill=0, padding_mode="constant"), + mocker.call(mask, padding=1, fill=fill, padding_mode="constant"), ] else: calls = [ @@ -467,7 +467,7 @@ def test__transform_image_mask(self, fill, mocker): if isinstance(fill, int): calls = [ mocker.call(image, **params, fill=fill), - mocker.call(mask, **params, fill=0), + mocker.call(mask, **params, fill=fill), ] else: calls = [ @@ -1555,7 +1555,7 @@ def test__get_params(self, mocker): @pytest.mark.parametrize("needs", list(itertools.product((False, True), repeat=2))) def test__transform(self, mocker, needs): - fill_sentinel = mocker.MagicMock() + fill_sentinel = 12 padding_mode_sentinel = mocker.MagicMock() transform = transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel)