Skip to content

Keep 16bits png decoding private #4732

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
write_file,
ImageReadMode,
read_image,
_read_png_16,
)

IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
Expand Down Expand Up @@ -156,8 +157,21 @@ def test_decode_png(img_path, pil_mode, mode):
img_pil = torch.from_numpy(np.array(img))

img_pil = normalize_dimensions(img_pil)
data = read_file(img_path)
img_lpng = decode_image(data, mode=mode)

if "16" in img_path:
# 16 bits image decoding is supported, but only as a private API
# FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public
with pytest.raises(RuntimeError, match="At most 8-bit PNG images are supported"):
data = read_file(img_path)
img_lpng = decode_image(data, mode=mode)

img_lpng = _read_png_16(img_path, mode=mode)
assert img_lpng.dtype == torch.int32
# PIL converts 16 bits pngs in uint8
img_lpng = torch.round(img_lpng / (2 ** 16 - 1) * 255).to(torch.uint8)
else:
data = read_file(img_path)
img_lpng = decode_image(data, mode=mode)

tol = 0 if pil_mode is None else 1

Expand All @@ -168,11 +182,6 @@ def test_decode_png(img_path, pil_mode, mode):
# TODO: remove once fix is released in PIL. Should be > 8.3.1.
img_lpng, img_pil = img_lpng[0], img_pil[0]

if "16" in img_path:
# PIL converts 16 bits pngs in uint8
assert img_lpng.dtype == torch.int32
img_lpng = torch.round(img_lpng / (2 ** 16 - 1) * 255).to(torch.uint8)

torch.testing.assert_close(img_lpng, img_pil, atol=tol, rtol=0)


Expand Down
17 changes: 13 additions & 4 deletions torchvision/csrc/io/image/cpu/decode_png.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ namespace vision {
namespace image {

#if !PNG_FOUND
torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
torch::Tensor decode_png(
const torch::Tensor& data,
ImageReadMode mode,
bool allow_16_bits) {
TORCH_CHECK(
false, "decode_png: torchvision not compiled with libPNG support");
}
Expand All @@ -16,7 +19,10 @@ bool is_little_endian() {
return *(uint8_t*)&x;
}

torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
torch::Tensor decode_png(
const torch::Tensor& data,
ImageReadMode mode,
bool allow_16_bits) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
Expand Down Expand Up @@ -77,9 +83,12 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
TORCH_CHECK(retval == 1, "Could read image metadata from content.")
}

if (bit_depth > 16) {
auto max_bit_depth = allow_16_bits ? 16 : 8;
auto err_msg = "At most " + std::to_string(max_bit_depth) +
"-bit PNG images are supported currently.";
if (bit_depth > max_bit_depth) {
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(false, "At most 16-bit PNG images are supported currently.")
TORCH_CHECK(false, err_msg)
}

int channels = png_get_channels(png_ptr, info_ptr);
Expand Down
3 changes: 2 additions & 1 deletion torchvision/csrc/io/image/cpu/decode_png.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ namespace image {

C10_EXPORT torch::Tensor decode_png(
const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED,
bool allow_16_bits = false);

} // namespace image
} // namespace vision
20 changes: 9 additions & 11 deletions torchvision/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,7 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE
"""
Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 in [0, 255], except for
16-bits pngs which are int32 tensors in [0, 65535].

.. warning::
Should pytorch ever support the uint16 dtype natively, the dtype of the
output for 16-bits pngs will be updated from int32 to uint16.
The values of the output tensor are uint8 in [0, 255].

Args:
input (Tensor[1]): a one dimensional uint8 tensor containing
Expand All @@ -79,7 +74,7 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE
Returns:
output (Tensor[image_channels, image_height, image_width])
"""
output = torch.ops.image.decode_png(input, mode.value)
output = torch.ops.image.decode_png(input, mode.value, False)
return output


Expand Down Expand Up @@ -193,8 +188,7 @@ def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN
operation to decode the image into a 3 dimensional RGB or grayscale Tensor.

Optionally converts the image to the desired format.
The values of the output tensor are uint8 in [0, 255], except for
16-bits pngs which are int32 tensors in [0, 65535].
The values of the output tensor are uint8 in [0, 255].

Args:
input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
Expand All @@ -215,8 +209,7 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc
"""
Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 in [0, 255], except for
16-bits pngs which are int32 tensors in [0, 65535].
The values of the output tensor are uint8 in [0, 255].

Args:
path (str): path of the JPEG or PNG image.
Expand All @@ -230,3 +223,8 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc
"""
data = read_file(path)
return decode_image(data, mode)


def _read_png_16(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Reads a PNG file and decodes it with the 16-bit opt-in flag enabled.

    This is a private helper: the public ``decode_png`` entry point always
    passes ``False`` for that flag, so this function is the only Python-level
    path that passes ``True``.

    Args:
        path (str): path of the PNG image.
        mode (ImageReadMode): the read mode used for optionally converting
            the image; its ``.value`` is forwarded to the underlying op.

    Returns:
        output (Tensor): the tensor returned by ``torch.ops.image.decode_png``.
        NOTE(review): presumably int32 for 16-bit inputs -- confirm against
        the C++ decoder.
    """
    raw_bytes = read_file(path)
    # Third positional argument is the flag that permits 16-bit PNG input.
    decoded = torch.ops.image.decode_png(raw_bytes, mode.value, True)
    return decoded