Skip to content

Commit 89557a8

Browse files
datumboxfacebook-github-bot
authored andcommitted
[fbsync] Adding RandomShortestSize transform (#5610)
Reviewed By: vmoens Differential Revision: D34878987 fbshipit-source-id: 0553419e2fdaa251590fbdeddae1cba88c19d023
1 parent d7490d1 commit 89557a8

File tree

1 file changed

+37
-1
lines changed

1 file changed

+37
-1
lines changed

references/detection/transforms.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Tuple, Dict, Optional
1+
from typing import List, Tuple, Dict, Optional, Union
22

33
import torch
44
import torchvision
@@ -401,3 +401,39 @@ def forward(self, img, target=None):
401401
img, target = self._pad(img, target, [0, 0, pad_right, pad_bottom])
402402

403403
return img, target
404+
405+
406+
class RandomShortestSize(nn.Module):
407+
def __init__(
408+
self,
409+
min_size: Union[List[int], Tuple[int], int],
410+
max_size: int,
411+
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
412+
):
413+
super().__init__()
414+
self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
415+
self.max_size = max_size
416+
self.interpolation = interpolation
417+
418+
def forward(
419+
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
420+
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
421+
_, orig_height, orig_width = F.get_dimensions(image)
422+
423+
min_size = self.min_size[torch.randint(len(self.min_size), (1,)).item()]
424+
r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width))
425+
426+
new_width = int(orig_width * r)
427+
new_height = int(orig_height * r)
428+
429+
image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)
430+
431+
if target is not None:
432+
target["boxes"][:, 0::2] *= new_width / orig_width
433+
target["boxes"][:, 1::2] *= new_height / orig_height
434+
if "masks" in target:
435+
target["masks"] = F.resize(
436+
target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
437+
)
438+
439+
return image, target

0 commit comments

Comments
 (0)