Skip to content

Commit 5b0d597

Browse files
author
Federico Pozzi
committed
refactor: reshape tensor in 4d, then pad
1 parent 846ba4b commit 5b0d597

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -377,12 +377,6 @@ def pad_segmentation_mask():
377377
[[1], [1, 1], [1, 1, 2, 2]], # padding
378378
["constant", "symmetric", "edge", "reflect"], # padding mode,
379379
):
380-
if padding_mode == "symmetric" and mask.ndim not in [3, 4]:
381-
continue
382-
383-
if (padding_mode == "edge" or padding_mode == "reflect") and mask.ndim not in [3, 4]:
384-
continue
385-
386380
yield SampleInput(mask, padding=padding, padding_mode=padding_mode)
387381

388382

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,15 @@ def rotate_segmentation_mask(
399399
def pad_segmentation_mask(
400400
segmentation_mask: torch.Tensor, padding: List[int], padding_mode: str = "constant"
401401
) -> torch.Tensor:
402-
return pad_image_tensor(img=segmentation_mask, padding=padding, fill=0, padding_mode=padding_mode)
402+
num_masks, height, width = segmentation_mask.shape[-3:]
403+
extra_dims = segmentation_mask.shape[:-3]
404+
405+
padded_mask = pad_image_tensor(
406+
img=segmentation_mask.view(-1, num_masks, height, width), padding=padding, fill=0, padding_mode=padding_mode
407+
)
408+
409+
new_height, new_width = padded_mask.shape[-2:]
410+
return padded_mask.view(extra_dims + (num_masks, new_height, new_width))
403411

404412

405413
def pad_bounding_box(

0 commit comments

Comments
 (0)