Skip to content

Commit dfbc999

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] [proto] Fixed issue with F.pad from RandomZoomOut (#6386)
Summary: * [proto] Fixed issue with `F.pad` from RandomZoomOut * Fixed failing tests * Fixed wrong type hint * Fixed fill=None in pad_image_pil * Try to support fill=None in functional * Code formatting Reviewed By: datumbox Differential Revision: D38824241 fbshipit-source-id: 919e7bb9e9c575899ac4d8845462521fa655f79e
1 parent 5ca2bf7 commit dfbc999

File tree

7 files changed

+15
-21
lines changed

7 files changed

+15
-21
lines changed

test/test_prototype_transforms.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -377,12 +377,11 @@ def test__get_params(self, fill, side_range, mocker):
377377
transform = transforms.RandomZoomOut(fill=fill, side_range=side_range)
378378

379379
image = mocker.MagicMock(spec=features.Image)
380-
c = image.num_channels = 3
381380
h, w = image.image_size = (24, 32)
382381

383382
params = transform._get_params(image)
384383

385-
assert params["fill"] == (fill if not isinstance(fill, int) else [fill] * c)
384+
assert params["fill"] == fill
386385
assert len(params["padding"]) == 4
387386
assert 0 <= params["padding"][0] <= (side_range[1] - 1) * w
388387
assert 0 <= params["padding"][1] <= (side_range[1] - 1) * h

test/test_prototype_transforms_functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ def pad_image_tensor():
464464
for image, padding, fill, padding_mode in itertools.product(
465465
make_images(),
466466
[[1], [1, 1], [1, 1, 2, 2]], # padding
467-
[12, 12.0], # fill
467+
[None, 12, 12.0], # fill
468468
["constant", "symmetric", "edge", "reflect"], # padding mode,
469469
):
470470
yield SampleInput(image, padding=padding, fill=fill, padding_mode=padding_mode)

torchvision/prototype/features/_image.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,8 @@ def pad(
174174
if not isinstance(padding, int):
175175
padding = list(padding)
176176

177-
if fill is None:
178-
fill = 0
179-
180177
# PyTorch's pad supports only scalars on fill. So we need to overwrite the colour
181-
if isinstance(fill, (int, float)):
178+
if isinstance(fill, (int, float)) or fill is None:
182179
output = _F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
183180
else:
184181
from torchvision.prototype.transforms.functional._geometry import _pad_with_vector_fill

torchvision/prototype/transforms/_geometry.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -294,12 +294,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
294294
bottom = canvas_height - (top + orig_h)
295295
padding = [left, top, right, bottom]
296296

297-
# vfdev-5: Can we put that into pad_image_tensor ?
298-
fill = self.fill
299-
if not isinstance(fill, collections.abc.Sequence):
300-
fill = [fill] * orig_c
301-
302-
return dict(padding=padding, fill=fill)
297+
return dict(padding=padding, fill=self.fill)
303298

304299
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
305300
return F.pad(inpt, **params)

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,10 @@ def rotate(
531531

532532

533533
def pad_image_tensor(
534-
img: torch.Tensor, padding: Union[int, List[int]], fill: Union[int, float] = 0, padding_mode: str = "constant"
534+
img: torch.Tensor,
535+
padding: Union[int, List[int]],
536+
fill: Optional[Union[int, float]] = 0,
537+
padding_mode: str = "constant",
535538
) -> torch.Tensor:
536539
num_channels, height, width = img.shape[-3:]
537540
extra_dims = img.shape[:-3]
@@ -555,7 +558,7 @@ def _pad_with_vector_fill(
555558
raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")
556559

557560
output = pad_image_tensor(img, padding, fill=0, padding_mode="constant")
558-
left, top, right, bottom = _FT._parse_pad_padding(padding)
561+
left, right, top, bottom = _FT._parse_pad_padding(padding)
559562
fill = torch.tensor(fill, dtype=img.dtype, device=img.device).view(-1, 1, 1)
560563

561564
if top > 0:
@@ -614,11 +617,8 @@ def pad(
614617
if not isinstance(padding, int):
615618
padding = list(padding)
616619

617-
if fill is None:
618-
fill = 0
619-
620620
# TODO: PyTorch's pad supports only scalars on fill. So we need to overwrite the colour
621-
if isinstance(fill, (int, float)):
621+
if isinstance(fill, (int, float)) or fill is None:
622622
return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
623623
return _pad_with_vector_fill(inpt, padding, fill=fill, padding_mode=padding_mode)
624624

torchvision/transforms/functional_pil.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def pad(
155155

156156
if not isinstance(padding, (numbers.Number, tuple, list)):
157157
raise TypeError("Got inappropriate padding arg")
158-
if not isinstance(fill, (numbers.Number, tuple, list)):
158+
if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
159159
raise TypeError("Got inappropriate fill arg")
160160
if not isinstance(padding_mode, str):
161161
raise TypeError("Got inappropriate padding_mode arg")

torchvision/transforms/functional_tensor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,10 +371,13 @@ def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
371371

372372

373373
def pad(
374-
img: Tensor, padding: Union[int, List[int]], fill: Union[int, float] = 0, padding_mode: str = "constant"
374+
img: Tensor, padding: Union[int, List[int]], fill: Optional[Union[int, float]] = 0, padding_mode: str = "constant"
375375
) -> Tensor:
376376
_assert_image_tensor(img)
377377

378+
if fill is None:
379+
fill = 0
380+
378381
if not isinstance(padding, (int, tuple, list)):
379382
raise TypeError("Got inappropriate padding arg")
380383
if not isinstance(fill, (int, float)):

0 commit comments

Comments
 (0)