From a24653849040ea0fb57d93201e5dae90b7dbff5f Mon Sep 17 00:00:00 2001 From: Chandresh Kanani Date: Thu, 12 Sep 2019 21:07:09 +0530 Subject: [PATCH 1/4] modified code of io.read_video and io.read_video_timestamps to interpret pts values in seconds --- torchvision/io/video.py | 54 ++++++++++++++++++++++++++++++++--------- 1 file changed, 43 insertions(+), 11 deletions(-) diff --git a/torchvision/io/video.py b/torchvision/io/video.py index bd25c224ecb..ffd1755cdb1 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -2,6 +2,8 @@ import gc import torch import numpy as np +import math +import warnings try: import av @@ -145,7 +147,7 @@ def _align_audio_frames(aframes, audio_frames, ref_start, ref_end): return aframes[:, s_idx:e_idx] -def read_video(filename, start_pts=0, end_pts=None): +def read_video(filename, start_pts=0, end_pts=None, pts_unit='sec'): """ Reads a video from a file, returning both the video frames as well as the audio frames @@ -158,6 +160,8 @@ def read_video(filename, start_pts=0, end_pts=None): the start presentation time of the video end_pts : int, optional the end presentation time + pts_unit : str, optional + unit in which start_pts and end_pts values will be interpreted, either 'pts' or 'sec'. Defaults to 'sec'. Returns ------- @@ -179,19 +183,37 @@ def read_video(filename, start_pts=0, end_pts=None): raise ValueError("end_pts should be larger than start_pts, got " "start_pts={} and end_pts={}".format(start_pts, end_pts)) + if pts_unit == 'pts': + warnings.warn("The pts_unit 'pts' gives wrong results and will be removed in a " + + "follow-up version. Please use pts_unit 'sec'.") + container = av.open(filename, metadata_errors='ignore') info = {} video_frames = [] if container.streams.video: - video_frames = _read_from_stream(container, start_pts, end_pts, - container.streams.video[0], {'video': 0}) - info["video_fps"] = float(container.streams.video[0].average_rate) + video_start_pts = start_pts + video_end_pts = end_pts + video_stream = container.streams.video[0] + if pts_unit == 'sec': + video_start_pts = math.floor(start_pts*(1/video_stream.time_base)) + if video_end_pts != float("inf"): + video_end_pts = math.ceil(end_pts*(1/video_stream.time_base)) + video_frames = _read_from_stream(container, video_start_pts, video_end_pts, + video_stream, {'video': 0}) + info["video_fps"] = float(video_stream.average_rate) audio_frames = [] if container.streams.audio: - audio_frames = _read_from_stream(container, start_pts, end_pts, - container.streams.audio[0], {'audio': 0}) - info["audio_fps"] = container.streams.audio[0].rate + audio_start_pts = start_pts + audio_end_pts = end_pts + audio_stream = container.streams.audio[0] + if pts_unit == 'sec': + audio_start_pts = math.floor(start_pts*(1/audio_stream.time_base)) + if audio_end_pts != float("inf"): + audio_end_pts = math.ceil(end_pts*(1/audio_stream.time_base)) + audio_frames = _read_from_stream(container, audio_start_pts, audio_end_pts, + audio_stream , {'audio': 0}) + info["audio_fps"] = audio_stream.rate container.close() @@ -217,7 +239,7 @@ def _can_read_timestamps_from_packets(container): return False -def read_video_timestamps(filename): +def read_video_timestamps(filename, pts_unit='pts'): """ List the video frames timestamps. @@ -227,27 +249,37 @@ def read_video_timestamps(filename): ---------- filename : str path to the video file + pts_unit : str, optional + unit in which timestamp values will be returned either 'pts' or 'sec'. Defaults to 'sec'. Returns ------- - pts : List[int] + pts : List[float] presentation timestamps for each one of the frames in the video. video_fps : int the frame rate for the video """ _check_av_available() + if pts_unit == 'pts': + warnings.warn("The pts_unit 'pts' gives wrong results and will be removed in a " + + "follow-up version. Please use pts_unit 'sec'.") + container = av.open(filename, metadata_errors='ignore') video_frames = [] video_fps = None if container.streams.video: + video_stream = container.streams.video[0] + video_time_base = video_stream.time_base if _can_read_timestamps_from_packets(container): # fast path video_frames = [x for x in container.demux(video=0) if x.pts is not None] else: video_frames = _read_from_stream(container, 0, float("inf"), - container.streams.video[0], {'video': 0}) - video_fps = float(container.streams.video[0].average_rate) + video_stream, {'video': 0}) + video_fps = float(video_stream.average_rate) container.close() + if pts_unit == 'sec': + return [float(x.pts*video_time_base) for x in video_frames], video_fps return [x.pts for x in video_frames], video_fps From 135278401b150699f769bfd4596495100030bf64 Mon Sep 17 00:00:00 2001 From: Chandresh Kanani Date: Thu, 19 Sep 2019 12:54:07 +0530 Subject: [PATCH 2/4] changed default value for pts_unit to pts, corrected formatting --- torchvision/io/video.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/torchvision/io/video.py b/torchvision/io/video.py index ffd1755cdb1..67096dc7742 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -147,7 +147,7 @@ def _align_audio_frames(aframes, audio_frames, ref_start, ref_end): return aframes[:, s_idx:e_idx] -def read_video(filename, start_pts=0, end_pts=None, pts_unit='sec'): +def read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'): """ Reads a video from a file, returning both the video frames as well as the audio frames @@ -161,7 +161,7 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit='sec'): end_pts : int, optional the end presentation time pts_unit : str, optional - unit in which start_pts and end_pts values will be interpreted, either 'pts' or 'sec'. Defaults to 'sec'. + unit in which start_pts and end_pts values will be interpreted, either 'pts' or 'sec'. Defaults to 'pts'. Returns ------- @@ -184,7 +184,7 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit='sec'): "start_pts={} and end_pts={}".format(start_pts, end_pts)) if pts_unit == 'pts': - warnings.warn("The pts_unit 'pts' gives wrong results and will be removed in a " + + warnings.warn("The pts_unit 'pts' gives wrong results and will be removed in a " + "follow-up version. Please use pts_unit 'sec'.") container = av.open(filename, metadata_errors='ignore') @@ -196,9 +196,9 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit='sec'): video_end_pts = end_pts video_stream = container.streams.video[0] if pts_unit == 'sec': - video_start_pts = math.floor(start_pts*(1/video_stream.time_base)) + video_start_pts = math.floor(start_pts * (1 / video_stream.time_base)) if video_end_pts != float("inf"): - video_end_pts = math.ceil(end_pts*(1/video_stream.time_base)) + video_end_pts = math.ceil(end_pts * (1 / video_stream.time_base)) video_frames = _read_from_stream(container, video_start_pts, video_end_pts, video_stream, {'video': 0}) info["video_fps"] = float(video_stream.average_rate) @@ -208,11 +208,11 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit='sec'): audio_end_pts = end_pts audio_stream = container.streams.audio[0] if pts_unit == 'sec': - audio_start_pts = math.floor(start_pts*(1/audio_stream.time_base)) + audio_start_pts = math.floor(start_pts * (1 / audio_stream.time_base)) if audio_end_pts != float("inf"): - audio_end_pts = math.ceil(end_pts*(1/audio_stream.time_base)) + audio_end_pts = math.ceil(end_pts * (1 / audio_stream.time_base)) audio_frames = _read_from_stream(container, audio_start_pts, audio_end_pts, - audio_stream , {'audio': 0}) + audio_stream, {'audio': 0}) info["audio_fps"] = audio_stream.rate container.close() @@ -250,7 +250,7 @@ def read_video_timestamps(filename, pts_unit='pts'): filename : str path to the video file pts_unit : str, optional - unit in which timestamp values will be returned either 'pts' or 'sec'. Defaults to 'sec'. + unit in which timestamp values will be returned either 'pts' or 'sec'. Defaults to 'pts'. Returns ------- @@ -262,7 +262,7 @@ def read_video_timestamps(filename, pts_unit='pts'): """ _check_av_available() if pts_unit == 'pts': - warnings.warn("The pts_unit 'pts' gives wrong results and will be removed in a " + + warnings.warn("The pts_unit 'pts' gives wrong results and will be removed in a " + "follow-up version. Please use pts_unit 'sec'.") container = av.open(filename, metadata_errors='ignore') @@ -281,5 +281,5 @@ def read_video_timestamps(filename, pts_unit='pts'): video_fps = float(video_stream.average_rate) container.close() if pts_unit == 'sec': - return [float(x.pts*video_time_base) for x in video_frames], video_fps + return [float(x.pts * video_time_base) for x in video_frames], video_fps return [x.pts for x in video_frames], video_fps From 2905d14ad6392336475c5e103c3c5724aa976fc6 Mon Sep 17 00:00:00 2001 From: Chandresh Kanani Date: Wed, 25 Sep 2019 22:30:11 +0530 Subject: [PATCH 3/4] hanndliing both fractions and floats for start_pts and end_pts, added test cases for pts_unit sec --- test/test_io.py | 36 ++++++++++++++++++++++++++++++++++++ torchvision/io/video.py | 19 +++++++++++-------- 2 files changed, 47 insertions(+), 8 deletions(-) diff --git a/test/test_io.py b/test/test_io.py index 96c33a4be68..cda706c8aac 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -145,6 +145,42 @@ def test_read_timestamps_from_packet(self): self.assertEqual(pts, expected_pts) + def test_read_video_pts_unit_sec(self): + with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data): + lv, _, info = io.read_video(f_name, pts_unit='sec') + + self.assertTrue(data.equal(lv)) + self.assertEqual(info["video_fps"], 5) + + def test_read_timestamps_pts_unit_sec(self): + with temp_video(10, 300, 300, 5) as (f_name, data): + pts, _ = io.read_video_timestamps(f_name, pts_unit='sec') + + container = av.open(f_name) + stream = container.streams[0] + pts_step = int(round(float(1 / (stream.average_rate * stream.time_base)))) + num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration))) + expected_pts = [i * pts_step * stream.time_base for i in range(num_frames)] + + self.assertEqual(pts, expected_pts) + + def test_read_partial_video_pts_unit_sec(self): + with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data): + pts, _ = io.read_video_timestamps(f_name, pts_unit='sec') + + for start in range(5): + for l in range(1, 4): + lv, _, _ = io.read_video(f_name, pts[start], pts[start + l - 1], pts_unit='sec') + s_data = data[start:(start + l)] + self.assertEqual(len(lv), l) + self.assertTrue(s_data.equal(lv)) + + container = av.open(f_name) + stream = container.streams[0] + lv, _, _ = io.read_video(f_name, int(pts[4] * (1.0 / stream.time_base) + 1) * stream.time_base, pts[7], pts_unit='sec') + self.assertEqual(len(lv), 4) + self.assertTrue(data[4:8].equal(lv)) + # TODO add tests for audio diff --git a/torchvision/io/video.py b/torchvision/io/video.py index 67096dc7742..1fdbf8cf727 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -156,9 +156,11 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'): ---------- filename : str path to the video file - start_pts : int, optional + start_pts : int if pts_unit = 'pts', optional + float / Fraction if pts_unit = 'sec', optional the start presentation time of the video - end_pts : int, optional + end_pts : int if pts_unit = 'pts', optional + float / Fraction if pts_unit = 'sec', optional the end presentation time pts_unit : str, optional unit in which start_pts and end_pts values will be interpreted, either 'pts' or 'sec'. Defaults to 'pts'. @@ -196,9 +198,9 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'): video_end_pts = end_pts video_stream = container.streams.video[0] if pts_unit == 'sec': - video_start_pts = math.floor(start_pts * (1 / video_stream.time_base)) + video_start_pts = int(math.floor(start_pts * (1 / video_stream.time_base))) if video_end_pts != float("inf"): - video_end_pts = math.ceil(end_pts * (1 / video_stream.time_base)) + video_end_pts = int(math.ceil(end_pts * (1 / video_stream.time_base))) video_frames = _read_from_stream(container, video_start_pts, video_end_pts, video_stream, {'video': 0}) info["video_fps"] = float(video_stream.average_rate) @@ -208,9 +210,9 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'): audio_end_pts = end_pts audio_stream = container.streams.audio[0] if pts_unit == 'sec': - audio_start_pts = math.floor(start_pts * (1 / audio_stream.time_base)) + audio_start_pts = int(math.floor(start_pts * (1 / audio_stream.time_base))) if audio_end_pts != float("inf"): - audio_end_pts = math.ceil(end_pts * (1 / audio_stream.time_base)) + audio_end_pts = int(math.ceil(end_pts * (1 / audio_stream.time_base))) audio_frames = _read_from_stream(container, audio_start_pts, audio_end_pts, audio_stream, {'audio': 0}) info["audio_fps"] = audio_stream.rate @@ -254,7 +256,8 @@ def read_video_timestamps(filename, pts_unit='pts'): Returns ------- - pts : List[float] + pts : List[int] if pts_unit = 'pts' + List[Fraction] if pts_unit = 'sec' presentation timestamps for each one of the frames in the video. video_fps : int the frame rate for the video @@ -281,5 +284,5 @@ def read_video_timestamps(filename, pts_unit='pts'): video_fps = float(video_stream.average_rate) container.close() if pts_unit == 'sec': - return [float(x.pts * video_time_base) for x in video_frames], video_fps + return [x.pts * video_time_base for x in video_frames], video_fps return [x.pts for x in video_frames], video_fps From 899a6e99dd42375eb3cc4ec2ce60d6ce0841b1b1 Mon Sep 17 00:00:00 2001 From: Chandresh Kanani Date: Fri, 27 Sep 2019 00:16:58 +0530 Subject: [PATCH 4/4] moved unit conversion logic to _read_from_stream method --- test/test_io.py | 4 +++- torchvision/io/video.py | 45 +++++++++++++++-------------------------- 2 files changed, 19 insertions(+), 30 deletions(-) diff --git a/test/test_io.py b/test/test_io.py index cda706c8aac..f962fb270bc 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -177,7 +177,9 @@ def test_read_partial_video_pts_unit_sec(self): container = av.open(f_name) stream = container.streams[0] - lv, _, _ = io.read_video(f_name, int(pts[4] * (1.0 / stream.time_base) + 1) * stream.time_base, pts[7], pts_unit='sec') + lv, _, _ = io.read_video(f_name, + int(pts[4] * (1.0 / stream.time_base) + 1) * stream.time_base, pts[7], + pts_unit='sec') self.assertEqual(len(lv), 4) self.assertTrue(data[4:8].equal(lv)) diff --git a/torchvision/io/video.py b/torchvision/io/video.py index 1fdbf8cf727..a957e96e9f5 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -76,12 +76,20 @@ def write_video(filename, video_array, fps, video_codec='libx264', options=None) container.close() -def _read_from_stream(container, start_offset, end_offset, stream, stream_name): +def _read_from_stream(container, start_offset, end_offset, pts_unit, stream, stream_name): global _CALLED_TIMES, _GC_COLLECTION_INTERVAL _CALLED_TIMES += 1 if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1: gc.collect() + if pts_unit == 'sec': + start_offset = int(math.floor(start_offset * (1 / stream.time_base))) + if end_offset != float("inf"): + end_offset = int(math.ceil(end_offset * (1 / stream.time_base))) + else: + warnings.warn("The pts_unit 'pts' gives wrong results and will be removed in a " + + "follow-up version. Please use pts_unit 'sec'.") + frames = {} should_buffer = False max_buffer_size = 5 @@ -185,37 +193,19 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'): raise ValueError("end_pts should be larger than start_pts, got " "start_pts={} and end_pts={}".format(start_pts, end_pts)) - if pts_unit == 'pts': - warnings.warn("The pts_unit 'pts' gives wrong results and will be removed in a " + - "follow-up version. Please use pts_unit 'sec'.") - container = av.open(filename, metadata_errors='ignore') info = {} video_frames = [] if container.streams.video: - video_start_pts = start_pts - video_end_pts = end_pts - video_stream = container.streams.video[0] - if pts_unit == 'sec': - video_start_pts = int(math.floor(start_pts * (1 / video_stream.time_base))) - if video_end_pts != float("inf"): - video_end_pts = int(math.ceil(end_pts * (1 / video_stream.time_base))) - video_frames = _read_from_stream(container, video_start_pts, video_end_pts, - video_stream, {'video': 0}) - info["video_fps"] = float(video_stream.average_rate) + video_frames = _read_from_stream(container, start_pts, end_pts, pts_unit, + container.streams.video[0], {'video': 0}) + info["video_fps"] = float(container.streams.video[0].average_rate) audio_frames = [] if container.streams.audio: - audio_start_pts = start_pts - audio_end_pts = end_pts - audio_stream = container.streams.audio[0] - if pts_unit == 'sec': - audio_start_pts = int(math.floor(start_pts * (1 / audio_stream.time_base))) - if audio_end_pts != float("inf"): - audio_end_pts = int(math.ceil(end_pts * (1 / audio_stream.time_base))) - audio_frames = _read_from_stream(container, audio_start_pts, audio_end_pts, - audio_stream, {'audio': 0}) - info["audio_fps"] = audio_stream.rate + audio_frames = _read_from_stream(container, start_pts, end_pts, pts_unit, + container.streams.audio[0], {'audio': 0}) + info["audio_fps"] = container.streams.audio[0].rate container.close() @@ -264,9 +254,6 @@ def read_video_timestamps(filename, pts_unit='pts'): """ _check_av_available() - if pts_unit == 'pts': - warnings.warn("The pts_unit 'pts' gives wrong results and will be removed in a " + - "follow-up version. Please use pts_unit 'sec'.") container = av.open(filename, metadata_errors='ignore') @@ -279,7 +266,7 @@ def read_video_timestamps(filename, pts_unit='pts'): # fast path video_frames = [x for x in container.demux(video=0) if x.pts is not None] else: - video_frames = _read_from_stream(container, 0, float("inf"), + video_frames = _read_from_stream(container, 0, float("inf"), pts_unit, video_stream, {'video': 0}) video_fps = float(video_stream.average_rate) container.close()