|
22 | 22 | "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
|
23 | 23 | "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
|
24 | 24 | "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize",
|
25 |
| - "RandomSolarize", "RandomAutocontrast"] |
| 25 | + "RandomSolarize", "RandomAutocontrast", "RandomEqualize"] |
26 | 26 |
|
27 | 27 |
|
28 | 28 | class Compose:
|
@@ -1876,3 +1876,43 @@ def forward(self, img):
|
1876 | 1876 |
|
1877 | 1877 | def __repr__(self):
|
1878 | 1878 | return self.__class__.__name__ + '(p={})'.format(self.p)
|
| 1879 | + |
| 1880 | + |
| 1881 | +class RandomEqualize(torch.nn.Module): |
| 1882 | + """Equalize the histogram of the given image randomly with a given probability. |
| 1883 | + The image can be a PIL Image or a torch Tensor, in which case it is expected |
| 1884 | + to have [..., H, W] shape, where ... means an arbitrary number of leading |
| 1885 | + dimensions. |
| 1886 | +
|
| 1887 | + Args: |
| 1888 | + p (float): probability of the image being equalized. Default value is 0.5 |
| 1889 | + """ |
| 1890 | + |
| 1891 | + def __init__(self, p=0.5): |
| 1892 | + super().__init__() |
| 1893 | + self.p = p |
| 1894 | + |
| 1895 | + @staticmethod |
| 1896 | + def get_params() -> float: |
| 1897 | + """Choose a value for the random transformation. |
| 1898 | +
|
| 1899 | + Returns: |
| 1900 | + float: Random value which is used to determine whether the random transformation |
| 1901 | + should occur. |
| 1902 | + """ |
| 1903 | + return torch.rand(1).item() |
| 1904 | + |
| 1905 | + def forward(self, img): |
| 1906 | + """ |
| 1907 | + Args: |
| 1908 | + img (PIL Image or Tensor): Image to be equalized. |
| 1909 | +
|
| 1910 | + Returns: |
| 1911 | + PIL Image or Tensor: Randomly equalized image. |
| 1912 | + """ |
| 1913 | + if self.get_params() < self.p: |
| 1914 | + return F.equalize(img) |
| 1915 | + return img |
| 1916 | + |
| 1917 | + def __repr__(self): |
| 1918 | + return self.__class__.__name__ + '(p={})'.format(self.p) |
0 commit comments