diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py
index 2c268fa4085..0d5bdc31d3c 100644
--- a/torchvision/prototype/transforms/functional/_color.py
+++ b/torchvision/prototype/transforms/functional/_color.py
@@ -2,7 +2,7 @@
 from torchvision.prototype import features
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
 
-from ._meta import _rgb_to_gray, get_dimensions_image_tensor, get_num_channels_image_tensor
+from ._meta import _rgb_to_gray, convert_dtype_image_tensor, get_dimensions_image_tensor, get_num_channels_image_tensor
 
 
 def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
@@ -257,7 +257,28 @@ def adjust_hue(inpt: features.InputTypeJIT, hue_factor: float) -> features.Input
     return adjust_hue_image_pil(inpt, hue_factor=hue_factor)
 
 
-adjust_gamma_image_tensor = _FT.adjust_gamma
+def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor:
+    if not isinstance(image, torch.Tensor):
+        raise TypeError("Input img should be Tensor image")
+
+    if gamma < 0:
+        raise ValueError("Gamma should be a non-negative real number")
+
+    # The input image is either assumed to be at [0, 1] scale (if float) or is converted to that scale (if integer).
+    # Since the gamma is non-negative, the output remains at [0, 1] scale.
+    if not torch.is_floating_point(image):
+        output = convert_dtype_image_tensor(image, torch.float32).pow_(gamma)
+    else:
+        output = image.pow(gamma)
+
+    if gain != 1.0:
+        # The clamp operation is needed only if multiplication is performed. It's only when gain != 1 that the scale
+        # of the output can go beyond [0, 1].
+        output = output.mul_(gain).clamp_(0.0, 1.0)
+
+    return convert_dtype_image_tensor(output, image.dtype)
+
+
 adjust_gamma_image_pil = _FP.adjust_gamma