Skip to content

Commit 03b1d38

Browse files
authored
PR: Improve handling of truncated/incomplete and corrupt JPEG images (#2471)
* Add corruption cases * Read jpeg headers until exhaustion * Minor error correction * Add test script * Raise exception when image is truncated * Add test * Skip damaged_jpeg folder * Compare against basename * Remove unused test file
1 parent a568c7f commit 03b1d38

File tree

8 files changed

+37
-1
lines changed

8 files changed

+37
-1
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
Copyright 2019 The TensorFlow Authors. All rights reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
15.1 KB
Loading

test/assets/damaged_jpeg/corrupt.jpg

1.52 KB
Loading
755 Bytes
Loading
5.38 KB
Loading
4.97 KB
Loading

test/test_image.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import glob
23
import unittest
34
import sys
45

@@ -10,11 +11,15 @@
1011

1112
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
1213
IMAGE_DIR = os.path.join(IMAGE_ROOT, "fakedata", "imagefolder")
14+
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg')
1315

1416

1517
def get_images(directory, img_ext):
1618
assert os.path.isdir(directory)
1719
for root, _, files in os.walk(directory):
20+
if os.path.basename(root) == 'damaged_jpeg':
21+
continue
22+
1823
for fl in files:
1924
_, ext = os.path.splitext(fl)
2025
if ext == img_ext:
@@ -44,6 +49,21 @@ def test_decode_jpeg(self):
4449
with self.assertRaises(RuntimeError):
4550
decode_jpeg(torch.empty((100), dtype=torch.uint8))
4651

52+
def test_damaged_images(self):
    """Check how ``read_jpeg`` reacts to the files under ``DAMAGED_JPEG``.

    ``bad_huffman.jpg`` must be read without a ``RuntimeError``, while every
    file matching ``corrupt*.jpg`` must make ``read_jpeg`` raise
    ``RuntimeError``.
    """
    # Test image with bad Huffman encoding (should not raise)
    bad_huff = os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg')
    try:
        read_jpeg(bad_huff)
    except RuntimeError as err:
        # self.fail states the intent and keeps the decoder's message,
        # unlike a bare assertTrue(False).
        self.fail(
            "read_jpeg raised RuntimeError on {}: {}".format(bad_huff, err))

    # Truncated images should raise an exception
    # Sorted so the files are always checked in the same order.
    truncated_images = sorted(glob.glob(
        os.path.join(DAMAGED_JPEG, 'corrupt*.jpg')))
    # An empty match would let the loop below pass without checking anything.
    self.assertTrue(
        truncated_images,
        "no corrupt*.jpg files found in {}".format(DAMAGED_JPEG))
    for image_path in truncated_images:
        # subTest reports each offending file instead of stopping at the first.
        with self.subTest(image_path=image_path):
            with self.assertRaises(RuntimeError):
                read_jpeg(image_path)
66+
4767
def test_read_png(self):
4868
# Check across .png
4969
for img_path in get_images(IMAGE_DIR, ".png"):

torchvision/csrc/cpu/image/readjpeg_cpu.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,10 @@ static void torch_jpeg_init_source(j_decompress_ptr cinfo) {}
4848

4949
static boolean torch_jpeg_fill_input_buffer(j_decompress_ptr cinfo) {
5050
torch_jpeg_mgr* src = (torch_jpeg_mgr*)cinfo->src;
51-
// No more data. Probably an incomplete image; just output EOI.
51+
// No more data. Probably an incomplete image; Raise exception.
52+
torch_jpeg_error_ptr myerr = (torch_jpeg_error_ptr)cinfo->err;
53+
strcpy(jpegLastErrorMsg, "Image is incomplete or truncated");
54+
longjmp(myerr->setjmp_buffer, 1);
5255
src->pub.next_input_byte = EOI_BUFFER;
5356
src->pub.bytes_in_buffer = 1;
5457
return TRUE;

0 commit comments

Comments
 (0)