Skip to content

Commit da89dad

Browse files
authored
Better handle corrupted videos (#1463)
* Handle corrupted video headers in io
* Catch exceptions while decoding partly-corrupted files
* Add more tests
1 parent 1d6145d commit da89dad

File tree

2 files changed: +96 additions, -34 deletions (lines changed)

test/test_io.py

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -236,6 +236,41 @@ def test_read_partial_video_pts_unit_sec(self):
236236
self.assertEqual(len(lv), 4)
237237
self.assertTrue(data[4:8].equal(lv))
238238

239+
def test_read_video_corrupted_file(self):
240+
with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
241+
f.write(b'This is not an mpg4 file')
242+
video, audio, info = io.read_video(f.name)
243+
self.assertIsInstance(video, torch.Tensor)
244+
self.assertIsInstance(audio, torch.Tensor)
245+
self.assertEqual(video.numel(), 0)
246+
self.assertEqual(audio.numel(), 0)
247+
self.assertEqual(info, {})
248+
249+
def test_read_video_timestamps_corrupted_file(self):
250+
with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
251+
f.write(b'This is not an mpg4 file')
252+
video_pts, video_fps = io.read_video_timestamps(f.name)
253+
self.assertEqual(video_pts, [])
254+
self.assertIs(video_fps, None)
255+
256+
def test_read_video_partially_corrupted_file(self):
257+
with temp_video(5, 4, 4, 5, lossless=True) as (f_name, data):
258+
with open(f_name, 'r+b') as f:
259+
size = os.path.getsize(f_name)
260+
bytes_to_overwrite = size // 10
261+
# seek to the middle of the file
262+
f.seek(5 * bytes_to_overwrite)
263+
# corrupt 10% of the file from the middle
264+
f.write(b'\xff' * bytes_to_overwrite)
265+
# this exercises the container.decode assertion check
266+
video, audio, info = io.read_video(f.name, pts_unit='sec')
267+
# check that size is not equal to 5, but 3
268+
self.assertEqual(len(video), 3)
269+
# but the valid decoded content is still correct
270+
self.assertTrue(video[:3].equal(data[:3]))
271+
# and the last few frames are wrong
272+
self.assertFalse(video.equal(data))
273+
239274
# TODO add tests for audio
240275

241276

torchvision/io/video.py

Lines changed: 61 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -124,13 +124,17 @@ def _read_from_stream(container, start_offset, end_offset, pts_unit, stream, str
124124
# print("Corrupted file?", container.name)
125125
return []
126126
buffer_count = 0
127-
for idx, frame in enumerate(container.decode(**stream_name)):
128-
frames[frame.pts] = frame
129-
if frame.pts >= end_offset:
130-
if should_buffer and buffer_count < max_buffer_size:
131-
buffer_count += 1
132-
continue
133-
break
127+
try:
128+
for idx, frame in enumerate(container.decode(**stream_name)):
129+
frames[frame.pts] = frame
130+
if frame.pts >= end_offset:
131+
if should_buffer and buffer_count < max_buffer_size:
132+
buffer_count += 1
133+
continue
134+
break
135+
except av.AVError:
136+
# TODO add a warning
137+
pass
134138
# ensure that the results are sorted wrt the pts
135139
result = [frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset]
136140
if start_offset > 0 and start_offset not in frames:
@@ -193,25 +197,39 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'):
193197
raise ValueError("end_pts should be larger than start_pts, got "
194198
"start_pts={} and end_pts={}".format(start_pts, end_pts))
195199

196-
container = av.open(filename, metadata_errors='ignore')
197200
info = {}
198-
199201
video_frames = []
200-
if container.streams.video:
201-
video_frames = _read_from_stream(container, start_pts, end_pts, pts_unit,
202-
container.streams.video[0], {'video': 0})
203-
info["video_fps"] = float(container.streams.video[0].average_rate)
204202
audio_frames = []
205-
if container.streams.audio:
206-
audio_frames = _read_from_stream(container, start_pts, end_pts, pts_unit,
207-
container.streams.audio[0], {'audio': 0})
208-
info["audio_fps"] = container.streams.audio[0].rate
209203

210-
container.close()
204+
try:
205+
container = av.open(filename, metadata_errors='ignore')
206+
except av.AVError:
207+
# TODO raise a warning?
208+
pass
209+
else:
210+
if container.streams.video:
211+
video_frames = _read_from_stream(container, start_pts, end_pts, pts_unit,
212+
container.streams.video[0], {'video': 0})
213+
video_fps = container.streams.video[0].average_rate
214+
# guard against potentially corrupted files
215+
if video_fps is not None:
216+
info["video_fps"] = float(video_fps)
217+
218+
if container.streams.audio:
219+
audio_frames = _read_from_stream(container, start_pts, end_pts, pts_unit,
220+
container.streams.audio[0], {'audio': 0})
221+
info["audio_fps"] = container.streams.audio[0].rate
222+
223+
container.close()
211224

212225
vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
213226
aframes = [frame.to_ndarray() for frame in audio_frames]
214-
vframes = torch.as_tensor(np.stack(vframes))
227+
228+
if vframes:
229+
vframes = torch.as_tensor(np.stack(vframes))
230+
else:
231+
vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
232+
215233
if aframes:
216234
aframes = np.concatenate(aframes, 1)
217235
aframes = torch.as_tensor(aframes)
@@ -255,21 +273,30 @@ def read_video_timestamps(filename, pts_unit='pts'):
255273
"""
256274
_check_av_available()
257275

258-
container = av.open(filename, metadata_errors='ignore')
259-
260276
video_frames = []
261277
video_fps = None
262-
if container.streams.video:
263-
video_stream = container.streams.video[0]
264-
video_time_base = video_stream.time_base
265-
if _can_read_timestamps_from_packets(container):
266-
# fast path
267-
video_frames = [x for x in container.demux(video=0) if x.pts is not None]
268-
else:
269-
video_frames = _read_from_stream(container, 0, float("inf"), pts_unit,
270-
video_stream, {'video': 0})
271-
video_fps = float(video_stream.average_rate)
272-
container.close()
278+
279+
try:
280+
container = av.open(filename, metadata_errors='ignore')
281+
except av.AVError:
282+
# TODO add a warning
283+
pass
284+
else:
285+
if container.streams.video:
286+
video_stream = container.streams.video[0]
287+
video_time_base = video_stream.time_base
288+
if _can_read_timestamps_from_packets(container):
289+
# fast path
290+
video_frames = [x for x in container.demux(video=0) if x.pts is not None]
291+
else:
292+
video_frames = _read_from_stream(container, 0, float("inf"), pts_unit,
293+
video_stream, {'video': 0})
294+
video_fps = float(video_stream.average_rate)
295+
container.close()
296+
297+
pts = [x.pts for x in video_frames]
298+
273299
if pts_unit == 'sec':
274-
return [x.pts * video_time_base for x in video_frames], video_fps
275-
return [x.pts for x in video_frames], video_fps
300+
pts = [x * video_time_base for x in pts]
301+
302+
return pts, video_fps

0 commit comments

Comments
 (0)