Skip to content

Commit 7199efe

Browse files
committed
Adding FixedSizeCrop transform
1 parent 7bb8186 commit 7199efe

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed

references/detection/transforms.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,78 @@ def forward(
326326
)
327327

328328
return image, target
329+
330+
331+
class FixedSizeCrop(nn.Module):
    """Crop and/or pad an image (and its detection target) towards a fixed size.

    Images larger than the configured size along an axis are cropped along that
    axis at a random offset; images smaller along an axis are padded on the
    right/bottom. When a ``target`` dict is given, its ``"boxes"``, ``"labels"``
    and optional ``"masks"`` entries are updated to match.

    NOTE(review): ``T`` and ``F`` are module-level names not visible in this
    block; presumably ``torchvision.transforms`` and its ``functional``
    module -- confirm against the file's imports.

    Args:
        size: Requested output size, normalised through ``T._setup_size`` and
            read as ``(height, width)``.
        fill: Fill value forwarded to ``F.pad`` for the image.
        padding_mode: Padding mode forwarded to ``F.pad`` for the image.
    """

    def __init__(self, size, fill=0, padding_mode="constant"):
        super().__init__()
        dims = tuple(T._setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
        self.crop_height, self.crop_width = dims[0], dims[1]
        # TODO: Fill is currently respected only on PIL. Apply tensor patch.
        self.fill = fill
        self.padding_mode = padding_mode

    def _pad(self, img, target, padding):
        """Pad ``img`` and shift/pad the annotations in ``target`` accordingly.

        Args:
            img: Image passed to ``F.pad``.
            target: Annotation dict or ``None``. When given, ``target["boxes"]``
                is shifted in place and ``target["masks"]`` (if present) is
                replaced by a zero-padded version.
            padding: An int (same padding on every side), or a sequence of
                length 1 (every side), 2 (left/right, top/bottom) or
                4 (left, top, right, bottom).

        Returns:
            The ``(img, target)`` pair after padding.
        """
        # Padding normalisation taken from the functional_tensor.py pad.
        if isinstance(padding, int):
            left = top = right = bottom = padding
        else:
            count = len(padding)
            if count == 1:
                left = top = right = bottom = padding[0]
            elif count == 2:
                left = right = padding[0]
                top = bottom = padding[1]
            else:
                left, top, right, bottom = padding[0], padding[1], padding[2], padding[3]

        padding = [left, top, right, bottom]
        img = F.pad(img, padding, self.fill, self.padding_mode)

        if target is None:
            return img, target

        # Content moves right/down by the left/top padding, so the box
        # x-coordinates (even columns) and y-coordinates (odd columns) move too.
        target["boxes"][:, 0::2] += left
        target["boxes"][:, 1::2] += top
        if "masks" in target:
            # Masks are always padded with zeros, regardless of self.fill.
            target["masks"] = F.pad(target["masks"], padding, 0, "constant")

        return img, target

    def _crop(self, img, target, top, left, height, width):
        """Crop ``img`` and bring the annotations in ``target`` into the crop frame.

        Boxes are translated by the crop origin, clamped to the crop extent and
        dropped when they no longer have positive width and height; labels and
        masks (if present) are filtered with the same selection.

        Args:
            img: Image passed to ``F.crop``.
            target: Annotation dict or ``None``. When given it must contain
                ``"boxes"`` and ``"labels"``; ``"boxes"`` is modified in place
                before being filtered.
            top: Vertical offset of the crop window.
            left: Horizontal offset of the crop window.
            height: Height of the crop window.
            width: Width of the crop window.

        Returns:
            The ``(img, target)`` pair after cropping.
        """
        img = F.crop(img, top, left, height, width)
        if target is None:
            return img, target

        boxes = target["boxes"]

        # Translate into the crop's coordinate frame ...
        boxes[:, 0::2] -= left
        boxes[:, 1::2] -= top
        # ... and clip to the crop window.
        boxes[:, 0::2].clamp_(min=0, max=width)
        boxes[:, 1::2].clamp_(min=0, max=height)

        # Keep only boxes that still span a positive extent on both axes.
        spans_x = boxes[:, 0] < boxes[:, 2]
        spans_y = boxes[:, 1] < boxes[:, 3]
        is_valid = spans_x & spans_y

        target["boxes"] = boxes[is_valid]
        target["labels"] = target["labels"][is_valid]
        if "masks" in target:
            kept_masks = target["masks"][is_valid]
            target["masks"] = F.crop(kept_masks, top, left, height, width)

        return img, target

    def forward(self, img, target=None):
        """Apply the random crop and/or the right/bottom padding.

        Args:
            img: Input image; its dimensions are read with ``F.get_dimensions``.
            target: Optional annotation dict transformed alongside the image.

        Returns:
            The transformed ``(img, target)`` pair.
        """
        _, height, width = F.get_dimensions(img)
        new_height = min(height, self.crop_height)
        new_width = min(width, self.crop_width)

        needs_crop = new_height != height or new_width != width
        if needs_crop:
            max_top = max(height - self.crop_height, 0)
            max_left = max(width - self.crop_width, 0)

            # A single random draw is shared by the vertical and horizontal offsets.
            r = torch.rand(1)
            top = int(max_top * r)
            left = int(max_left * r)

            img, target = self._crop(img, target, top, left, new_height, new_width)

        # Whatever is still missing is added on the bottom and right edges only.
        pad_bottom = max(self.crop_height - new_height, 0)
        pad_right = max(self.crop_width - new_width, 0)
        if pad_bottom or pad_right:
            img, target = self._pad(img, target, [0, 0, pad_right, pad_bottom])

        return img, target

0 commit comments

Comments
 (0)