Skip to content

Commit 77e9a57

Browse files
committed
Adding Scale Jitter in references.
1 parent 26fe8fa commit 77e9a57

File tree

1 file changed

+45
-1
lines changed

1 file changed

+45
-1
lines changed

references/detection/transforms.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torchvision
55
from torch import nn, Tensor
66
from torchvision.transforms import functional as F
7-
from torchvision.transforms import transforms as T
7+
from torchvision.transforms import transforms as T, InterpolationMode
88

99

1010
def _flip_coco_person_keypoints(kps, width):
@@ -282,3 +282,47 @@ def forward(
282282
image = F.to_pil_image(image)
283283

284284
return image, target
285+
286+
287+
class ScaleJitter(nn.Module):
    """Randomly resizes the image and its bounding boxes within a specified ratio range.

    The class implements the Scale Jitter augmentation as described in the paper
    `"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" <https://arxiv.org/abs/2012.07177>`_.

    Args:
        scale_range (tuple of floats): scaling factor interval, e.g (a, b), then scale is randomly sampled from the
            range a <= scale <= b.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
    """

    def __init__(
        self,
        scale_range: Tuple[float, float] = (0.1, 2.0),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    ):
        super().__init__()
        self.scale_range = scale_range
        self.interpolation = interpolation

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        """Apply a random scale jitter to ``image`` and, if given, its detection ``target``.

        Args:
            image (Tensor): image of shape (C, H, W) or (H, W); a 2D image is
                promoted to (1, H, W).
            target (dict, optional): detection target whose "boxes" (and
                "masks", if present) are rescaled in place to stay aligned
                with the jittered image.

        Returns:
            Tuple of the resized image and the (possibly updated) target.

        Raises:
            ValueError: if a tensor image is not 2- or 3-dimensional.
        """
        if isinstance(image, torch.Tensor):
            if image.ndimension() not in {2, 3}:
                raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
            elif image.ndimension() == 2:
                image = image.unsqueeze(0)

        old_width, old_height = F.get_image_size(image)

        # Sample the scale factor uniformly from [scale_range[0], scale_range[1]].
        r = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
        new_width = int(old_width * r)
        new_height = int(old_height * r)

        image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)

        if target is not None:
            # Scale boxes by the *realized* per-axis ratios, not by r: int()
            # truncates above, so new_width / old_width can differ slightly
            # from r and a uniform `boxes *= r` would drift off the image.
            target["boxes"][:, 0::2] *= new_width / old_width
            target["boxes"][:, 1::2] *= new_height / old_height
            if "masks" in target:
                # Masks are discrete label maps; nearest interpolation keeps
                # their values exact while following the image resize.
                target["masks"] = F.resize(
                    target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
                )

        return image, target

0 commit comments

Comments
 (0)