Skip to content

Adjust adjust_* transforms #3222

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:

Args:
img (PIL Image or Tensor): Image to be adjusted.
If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
brightness_factor (float): How much to adjust the brightness. Can be
any non negative number. 0 gives a black image, 1 gives the
Expand All @@ -729,6 +729,8 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:

Args:
img (PIL Image or Tensor): Image to be adjusted.
If img is torch Tensor, it is expected to be in [..., 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
contrast_factor (float): How much to adjust the contrast. Can be any
non negative number. 0 gives a solid gray image, 1 gives the
original image while 2 increases the contrast by a factor of 2.
Expand All @@ -747,6 +749,8 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:

Args:
img (PIL Image or Tensor): Image to be adjusted.
If img is torch Tensor, it is expected to be in [..., 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
saturation_factor (float): How much to adjust the saturation. 0 will
give a black and white image, 1 will give the original image while
2 will enhance the saturation by a factor of 2.
Expand Down Expand Up @@ -776,6 +780,9 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:

Args:
img (PIL Image or Tensor): Image to be adjusted.
If img is torch Tensor, it is expected to be in [..., 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
If img is PIL Image mode "1", "L", "I", "F" and modes with transparency (alpha channel) are not supported.
hue_factor (float): How much to shift the hue channel. Should be in
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
HSV space in positive and negative direction respectively.
Expand Down Expand Up @@ -806,8 +813,9 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:

Args:
img (PIL Image or Tensor): PIL Image to be adjusted.
If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
If img is PIL Image, modes with transparency (alpha channel) are not supported.
gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
gamma larger than 1 make the shadows darker,
while gamma smaller than 1 make dark regions lighter.
Expand Down Expand Up @@ -1185,7 +1193,7 @@ def invert(img: Tensor) -> Tensor:

Args:
img (PIL Image or Tensor): Image to have its colors inverted.
If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".

Expand All @@ -1203,7 +1211,7 @@ def posterize(img: Tensor, bits: int) -> Tensor:

Args:
img (PIL Image or Tensor): Image to have its colors posterized.
If img is a Tensor, it should be of type torch.uint8 and
If img is torch Tensor, it should be of type torch.uint8 and
it is expected to be in [..., 1 or 3, H, W] format, where ... means
it can have an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Expand All @@ -1225,7 +1233,7 @@ def solarize(img: Tensor, threshold: float) -> Tensor:

Args:
img (PIL Image or Tensor): Image to have its colors inverted.
If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
threshold (float): All pixels equal or above this value are inverted.
Expand All @@ -1243,7 +1251,7 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:

Args:
img (PIL Image or Tensor): Image to be adjusted.
If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
sharpness_factor (float): How much to adjust the sharpness. Can be
any non negative number. 0 gives a blurred image, 1 gives the
Expand All @@ -1265,7 +1273,7 @@ def autocontrast(img: Tensor) -> Tensor:

Args:
img (PIL Image or Tensor): Image on which autocontrast is applied.
If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".

Expand All @@ -1285,7 +1293,7 @@ def equalize(img: Tensor) -> Tensor:

Args:
img (PIL Image or Tensor): Image on which equalize is applied.
If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".

Expand Down
4 changes: 3 additions & 1 deletion torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,9 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:

_assert_image_tensor(img)

_assert_channels(img, [3])
_assert_channels(img, [1, 3])
if _get_image_num_channels(img) == 1: # Match PIL behaviour
return img

orig_dtype = img.dtype
if img.dtype == torch.uint8:
Expand Down
3 changes: 2 additions & 1 deletion torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,7 +1043,8 @@ def __repr__(self):
class ColorJitter(torch.nn.Module):
"""Randomly change the brightness, contrast, saturation and hue of an image.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, mode "1", "L", "I", "F" and modes with transparency (alpha channel) are not supported.

Args:
brightness (float or tuple of float (min, max)): How much to jitter brightness.
Expand Down