|
1 |
| -from typing import List, Tuple, Dict, Optional |
| 1 | +from typing import List, Tuple, Dict, Optional, Union |
2 | 2 |
|
3 | 3 | import torch
|
4 | 4 | import torchvision
|
@@ -401,3 +401,39 @@ def forward(self, img, target=None):
|
401 | 401 | img, target = self._pad(img, target, [0, 0, pad_right, pad_bottom])
|
402 | 402 |
|
403 | 403 | return img, target
|
| 404 | + |
| 405 | + |
class RandomShortestSize(nn.Module):
    """Randomly resize the image so its shorter side matches a sampled size.

    One candidate length is drawn uniformly from ``min_size``; the image is
    scaled so its shorter side equals that length, unless doing so would push
    the longer side past ``max_size`` — in that case the longer side is capped
    at ``max_size`` instead. Target boxes are rescaled in place and masks are
    resized with nearest-neighbor interpolation to keep their values discrete.

    Args:
        min_size (Union[List[int], Tuple[int, ...], int]): candidate lengths
            for the shorter side. A single int is treated as a one-element
            list.  (Annotation fixed: ``Tuple[int]`` would mean a 1-tuple.)
        max_size (int): upper bound for the longer side after resizing.
        interpolation (InterpolationMode): interpolation used for the image
            resize. Default: ``InterpolationMode.BILINEAR``.
    """

    def __init__(
        self,
        min_size: Union[List[int], Tuple[int, ...], int],
        max_size: int,
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    ):
        super().__init__()
        # Normalize to a list so forward() can index a random element.
        self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
        self.max_size = max_size
        self.interpolation = interpolation

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        """Resize ``image`` (and, if given, ``target`` boxes/masks); returns both."""
        _, orig_height, orig_width = F.get_dimensions(image)

        # Sample one candidate shorter-side length uniformly.
        min_size = self.min_size[torch.randint(len(self.min_size), (1,)).item()]
        # Scale factor: shorter side -> min_size, but never let the longer
        # side exceed max_size (the tighter of the two constraints wins).
        r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width))

        new_width = int(orig_width * r)
        new_height = int(orig_height * r)

        image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)

        if target is not None:
            # Rescale box coordinates with the *rounded* output size so they
            # stay aligned with the actual resized pixels (not the raw ratio r).
            # NOTE(review): assumes target["boxes"] is a float tensor — the
            # in-place *= with a float ratio would fail on integer dtypes.
            target["boxes"][:, 0::2] *= new_width / orig_width
            target["boxes"][:, 1::2] *= new_height / orig_height
            if "masks" in target:
                # NEAREST keeps mask values binary/label-valued after resize.
                target["masks"] = F.resize(
                    target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
                )

        return image, target
0 commit comments