Skip to content

Commit 8bf31ba

Browse files
committed
Updated pad op on prototype side
1 parent e24d71a commit 8bf31ba

File tree

3 files changed

+36
-21
lines changed

3 files changed

+36
-21
lines changed

torchvision/prototype/features/_image.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -177,12 +177,11 @@ def pad(
177177
if not isinstance(padding, int):
178178
padding = list(padding)
179179

180-
# PyTorch's pad supports only scalars on fill. So we need to overwrite the colour
181-
if isinstance(fill, (int, float)) or fill is None:
182-
output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
183-
else:
184-
output = self._F._geometry._pad_with_vector_fill(self, padding, fill=fill, padding_mode=padding_mode)
180+
# This cast does Sequence -> List and is required to make mypy happy
181+
if not (fill is None or isinstance(fill, (int, float))):
182+
fill = list(fill)
185183

184+
output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
186185
return Image.new_like(self, output)
187186

188187
def rotate(

torchvision/prototype/features/_mask.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,7 @@ def pad(
5858
if not isinstance(padding, int):
5959
padding = list(padding)
6060

61-
if isinstance(fill, (int, float)) or fill is None:
62-
if fill is None:
63-
fill = 0
64-
output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill)
65-
else:
66-
# Let's raise an error for vector fill on masks
67-
raise ValueError("Non-scalar fill value is not supported")
68-
61+
output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill)
6962
return Mask.new_like(self, output)
7063

7164
def rotate(

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,23 @@ def rotate(
590590
def pad_image_tensor(
591591
img: torch.Tensor,
592592
padding: Union[int, List[int]],
593-
fill: Optional[Union[int, float]] = 0,
593+
fill: Optional[Union[int, float, List[float]]] = None,
594+
padding_mode: str = "constant",
595+
) -> torch.Tensor:
596+
if fill is None:
597+
# This JIT workaround
598+
return _pad_with_scalar_fill(img, padding, fill=None, padding_mode=padding_mode)
599+
elif isinstance(fill, (int, float)) or len(fill) == 1:
600+
fill_number = fill[0] if isinstance(fill, list) else fill
601+
return _pad_with_scalar_fill(img, padding, fill=fill_number, padding_mode=padding_mode)
602+
else:
603+
return _pad_with_vector_fill(img, padding, fill=fill, padding_mode=padding_mode)
604+
605+
606+
def _pad_with_scalar_fill(
607+
img: torch.Tensor,
608+
padding: Union[int, List[int]],
609+
fill: Optional[Union[int, float]] = None,
594610
padding_mode: str = "constant",
595611
) -> torch.Tensor:
596612
num_channels, height, width = img.shape[-3:]
@@ -613,13 +629,13 @@ def pad_image_tensor(
613629
def _pad_with_vector_fill(
614630
img: torch.Tensor,
615631
padding: Union[int, List[int]],
616-
fill: Sequence[float] = [0.0],
632+
fill: List[float],
617633
padding_mode: str = "constant",
618634
) -> torch.Tensor:
619635
if padding_mode != "constant":
620636
raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")
621637

622-
output = pad_image_tensor(img, padding, fill=0, padding_mode="constant")
638+
output = _pad_with_scalar_fill(img, padding, fill=0, padding_mode="constant")
623639
left, right, top, bottom = _parse_pad_padding(padding)
624640
fill = torch.tensor(fill, dtype=img.dtype, device=img.device).view(-1, 1, 1)
625641

@@ -638,8 +654,14 @@ def pad_mask(
638654
mask: torch.Tensor,
639655
padding: Union[int, List[int]],
640656
padding_mode: str = "constant",
641-
fill: Optional[Union[int, float]] = 0,
657+
fill: Optional[Union[int, float, List[float]]] = None,
642658
) -> torch.Tensor:
659+
if fill is None:
660+
fill = 0
661+
662+
if isinstance(fill, list):
663+
raise ValueError("Non-scalar fill value is not supported")
664+
643665
if mask.ndim < 3:
644666
mask = mask.unsqueeze(0)
645667
needs_squeeze = True
@@ -692,10 +714,11 @@ def pad(
692714
if not isinstance(padding, int):
693715
padding = list(padding)
694716

695-
# TODO: PyTorch's pad supports only scalars on fill. So we need to overwrite the colour
696-
if isinstance(fill, (int, float)) or fill is None:
697-
return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
698-
return _pad_with_vector_fill(inpt, padding, fill=fill, padding_mode=padding_mode)
717+
# This cast does Sequence -> List and is required to make mypy happy
718+
if not (fill is None or isinstance(fill, (int, float))):
719+
fill = list(fill)
720+
721+
return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
699722

700723

701724
crop_image_tensor = _FT.crop

0 commit comments

Comments
 (0)