Skip to content

Commit 10c3efa

Browse files
authored
Invert Transform (#3104)
* Adding invert operator. * Make use of the _assert_channels(). * Update upper bound value.
1 parent e2e323c commit 10c3efa

File tree

7 files changed

+156
-1
lines changed

7 files changed

+156
-1
lines changed

test/test_functional_tensor.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,21 @@ def test_gaussian_blur(self):
862862
msg="{}, {}".format(ksize, sigma)
863863
)
864864

865+
def test_invert(self):
866+
script_invert = torch.jit.script(F.invert)
867+
868+
img_tensor, pil_img = self._create_data(16, 18, device=self.device)
869+
inverted_img = F.invert(img_tensor)
870+
inverted_pil_img = F.invert(pil_img)
871+
self.compareTensorToPIL(inverted_img, inverted_pil_img)
872+
873+
# scriptable function test
874+
inverted_img_script = script_invert(img_tensor)
875+
self.assertTrue(inverted_img.equal(inverted_img_script))
876+
877+
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
878+
self._test_fn_on_batch(batch_tensors, F.invert)
879+
865880

866881
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
867882
class CUDATester(Tester):

test/test_transforms.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1749,6 +1749,38 @@ def test_gaussian_blur_asserts(self):
17491749
with self.assertRaisesRegex(ValueError, r"sigma should be a single number or a list/tuple with length 2"):
17501750
transforms.GaussianBlur(3, "sigma_string")
17511751

1752+
@unittest.skipIf(stats is None, 'scipy.stats not available')
1753+
def test_random_invert(self):
1754+
random_state = random.getstate()
1755+
random.seed(42)
1756+
img = transforms.ToPILImage()(torch.rand(3, 10, 10))
1757+
inv_img = F.invert(img)
1758+
1759+
num_samples = 250
1760+
num_inverts = 0
1761+
for _ in range(num_samples):
1762+
out = transforms.RandomInvert()(img)
1763+
if out == inv_img:
1764+
num_inverts += 1
1765+
1766+
p_value = stats.binom_test(num_inverts, num_samples, p=0.5)
1767+
random.setstate(random_state)
1768+
self.assertGreater(p_value, 0.0001)
1769+
1770+
num_samples = 250
1771+
num_inverts = 0
1772+
for _ in range(num_samples):
1773+
out = transforms.RandomInvert(p=0.7)(img)
1774+
if out == inv_img:
1775+
num_inverts += 1
1776+
1777+
p_value = stats.binom_test(num_inverts, num_samples, p=0.7)
1778+
random.setstate(random_state)
1779+
self.assertGreater(p_value, 0.0001)
1780+
1781+
# Checking if RandomInvert can be printed as string
1782+
transforms.RandomInvert().__repr__()
1783+
17521784

17531785
if __name__ == '__main__':
17541786
unittest.main()

test/test_transforms_tensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ def test_random_horizontal_flip(self):
8989
def test_random_vertical_flip(self):
9090
self._test_op('vflip', 'RandomVerticalFlip')
9191

92+
def test_random_invert(self):
93+
self._test_op('invert', 'RandomInvert')
94+
9295
def test_color_jitter(self):
9396

9497
tol = 1.0 + 1e-10

torchvision/transforms/functional.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1178,3 +1178,21 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa
11781178
if not isinstance(img, torch.Tensor):
11791179
output = to_pil_image(output)
11801180
return output
1181+
1182+
1183+
def invert(img: Tensor) -> Tensor:
1184+
"""Invert the colors of a PIL Image or torch Tensor.
1185+
1186+
Args:
1187+
img (PIL Image or Tensor): Image to have its colors inverted.
1188+
If img is a Tensor, it is expected to be in [..., H, W] format,
1189+
where ... means it can have an arbitrary number of trailing
1190+
dimensions.
1191+
1192+
Returns:
1193+
PIL Image: Color inverted image.
1194+
"""
1195+
if not isinstance(img, torch.Tensor):
1196+
return F_pil.invert(img)
1197+
1198+
return F_t.invert(img)

torchvision/transforms/functional_pil.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,3 +606,23 @@ def to_grayscale(img, num_output_channels):
606606
raise ValueError('num_output_channels should be either 1 or 3')
607607

608608
return img
609+
610+
611+
@torch.jit.unused
612+
def invert(img):
613+
"""PRIVATE METHOD. Invert the colors of an image.
614+
615+
.. warning::
616+
617+
Module ``transforms.functional_pil`` is private and should not be used in user application.
618+
Please, consider instead using methods from `transforms.functional` module.
619+
620+
Args:
621+
img (PIL Image): Image to have its colors inverted.
622+
623+
Returns:
624+
PIL Image: Color inverted image Tensor.
625+
"""
626+
if not _is_pil_image(img):
627+
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
628+
return ImageOps.invert(img)

torchvision/transforms/functional_tensor.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,3 +1179,30 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te
11791179

11801180
img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
11811181
return img
1182+
1183+
1184+
def invert(img: Tensor) -> Tensor:
1185+
"""PRIVATE METHOD. Invert the colors of a grayscale or RGB image.
1186+
1187+
.. warning::``
1188+
1189+
Module ``transforms.functional_tensor`` is private and should not be used in user application.
1190+
Please, consider instead using methods from `transforms.functional` module.
1191+
1192+
Args:
1193+
img (Tensor): Image to have its colors inverted in the form [C, H, W].
1194+
1195+
Returns:
1196+
Tensor: Color inverted image Tensor.
1197+
"""
1198+
if not _is_tensor_a_torch_image(img):
1199+
raise TypeError('tensor is not a torch image.')
1200+
1201+
if img.ndim < 3:
1202+
raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
1203+
1204+
_assert_channels(img, [1, 3])
1205+
1206+
bound = 1.0 if img.is_floating_point() else 255.0
1207+
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
1208+
return (bound - img.to(dtype)).to(img.dtype)

torchvision/transforms/transforms.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
"CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
2222
"RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
2323
"LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
24-
"RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode"]
24+
"RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert"]
2525

2626

2727
class Compose:
@@ -1699,3 +1699,43 @@ def _setup_angle(x, name, req_sizes=(2, )):
16991699
_check_sequence_input(x, name, req_sizes)
17001700

17011701
return [float(d) for d in x]
1702+
1703+
1704+
class RandomInvert(torch.nn.Module):
1705+
"""Inverts the colors of the given image randomly with a given probability.
1706+
The image can be a PIL Image or a torch Tensor, in which case it is expected
1707+
to have [..., H, W] shape, where ... means an arbitrary number of leading
1708+
dimensions
1709+
1710+
Args:
1711+
p (float): probability of the image being color inverted. Default value is 0.5
1712+
"""
1713+
1714+
def __init__(self, p=0.5):
1715+
super().__init__()
1716+
self.p = p
1717+
1718+
@staticmethod
1719+
def get_params() -> float:
1720+
"""Choose value for random color inversion.
1721+
1722+
Returns:
1723+
float: Random value which is used to determine whether the random color inversion
1724+
should occur.
1725+
"""
1726+
return torch.rand(1).item()
1727+
1728+
def forward(self, img):
1729+
"""
1730+
Args:
1731+
img (PIL Image or Tensor): Image to be inverted.
1732+
1733+
Returns:
1734+
PIL Image or Tensor: Randomly color inverted image.
1735+
"""
1736+
if self.get_params() < self.p:
1737+
return F.invert(img)
1738+
return img
1739+
1740+
def __repr__(self):
1741+
return self.__class__.__name__ + '(p={})'.format(self.p)

0 commit comments

Comments
 (0)