Skip to content

Commit d481f2d

Browse files
nairbvvfdev-5
andauthored
Add torchscriptable adjust_gamma transform (#2459)
* add torchscriptable adjust_gamma transform #1375 * changes based on code-review * Apply suggested change to add type hint Required by mypy, even thought technically incorrect due to possible Image parameter. torchscript doesn't support a union based type hint. Co-authored-by: vfdev <[email protected]> Co-authored-by: vfdev <[email protected]>
1 parent ab73b44 commit d481f2d

File tree

5 files changed

+122
-18
lines changed

5 files changed

+122
-18
lines changed

test/test_functional_tensor.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def _create_data(self, height=3, width=3, channels=3):
2424

2525
def compareTensorToPIL(self, tensor, pil_image, msg=None):
2626
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
27+
if msg is None:
28+
msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
2729
self.assertTrue(tensor.equal(pil_tensor), msg)
2830

2931
def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None):
@@ -300,6 +302,33 @@ def test_pad(self):
300302
with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
301303
F_t.pad(tensor, (-2, -3), padding_mode="symmetric")
302304

305+
def test_adjust_gamma(self):
306+
script_fn = torch.jit.script(F_t.adjust_gamma)
307+
tensor, pil_img = self._create_data(26, 36)
308+
309+
for dt in [torch.float64, torch.float32, None]:
310+
311+
if dt is not None:
312+
tensor = F.convert_image_dtype(tensor, dt)
313+
314+
gammas = [0.8, 1.0, 1.2]
315+
gains = [0.7, 1.0, 1.3]
316+
for gamma, gain in zip(gammas, gains):
317+
318+
adjusted_tensor = F_t.adjust_gamma(tensor, gamma, gain)
319+
adjusted_pil = F_pil.adjust_gamma(pil_img, gamma, gain)
320+
scripted_result = script_fn(tensor, gamma, gain)
321+
self.assertEqual(adjusted_tensor.dtype, scripted_result.dtype)
322+
self.assertEqual(adjusted_tensor.size()[1:], adjusted_pil.size[::-1])
323+
324+
rbg_tensor = adjusted_tensor
325+
if adjusted_tensor.dtype != torch.uint8:
326+
rbg_tensor = F.convert_image_dtype(adjusted_tensor, torch.uint8)
327+
328+
self.compareTensorToPIL(rbg_tensor, adjusted_pil)
329+
330+
self.assertTrue(adjusted_tensor.equal(scripted_result))
331+
303332
def test_resize(self):
304333
script_fn = torch.jit.script(F_t.resize)
305334
tensor, pil_img = self._create_data(26, 36)

test/test_transforms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,14 +1179,14 @@ def test_adjust_gamma(self):
11791179
# test 1
11801180
y_pil = F.adjust_gamma(x_pil, 0.5)
11811181
y_np = np.array(y_pil)
1182-
y_ans = [0, 35, 57, 117, 185, 240, 97, 45, 244, 151, 255, 15]
1182+
y_ans = [0, 35, 57, 117, 186, 241, 97, 45, 245, 152, 255, 16]
11831183
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
11841184
self.assertTrue(np.allclose(y_np, y_ans))
11851185

11861186
# test 2
11871187
y_pil = F.adjust_gamma(x_pil, 2)
11881188
y_np = np.array(y_pil)
1189-
y_ans = [0, 0, 0, 11, 71, 200, 5, 0, 214, 31, 255, 0]
1189+
y_ans = [0, 0, 0, 11, 71, 201, 5, 0, 215, 31, 255, 0]
11901190
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
11911191
self.assertTrue(np.allclose(y_np, y_ans))
11921192

torchvision/transforms/functional.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,14 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
160160
msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
161161
raise RuntimeError(msg)
162162

163+
# https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
164+
# For data in the range 0-1, (float * 255).to(uint) is only 255
165+
# when float is exactly 1.0.
166+
# `max + 1 - epsilon` provides more evenly distributed mapping of
167+
# ranges of floats to ints.
163168
eps = 1e-3
164-
return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype)
169+
result = image.mul(torch.iinfo(dtype).max + 1 - eps)
170+
return result.to(dtype)
165171
else:
166172
# int to float
167173
if dtype.is_floating_point:
@@ -722,7 +728,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
722728
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
723729

724730

725-
def adjust_gamma(img, gamma, gain=1):
731+
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
726732
r"""Perform gamma correction on an image.
727733
728734
Also known as Power Law Transform. Intensities in RGB mode are adjusted
@@ -736,26 +742,18 @@ def adjust_gamma(img, gamma, gain=1):
736742
.. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction
737743
738744
Args:
739-
img (PIL Image): PIL Image to be adjusted.
745+
img (PIL Image or Tensor): PIL Image to be adjusted.
740746
gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
741747
gamma larger than 1 make the shadows darker,
742748
while gamma smaller than 1 make dark regions lighter.
743749
gain (float): The constant multiplier.
750+
Returns:
751+
PIL Image or Tensor: Gamma correction adjusted image.
744752
"""
745-
if not F_pil._is_pil_image(img):
746-
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
747-
748-
if gamma < 0:
749-
raise ValueError('Gamma should be a non-negative real number')
750-
751-
input_mode = img.mode
752-
img = img.convert('RGB')
753-
754-
gamma_map = [255 * gain * pow(ele / 255., gamma) for ele in range(256)] * 3
755-
img = img.point(gamma_map) # use PIL's point-function to accelerate this part
753+
if not isinstance(img, torch.Tensor):
754+
return F_pil.adjust_gamma(img, gamma, gain)
756755

757-
img = img.convert(input_mode)
758-
return img
756+
return F_t.adjust_gamma(img, gamma, gain)
759757

760758

761759
def rotate(img, angle, resample=False, expand=False, center=None, fill=None):

torchvision/transforms/functional_pil.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,42 @@ def adjust_hue(img, hue_factor):
165165
return img
166166

167167

168+
@torch.jit.unused
169+
def adjust_gamma(img, gamma, gain=1):
170+
r"""Perform gamma correction on an image.
171+
172+
Also known as Power Law Transform. Intensities in RGB mode are adjusted
173+
based on the following equation:
174+
175+
.. math::
176+
I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}
177+
178+
See `Gamma Correction`_ for more details.
179+
180+
.. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction
181+
182+
Args:
183+
img (PIL Image): PIL Image to be adjusted.
184+
gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
185+
gamma larger than 1 make the shadows darker,
186+
while gamma smaller than 1 make dark regions lighter.
187+
gain (float): The constant multiplier.
188+
"""
189+
if not _is_pil_image(img):
190+
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
191+
192+
if gamma < 0:
193+
raise ValueError('Gamma should be a non-negative real number')
194+
195+
input_mode = img.mode
196+
img = img.convert('RGB')
197+
gamma_map = [(255 + 1 - 1e-3) * gain * pow(ele / 255., gamma) for ele in range(256)] * 3
198+
img = img.point(gamma_map) # use PIL's point-function to accelerate this part
199+
200+
img = img.convert(input_mode)
201+
return img
202+
203+
168204
@torch.jit.unused
169205
def pad(img, padding, fill=0, padding_mode="constant"):
170206
r"""Pad the given PIL.Image on all sides with the given "pad" value.

torchvision/transforms/functional_tensor.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,47 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
198198
return _blend(img, rgb_to_grayscale(img), saturation_factor)
199199

200200

201+
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
202+
r"""Adjust gamma of an RGB image.
203+
204+
Also known as Power Law Transform. Intensities in RGB mode are adjusted
205+
based on the following equation:
206+
207+
.. math::
208+
`I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}`
209+
210+
See `Gamma Correction`_ for more details.
211+
212+
.. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction
213+
214+
Args:
215+
img (Tensor): Tensor of RBG values to be adjusted.
216+
gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
217+
gamma larger than 1 make the shadows darker,
218+
while gamma smaller than 1 make dark regions lighter.
219+
gain (float): The constant multiplier.
220+
"""
221+
222+
if not isinstance(img, torch.Tensor):
223+
raise TypeError('img should be a Tensor. Got {}'.format(type(img)))
224+
225+
if gamma < 0:
226+
raise ValueError('Gamma should be a non-negative real number')
227+
228+
result = img
229+
dtype = img.dtype
230+
if not torch.is_floating_point(img):
231+
result = result / 255.0
232+
233+
result = (gain * result ** gamma).clamp(0, 1)
234+
235+
if result.dtype != dtype:
236+
eps = 1e-3
237+
result = (255 + 1.0 - eps) * result
238+
result = result.to(dtype)
239+
return result
240+
241+
201242
def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
202243
"""Crop the Image Tensor and resize it to desired size.
203244

0 commit comments

Comments
 (0)