Skip to content

make convert_image_dtype scriptable #2485

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Oct 5, 2020
4 changes: 2 additions & 2 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def test_adjust_gamma(self):
for dt in [torch.float64, torch.float32, None]:

if dt is not None:
tensor = F.convert_image_dtype(tensor, dt)
tensor = F_t.convert_image_dtype(tensor, dt)

gammas = [0.8, 1.0, 1.2]
gains = [0.7, 1.0, 1.3]
Expand All @@ -323,7 +323,7 @@ def test_adjust_gamma(self):

rbg_tensor = adjusted_tensor
if adjusted_tensor.dtype != torch.uint8:
rbg_tensor = F.convert_image_dtype(adjusted_tensor, torch.uint8)
rbg_tensor = F_t.convert_image_dtype(adjusted_tensor, torch.uint8)

self.compareTensorToPIL(rbg_tensor, adjusted_pil)

Expand Down
5 changes: 5 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
import torchvision.transforms.functional_tensor as F_t
from torch._utils_internal import get_file_path_2
from numpy.testing import assert_array_almost_equal
import unittest
Expand Down Expand Up @@ -526,6 +527,10 @@ def test_to_tensor(self):
output = trans(img)
self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))

def test_max_value(self):
    """Check ``F_t._max_value`` against ``torch.iinfo`` for each dtype from ``int_dtypes()``."""
    for int_dtype in int_dtypes():
        expected = torch.iinfo(int_dtype).max
        actual = F_t._max_value(int_dtype)
        self.assertEqual(actual, expected)

def test_convert_image_dtype_float_to_float(self):
for input_dtype, output_dtypes in cycle_over(float_dtypes()):
input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
Expand Down
65 changes: 0 additions & 65 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,71 +124,6 @@ def pil_to_tensor(pic):
return img


def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly

    Args:
        image (torch.Tensor): Image to be converted
        dtype (torch.dtype): Desired data type of the output

    Returns:
        (torch.Tensor): Converted image. When ``image`` already has the requested ``dtype`` the very same
        tensor object is returned.

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """
    if image.dtype == dtype:
        return image

    if image.dtype.is_floating_point:
        # float to float
        if dtype.is_floating_point:
            return image.to(dtype)

        # float to int
        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
            image.dtype == torch.float64 and dtype == torch.int64
        ):
            msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
            raise RuntimeError(msg)

        # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
        # For data in the range 0-1, (float * 255).to(uint) is only 255
        # when float is exactly 1.0.
        # `max + 1 - epsilon` provides more evenly distributed mapping of
        # ranges of floats to ints.
        eps = 1e-3
        result = image.mul(torch.iinfo(dtype).max + 1 - eps)
        return result.to(dtype)

    # From here on the input is an integer image.
    input_max = torch.iinfo(image.dtype).max

    # int to float: divide by the largest input value so the result lies in [0, 1]
    if dtype.is_floating_point:
        return image.to(dtype) / input_max

    # int to int
    output_max = torch.iinfo(dtype).max

    if input_max > output_max:
        # Downscale before the cast so values fit in the smaller target dtype.
        factor = (input_max + 1) // (output_max + 1)
        return (image // factor).to(dtype)

    # Cast first so the multiplication is carried out in the wider target dtype.
    factor = (output_max + 1) // (input_max + 1)
    return image.to(dtype) * factor


def to_pil_image(pic, mode=None):
"""Convert a tensor or an ndarray to PIL Image.

Expand Down
88 changes: 84 additions & 4 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,88 @@ def _get_image_size(img: Tensor) -> List[int]:
raise TypeError("Unexpected type {}".format(type(img)))


# torch.iinfo isn't scriptable so using this helper function
# https://github.com/pytorch/pytorch/issues/41492
def _max_value(dtype: int) -> int:
    """Return the largest value representable by the integer ``dtype``.

    Stand-in for ``torch.iinfo(dtype).max``. The value is found by probing:
    ``2 ** (bits - signed) - 1`` is evaluated in ``dtype`` for ``bits = 1, 2, 4, ...``
    and the search stops at the first candidate that is not larger than the best
    value seen so far.

    Args:
        dtype: The ``torch.dtype`` to inspect.
            NOTE(review): annotated as ``int``, presumably so the function can be
            scripted -- confirm before changing the annotation.

    Returns:
        int: The maximum value of ``dtype``.
    """
    a = torch.tensor(2, dtype=dtype)
    # One bit is reserved for the sign in signed dtypes.
    signed = 1 if torch.tensor(0, dtype=dtype).is_signed() else 0
    bits = 1
    # Start below every candidate (-1 for signed, 0 for unsigned dtypes).
    max_value = torch.tensor(-signed, dtype=torch.long)
    while True:
        # NOTE(review): this appears to rely on integer arithmetic in ``dtype``
        # wrapping around once ``bits`` exceeds the real width, so that the
        # candidate stops growing -- confirm.
        next_value = a.pow(bits - signed).sub(1)
        if next_value > max_value:
            max_value = next_value
            bits *= 2
        else:
            break
    return max_value.item()


def convert_image_dtype(image: torch.Tensor, dtype: int = torch.float) -> torch.Tensor:
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly

    Args:
        image (torch.Tensor): Image to be converted
        dtype (torch.dtype): Desired data type of the output

    Returns:
        (torch.Tensor): Converted image. When ``image`` already has the requested ``dtype`` the very same
        tensor object is returned.

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """
    if image.dtype == dtype:
        return image

    if image.is_floating_point():
        # float to float
        # A scalar probe tensor is used to classify the target dtype.
        # NOTE(review): presumably because ``dtype.is_floating_point`` is not scriptable -- confirm.
        if torch.tensor(0, dtype=dtype).is_floating_point():
            return image.to(dtype)

        # float to int
        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
            image.dtype == torch.float64 and dtype == torch.int64
        ):
            msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
            raise RuntimeError(msg)

        # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
        # For data in the range 0-1, (float * 255).to(uint) is only 255
        # when float is exactly 1.0.
        # `max + 1 - epsilon` provides more evenly distributed mapping of
        # ranges of floats to ints.
        eps = 1e-3
        max_val = _max_value(dtype)
        result = image.mul(max_val + 1.0 - eps)
        return result.to(dtype)

    # From here on the input is an integer image.
    input_max = _max_value(image.dtype)

    # int to float
    if torch.tensor(0, dtype=dtype).is_floating_point():
        image = image.to(dtype)
        return image / input_max

    # int to int
    # The target maximum is only looked up once the target is known to be an
    # integer dtype, so the probing helper never runs on a floating-point dtype.
    output_max = _max_value(dtype)

    if input_max > output_max:
        # Downscale before the cast so values fit in the smaller target dtype.
        factor = (input_max + 1) // (output_max + 1)
        image = image // factor
        return image.to(dtype)

    # Cast first so the multiplication is carried out in the wider target dtype.
    factor = (output_max + 1) // (input_max + 1)
    image = image.to(dtype)
    return image * factor


def vflip(img: Tensor) -> Tensor:
"""Vertically flip the given the Image Tensor.

Expand Down Expand Up @@ -228,13 +310,11 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
result = img
dtype = img.dtype
if not torch.is_floating_point(img):
result = result / 255.0
result = convert_image_dtype(result, torch.float32)

result = (gain * result ** gamma).clamp(0, 1)

if result.dtype != dtype:
eps = 1e-3
result = (255 + 1.0 - eps) * result
result = convert_image_dtype(result, dtype)
result = result.to(dtype)
return result

Expand Down
4 changes: 2 additions & 2 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
accimage = None

from . import functional as F

from . import functional_tensor as F_t

__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale",
"CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
Expand Down Expand Up @@ -131,7 +131,7 @@ def __init__(self, dtype: torch.dtype) -> None:
self.dtype = dtype

def __call__(self, image: torch.Tensor) -> torch.Tensor:
return F.convert_image_dtype(image, self.dtype)
return F_t.convert_image_dtype(image, self.dtype)


class ToPILImage(object):
Expand Down