
Commit 1d7d92c

pmeier and vfdev-5 authored
refactor prototype.transforms.RandomCrop (#6640)
* refactor RandomCrop
* mypy
* fix test
* use padding directly rather than private attribute
* only compute type specific fill if padding is needed
* [DRAFT] don't use the diff trick
* fix error message

Co-authored-by: vfdev <[email protected]>

* remove height and width diff
* reinstate separate diff checking
* introduce needs_crop flag

Co-authored-by: vfdev <[email protected]>
1 parent d7d90f5 commit 1d7d92c
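
The net effect of the refactor: _get_params now returns a single padding list in F.pad order together with needs_pad and needs_crop flags, so _transform issues at most one pad and one crop per input instead of up to three pad calls. A hypothetical usage sketch (assuming the prototype transforms namespace at this commit accepts plain image tensors; the public behavior is unchanged by the refactor):

import torch
from torchvision.prototype import transforms

# padding=[4, 8] pads left/right by 4 and top/bottom by 8; pad_if_needed then adds
# whatever symmetric padding is still missing to reach the crop size.
crop = transforms.RandomCrop(size=(224, 224), padding=[4, 8], pad_if_needed=True)
image = torch.randint(0, 256, (3, 200, 300), dtype=torch.uint8)
output = crop(image)
print(output.shape)  # expected: torch.Size([3, 224, 224])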

File tree

3 files changed: +72 additions, -62 deletions

test/test_prototype_transforms.py

Lines changed: 20 additions & 12 deletions
@@ -715,30 +715,38 @@ def test__get_params(self, padding, pad_if_needed, size, mocker):

         if padding is not None:
             if isinstance(padding, int):
-                h += 2 * padding
-                w += 2 * padding
+                pad_top = pad_bottom = pad_left = pad_right = padding
             elif isinstance(padding, list) and len(padding) == 2:
-                w += 2 * padding[0]
-                h += 2 * padding[1]
+                pad_left = pad_right = padding[0]
+                pad_top = pad_bottom = padding[1]
             elif isinstance(padding, list) and len(padding) == 4:
-                w += padding[0] + padding[2]
-                h += padding[1] + padding[3]
+                pad_left, pad_top, pad_right, pad_bottom = padding

-        expected_input_width = w
-        expected_input_height = h
+            h += pad_top + pad_bottom
+            w += pad_left + pad_right
+        else:
+            pad_left = pad_right = pad_top = pad_bottom = 0

         if pad_if_needed:
             if w < size[1]:
-                w += 2 * (size[1] - w)
+                diff = size[1] - w
+                pad_left += diff
+                pad_right += diff
+                w += 2 * diff
             if h < size[0]:
-                h += 2 * (size[0] - h)
+                diff = size[0] - h
+                pad_top += diff
+                pad_bottom += diff
+                h += 2 * diff
+
+        padding = [pad_left, pad_top, pad_right, pad_bottom]

         assert 0 <= params["top"] <= h - size[0] + 1
         assert 0 <= params["left"] <= w - size[1] + 1
         assert params["height"] == size[0]
         assert params["width"] == size[1]
-        assert params["input_width"] == expected_input_width
-        assert params["input_height"] == expected_input_height
+        assert params["needs_pad"] is any(padding)
+        assert params["padding"] == padding

     @pytest.mark.parametrize("padding", [None, 1, [2, 3], [1, 2, 3, 4]])
     @pytest.mark.parametrize("pad_if_needed", [False, True])
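
For concreteness, one worked instance of the arithmetic the updated test mirrors (hypothetical numbers, plain Python): a 32x32 image, padding=[1, 2, 3, 4] interpreted as (left, top, right, bottom), crop size (40, 30), and pad_if_needed=True.

h, w = 32, 32
pad_left, pad_top, pad_right, pad_bottom = 1, 2, 3, 4
h += pad_top + pad_bottom  # 38
w += pad_left + pad_right  # 36

size = (40, 30)
# Width already fits (36 >= 30); height is short, so pad_if_needed adds the
# missing 2 pixels to the top *and* the bottom, growing h by 2 * diff.
if h < size[0]:
    diff = size[0] - h  # 2
    pad_top += diff     # 4
    pad_bottom += diff  # 6
    h += 2 * diff       # 42

padding = [pad_left, pad_top, pad_right, pad_bottom]
assert padding == [1, 4, 3, 6]
assert (h, w) == (42, 36)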

test/test_prototype_transforms_consistency.py

Lines changed: 1 addition & 1 deletion
@@ -966,7 +966,7 @@ def _transform(self, inpt, params):

 class TestRefSegTransforms:
     def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8):
-        size = (256, 640)
+        size = (256, 460)
         num_categories = 21

         conv_fns = []

torchvision/prototype/transforms/_geometry.py

Lines changed: 51 additions & 49 deletions
@@ -414,78 +414,80 @@ def __init__(
         _check_padding_arg(padding)
         _check_padding_mode_arg(padding_mode)

-        self.padding = padding
+        self.padding = F._geometry._parse_pad_padding(padding) if padding else None  # type: ignore[arg-type]
         self.pad_if_needed = pad_if_needed
         self.fill = _setup_fill_arg(fill)
         self.padding_mode = padding_mode

     def _get_params(self, sample: Any) -> Dict[str, Any]:
-        _, height, width = query_chw(sample)
+        _, padded_height, padded_width = query_chw(sample)

         if self.padding is not None:
-            # update height, width with static padding data
-            padding = self.padding
-            if isinstance(padding, Sequence):
-                padding = list(padding)
-            pad_left, pad_right, pad_top, pad_bottom = F._geometry._parse_pad_padding(padding)
-            height += pad_top + pad_bottom
-            width += pad_left + pad_right
-
-        output_height, output_width = self.size
-        # We have to store maybe padded image size for pad_if_needed branch in _transform
-        input_height, input_width = height, width
+            pad_left, pad_right, pad_top, pad_bottom = self.padding
+            padded_height += pad_top + pad_bottom
+            padded_width += pad_left + pad_right
+        else:
+            pad_left = pad_right = pad_top = pad_bottom = 0
+
+        cropped_height, cropped_width = self.size

         if self.pad_if_needed:
-            # pad width if needed
-            if width < output_width:
-                width += 2 * (output_width - width)
-            # pad height if needed
-            if height < output_height:
-                height += 2 * (output_height - height)
-
-        if height < output_height or width < output_width:
+            if padded_height < cropped_height:
+                diff = cropped_height - padded_height
+
+                pad_top += diff
+                pad_bottom += diff
+                padded_height += 2 * diff
+
+            if padded_width < cropped_width:
+                diff = cropped_width - padded_width
+
+                pad_left += diff
+                pad_right += diff
+                padded_width += 2 * diff
+
+        if padded_height < cropped_height or padded_width < cropped_width:
             raise ValueError(
-                f"Required crop size {(output_height, output_width)} is larger then input image size {(height, width)}"
+                f"Required crop size {(cropped_height, cropped_width)} is larger than "
+                f"{'padded ' if self.padding is not None else ''}input image size {(padded_height, padded_width)}."
             )

-        if width == output_width and height == output_height:
-            return dict(top=0, left=0, height=height, width=width, input_width=input_width, input_height=input_height)
+        # We need a different order here than we have in self.padding since this padding will be parsed again in `F.pad`
+        padding = [pad_left, pad_top, pad_right, pad_bottom]
+        needs_pad = any(padding)

-        top = torch.randint(0, height - output_height + 1, size=(1,)).item()
-        left = torch.randint(0, width - output_width + 1, size=(1,)).item()
+        needs_vert_crop, top = (
+            (True, int(torch.randint(0, padded_height - cropped_height + 1, size=())))
+            if padded_height > cropped_height
+            else (False, 0)
+        )
+        needs_horz_crop, left = (
+            (True, int(torch.randint(0, padded_width - cropped_width + 1, size=())))
+            if padded_width > cropped_width
+            else (False, 0)
+        )

         return dict(
+            needs_crop=needs_vert_crop or needs_horz_crop,
             top=top,
             left=left,
-            height=output_height,
-            width=output_width,
-            input_width=input_width,
-            input_height=input_height,
+            height=cropped_height,
+            width=cropped_width,
+            needs_pad=needs_pad,
+            padding=padding,
         )

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        # TODO: (PERF) check for speed optimization if we avoid repeated pad calls
-        fill = self.fill[type(inpt)]
-        fill = F._geometry._convert_fill_arg(fill)
+        if params["needs_pad"]:
+            fill = self.fill[type(inpt)]
+            fill = F._geometry._convert_fill_arg(fill)

-        if self.padding is not None:
-            # This cast does Sequence[int] -> List[int] and is required to make mypy happy
-            padding = self.padding
-            if not isinstance(padding, int):
-                padding = list(padding)
+            inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)

-            inpt = F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)
+        if params["needs_crop"]:
+            inpt = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])

-        if self.pad_if_needed:
-            input_width, input_height = params["input_width"], params["input_height"]
-            if input_width < self.size[1]:
-                padding = [self.size[1] - input_width, 0]
-                inpt = F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)
-            if input_height < self.size[0]:
-                padding = [0, self.size[0] - input_height]
-                inpt = F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)
-
-        return F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
+        return inpt


 class RandomPerspective(_RandomApplyTransform):
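
One subtlety worth spelling out is the reordering noted in the diff comment above: the padding pre-parsed in __init__ is unpacked in _get_params as (left, right, top, bottom), while a 4-element padding handed back to F.pad is read as (left, top, right, bottom). A standalone round-trip sketch in plain Python, with parse_pad_padding as an illustrative stand-in for the private helper:

def parse_pad_padding(padding):
    # Illustrative stand-in: normalizes user padding to [left, right, top, bottom],
    # matching the unpacking order used in _get_params above.
    if isinstance(padding, int):
        return [padding] * 4
    if len(padding) == 2:
        lr, tb = padding
        return [lr, lr, tb, tb]
    left, top, right, bottom = padding
    return [left, right, top, bottom]


pad_left, pad_right, pad_top, pad_bottom = parse_pad_padding([1, 2, 3, 4])
# F.pad expects (left, top, right, bottom), hence the reordering before the list
# is stored in params["padding"].
padding_for_pad = [pad_left, pad_top, pad_right, pad_bottom]
assert padding_for_pad == [1, 2, 3, 4]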
