Commit 1d12cd2

add int to int and cleanup
1 parent a5d9573 commit 1d12cd2

File tree

3 files changed (+143, -75 lines)


test/test_transforms.py

Lines changed: 98 additions & 39 deletions
@@ -24,6 +24,20 @@
     os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
 
 
+def cycle_over(objs):
+    objs = list(objs)
+    for idx, obj in enumerate(objs):
+        yield obj, objs[:idx] + objs[idx + 1:]
+
+def int_dtypes():
+    yield from iter(
+        (torch.uint8, torch.int8, torch.int16, torch.short, torch.int32, torch.int, torch.int64, torch.long,)
+    )
+
+def float_dtypes():
+    yield from iter((torch.float32, torch.float, torch.float64, torch.double))
+
+
 class Tester(unittest.TestCase):
 
     def test_crop(self):
@@ -502,54 +516,99 @@ def test_to_tensor(self):
             output = trans(img)
             self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))
 
-    def test_convert_image_dtype(self):
-        def cycle_over(objs):
-            objs = list(objs)
-            for idx, obj in enumerate(objs):
-                yield obj, objs[:idx] + objs[idx + 1:]
-
-        # dtype_max_value = {
-        #     dtype: 1.0
-        #     for dtype in (torch.float32, torch.float, torch.float64, torch.double)#, torch.bool,)
-        #     # torch.float16 and torch.half are disabled for now since they do not support torch.max
-        #     # See https://github.com/pytorch/pytorch/issues/28623#issuecomment-611379051
-        #     # (torch.float32, torch.float, torch.float64, torch.double, torch.float16, torch.half, torch.bool, )
-        # }
-        dtype_max_value = {}
-        dtype_max_value.update(
-            {
-                dtype: torch.iinfo(dtype).max
-                for dtype in (
-                    torch.uint8,
-                    torch.int8,
-                    torch.int16,
-                    torch.short,
-                    torch.int32,
-                    torch.int,
-                    torch.int64,
-                    torch.long,
-                )
-            }
-        )
+    def test_convert_image_dtype_float_to_float(self):
+        for input_dtype, output_dtypes in cycle_over(float_dtypes()):
+            input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
+            for output_dtype in output_dtypes:
+                with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
+                    transform = transforms.ConvertImageDtype(output_dtype)
+                    output_image = transform(input_image)
+
+                    actual_min, actual_max = output_image.tolist()
+                    desired_min, desired_max = 0.0, 1.0
+
+                    self.assertAlmostEqual(actual_min, desired_min)
+                    self.assertAlmostEqual(actual_max, desired_max)
+
+    def test_convert_image_dtype_float_to_int(self):
+        for input_dtype in float_dtypes():
+            input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
+            for output_dtype in int_dtypes():
+                with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
+                    transform = transforms.ConvertImageDtype(output_dtype)
+
+                    if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or (
+                        input_dtype == torch.float64 and output_dtype == torch.int64
+                    ):
+                        with self.assertRaises(RuntimeError):
+                            transform(input_image)
+                    else:
+                        output_image = transform(input_image)
 
-        for input_dtype, output_dtypes in cycle_over(dtype_max_value.keys()):
-            input_image = torch.ones(1, dtype=input_dtype) * dtype_max_value[input_dtype]
+                        actual_min, actual_max = output_image.tolist()
+                        desired_min, desired_max = 0, torch.iinfo(output_dtype).max
 
+                        self.assertEqual(actual_min, desired_min)
+                        self.assertEqual(actual_max, desired_max)
+
+    def test_convert_image_dtype_int_to_float(self):
+        for input_dtype in int_dtypes():
+            input_image = torch.tensor((0, torch.iinfo(input_dtype).max), dtype=input_dtype)
+            for output_dtype in float_dtypes():
+                with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
+                    transform = transforms.ConvertImageDtype(output_dtype)
+                    output_image = transform(input_image)
+
+                    actual_min, actual_max = output_image.tolist()
+                    desired_min, desired_max = 0.0, 1.0
+
+                    self.assertAlmostEqual(actual_min, desired_min)
+                    self.assertGreaterEqual(actual_min, desired_min)
+                    self.assertAlmostEqual(actual_max, desired_max)
+                    self.assertLessEqual(actual_max, desired_max)
+
+    def test_convert_image_dtype_int_to_int(self):
+        for input_dtype, output_dtypes in cycle_over(int_dtypes()):
+            input_max = torch.iinfo(input_dtype).max
+            input_image = torch.tensor((0, input_max), dtype=input_dtype)
             for output_dtype in output_dtypes:
+                output_max = torch.iinfo(output_dtype).max
+
                 with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
                     transform = transforms.ConvertImageDtype(output_dtype)
                     output_image = transform(input_image)
 
-                    actual = output_image.dtype
-                    desired = output_dtype
-                    self.assertEqual(actual, desired)
+                    actual_min, actual_max = output_image.tolist()
+                    desired_min, desired_max = 0, output_max
 
-                    actual = torch.max(output_image).item()
-                    desired = dtype_max_value[output_dtype]
-                    if output_dtype.is_floating_point:
-                        self.assertAlmostEqual(actual, desired)
+                    # see https://github.com/pytorch/vision/pull/2078#issuecomment-641036236 for details
+                    if input_max >= output_max:
+                        error_term = 0
                     else:
-                        self.assertEqual(actual, desired)
+                        error_term = 1 - (torch.iinfo(output_dtype).max + 1) // (torch.iinfo(input_dtype).max + 1)
+
+                    self.assertEqual(actual_min, desired_min)
+                    self.assertEqual(actual_max, desired_max + error_term)
+
+    def test_convert_image_dtype_int_to_int_consistency(self):
+        for input_dtype, output_dtypes in cycle_over(int_dtypes()):
+            input_max = torch.iinfo(input_dtype).max
+            input_image = torch.tensor((0, input_max), dtype=input_dtype)
+            for output_dtype in output_dtypes:
+                output_max = torch.iinfo(output_dtype).max
+                if output_max <= input_max:
+                    continue
+
+                with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
+                    transform = transforms.ConvertImageDtype(output_dtype)
+                    inverse_transfrom = transforms.ConvertImageDtype(input_dtype)
+                    output_image = inverse_transfrom(transform(input_image))
+
+                    actual_min, actual_max = output_image.tolist()
+                    desired_min, desired_max = 0, input_max
+
+                    self.assertEqual(actual_min, desired_min)
+                    self.assertEqual(actual_max, desired_max)
 
     @unittest.skipIf(accimage is None, 'accimage not available')
     def test_accimage_to_tensor(self):
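
The error_term checked in test_convert_image_dtype_int_to_int comes from the integer upscaling factor used in functional.py. A minimal sketch of that arithmetic, not part of the commit, with uint8 -> int16 picked purely for illustration:

import torch

# Hypothetical illustration (not from the commit): why upscaling misses the new
# dtype's maximum by exactly `error_term` under the factor used in functional.py.
input_dtype, output_dtype = torch.uint8, torch.int16
input_max = torch.iinfo(input_dtype).max      # 255
output_max = torch.iinfo(output_dtype).max    # 32767

factor = (output_max + 1) // (input_max + 1)  # 32768 // 256 == 128
upscaled_max = input_max * factor             # 255 * 128 == 32640, not 32767

error_term = 1 - (output_max + 1) // (input_max + 1)  # -127
assert upscaled_max == output_max + error_term

Dividing by the same factor on the way back restores 255 exactly, which is what the consistency test relies on.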

torchvision/transforms/functional.py

Lines changed: 34 additions & 36 deletions
@@ -82,9 +82,7 @@ def to_tensor(pic):
     return img
 
 
-def convert_image_dtype(
-    image: torch.Tensor, dtype: torch.dtype = torch.float
-) -> torch.Tensor:
+def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
     """Convert a tensor image to the given ``dtype`` and scale the values accordingly
 
     Args:
@@ -94,28 +92,42 @@ def convert_image_dtype(
     Returns:
         (torch.Tensor): Converted image
 
+    .. note::
+
+        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
+        If converted back and forth, this mismatch has no effect.
+
     Raises:
-        TypeError: When trying to cast :class:`torch.float32` to :class:`torch.int32`
-            or :class:`torch.int64` as well as for trying to cast
-            :class:`torch.float64` to :class:`torch.int64`. These conversions are
-            unsafe since the floating point ``dtype`` cannot store consecutive XXX. which might lead to overflow errors
+        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
+            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
+            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
+            of the integer ``dtype``.
     """
-    def float_to_float(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
-        return image.to(dtype)
-
-    def float_to_int(image: torch.Tensor, dtype: torch.dtype, eps=1e-3) -> torch.Tensor:
-        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (image.dtype == torch.float64 and dtype == torch.int64):
-            msg = (f"The cast from {image.dtype} to {dtype} cannot be performed safely, "
-                   f"since {image.dtype} cannot ")
-            raise TypeError(msg)
-        return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype)
+    if image.dtype == dtype:
+        return image
 
-    def int_to_float(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
-        max = torch.iinfo(image.dtype).max
-        image = image.to(dtype)
-        return image / max
+    if image.dtype.is_floating_point:
+        # float to float
+        if dtype.is_floating_point:
+            return image.to(dtype)
+
+        # float to int
+        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
+            image.dtype == torch.float64 and dtype == torch.int64
+        ):
+            msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
+            raise RuntimeError(msg)
 
-    def int_to_int(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+        eps = 1e-3
+        return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype)
+    else:
+        # int to float
+        if dtype.is_floating_point:
+            max = torch.iinfo(image.dtype).max
+            image = image.to(dtype)
+            return image / max
+
+        # int to int
         input_max = torch.iinfo(image.dtype).max
         output_max = torch.iinfo(dtype).max
 
@@ -126,21 +138,7 @@ def int_to_int(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
         else:
             factor = (output_max + 1) // (input_max + 1)
             image = image.to(dtype)
-            return (image + 1) * factor - 1
-
-    if image.dtype == dtype:
-        return image
-
-    if image.dtype.is_floating_point:
-        if dtype.is_floating_point:
-            return float_to_float(image, dtype)
-        else:
-            return float_to_int(image, dtype)
-    else:
-        if dtype.is_floating_point:
-            return int_to_float(image, dtype)
-        else:
-            return int_to_int(image, dtype)
+            return image * factor
 
 
 def to_pil_image(pic, mode=None):
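
A hedged usage sketch of the rewritten function (not part of the commit; it only assumes convert_image_dtype is importable from torchvision.transforms.functional, the module edited here), exercising the same corner values the new tests use:

import torch
from torchvision.transforms import functional as F

img_u8 = torch.tensor((0, 255), dtype=torch.uint8)
print(F.convert_image_dtype(img_u8, torch.float32))   # expected: tensor([0., 1.])

img_f32 = torch.tensor((0.0, 1.0), dtype=torch.float32)
print(F.convert_image_dtype(img_f32, torch.uint8))    # expected: tensor([  0, 255], dtype=torch.uint8)

try:
    # float32 cannot represent every int32 value, so this cast is rejected
    F.convert_image_dtype(img_f32, torch.int32)
except RuntimeError as exc:
    print(exc)

The eps subtracted before the float-to-int cast keeps an input of exactly 1.0 from being scaled to max + 1, which would overflow after truncation.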

torchvision/transforms/transforms.py

Lines changed: 11 additions & 0 deletions
@@ -101,7 +101,18 @@ class ConvertImageDtype(object):
     Args:
         dtype (torch.dtype): Desired data type of the output
 
+    .. note::
+
+        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
+        If converted back and forth, this mismatch has no effect.
+
+    Raises:
+        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
+            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
+            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
+            of the integer ``dtype``.
     """
+
     def __init__(self, dtype: torch.dtype) -> None:
         self.dtype = dtype
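
A short sketch of the transform wrapper in use (not from the commit; it assumes only the ConvertImageDtype class documented above and an image already represented as a tensor):

import torch
from torchvision import transforms

to_float = transforms.ConvertImageDtype(torch.float32)
to_uint8 = transforms.ConvertImageDtype(torch.uint8)

img = torch.tensor((0, 128, 255), dtype=torch.uint8)
print(to_float(img))            # roughly tensor([0.0000, 0.5020, 1.0000])
print(to_uint8(to_float(img)))  # round-trips back to tensor([  0, 128, 255], dtype=torch.uint8)

The class presumably delegates to convert_image_dtype from functional.py, which would explain why its docstring gains the same note and Raises section in this commit.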
