Commit 7325e1d

Adjust adjust_* transforms (#3222)
* adjust_hue
* adjust_*
* colorjitter
1 parent 6315358 commit 7325e1d

3 files changed: 21 additions and 10 deletions

torchvision/transforms/functional.py

Lines changed: 16 additions & 8 deletions
@@ -709,7 +709,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:

     Args:
         img (PIL Image or Tensor): Image to be adjusted.
-            If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
+            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
             where ... means it can have an arbitrary number of leading dimensions.
         brightness_factor (float): How much to adjust the brightness. Can be
             any non negative number. 0 gives a black image, 1 gives the
@@ -729,6 +729,8 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:

     Args:
         img (PIL Image or Tensor): Image to be adjusted.
+            If img is torch Tensor, it is expected to be in [..., 3, H, W] format,
+            where ... means it can have an arbitrary number of leading dimensions.
         contrast_factor (float): How much to adjust the contrast. Can be any
             non negative number. 0 gives a solid gray image, 1 gives the
             original image while 2 increases the contrast by a factor of 2.
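The contrast_factor semantics in the docstring above (0 gives solid gray, 1 gives the original) can be illustrated with a blend toward the mean intensity. This is a minimal pure-Python sketch, not the torchvision implementation: the helper name `adjust_contrast_gray` is hypothetical, and it assumes grayscale pixel values in [0, 1] stored in a flat list.

```python
from typing import List


def adjust_contrast_gray(pixels: List[float], contrast_factor: float) -> List[float]:
    """Blend grayscale pixels (values in [0, 1]) with their mean intensity.

    Hypothetical helper for illustration only.
    """
    if contrast_factor < 0:
        raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")
    mean = sum(pixels) / len(pixels)
    # factor 0 -> every pixel equals the mean (solid gray); factor 1 -> unchanged;
    # factor > 1 -> pixels are pushed away from the mean.
    blended = [contrast_factor * p + (1.0 - contrast_factor) * mean for p in pixels]
    # Clamp back into the valid range.
    return [min(1.0, max(0.0, b)) for b in blended]
```

Calling `adjust_contrast_gray([0.2, 0.4, 0.6], 0.0)` collapses all three values to their mean, while a factor of 1.0 returns the input values unchanged.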
@@ -747,6 +749,8 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:

     Args:
         img (PIL Image or Tensor): Image to be adjusted.
+            If img is torch Tensor, it is expected to be in [..., 3, H, W] format,
+            where ... means it can have an arbitrary number of leading dimensions.
         saturation_factor (float): How much to adjust the saturation. 0 will
             give a black and white image, 1 will give the original image while
             2 will enhance the saturation by a factor of 2.
@@ -776,6 +780,9 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:

     Args:
         img (PIL Image or Tensor): Image to be adjusted.
+            If img is torch Tensor, it is expected to be in [..., 3, H, W] format,
+            where ... means it can have an arbitrary number of leading dimensions.
+            If img is PIL Image mode "1", "L", "I", "F" and modes with transparency (alpha channel) are not supported.
         hue_factor (float): How much to shift the hue channel. Should be in
             [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
             HSV space in positive and negative direction respectively.
@@ -806,8 +813,9 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:

     Args:
         img (PIL Image or Tensor): PIL Image to be adjusted.
-            If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
+            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
             where ... means it can have an arbitrary number of leading dimensions.
+            If img is PIL Image, modes with transparency (alpha channel) are not supported.
         gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
             gamma larger than 1 make the shadows darker,
             while gamma smaller than 1 make dark regions lighter.
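The equation that the gamma docstring refers to falls outside this hunk. As a rough illustration, the sketch below assumes the usual power-law form, I_out = 255 * gain * (I_in / 255) ** gamma, applied to a single 8-bit value. The helper name `gamma_correct` is hypothetical and the rounding choice is an assumption, not taken from the commit.

```python
def gamma_correct(value: int, gamma: float, gain: float = 1.0) -> int:
    """Apply power-law gamma correction to one 8-bit intensity (0..255).

    Hypothetical helper; assumes I_out = 255 * gain * (I_in / 255) ** gamma.
    """
    if gamma < 0:
        raise ValueError("gamma should be a non-negative real number")
    out = 255.0 * gain * (value / 255.0) ** gamma
    # Round to the nearest integer and clamp to the 8-bit range.
    return max(0, min(255, round(out)))
```

With this form, a mid-dark value such as 64 moves down for gamma = 2.0 and up for gamma = 0.5, which matches the docstring's description of darker shadows for gamma above 1 and lighter dark regions for gamma below 1.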
@@ -1185,7 +1193,7 @@ def invert(img: Tensor) -> Tensor:

     Args:
         img (PIL Image or Tensor): Image to have its colors inverted.
-            If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
+            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
             where ... means it can have an arbitrary number of leading dimensions.
             If img is PIL Image, it is expected to be in mode "L" or "RGB".

@@ -1203,7 +1211,7 @@ def posterize(img: Tensor, bits: int) -> Tensor:

     Args:
         img (PIL Image or Tensor): Image to have its colors posterized.
-            If img is a Tensor, it should be of type torch.uint8 and
+            If img is torch Tensor, it should be of type torch.uint8 and
             it is expected to be in [..., 1 or 3, H, W] format, where ... means
             it can have an arbitrary number of leading dimensions.
             If img is PIL Image, it is expected to be in mode "L" or "RGB".
@@ -1225,7 +1233,7 @@ def solarize(img: Tensor, threshold: float) -> Tensor:

     Args:
         img (PIL Image or Tensor): Image to have its colors inverted.
-            If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
+            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
             where ... means it can have an arbitrary number of leading dimensions.
             If img is PIL Image, it is expected to be in mode "L" or "RGB".
         threshold (float): All pixels equal or above this value are inverted.
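The threshold rule in the solarize docstring (pixels equal to or above the threshold are inverted) can be sketched on plain 8-bit values. This is an illustration only: the helper name `solarize_values` is hypothetical, and it assumes inversion means `255 - value` for uint8 data.

```python
from typing import List


def solarize_values(values: List[int], threshold: float) -> List[int]:
    """Invert every 8-bit value that is equal to or above the threshold.

    Hypothetical helper; values below the threshold pass through unchanged.
    """
    return [255 - v if v >= threshold else v for v in values]
```

For example, with a threshold of 128 the values 0 and 100 are left alone, while 128 and 255 are inverted.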
@@ -1243,7 +1251,7 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:

     Args:
         img (PIL Image or Tensor): Image to be adjusted.
-            If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
+            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
             where ... means it can have an arbitrary number of leading dimensions.
         sharpness_factor (float): How much to adjust the sharpness. Can be
             any non negative number. 0 gives a blurred image, 1 gives the
@@ -1265,7 +1273,7 @@ def autocontrast(img: Tensor) -> Tensor:

     Args:
         img (PIL Image or Tensor): Image on which autocontrast is applied.
-            If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
+            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
             where ... means it can have an arbitrary number of leading dimensions.
             If img is PIL Image, it is expected to be in mode "L" or "RGB".

@@ -1285,7 +1293,7 @@ def equalize(img: Tensor) -> Tensor:

     Args:
         img (PIL Image or Tensor): Image on which equalize is applied.
-            If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
+            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
             where ... means it can have an arbitrary number of leading dimensions.
             If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".


torchvision/transforms/functional_tensor.py

Lines changed: 3 additions & 1 deletion
@@ -180,7 +180,9 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:

     _assert_image_tensor(img)

-    _assert_channels(img, [3])
+    _assert_channels(img, [1, 3])
+    if _get_image_num_channels(img) == 1:  # Match PIL behaviour
+        return img

     orig_dtype = img.dtype
     if img.dtype == torch.uint8:
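The change above lets single-channel input through the channel check and returns it untouched, since a grayscale image has no hue to shift. The following sketch mirrors that control flow on a single pixel using the standard-library `colorsys` module. It is not the tensor implementation: the helper name `adjust_hue_pixel` is hypothetical, and it assumes channel values in [0, 1].

```python
import colorsys
from typing import List


def adjust_hue_pixel(pixel: List[float], hue_factor: float) -> List[float]:
    """Shift the hue of one pixel given as [gray] or [r, g, b] in [0, 1].

    Hypothetical helper that mirrors the early return for one channel.
    """
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
    if len(pixel) not in (1, 3):
        raise TypeError(f"expected 1 or 3 channels, got {len(pixel)}")
    if len(pixel) == 1:
        # Grayscale carries no hue information, so return the input as-is.
        return pixel
    h, s, v = colorsys.rgb_to_hsv(*pixel)
    # Hue is cyclic, so wrap the shifted value back into [0, 1).
    h = (h + hue_factor) % 1.0
    return list(colorsys.hsv_to_rgb(h, s, v))
```

A grayscale pixel such as `[0.3]` comes back unchanged for any valid hue_factor, while pure red `[1.0, 0.0, 0.0]` shifted by 0.5 lands on the opposite side of the hue wheel (cyan).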

torchvision/transforms/transforms.py

Lines changed: 2 additions & 1 deletion
@@ -1043,7 +1043,8 @@ def __repr__(self):
 class ColorJitter(torch.nn.Module):
     """Randomly change the brightness, contrast, saturation and hue of an image.
     If the image is torch Tensor, it is expected
-    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
+    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If img is PIL Image, mode "1", "L", "I", "F" and modes with transparency (alpha channel) are not supported.

     Args:
         brightness (float or tuple of float (min, max)): How much to jitter brightness.
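The PIL mode restriction added to the ColorJitter docstring can be checked before a transform is applied. The sketch below is a hypothetical pre-check, not part of torchvision: the function name `colorjitter_supports_mode` and the set of alpha-carrying mode strings are assumptions made for illustration, and modes not named in the docstring are simply not rejected here.

```python
# Modes the docstring lists explicitly as unsupported.
UNSUPPORTED_MODES = {"1", "L", "I", "F"}
# Assumption: common PIL mode strings that carry an alpha channel.
ALPHA_MODES = {"LA", "PA", "RGBA", "La", "RGBa"}


def colorjitter_supports_mode(mode: str) -> bool:
    """Return False for modes the ColorJitter docstring rules out.

    Hypothetical helper; it only encodes the documented exclusions.
    """
    return mode not in UNSUPPORTED_MODES and mode not in ALPHA_MODES
```

A caller could evaluate `colorjitter_supports_mode(img.mode)` on a PIL image and convert to "RGB" first when the result is False.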
