From 6edd0652065f56f983f25a3e5e1bbc5ce1fbbffc Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 24 Oct 2022 11:46:25 +0100 Subject: [PATCH 1/4] Speed improvement for adjust gamma op --- .../prototype/transforms/functional/_color.py | 25 +++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 2c268fa4085..23d58970191 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -2,7 +2,7 @@ from torchvision.prototype import features from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT -from ._meta import _rgb_to_gray, get_dimensions_image_tensor, get_num_channels_image_tensor +from ._meta import _rgb_to_gray, get_dimensions_image_tensor, get_num_channels_image_tensor, convert_dtype_image_tensor def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor: @@ -257,7 +257,28 @@ def adjust_hue(inpt: features.InputTypeJIT, hue_factor: float) -> features.Input return adjust_hue_image_pil(inpt, hue_factor=hue_factor) -adjust_gamma_image_tensor = _FT.adjust_gamma +def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor: + if not (isinstance(image, torch.Tensor)): + raise TypeError("Input img should be Tensor image") + + c = get_num_channels_image_tensor(image) + + if c not in [1, 3]: + raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}") + + if gamma < 0: + raise ValueError("Gamma should be a non-negative real number") + + if not torch.is_floating_point(image): + output = convert_dtype_image_tensor(image, torch.float32).pow_(gamma) + else: + output = image.pow(gamma) + + output = output.mul_(gain).clamp_(0.0, 1.0) + + return convert_dtype_image_tensor(output, image.dtype) + + adjust_gamma_image_pil = _FP.adjust_gamma From d9dbb6aa4832723464c46bca2cc9a5822e396c3a Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 24 Oct 2022 12:10:27 +0100 Subject: [PATCH 2/4] Adding comments and optimizations. --- torchvision/prototype/transforms/functional/_color.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 23d58970191..67de525255f 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -2,7 +2,7 @@ from torchvision.prototype import features from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT -from ._meta import _rgb_to_gray, get_dimensions_image_tensor, get_num_channels_image_tensor, convert_dtype_image_tensor +from ._meta import _rgb_to_gray, convert_dtype_image_tensor, get_dimensions_image_tensor, get_num_channels_image_tensor def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor: @@ -269,12 +269,17 @@ def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1 if gamma < 0: raise ValueError("Gamma should be a non-negative real number") + # The input image is either assumed to be at [0, 1] scale (if float) or is converted to that scale (if integer). + # Since the gamma is non-negative, the output remains at [0, 1] scale. if not torch.is_floating_point(image): output = convert_dtype_image_tensor(image, torch.float32).pow_(gamma) else: output = image.pow(gamma) - output = output.mul_(gain).clamp_(0.0, 1.0) + if gain != 1.0: + # The clamp operation is needed only if mutiplication is performed. It's only when gain != 1, that the scale + # of the output can go beyond [0, 1]. + output = output.mul_(gain).clamp_(0.0, 1.0) return convert_dtype_image_tensor(output, image.dtype) From 5c78f1e1a37baf130df56a4d4c55367104e11625 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 24 Oct 2022 12:15:36 +0100 Subject: [PATCH 3/4] fixing typo --- torchvision/prototype/transforms/functional/_color.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 67de525255f..83b9772690d 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -277,7 +277,7 @@ def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1 output = image.pow(gamma) if gain != 1.0: - # The clamp operation is needed only if mutiplication is performed. It's only when gain != 1, that the scale + # The clamp operation is needed only if multiplication is performed. It's only when gain != 1, that the scale # of the output can go beyond [0, 1]. output = output.mul_(gain).clamp_(0.0, 1.0) From c34fa4093ea3a489c935e8715116fc111110ed46 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 24 Oct 2022 12:41:34 +0100 Subject: [PATCH 4/4] Remove unnecessary channel check. --- torchvision/prototype/transforms/functional/_color.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 83b9772690d..0d5bdc31d3c 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -261,11 +261,6 @@ def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1 if not (isinstance(image, torch.Tensor)): raise TypeError("Input img should be Tensor image") - c = get_num_channels_image_tensor(image) - - if c not in [1, 3]: - raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}") - if gamma < 0: raise ValueError("Gamma should be a non-negative real number")