diff --git a/test/test_io.py b/test/test_io.py
index 8b75cdea1c1..699d49c119c 100644
--- a/test/test_io.py
+++ b/test/test_io.py
@@ -7,6 +7,7 @@
 import unittest
 import sys
 import warnings
+from fractions import Fraction
 
 from common_utils import get_tmp_dir
 
@@ -76,15 +77,16 @@ def test_read_timestamps(self):
             # so we use it as a baseline
             container = av.open(f_name)
             stream = container.streams[0]
-            pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
             num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
-            expected_pts = [i * pts_step for i in range(num_frames)]
+            # pts is now a global fraction of a second, so we expect
+            # num_frames frames whose pts advance in steps of 1/fps
+            expected_pts = [Fraction(i, stream.average_rate) for i in range(num_frames)]
 
             self.assertEqual(pts, expected_pts)
 
     def test_read_partial_video(self):
         with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
-            pts, _ = io.read_video_timestamps(f_name)
+            pts, fps = io.read_video_timestamps(f_name)
             for start in range(5):
                 for l in range(1, 4):
                     lv, _, _ = io.read_video(f_name, pts[start], pts[start + l - 1])
@@ -92,7 +94,7 @@ def test_read_partial_video(self):
                     self.assertEqual(len(lv), l)
                     self.assertTrue(s_data.equal(lv))
 
-            lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
+            lv, _, _ = io.read_video(f_name, pts[4] + 1 / (fps + 1), pts[7])
             self.assertEqual(len(lv), 4)
             self.assertTrue(data[4:8].equal(lv))
 
@@ -100,7 +102,7 @@ def test_read_partial_video_bframes(self):
         # do not use lossless encoding, to test the presence of B-frames
         options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'}
         with temp_video(100, 300, 300, 5, options=options) as (f_name, data):
-            pts, _ = io.read_video_timestamps(f_name)
+            pts, fps = io.read_video_timestamps(f_name)
             for start in range(0, 80, 20):
                 for l in range(1, 4):
                     lv, _, _ = io.read_video(f_name, pts[start], pts[start + l - 1])
@@ -108,7 +110,7 @@ def test_read_partial_video_bframes(self):
                     self.assertEqual(len(lv), l)
                     self.assertTrue((s_data.float() - lv.float()).abs().max() < self.TOLERANCE)
 
-            lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
+            lv, _, _ = io.read_video(f_name, pts[4] + 1 / (fps + 1), pts[7])
            self.assertEqual(len(lv), 4)
            self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE)
 
@@ -138,13 +140,32 @@ def test_read_timestamps_from_packet(self):
             stream = container.streams[0]
             # make sure we went through the optimized codepath
             self.assertIn(b'Lavc', stream.codec_context.extradata)
-            pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
-            num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
-            expected_pts = [i * pts_step for i in range(num_frames)]
+            num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
+            # pts is now a global fraction of a second, so we expect
+            # num_frames frames whose pts advance in steps of 1/fps
+            expected_pts = [Fraction(i, stream.average_rate) for i in range(num_frames)]
 
             self.assertEqual(pts, expected_pts)
 
     # TODO add tests for audio
 
+    def test_audio_cutting(self):
+        with get_tmp_dir() as temp_dir:
+            name = "hmdb51_Turnk_r_Pippi_Michel_cartwheel_f_cm_np2_le_med_6.avi"
+            f_name = os.path.join(temp_dir, name)
+            url = "https://download.pytorch.org/vision_tests/io/" + name
+            try:
+                utils.download_url(url, temp_dir)
+                pts, fps = io.read_video_timestamps(f_name)
+                self.assertEqual(pts, sorted(pts))
+                self.assertEqual(fps, 30)
+            except URLError:
+                msg = "could not download test file '{}'".format(url)
+                warnings.warn(msg, RuntimeWarning)
+                raise unittest.SkipTest(msg)
+
+            lv, la, info = io.read_video(f_name, pts[3], pts[7])
+            # FIXME: add another video - this one doesn't have audio
+            # self.assertEqual(lv/info['video_fps'], la/info['audio_fps'])
 
 if __name__ == '__main__':
diff --git a/torchvision/io/video.py b/torchvision/io/video.py
index bd25c224ecb..35057ea5efd 100644
--- a/torchvision/io/video.py
+++ b/torchvision/io/video.py
@@ -2,6 +2,8 @@
 import gc
 import torch
 import numpy as np
+import warnings
+
 
 try:
     import av
@@ -83,6 +85,15 @@ def _read_from_stream(container, start_offset, end_offset, stream, stream_name):
     frames = {}
     should_buffer = False
     max_buffer_size = 5
+    # the input offsets are expected to be global (i.e. in seconds);
+    # to find the right frames we convert them to the stream's local
+    # pts values via seconds / time_base
+    cur_time_base = stream.time_base
+    seek_offset = int(round(float(start_offset / cur_time_base)))
+    start_offset = int(round(float(start_offset / cur_time_base)))
+    if end_offset != float("inf"):
+        end_offset = int(round(float(end_offset / cur_time_base)))
+
     if stream.type == "video":
         # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
         # so need to buffer some extra frames to sort everything
@@ -99,7 +110,7 @@ def _read_from_stream(container, start_offset, end_offset, stream, stream_name):
             o = re.search(br"DivX(\d+)b(\d+)(\w)", d)
             if o is not None:
                 should_buffer = o.group(3) == b"p"
-    seek_offset = start_offset
+
     # some files don't seek to the right location, so better be safe here
     seek_offset = max(seek_offset - 1, 0)
     if should_buffer:
@@ -133,6 +144,11 @@ def _read_from_stream(container, start_offset, end_offset, stream, stream_name):
 
 
 def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
+    # the reference start/end are global (seconds); convert them
+    # to the audio stream's local pts values
+    ref_start = int(round(float(ref_start / audio_frames[0].time_base)))
+    ref_end = int(round(float(ref_end / audio_frames[0].time_base)))
+
     start, end = audio_frames[0].pts, audio_frames[-1].pts
     total_aframes = aframes.shape[1]
     step_per_aframe = (end - start + 1) / total_aframes
@@ -155,9 +171,9 @@ def read_video(filename, start_pts=0, end_pts=None):
     filename : str
         path to the video file
     start_pts : int, optional
-        the start presentation time of the video
+        the start presentation time of the video, in seconds (global presentation time)
     end_pts : int, optional
-        the end presentation time
+        the end presentation time, in seconds (global presentation time)
 
     Returns
     -------
@@ -217,7 +233,7 @@ def _can_read_timestamps_from_packets(container):
     return False
 
 
-def read_video_timestamps(filename):
+def read_video_timestamps(filename, output_format="global"):
     """
     List the video frames timestamps.
 
@@ -232,6 +248,8 @@ def read_video_timestamps(filename):
     -------
     pts : List[int]
         presentation timestamps for each one of the frames in the video.
+        Note that these are returned in seconds, i.e. w.r.t. the global
+        presentation time, not in the stream's time_base units.
     video_fps : int
         the frame rate for the video
 
@@ -250,4 +268,11 @@ def read_video_timestamps(filename):
                                      container.streams.video[0], {'video': 0})
     video_fps = float(container.streams.video[0].average_rate)
     container.close()
-    return [x.pts for x in video_frames], video_fps
+
+    if output_format is None:
+        msg = ("Currently, pts are returned as integers w.r.t. the video stream's "
+               "time_base. This behaviour is deprecated.")
+        warnings.warn(msg, RuntimeWarning)
+        return [x.pts for x in video_frames], video_fps
+
+    return [x.pts * x.time_base for x in video_frames], video_fps
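
For reference, a minimal usage sketch of the API after this change. It assumes the new "global seconds" behaviour introduced above (timestamps returned as fractions of a second rather than stream time_base units); the file name "video.mp4" is only illustrative.

    from torchvision import io

    # with this patch, pts values are global, i.e. expressed in seconds
    # rather than in the video stream's time_base units
    pts, fps = io.read_video_timestamps("video.mp4")

    # start_pts/end_pts for read_video are interpreted in the same
    # seconds-based format, so the returned timestamps can be passed
    # straight back; here we read roughly the first second of video
    frames, _, info = io.read_video("video.mp4", start_pts=pts[0], end_pts=pts[0] + 1)
    print(len(pts), fps, frames.shape, info)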