Skip to content

Commit 794bbb1

Browse files
committed
make convert_image_dtype scriptable
1 parent d481f2d commit 794bbb1

File tree

3 files changed

+35
-14
lines changed

3 files changed

+35
-14
lines changed

test/test_transforms.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,10 @@ def test_to_tensor(self):
526526
output = trans(img)
527527
self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))
528528

529+
def test_max_value(self):
530+
for dtype in int_dtypes():
531+
self.assertEqual(F._max_value(dtype), torch.iinfo(dtype).max)
532+
529533
def test_convert_image_dtype_float_to_float(self):
530534
for input_dtype, output_dtypes in cycle_over(float_dtypes()):
531535
input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)

torchvision/transforms/functional.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,24 @@ def pil_to_tensor(pic):
124124
return img
125125

126126

127-
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
127+
# torch.iinfo isn't scriptable so using this helper function
128+
# https://github.com/pytorch/pytorch/issues/41492
129+
def _max_value(dtype: int) -> int:
130+
a = torch.tensor(2, dtype=dtype)
131+
signed = 1 if torch.tensor(0, dtype=dtype).is_signed() else 0
132+
bits = 1
133+
max_value = torch.tensor(-signed, dtype=torch.long)
134+
while(True):
135+
next_value = a.pow(bits - signed).sub(1)
136+
if next_value > max_value:
137+
max_value = next_value
138+
bits *= 2
139+
else:
140+
return max_value.item()
141+
return max_value.item()
142+
143+
144+
def convert_image_dtype(image: torch.Tensor, dtype: int = torch.float) -> torch.Tensor:
128145
"""Convert a tensor image to the given ``dtype`` and scale the values accordingly
129146
130147
Args:
@@ -148,9 +165,9 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
148165
if image.dtype == dtype:
149166
return image
150167

151-
if image.dtype.is_floating_point:
168+
if torch.empty(0, dtype=image.dtype).is_floating_point():
152169
# float to float
153-
if dtype.is_floating_point:
170+
if torch.tensor(0, dtype=dtype).is_floating_point():
154171
return image.to(dtype)
155172

156173
# float to int
@@ -166,19 +183,19 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
166183
# `max + 1 - epsilon` provides more evenly distributed mapping of
167184
# ranges of floats to ints.
168185
eps = 1e-3
169-
result = image.mul(torch.iinfo(dtype).max + 1 - eps)
186+
max_val = _max_value(dtype)
187+
result = image.mul(max_val + 1.0 - eps)
170188
return result.to(dtype)
171189
else:
190+
input_max = _max_value(image.dtype)
191+
output_max = _max_value(dtype)
192+
172193
# int to float
173-
if dtype.is_floating_point:
174-
max = torch.iinfo(image.dtype).max
194+
if torch.tensor(0, dtype=dtype).is_floating_point():
175195
image = image.to(dtype)
176-
return image / max
196+
return image / input_max
177197

178198
# int to int
179-
input_max = torch.iinfo(image.dtype).max
180-
output_max = torch.iinfo(dtype).max
181-
182199
if input_max > output_max:
183200
factor = (input_max + 1) // (output_max + 1)
184201
image = image // factor

torchvision/transforms/functional_tensor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from torch.nn.functional import affine_grid, grid_sample
77
from torch.jit.annotations import List, BroadcastingList2
88

9+
import torchvision.transforms.functional as F
10+
911

1012
def _is_tensor_a_torch_image(x: Tensor) -> bool:
1113
return x.ndim >= 2
@@ -228,13 +230,11 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
228230
result = img
229231
dtype = img.dtype
230232
if not torch.is_floating_point(img):
231-
result = result / 255.0
233+
result = F.convert_image_dtype(result, torch.get_default_dtype())
232234

233235
result = (gain * result ** gamma).clamp(0, 1)
234236

235-
if result.dtype != dtype:
236-
eps = 1e-3
237-
result = (255 + 1.0 - eps) * result
237+
result = F.convert_image_dtype(result, dtype)
238238
result = result.to(dtype)
239239
return result
240240

0 commit comments

Comments
 (0)