@@ -4,7 +4,7 @@
 import torchvision
 from torch import nn, Tensor
 from torchvision.transforms import functional as F
-from torchvision.transforms import transforms as T
+from torchvision.transforms import transforms as T, InterpolationMode


 def _flip_coco_person_keypoints(kps, width):
@@ -282,3 +282,52 @@ def forward(
             image = F.to_pil_image(image)

         return image, target
+
+
+class ScaleJitter(nn.Module):
+    """Randomly resizes the image and its bounding boxes within the specified scale range.
+
+    The class implements the Scale Jitter augmentation as described in the paper
+    `"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" <https://arxiv.org/abs/2012.07177>`_.
+
+    Args:
+        target_size (tuple of ints): The target size for the transform provided in (height, width) format.
+        scale_range (tuple of floats): scaling factor interval, e.g. (a, b), then scale is randomly sampled from the
+            range a <= scale <= b.
+        interpolation (InterpolationMode): Desired interpolation enum defined by
+            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
+    """
+
+    def __init__(
+        self,
+        target_size: Tuple[int, int],
+        scale_range: Tuple[float, float] = (0.1, 2.0),
+        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    ):
+        super().__init__()
+        self.target_size = target_size
+        self.scale_range = scale_range
+        self.interpolation = interpolation
+
+    def forward(
+        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
+    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+        if isinstance(image, torch.Tensor):
+            if image.ndimension() not in {2, 3}:
+                raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
+            elif image.ndimension() == 2:
+                image = image.unsqueeze(0)
+
+        orig_width, orig_height = F.get_image_size(image)
+
+        # Sample a scale uniformly from scale_range, then rescale it so the
+        # jittered size stays within target_size while preserving the aspect
+        # ratio. Using one factor for both axes keeps the box rescaling valid.
+        scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
+        r = min(self.target_size[1] / orig_width, self.target_size[0] / orig_height) * scale
+        new_width = int(orig_width * r)
+        new_height = int(orig_height * r)
+
+        image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)
+
+        if target is not None:
+            target["boxes"] *= r
+            if "masks" in target:
+                target["masks"] = F.resize(
+                    target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
+                )
+
+        return image, target
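
For reference, here is a minimal usage sketch of the new transform. It is an illustrative assumption, not part of the diff: the `from transforms import ScaleJitter` path and the dummy image/target tensors are invented for the example.

    import torch
    from torchvision.transforms import InterpolationMode
    from transforms import ScaleJitter  # hypothetical import of the module patched above

    jitter = ScaleJitter(target_size=(1024, 1024), scale_range=(0.1, 2.0),
                         interpolation=InterpolationMode.BILINEAR)

    image = torch.rand(3, 480, 640)  # dummy CHW float image
    target = {
        "boxes": torch.tensor([[10.0, 20.0, 200.0, 300.0]]),  # xyxy, absolute pixels
        "masks": torch.zeros(1, 480, 640, dtype=torch.uint8),  # one dummy instance mask
    }

    # The image, boxes, and masks are all resized by the same random factor,
    # so their relative geometry is preserved.
    image, target = jitter(image, target)

Note that masks are label maps, so they are resized with ``InterpolationMode.NEAREST``; bilinear interpolation would blend label values along instance boundaries.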