From 65f0d0cce5e4b0d8eff2a0bf36a49891e1eae6f5 Mon Sep 17 00:00:00 2001 From: voldemortX Date: Wed, 6 Jan 2021 19:38:39 +0800 Subject: [PATCH 1/3] adjust_hue --- torchvision/transforms/functional.py | 2 ++ torchvision/transforms/functional_tensor.py | 4 +++- torchvision/transforms/transforms.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index c6012fddf36..d6f55a28a35 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -776,6 +776,8 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: Args: img (PIL Image or Tensor): Image to be adjusted. + If img is a Tensor, it is expected to be a RGB image in [..., 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. hue_factor (float): How much to shift the hue channel. Should be in [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in HSV space in positive and negative direction respectively. diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index b196ab483c0..c97fcf44fe7 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -180,7 +180,9 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: _assert_image_tensor(img) - _assert_channels(img, [3]) + _assert_channels(img, [1, 3]) + if _get_image_num_channels(img) == 1: # Match PIL behaviour + return img orig_dtype = img.dtype if img.dtype == torch.uint8: diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 1274c19795a..6825f8c7e15 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1043,7 +1043,7 @@ def __repr__(self): class ColorJitter(torch.nn.Module): """Randomly change the brightness, contrast, saturation and hue of an image. 
If the image is torch Tensor, it is expected - to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + to be RGB and have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions. Args: brightness (float or tuple of float (min, max)): How much to jitter brightness. From 5f8a81ec42fa527b1a52d21a20891970df0ff856 Mon Sep 17 00:00:00 2001 From: voldemortX Date: Wed, 6 Jan 2021 20:30:33 +0800 Subject: [PATCH 2/3] adjust_* --- torchvision/transforms/functional.py | 24 +++++++++++++++--------- torchvision/transforms/transforms.py | 2 +- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index d6f55a28a35..587bc12b108 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -709,7 +709,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: Args: img (PIL Image or Tensor): Image to be adjusted. - If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, + If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. brightness_factor (float): How much to adjust the brightness. Can be any non negative number. 0 gives a black image, 1 gives the @@ -729,6 +729,8 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: Args: img (PIL Image or Tensor): Image to be adjusted. + If img is torch Tensor, it is expected to be in [..., 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. contrast_factor (float): How much to adjust the contrast. Can be any non negative number. 0 gives a solid gray image, 1 gives the original image while 2 increases the contrast by a factor of 2. @@ -747,6 +749,8 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: Args: img (PIL Image or Tensor): Image to be adjusted. 
+ If img is torch Tensor, it is expected to be in [..., 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. saturation_factor (float): How much to adjust the saturation. 0 will give a black and white image, 1 will give the original image while 2 will enhance the saturation by a factor of 2. @@ -776,8 +780,9 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: Args: img (PIL Image or Tensor): Image to be adjusted. - If img is a Tensor, it is expected to be a RGB image in [..., 3, H, W] format, + If img is torch Tensor, it is expected to be in [..., 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. + If img is PIL Image, mode "1", "L", "I", "F" and modes with transparency (alpha channel) are not supported. hue_factor (float): How much to shift the hue channel. Should be in [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in HSV space in positive and negative direction respectively. @@ -808,8 +813,9 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: Args: img (PIL Image or Tensor): PIL Image to be adjusted. - If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, + If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. + If img is PIL Image, modes with transparency (alpha channel) are not supported. gamma (float): Non negative real number, same as :math:`\gamma` in the equation. gamma larger than 1 make the shadows darker, while gamma smaller than 1 make dark regions lighter. @@ -1187,7 +1193,7 @@ def invert(img: Tensor) -> Tensor: Args: img (PIL Image or Tensor): Image to have its colors inverted. - If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, + If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. 
If img is PIL Image, it is expected to be in mode "L" or "RGB". @@ -1205,7 +1211,7 @@ def posterize(img: Tensor, bits: int) -> Tensor: Args: img (PIL Image or Tensor): Image to have its colors posterized. - If img is a Tensor, it should be of type torch.uint8 and + If img is torch Tensor, it should be of type torch.uint8 and it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. If img is PIL Image, it is expected to be in mode "L" or "RGB". @@ -1227,7 +1233,7 @@ def solarize(img: Tensor, threshold: float) -> Tensor: Args: img (PIL Image or Tensor): Image to have its colors inverted. - If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, + If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. If img is PIL Image, it is expected to be in mode "L" or "RGB". threshold (float): All pixels equal or above this value are inverted. @@ -1245,7 +1251,7 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: Args: img (PIL Image or Tensor): Image to be adjusted. - If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, + If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. sharpness_factor (float): How much to adjust the sharpness. Can be any non negative number. 0 gives a blurred image, 1 gives the @@ -1267,7 +1273,7 @@ def autocontrast(img: Tensor) -> Tensor: Args: img (PIL Image or Tensor): Image on which autocontrast is applied. - If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, + If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. If img is PIL Image, it is expected to be in mode "L" or "RGB". 
@@ -1287,7 +1293,7 @@ def equalize(img: Tensor) -> Tensor: Args: img (PIL Image or Tensor): Image on which equalize is applied. - If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, + If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. If img is PIL Image, it is expected to be in mode "P", "L" or "RGB". diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 6825f8c7e15..878446f7d04 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1043,7 +1043,7 @@ def __repr__(self): class ColorJitter(torch.nn.Module): """Randomly change the brightness, contrast, saturation and hue of an image. If the image is torch Tensor, it is expected - to be RGB and have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions. Args: brightness (float or tuple of float (min, max)): How much to jitter brightness. From 018bfcc1d359a2dc9962c3bca348940245d30849 Mon Sep 17 00:00:00 2001 From: voldemortX Date: Wed, 6 Jan 2021 20:39:33 +0800 Subject: [PATCH 3/3] colorjitter --- torchvision/transforms/transforms.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 878446f7d04..1a308f4a7ef 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1044,6 +1044,7 @@ class ColorJitter(torch.nn.Module): """Randomly change the brightness, contrast, saturation and hue of an image. If the image is torch Tensor, it is expected to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, mode "1", "L", "I", "F" and modes with transparency (alpha channel) are not supported. 
Args: brightness (float or tuple of float (min, max)): How much to jitter brightness.