|
7 | 7 |
|
8 | 8 | from . import functional as F, InterpolationMode
|
9 | 9 |
|
10 |
| -__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment"] |
| 10 | +__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"] |
11 | 11 |
|
12 | 12 |
|
13 | 13 | def _apply_op(img: Tensor, op_name: str, magnitude: float,
|
@@ -44,6 +44,8 @@ def _apply_op(img: Tensor, op_name: str, magnitude: float,
|
44 | 44 | img = F.equalize(img)
|
45 | 45 | elif op_name == "Invert":
|
46 | 46 | img = F.invert(img)
|
| 47 | + elif op_name == "Identity": |
| 48 | + pass |
47 | 49 | else:
|
48 | 50 | raise ValueError("The provided operator {} is not recognized.".format(op_name))
|
49 | 51 | return img
|
@@ -325,3 +327,79 @@ def __repr__(self) -> str:
|
325 | 327 | s += ', fill={fill}'
|
326 | 328 | s += ')'
|
327 | 329 | return s.format(**self.__dict__)
|
| 330 | + |
| 331 | + |
class TrivialAugmentWide(torch.nn.Module):
    r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
    `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`.

    Each call picks one augmentation op uniformly at random and applies it with a
    uniformly sampled magnitude (negated with probability 1/2 for signed ops).
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_magnitude_bins (int): The number of different magnitude values.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(self, num_magnitude_bins: int = 30, interpolation: InterpolationMode = InterpolationMode.NEAREST,
                 fill: Optional[List[float]] = None) -> None:
        super().__init__()
        self.num_magnitude_bins = num_magnitude_bins
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
        # Maps op name -> (candidate magnitudes, whether the sign may be flipped).
        # A 0-dim tensor marks an op that takes no magnitude. Insertion order is
        # significant: forward() indexes into the key list with a random integer.
        no_magnitude = (torch.tensor(0.0), False)
        return {
            "Identity": no_magnitude,
            "ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
            "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
            "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),
            "Rotate": (torch.linspace(0.0, 135.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.99, num_bins), True),
            "Color": (torch.linspace(0.0, 0.99, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.99, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True),
            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False),
            "Solarize": (torch.linspace(256.0, 0.0, num_bins), False),
            "AutoContrast": no_magnitude,
            "Equalize": no_magnitude,
        }

    def forward(self, img: Tensor) -> Tensor:
        """
        img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        # Normalize fill to a per-channel float list when operating on tensors.
        fill_value = self.fill
        if isinstance(img, Tensor):
            if isinstance(fill_value, (int, float)):
                fill_value = [float(fill_value)] * F.get_image_num_channels(img)
            elif fill_value is not None:
                fill_value = [float(channel) for channel in fill_value]

        space = self._augmentation_space(self.num_magnitude_bins)
        # Uniformly choose an op, then (if it takes one) a magnitude bin.
        chosen = int(torch.randint(len(space), (1,)).item())
        name = list(space.keys())[chosen]
        magnitudes, signed = space[name]
        if magnitudes.ndim > 0:
            bin_idx = torch.randint(len(magnitudes), (1,), dtype=torch.long)
            magnitude = float(magnitudes[bin_idx].item())
        else:
            magnitude = 0.0
        # Signed ops are negated with probability 1/2.
        if signed and torch.randint(2, (1,)):
            magnitude *= -1.0

        return _apply_op(img, name, magnitude, interpolation=self.interpolation, fill=fill_value)

    def __repr__(self) -> str:
        # Same rendering as the original format-string version.
        return (self.__class__.__name__ + '('
                + 'num_magnitude_bins={}'.format(self.num_magnitude_bins)
                + ', interpolation={}'.format(self.interpolation)
                + ', fill={}'.format(self.fill)
                + ')')
0 commit comments