Skip to content
4 changes: 4 additions & 0 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,10 @@ def test_perspective(self):
batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=0
)

def test_convert_image_dtype(self):
    # Placeholder: the body contains no assertions yet, so this test
    # currently exercises nothing.
    # TODO: add tests of CPU/CUDA on tensor and batch
    pass


@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
Expand Down
33 changes: 33 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
import torchvision.transforms.functional_tensor as F_t
from torch._utils_internal import get_file_path_2
from numpy.testing import assert_array_almost_equal
import unittest
Expand Down Expand Up @@ -544,13 +545,26 @@ def test_to_tensor(self):
output = trans(img)
self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))

def test_max_value(self):
    """Check ``F_t._max_value`` against ``torch.iinfo`` / ``torch.finfo``.

    Each dtype runs in its own ``subTest`` so that a failure reports the
    offending dtype and does not prevent the remaining dtypes from running.
    """
    # Integer dtypes: the helper must reproduce torch.iinfo(dtype).max exactly.
    for dtype in int_dtypes():
        with self.subTest(dtype=dtype):
            self.assertEqual(F_t._max_value(dtype), torch.iinfo(dtype).max)

    # Floating point dtypes: the helper is only expected to exceed the
    # largest finite value reported by torch.finfo.
    for dtype in float_dtypes():
        with self.subTest(dtype=dtype):
            self.assertGreater(F_t._max_value(dtype), torch.finfo(dtype).max)

def test_convert_image_dtype_float_to_float(self):
for input_dtype, output_dtypes in cycle_over(float_dtypes()):
input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
for output_dtype in output_dtypes:
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype)

output_image = transform(input_image)
output_image_script = transform_script(input_image, output_dtype)

script_diff = output_image_script - output_image
self.assertLess(script_diff.abs().max(), 1e-6)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0.0, 1.0
Expand All @@ -564,6 +578,7 @@ def test_convert_image_dtype_float_to_int(self):
for output_dtype in int_dtypes():
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype)

if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or (
input_dtype == torch.float64 and output_dtype == torch.int64
Expand All @@ -572,6 +587,10 @@ def test_convert_image_dtype_float_to_int(self):
transform(input_image)
else:
output_image = transform(input_image)
output_image_script = transform_script(input_image, output_dtype)

script_diff = output_image_script - output_image
self.assertLess(script_diff.abs().max(), 1e-6)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, torch.iinfo(output_dtype).max
Expand All @@ -585,7 +604,13 @@ def test_convert_image_dtype_int_to_float(self):
for output_dtype in float_dtypes():
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype)

output_image = transform(input_image)
output_image_script = transform_script(input_image, output_dtype)

script_diff = output_image_script - output_image
self.assertLess(script_diff.abs().max(), 1e-6)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0.0, 1.0
Expand All @@ -604,7 +629,15 @@ def test_convert_image_dtype_int_to_int(self):

with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype)

output_image = transform(input_image)
output_image_script = transform_script(input_image, output_dtype)

script_diff = output_image_script.float() - output_image.float()
self.assertLess(
script_diff.abs().max(), 1e-6, msg="{} vs {}".format(output_image_script, output_image)
)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, output_max
Expand Down
46 changes: 4 additions & 42 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,48 +152,10 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
of the integer ``dtype``.
"""
if image.dtype == dtype:
return image

if image.dtype.is_floating_point:
# float to float
if dtype.is_floating_point:
return image.to(dtype)

# float to int
if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
image.dtype == torch.float64 and dtype == torch.int64
):
msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
raise RuntimeError(msg)

# https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
# For data in the range 0-1, (float * 255).to(uint) is only 255
# when float is exactly 1.0.
# `max + 1 - epsilon` provides more evenly distributed mapping of
# ranges of floats to ints.
eps = 1e-3
result = image.mul(torch.iinfo(dtype).max + 1 - eps)
return result.to(dtype)
else:
# int to float
if dtype.is_floating_point:
max = torch.iinfo(image.dtype).max
image = image.to(dtype)
return image / max

# int to int
input_max = torch.iinfo(image.dtype).max
output_max = torch.iinfo(dtype).max

if input_max > output_max:
factor = (input_max + 1) // (output_max + 1)
image = image // factor
return image.to(dtype)
else:
factor = (output_max + 1) // (input_max + 1)
image = image.to(dtype)
return image * factor
if not isinstance(image, torch.Tensor):
raise TypeError('Input img should be Tensor Image')

return F_t.convert_image_dtype(image, dtype)


def to_pil_image(pic, mode=None):
Expand Down
101 changes: 97 additions & 4 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,101 @@ def _get_image_num_channels(img: Tensor) -> int:
raise TypeError("Input ndim should be 2 or more. Got {}".format(img.ndim))


def _max_value(dtype: torch.dtype) -> float:
# TODO: replace this method with torch.iinfo when it gets torchscript support.
# https://github.com/pytorch/pytorch/issues/41492

a = torch.tensor(2, dtype=dtype)
signed = 1 if torch.tensor(0, dtype=dtype).is_signed() else 0
bits = 1
max_value = torch.tensor(-signed, dtype=torch.long)
while True:
next_value = a.pow(bits - signed).sub(1)
if next_value > max_value:
max_value = next_value
bits *= 2
else:
return max_value.item()
return max_value.item()


def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    """PRIVATE METHOD. Convert a tensor image to the given ``dtype`` and scale the values accordingly

    .. warning::

        Module ``transforms.functional_tensor`` is private and should not be used in user application.
        Please, consider instead using methods from `transforms.functional` module.

    Args:
        image (torch.Tensor): Image to be converted
        dtype (torch.dtype): Desired data type of the output

    Returns:
        (torch.Tensor): Converted image. If ``image`` already has the requested ``dtype`` the very same
        tensor object is returned.

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """
    if image.dtype == dtype:
        return image

    # Evaluated once and shared by both branches below.
    # TODO: replace with dtype.is_floating_point when torchscript supports it
    output_is_float = torch.tensor(0, dtype=dtype).is_floating_point()

    # The tensor itself can answer this; no throwaway tensor is needed.
    if image.is_floating_point():
        # float to float
        if output_is_float:
            return image.to(dtype)

        # float to int
        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
            image.dtype == torch.float64 and dtype == torch.int64
        ):
            msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
            raise RuntimeError(msg)

        # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
        # For data in the range 0-1, (float * 255).to(uint) is only 255
        # when float is exactly 1.0.
        # `max + 1 - epsilon` provides more evenly distributed mapping of
        # ranges of floats to ints.
        eps = 1e-3
        max_val = _max_value(dtype)
        result = image.mul(max_val + 1.0 - eps)
        return result.to(dtype)
    else:
        input_max = _max_value(image.dtype)

        # int to float
        if output_is_float:
            image = image.to(dtype)
            return image / input_max

        # int to int
        # The output maximum is only needed here, so it is computed after the
        # int-to-float early return.
        output_max = _max_value(dtype)

        if input_max > output_max:
            # factor should be forced to int for torch jit script
            # otherwise factor is a float and image // factor can produce different results
            factor = int((input_max + 1) // (output_max + 1))
            image = image // factor
            return image.to(dtype)
        else:
            # factor should be forced to int for torch jit script
            # otherwise factor is a float and image * factor can produce different results
            factor = int((output_max + 1) // (input_max + 1))
            image = image.to(dtype)
            return image * factor


def vflip(img: Tensor) -> Tensor:
"""PRIVATE METHOD. Vertically flip the given the Image Tensor.

Expand Down Expand Up @@ -302,13 +397,11 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
result = img
dtype = img.dtype
if not torch.is_floating_point(img):
result = result / 255.0
result = convert_image_dtype(result, torch.float32)

result = (gain * result ** gamma).clamp(0, 1)

if result.dtype != dtype:
eps = 1e-3
result = (255 + 1.0 - eps) * result
result = convert_image_dtype(result, dtype)
result = result.to(dtype)
return result

Expand Down
6 changes: 3 additions & 3 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from . import functional as F


__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale",
"CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
"RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
Expand Down Expand Up @@ -127,7 +126,7 @@ def __repr__(self):
return self.__class__.__name__ + '()'


class ConvertImageDtype:
class ConvertImageDtype(torch.nn.Module):
"""Convert a tensor image to the given ``dtype`` and scale the values accordingly

Args:
Expand All @@ -146,9 +145,10 @@ class ConvertImageDtype:
"""

def __init__(self, dtype: torch.dtype) -> None:
super().__init__()
self.dtype = dtype

def __call__(self, image: torch.Tensor) -> torch.Tensor:
def forward(self, image: torch.Tensor) -> torch.Tensor:
return F.convert_image_dtype(image, self.dtype)


Expand Down