|
21 | 21 | "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
|
22 | 22 | "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
|
23 | 23 | "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
|
24 |
| - "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode"] |
| 24 | + "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert"] |
25 | 25 |
|
26 | 26 |
|
27 | 27 | class Compose:
|
@@ -1699,3 +1699,43 @@ def _setup_angle(x, name, req_sizes=(2, )):
|
1699 | 1699 | _check_sequence_input(x, name, req_sizes)
|
1700 | 1700 |
|
1701 | 1701 | return [float(d) for d in x]
|
| 1702 | + |
| 1703 | + |
| 1704 | +class RandomInvert(torch.nn.Module): |
| 1705 | + """Inverts the colors of the given image randomly with a given probability. |
| 1706 | + The image can be a PIL Image or a torch Tensor, in which case it is expected |
| 1707 | + to have [..., H, W] shape, where ... means an arbitrary number of leading |
| 1708 | + dimensions |
| 1709 | +
|
| 1710 | + Args: |
| 1711 | + p (float): probability of the image being color inverted. Default value is 0.5 |
| 1712 | + """ |
| 1713 | + |
| 1714 | + def __init__(self, p=0.5): |
| 1715 | + super().__init__() |
| 1716 | + self.p = p |
| 1717 | + |
| 1718 | + @staticmethod |
| 1719 | + def get_params() -> float: |
| 1720 | + """Choose value for random color inversion. |
| 1721 | +
|
| 1722 | + Returns: |
| 1723 | + float: Random value which is used to determine whether the random color inversion |
| 1724 | + should occur. |
| 1725 | + """ |
| 1726 | + return torch.rand(1).item() |
| 1727 | + |
| 1728 | + def forward(self, img): |
| 1729 | + """ |
| 1730 | + Args: |
| 1731 | + img (PIL Image or Tensor): Image to be inverted. |
| 1732 | +
|
| 1733 | + Returns: |
| 1734 | + PIL Image or Tensor: Randomly color inverted image. |
| 1735 | + """ |
| 1736 | + if self.get_params() < self.p: |
| 1737 | + return F.invert(img) |
| 1738 | + return img |
| 1739 | + |
| 1740 | + def __repr__(self): |
| 1741 | + return self.__class__.__name__ + '(p={})'.format(self.p) |
0 commit comments