Commit 793c3db

prabhat00155 authored and facebook-github-bot committed
[fbsync] Adding Scale Jitter transform for detection (#5435)
Summary:
* Adding Scale Jitter in references.
* Update documentation.
* Address review comments.

Reviewed By: jdsgomes
Differential Revision: D34475308
fbshipit-source-id: dcdb00685d4de39b7315ff7d4b9cfb2411218e5c
1 parent d677be7 commit 793c3db

File tree

1 file changed: +50 −1 lines changed


references/detection/transforms.py

Lines changed: 50 additions & 1 deletion
@@ -4,7 +4,7 @@
 import torchvision
 from torch import nn, Tensor
 from torchvision.transforms import functional as F
-from torchvision.transforms import transforms as T
+from torchvision.transforms import transforms as T, InterpolationMode
 
 
 def _flip_coco_person_keypoints(kps, width):
@@ -282,3 +282,52 @@ def forward(
             image = F.to_pil_image(image)
 
         return image, target
+
+
+class ScaleJitter(nn.Module):
+    """Randomly resizes the image and its bounding boxes within the specified scale range.
+    The class implements the Scale Jitter augmentation as described in the paper
+    `"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" <https://arxiv.org/abs/2012.07177>`_.
+
+    Args:
+        target_size (tuple of ints): The target size for the transform provided in (height, width) format.
+        scale_range (tuple of floats): scaling factor interval, e.g. (a, b), then scale is randomly sampled from the
+            range a <= scale <= b.
+        interpolation (InterpolationMode): Desired interpolation enum defined by
+            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
+    """
+
+    def __init__(
+        self,
+        target_size: Tuple[int, int],
+        scale_range: Tuple[float, float] = (0.1, 2.0),
+        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    ):
+        super().__init__()
+        self.target_size = target_size
+        self.scale_range = scale_range
+        self.interpolation = interpolation
+
+    def forward(
+        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
+    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+        if isinstance(image, torch.Tensor):
+            if image.ndimension() not in {2, 3}:
+                raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
+            elif image.ndimension() == 2:
+                image = image.unsqueeze(0)
+
+        r = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
+        new_width = int(self.target_size[1] * r)
+        new_height = int(self.target_size[0] * r)
+
+        image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)
+
+        if target is not None:
+            target["boxes"] *= r
+            if "masks" in target:
+                target["masks"] = F.resize(
+                    target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
+                )
+
+        return image, target
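
For context beyond the diff itself, here is a minimal usage sketch (not part of the commit). It assumes ScaleJitter is importable from references/detection/transforms.py, and that, as in the reference detection pipeline, transforms receive an (image, target) pair; the sizes and box values below are illustrative.

import torch

from transforms import ScaleJitter  # references/detection/transforms.py

# Jitter around a 512x512 target; r is drawn uniformly from [0.5, 1.5].
transform = ScaleJitter(target_size=(512, 512), scale_range=(0.5, 1.5))

image = torch.rand(3, 512, 512)  # CHW float tensor
target = {
    "boxes": torch.tensor([[32.0, 48.0, 200.0, 260.0]]),  # xyxy pixel coords
    "labels": torch.tensor([1]),
}

image, target = transform(image, target)
# image now has shape (3, int(512 * r), int(512 * r)), and the box
# coordinates have been multiplied by the same factor r.
print(image.shape, target["boxes"])

Note that the output size is derived from target_size rather than from the input resolution, and the boxes are scaled by r directly; the box update therefore matches the image resize exactly when the input image already has size target_size, as in this sketch.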
