diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index c6012fddf36..587bc12b108 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -709,7 +709,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: Args: img (PIL Image or Tensor): Image to be adjusted. - If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, + If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. brightness_factor (float): How much to adjust the brightness. Can be any non negative number. 0 gives a black image, 1 gives the @@ -729,6 +729,8 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: Args: img (PIL Image or Tensor): Image to be adjusted. + If img is torch Tensor, it is expected to be in [..., 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. contrast_factor (float): How much to adjust the contrast. Can be any non negative number. 0 gives a solid gray image, 1 gives the original image while 2 increases the contrast by a factor of 2. @@ -747,6 +749,8 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: Args: img (PIL Image or Tensor): Image to be adjusted. + If img is torch Tensor, it is expected to be in [..., 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. saturation_factor (float): How much to adjust the saturation. 0 will give a black and white image, 1 will give the original image while 2 will enhance the saturation by a factor of 2. @@ -776,6 +780,9 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: Args: img (PIL Image or Tensor): Image to be adjusted. + If img is torch Tensor, it is expected to be in [..., 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. 
+ If img is PIL Image, modes "1", "L", "I", "F" and modes with transparency (alpha channel) are not supported. hue_factor (float): How much to shift the hue channel. Should be in [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in HSV space in positive and negative direction respectively. @@ -806,8 +813,9 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: Args: img (PIL Image or Tensor): PIL Image to be adjusted. - If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, + If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. + If img is PIL Image, modes with transparency (alpha channel) are not supported. gamma (float): Non negative real number, same as :math:`\gamma` in the equation. gamma larger than 1 make the shadows darker, while gamma smaller than 1 make dark regions lighter. @@ -1185,7 +1193,7 @@ def invert(img: Tensor) -> Tensor: Args: img (PIL Image or Tensor): Image to have its colors inverted. - If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, + If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. If img is PIL Image, it is expected to be in mode "L" or "RGB". @@ -1203,7 +1211,7 @@ def posterize(img: Tensor, bits: int) -> Tensor: Args: img (PIL Image or Tensor): Image to have its colors posterized. - If img is a Tensor, it should be of type torch.uint8 and + If img is torch Tensor, it should be of type torch.uint8 and it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. If img is PIL Image, it is expected to be in mode "L" or "RGB". @@ -1225,7 +1233,7 @@ def solarize(img: Tensor, threshold: float) -> Tensor: Args: img (PIL Image or Tensor): Image to have its colors inverted. 
- If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, + If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. If img is PIL Image, it is expected to be in mode "L" or "RGB". threshold (float): All pixels equal or above this value are inverted. @@ -1243,7 +1251,7 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: Args: img (PIL Image or Tensor): Image to be adjusted. - If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, + If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. sharpness_factor (float): How much to adjust the sharpness. Can be any non negative number. 0 gives a blurred image, 1 gives the @@ -1265,7 +1273,7 @@ def autocontrast(img: Tensor) -> Tensor: Args: img (PIL Image or Tensor): Image on which autocontrast is applied. - If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, + If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. If img is PIL Image, it is expected to be in mode "L" or "RGB". @@ -1285,7 +1293,7 @@ def equalize(img: Tensor) -> Tensor: Args: img (PIL Image or Tensor): Image on which equalize is applied. - If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, + If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. If img is PIL Image, it is expected to be in mode "P", "L" or "RGB". 
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index b196ab483c0..c97fcf44fe7 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -180,7 +180,9 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: _assert_image_tensor(img) - _assert_channels(img, [3]) + _assert_channels(img, [1, 3]) + if _get_image_num_channels(img) == 1: # Match PIL behaviour + return img orig_dtype = img.dtype if img.dtype == torch.uint8: diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 1274c19795a..1a308f4a7ef 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1043,7 +1043,8 @@ def __repr__(self): class ColorJitter(torch.nn.Module): """Randomly change the brightness, contrast, saturation and hue of an image. If the image is torch Tensor, it is expected - to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If the image is PIL Image, modes "1", "L", "I", "F" and modes with transparency (alpha channel) are not supported. Args: brightness (float or tuple of float (min, max)): How much to jitter brightness.