From 62ba4c3a67b156ae22a3d7d7f29f32b689a14c09 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 25 Oct 2021 12:06:08 +0100 Subject: [PATCH] Keep 16bits png decoding private --- test/test_image.py | 23 ++++++++++++++------ torchvision/csrc/io/image/cpu/decode_png.cpp | 17 +++++++++++---- torchvision/csrc/io/image/cpu/decode_png.h | 3 ++- torchvision/io/image.py | 20 ++++++++--------- 4 files changed, 40 insertions(+), 23 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index 7cd74fc915c..489170077dc 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -22,6 +22,7 @@ write_file, ImageReadMode, read_image, + _read_png_16, ) IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") @@ -156,8 +157,21 @@ def test_decode_png(img_path, pil_mode, mode): img_pil = torch.from_numpy(np.array(img)) img_pil = normalize_dimensions(img_pil) - data = read_file(img_path) - img_lpng = decode_image(data, mode=mode) + + if "16" in img_path: + # 16 bits image decoding is supported, but only as a private API + # FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public + with pytest.raises(RuntimeError, match="At most 8-bit PNG images are supported"): + data = read_file(img_path) + img_lpng = decode_image(data, mode=mode) + + img_lpng = _read_png_16(img_path, mode=mode) + assert img_lpng.dtype == torch.int32 + # PIL converts 16 bits pngs in uint8 + img_lpng = torch.round(img_lpng / (2 ** 16 - 1) * 255).to(torch.uint8) + else: + data = read_file(img_path) + img_lpng = decode_image(data, mode=mode) tol = 0 if pil_mode is None else 1 @@ -168,11 +182,6 @@ def test_decode_png(img_path, pil_mode, mode): # TODO: remove once fix is released in PIL. Should be > 8.3.1. img_lpng, img_pil = img_lpng[0], img_pil[0] - if "16" in img_path: - # PIL converts 16 bits pngs in uint8 - assert img_lpng.dtype == torch.int32 - img_lpng = torch.round(img_lpng / (2 ** 16 - 1) * 255).to(torch.uint8) - torch.testing.assert_close(img_lpng, img_pil, atol=tol, rtol=0) diff --git a/torchvision/csrc/io/image/cpu/decode_png.cpp b/torchvision/csrc/io/image/cpu/decode_png.cpp index 2bd25c3d91a..0df55daed68 100644 --- a/torchvision/csrc/io/image/cpu/decode_png.cpp +++ b/torchvision/csrc/io/image/cpu/decode_png.cpp @@ -5,7 +5,10 @@ namespace vision { namespace image { #if !PNG_FOUND -torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) { +torch::Tensor decode_png( + const torch::Tensor& data, + ImageReadMode mode, + bool allow_16_bits) { TORCH_CHECK( false, "decode_png: torchvision not compiled with libPNG support"); } @@ -16,7 +19,10 @@ bool is_little_endian() { return *(uint8_t*)&x; } -torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) { +torch::Tensor decode_png( + const torch::Tensor& data, + ImageReadMode mode, + bool allow_16_bits) { // Check that the input tensor dtype is uint8 TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); // Check that the input tensor is 1-dimensional @@ -77,9 +83,12 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) { TORCH_CHECK(retval == 1, "Could read image metadata from content.") } - if (bit_depth > 16) { + auto max_bit_depth = allow_16_bits ? 16 : 8; + auto err_msg = "At most " + std::to_string(max_bit_depth) + + "-bit PNG images are supported currently."; + if (bit_depth > max_bit_depth) { png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - TORCH_CHECK(false, "At most 16-bit PNG images are supported currently.") + TORCH_CHECK(false, err_msg) } int channels = png_get_channels(png_ptr, info_ptr); diff --git a/torchvision/csrc/io/image/cpu/decode_png.h b/torchvision/csrc/io/image/cpu/decode_png.h index 471bf77d935..fed89327cdb 100644 --- a/torchvision/csrc/io/image/cpu/decode_png.h +++ b/torchvision/csrc/io/image/cpu/decode_png.h @@ -8,7 +8,8 @@ namespace image { C10_EXPORT torch::Tensor decode_png( const torch::Tensor& data, - ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); + ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED, + bool allow_16_bits = false); } // namespace image } // namespace vision diff --git a/torchvision/io/image.py b/torchvision/io/image.py index b0969ca3ae3..6dba6a7b168 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -61,12 +61,7 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE """ Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor. Optionally converts the image to the desired format. - The values of the output tensor are uint8 in [0, 255], except for - 16-bits pngs which are int32 tensors in [0, 65535]. - - .. warning:: - Should pytorch ever support the uint16 dtype natively, the dtype of the - output for 16-bits pngs will be updated from int32 to uint16. + The values of the output tensor are uint8 in [0, 255]. Args: input (Tensor[1]): a one dimensional uint8 tensor containing @@ -79,7 +74,7 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE Returns: output (Tensor[image_channels, image_height, image_width]) """ - output = torch.ops.image.decode_png(input, mode.value) + output = torch.ops.image.decode_png(input, mode.value, False) return output @@ -193,8 +188,7 @@ def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN operation to decode the image into a 3 dimensional RGB or grayscale Tensor. Optionally converts the image to the desired format. - The values of the output tensor are uint8 in [0, 255], except for - 16-bits pngs which are int32 tensors in [0, 65535]. + The values of the output tensor are uint8 in [0, 255]. Args: input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the @@ -215,8 +209,7 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc """ Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale Tensor. Optionally converts the image to the desired format. - The values of the output tensor are uint8 in [0, 255], except for - 16-bits pngs which are int32 tensors in [0, 65535]. + The values of the output tensor are uint8 in [0, 255]. Args: path (str): path of the JPEG or PNG image. @@ -230,3 +223,8 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc """ data = read_file(path) return decode_image(data, mode) + + +def _read_png_16(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor: + data = read_file(path) + return torch.ops.image.decode_png(data, mode.value, True)