From f70229b607fe54ad53c0660a822f832e4e8aaf9b Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 20 Oct 2022 10:24:32 +0000 Subject: [PATCH 1/8] WIP --- .../prototype/transforms/functional/_color.py | 76 ++++++++++++++++++- 1 file changed, 74 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 49a769e04e0..b465c5d0b90 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -2,7 +2,7 @@ from torchvision.prototype import features from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT -from ._meta import get_dimensions_image_tensor +from ._meta import get_dimensions_image_tensor, get_num_channels_image_tensor adjust_brightness_image_tensor = _FT.adjust_brightness adjust_brightness_image_pil = _FP.adjust_brightness @@ -98,7 +98,79 @@ def adjust_sharpness(inpt: features.InputTypeJIT, sharpness_factor: float) -> fe return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor) -adjust_hue_image_tensor = _FT.adjust_hue +def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor: + r, g, b = image.unbind(dim=-3) + + # Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/ + # src/libImaging/Convert.c#L330 + minc, maxc = torch.aminmax(image, dim=-3) + + # The algorithm erases S and H channel where `maxc = minc`. This avoids NaN + # from happening in the results, because + # + S channel has division by `maxc`, which is zero only if `maxc = minc` + # + H channel has division by `(maxc - minc)`. + # + # Instead of overwriting NaN afterwards, we just prevent it from occuring so + # we don't need to deal with it in case we save the NaN in a buffer in + # backprop, if it is ever supported, but it doesn't hurt to do so. + eqc = maxc == minc + + cr = maxc - minc + # Since `eqc => cr = 0`, replacing denominator with 1 when `eqc` is fine. + ones = torch.ones_like(maxc) + s = cr / torch.where(eqc, ones, maxc) + # Note that `eqc => maxc = minc = r = g = b`. So the following calculation + # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it + # would not matter what values `rc`, `gc`, and `bc` have here, and thus + # replacing denominator with 1 when `eqc` is fine. + cr_divisor = torch.where(eqc, ones, cr) + rc = (maxc - r).div_(cr_divisor) + gc = (maxc - g).div_(cr_divisor) + bc = (maxc - b).div_(cr_divisor) + + mask_maxc_eq_r = maxc == r + mask_maxc_neq_r = ~mask_maxc_eq_r + mask_maxc_eq_g = maxc == g + mask_maxc_neq_g = ~mask_maxc_eq_g + hr = mask_maxc_eq_r * (bc - gc) + hg = (mask_maxc_eq_g & mask_maxc_neq_r) * (2.0 + rc - bc) + hb = (mask_maxc_neq_g & mask_maxc_neq_r) * (4.0 + gc - rc) + h = hr + hg + hb + h = torch.fmod((h / 6.0 + 1.0), 1.0) + return torch.stack((h, s, maxc), dim=-3) + + +def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor: + if not (-0.5 <= hue_factor <= 0.5): + raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].") + + if not (isinstance(image, torch.Tensor)): + raise TypeError("Input img should be Tensor image") + + c = get_num_channels_image_tensor(image) + + if c not in [1, 3]: + raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}") + + if c == 1: # Match PIL behaviour + return image + + orig_dtype = image.dtype + if image.dtype == torch.uint8: + image = image.to(dtype=torch.float32) / 255.0 + + image = _rgb_to_hsv(image) + h, s, v = image.unbind(dim=-3) + h = (h + hue_factor) % 1.0 + image = torch.stack((h, s, v), dim=-3) + image_hue_adj = _FT._hsv2rgb(image) + + if orig_dtype == torch.uint8: + image_hue_adj = (image_hue_adj * 255.0).to(dtype=orig_dtype) + + return image_hue_adj + + adjust_hue_image_pil = _FP.adjust_hue From 6ebc646e3dcf7586f6966ea0473e15d76a5862cc Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 20 Oct 2022 15:24:41 +0000 Subject: [PATCH 2/8] Updated rgb2hsv and a bit of hsv2rgb --- .../prototype/transforms/functional/_color.py | 41 ++++++++++++++----- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index b465c5d0b90..f3f2b8aa50f 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -124,22 +124,41 @@ def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor: # would not matter what values `rc`, `gc`, and `bc` have here, and thus # replacing denominator with 1 when `eqc` is fine. cr_divisor = torch.where(eqc, ones, cr) - rc = (maxc - r).div_(cr_divisor) - gc = (maxc - g).div_(cr_divisor) - bc = (maxc - b).div_(cr_divisor) + rc, gc, bc = ((maxc - image) / cr_divisor).unbind(dim=-3) - mask_maxc_eq_r = maxc == r - mask_maxc_neq_r = ~mask_maxc_eq_r + mask_maxc_neq_r = maxc != r mask_maxc_eq_g = maxc == g mask_maxc_neq_g = ~mask_maxc_eq_g - hr = mask_maxc_eq_r * (bc - gc) - hg = (mask_maxc_eq_g & mask_maxc_neq_r) * (2.0 + rc - bc) - hb = (mask_maxc_neq_g & mask_maxc_neq_r) * (4.0 + gc - rc) - h = hr + hg + hb - h = torch.fmod((h / 6.0 + 1.0), 1.0) + + hr = (bc - gc).mul_(~mask_maxc_neq_r) + hg = (2.0 + rc).sub_(bc).mul_(mask_maxc_eq_g & mask_maxc_neq_r) + hb = (4.0 + gc).sub_(rc).mul_(mask_maxc_neq_g & mask_maxc_neq_r) + + h = hr.add_(hg).add_(hb) + h = h.div_(6.0).add_(1.0).fmod_(1.0) return torch.stack((h, s, maxc), dim=-3) +def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor: + h, s, v = img.unbind(dim=-3) + i = torch.floor(h * 6.0) + f = (h * 6.0) - i + i = i.to(dtype=torch.int32) + + p = (v * (1.0 - s)).clamp_(0.0, 1.0) + q = (v * (1.0 - s * f)).clamp_(0.0, 1.0) + t = (v * (1.0 - s * (1.0 - f))).clamp_(0.0, 1.0) + i = i % 6 + + mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1) + + a1 = torch.stack((v, q, p, p, t, v), dim=-3) + a2 = torch.stack((t, v, v, q, p, p), dim=-3) + a3 = torch.stack((p, p, t, v, v, q), dim=-3) + a4 = torch.stack((a1, a2, a3), dim=-4) + + return torch.einsum("...ijk, ...xijk -> ...xjk", mask.to(dtype=img.dtype), a4) + def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor: if not (-0.5 <= hue_factor <= 0.5): raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].") @@ -163,7 +182,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten h, s, v = image.unbind(dim=-3) h = (h + hue_factor) % 1.0 image = torch.stack((h, s, v), dim=-3) - image_hue_adj = _FT._hsv2rgb(image) + image_hue_adj = _hsv_to_rgb(image) if orig_dtype == torch.uint8: image_hue_adj = (image_hue_adj * 255.0).to(dtype=orig_dtype) From b23959288d9dbafeddc56707c67f79fb2e140b59 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 20 Oct 2022 20:58:21 +0000 Subject: [PATCH 3/8] Fix issue with batch of images --- torchvision/prototype/transforms/functional/_color.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index f3f2b8aa50f..5614c653fb2 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -99,7 +99,7 @@ def adjust_sharpness(inpt: features.InputTypeJIT, sharpness_factor: float) -> fe def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor: - r, g, b = image.unbind(dim=-3) + r, g, _ = image.unbind(dim=-3) # Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/ # src/libImaging/Convert.c#L330 @@ -123,8 +123,8 @@ def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor: # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it # would not matter what values `rc`, `gc`, and `bc` have here, and thus # replacing denominator with 1 when `eqc` is fine. - cr_divisor = torch.where(eqc, ones, cr) - rc, gc, bc = ((maxc - image) / cr_divisor).unbind(dim=-3) + cr_divisor = torch.where(eqc, ones, cr).unsqueeze_(dim=-3) + rc, gc, bc = ((maxc.unsqueeze(dim=-3) - image) / cr_divisor).unbind(dim=-3) mask_maxc_neq_r = maxc != r mask_maxc_eq_g = maxc == g From 8201a1835f3530061fb5b3af21956beca06312c7 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 21 Oct 2022 08:05:15 +0000 Subject: [PATCH 4/8] Few improvements --- torchvision/prototype/transforms/functional/_color.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 5614c653fb2..87485b394bb 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -176,7 +176,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten orig_dtype = image.dtype if image.dtype == torch.uint8: - image = image.to(dtype=torch.float32) / 255.0 + image = image / 255.0 image = _rgb_to_hsv(image) h, s, v = image.unbind(dim=-3) @@ -185,7 +185,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten image_hue_adj = _hsv_to_rgb(image) if orig_dtype == torch.uint8: - image_hue_adj = (image_hue_adj * 255.0).to(dtype=orig_dtype) + image_hue_adj = image_hue_adj.mul_(255.0).to(dtype=orig_dtype) return image_hue_adj From df86cc9b17f9d9abcd356fe5bc71f2836e2579d1 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 21 Oct 2022 11:19:32 +0000 Subject: [PATCH 5/8] hsv2rgb improvements --- .../prototype/transforms/functional/_color.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 87485b394bb..f126ee5a9be 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -141,14 +141,15 @@ def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor: def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor: h, s, v = img.unbind(dim=-3) - i = torch.floor(h * 6.0) - f = (h * 6.0) - i + h6 = h * 6 + i = torch.floor(h6) + f = (h6) - i i = i.to(dtype=torch.int32) p = (v * (1.0 - s)).clamp_(0.0, 1.0) q = (v * (1.0 - s * f)).clamp_(0.0, 1.0) t = (v * (1.0 - s * (1.0 - f))).clamp_(0.0, 1.0) - i = i % 6 + i.remainder_(6) mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1) @@ -157,7 +158,8 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor: a3 = torch.stack((p, p, t, v, v, q), dim=-3) a4 = torch.stack((a1, a2, a3), dim=-4) - return torch.einsum("...ijk, ...xijk -> ...xjk", mask.to(dtype=img.dtype), a4) + return (a4.mul_(mask.to(dtype=img.dtype).unsqueeze(dim=-4))).sum_(dim=-3) + def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor: if not (-0.5 <= hue_factor <= 0.5): @@ -180,7 +182,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten image = _rgb_to_hsv(image) h, s, v = image.unbind(dim=-3) - h = (h + hue_factor) % 1.0 + h = (h + hue_factor).remainder_(1.0) image = torch.stack((h, s, v), dim=-3) image_hue_adj = _hsv_to_rgb(image) From 0fbc66dcc4a5066864b8e3a8c8ac276a61058a2c Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 21 Oct 2022 12:17:44 +0000 Subject: [PATCH 6/8] PR review --- .../prototype/transforms/functional/_color.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index a3ae22720f5..36b22ebda62 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -146,8 +146,8 @@ def adjust_sharpness(inpt: features.InputTypeJIT, sharpness_factor: float) -> fe def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor: r, g, _ = image.unbind(dim=-3) - # Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/ - # src/libImaging/Convert.c#L330 + # Implementation is based on + # https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/src/libImaging/Convert.c#L330 minc, maxc = torch.aminmax(image, dim=-3) # The algorithm erases S and H channel where `maxc = minc`. This avoids NaN @@ -160,16 +160,16 @@ def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor: # backprop, if it is ever supported, but it doesn't hurt to do so. eqc = maxc == minc - cr = maxc - minc - # Since `eqc => cr = 0`, replacing denominator with 1 when `eqc` is fine. + channels_range = maxc - minc + # Since `eqc => channels_range = 0`, replacing denominator with 1 when `eqc` is fine. ones = torch.ones_like(maxc) - s = cr / torch.where(eqc, ones, maxc) + s = channels_range / torch.where(eqc, ones, maxc) # Note that `eqc => maxc = minc = r = g = b`. So the following calculation # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it # would not matter what values `rc`, `gc`, and `bc` have here, and thus # replacing denominator with 1 when `eqc` is fine. - cr_divisor = torch.where(eqc, ones, cr).unsqueeze_(dim=-3) - rc, gc, bc = ((maxc.unsqueeze(dim=-3) - image) / cr_divisor).unbind(dim=-3) + channels_range_divisor = torch.where(eqc, ones, channels_range).unsqueeze_(dim=-3) + rc, gc, bc = ((maxc.unsqueeze(dim=-3) - image) / channels_range_divisor).unbind(dim=-3) mask_maxc_neq_r = maxc != r mask_maxc_eq_g = maxc == g From 6bde2919bc705699f1be6713da772efaf5d0dabe Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 21 Oct 2022 12:19:05 +0000 Subject: [PATCH 7/8] another update --- torchvision/prototype/transforms/functional/_color.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 36b22ebda62..94f48028654 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -227,7 +227,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten image = _rgb_to_hsv(image) h, s, v = image.unbind(dim=-3) - h = (h + hue_factor).remainder_(1.0) + h.add_(hue_factor).remainder_(1.0) image = torch.stack((h, s, v), dim=-3) image_hue_adj = _hsv_to_rgb(image) From 49e5d7fef001f9d951c13d842b93762d898ff555 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 21 Oct 2022 14:03:16 +0000 Subject: [PATCH 8/8] Fix cuda issue with empty images torch.aminmax is failing --- torchvision/prototype/transforms/functional/_color.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 94f48028654..bb825e60eef 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -221,6 +221,10 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten if c == 1: # Match PIL behaviour return image + if image.numel() == 0: + # exit earlier on empty images + return image + orig_dtype = image.dtype if image.dtype == torch.uint8: image = image / 255.0