Skip to content

Commit b4d3b5c

Browse files
datumboxfacebook-github-bot
authored andcommitted
[fbsync] Keep 16bits png decoding private (#4732)
Reviewed By: NicolasHug Differential Revision: D31916314 fbshipit-source-id: d06ace86997de7482b72edb922725e505cb24cad
1 parent 8e50837 commit b4d3b5c

File tree

4 files changed

+40
-23
lines changed

4 files changed

+40
-23
lines changed

test/test_image.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
write_file,
2323
ImageReadMode,
2424
read_image,
25+
_read_png_16,
2526
)
2627

2728
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
@@ -156,8 +157,21 @@ def test_decode_png(img_path, pil_mode, mode):
156157
img_pil = torch.from_numpy(np.array(img))
157158

158159
img_pil = normalize_dimensions(img_pil)
159-
data = read_file(img_path)
160-
img_lpng = decode_image(data, mode=mode)
160+
161+
if "16" in img_path:
162+
# 16 bits image decoding is supported, but only as a private API
163+
# FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public
164+
with pytest.raises(RuntimeError, match="At most 8-bit PNG images are supported"):
165+
data = read_file(img_path)
166+
img_lpng = decode_image(data, mode=mode)
167+
168+
img_lpng = _read_png_16(img_path, mode=mode)
169+
assert img_lpng.dtype == torch.int32
170+
# PIL converts 16 bits pngs in uint8
171+
img_lpng = torch.round(img_lpng / (2 ** 16 - 1) * 255).to(torch.uint8)
172+
else:
173+
data = read_file(img_path)
174+
img_lpng = decode_image(data, mode=mode)
161175

162176
tol = 0 if pil_mode is None else 1
163177

@@ -168,11 +182,6 @@ def test_decode_png(img_path, pil_mode, mode):
168182
# TODO: remove once fix is released in PIL. Should be > 8.3.1.
169183
img_lpng, img_pil = img_lpng[0], img_pil[0]
170184

171-
if "16" in img_path:
172-
# PIL converts 16 bits pngs in uint8
173-
assert img_lpng.dtype == torch.int32
174-
img_lpng = torch.round(img_lpng / (2 ** 16 - 1) * 255).to(torch.uint8)
175-
176185
torch.testing.assert_close(img_lpng, img_pil, atol=tol, rtol=0)
177186

178187

torchvision/csrc/io/image/cpu/decode_png.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@ namespace vision {
55
namespace image {
66

77
#if !PNG_FOUND
8-
torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
8+
torch::Tensor decode_png(
9+
const torch::Tensor& data,
10+
ImageReadMode mode,
11+
bool allow_16_bits) {
912
TORCH_CHECK(
1013
false, "decode_png: torchvision not compiled with libPNG support");
1114
}
@@ -16,7 +19,10 @@ bool is_little_endian() {
1619
return *(uint8_t*)&x;
1720
}
1821

19-
torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
22+
torch::Tensor decode_png(
23+
const torch::Tensor& data,
24+
ImageReadMode mode,
25+
bool allow_16_bits) {
2026
// Check that the input tensor dtype is uint8
2127
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
2228
// Check that the input tensor is 1-dimensional
@@ -77,9 +83,12 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
7783
TORCH_CHECK(retval == 1, "Could read image metadata from content.")
7884
}
7985

80-
if (bit_depth > 16) {
86+
auto max_bit_depth = allow_16_bits ? 16 : 8;
87+
auto err_msg = "At most " + std::to_string(max_bit_depth) +
88+
"-bit PNG images are supported currently.";
89+
if (bit_depth > max_bit_depth) {
8190
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
82-
TORCH_CHECK(false, "At most 16-bit PNG images are supported currently.")
91+
TORCH_CHECK(false, err_msg)
8392
}
8493

8594
int channels = png_get_channels(png_ptr, info_ptr);

torchvision/csrc/io/image/cpu/decode_png.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ namespace image {
88

99
C10_EXPORT torch::Tensor decode_png(
1010
const torch::Tensor& data,
11-
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
11+
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED,
12+
bool allow_16_bits = false);
1213

1314
} // namespace image
1415
} // namespace vision

torchvision/io/image.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,7 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE
6161
"""
6262
Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor.
6363
Optionally converts the image to the desired format.
64-
The values of the output tensor are uint8 in [0, 255], except for
65-
16-bits pngs which are int32 tensors in [0, 65535].
66-
67-
.. warning::
68-
Should pytorch ever support the uint16 dtype natively, the dtype of the
69-
output for 16-bits pngs will be updated from int32 to uint16.
64+
The values of the output tensor are uint8 in [0, 255].
7065
7166
Args:
7267
input (Tensor[1]): a one dimensional uint8 tensor containing
@@ -79,7 +74,7 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE
7974
Returns:
8075
output (Tensor[image_channels, image_height, image_width])
8176
"""
82-
output = torch.ops.image.decode_png(input, mode.value)
77+
output = torch.ops.image.decode_png(input, mode.value, False)
8378
return output
8479

8580

@@ -193,8 +188,7 @@ def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN
193188
operation to decode the image into a 3 dimensional RGB or grayscale Tensor.
194189
195190
Optionally converts the image to the desired format.
196-
The values of the output tensor are uint8 in [0, 255], except for
197-
16-bits pngs which are int32 tensors in [0, 65535].
191+
The values of the output tensor are uint8 in [0, 255].
198192
199193
Args:
200194
input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
@@ -215,8 +209,7 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc
215209
"""
216210
Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale Tensor.
217211
Optionally converts the image to the desired format.
218-
The values of the output tensor are uint8 in [0, 255], except for
219-
16-bits pngs which are int32 tensors in [0, 65535].
212+
The values of the output tensor are uint8 in [0, 255].
220213
221214
Args:
222215
path (str): path of the JPEG or PNG image.
@@ -230,3 +223,8 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc
230223
"""
231224
data = read_file(path)
232225
return decode_image(data, mode)
226+
227+
228+
def _read_png_16(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
229+
data = read_file(path)
230+
return torch.ops.image.decode_png(data, mode.value, True)

0 commit comments

Comments
 (0)