Skip to content

modified code of io.read_video and io.read_video_timestamps to interpret pts values in seconds #1331

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions test/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,44 @@ def test_read_timestamps_from_packet(self):

self.assertEqual(pts, expected_pts)

def test_read_video_pts_unit_sec(self):
    """Read a whole lossless clip with ``pts_unit='sec'`` and compare it to its source.

    Asserts that the decoded frames equal the tensor yielded by ``temp_video``
    and that the reported ``video_fps`` is 5.
    """
    with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
        frames, _, meta = io.read_video(f_name, pts_unit='sec')

        self.assertTrue(data.equal(frames))
        self.assertEqual(meta["video_fps"], 5)

def test_read_timestamps_pts_unit_sec(self):
    """Check that ``read_video_timestamps`` reports timestamps in seconds.

    The expected values are rebuilt from the metadata PyAV reports for the
    first stream: consecutive frames are ``pts_step`` ticks apart, and
    multiplying a tick count by ``stream.time_base`` yields the value that
    ``pts_unit='sec'`` is expected to return.
    """
    with temp_video(10, 300, 300, 5) as (f_name, data):
        pts, _ = io.read_video_timestamps(f_name, pts_unit='sec')

        container = av.open(f_name)
        try:
            stream = container.streams[0]
            time_base = stream.time_base
            # Frames per tick; its inverse is the tick distance between frames.
            frames_per_tick = stream.average_rate * time_base
            pts_step = int(round(float(1 / frames_per_tick)))
            num_frames = int(round(float(frames_per_tick * stream.duration)))
            expected_pts = [i * pts_step * time_base for i in range(num_frames)]
        finally:
            # Release the file handle even if reading the metadata fails.
            container.close()

        self.assertEqual(pts, expected_pts)

def test_read_partial_video_pts_unit_sec(self):
    """Check partial reads when the start/end bounds are given in seconds.

    First reads every window of 1 to 3 frames starting at frames 0..4, using
    the exact per-frame timestamps as bounds, and compares each result with
    the matching slice of the source tensor. Then reads a window whose start
    lies just after frame 4's timestamp and asserts that frames 4..7 are
    still returned.
    """
    with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
        pts, _ = io.read_video_timestamps(f_name, pts_unit='sec')

        for start in range(5):
            for length in range(1, 4):
                end = start + length
                # Both bounds are inclusive, hence the timestamp of frame end - 1.
                lv, _, _ = io.read_video(f_name, pts[start], pts[end - 1], pts_unit='sec')
                self.assertEqual(len(lv), length)
                self.assertTrue(data[start:end].equal(lv))

        container = av.open(f_name)
        try:
            # Only the time base is needed; copy it out so the container can
            # be closed before the next read.
            time_base = container.streams[0].time_base
        finally:
            container.close()

        # Convert frame 4's timestamp to ticks, add one tick, truncate, and
        # convert back to seconds: a start bound that falls after frame 4's
        # own timestamp.
        shifted_start = int(pts[4] * (1.0 / time_base) + 1) * time_base
        lv, _, _ = io.read_video(f_name, shifted_start, pts[7], pts_unit='sec')
        self.assertEqual(len(lv), 4)
        self.assertTrue(data[4:8].equal(lv))

# TODO add tests for audio


Expand Down
44 changes: 33 additions & 11 deletions torchvision/io/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import gc
import torch
import numpy as np
import math
import warnings

try:
import av
Expand Down Expand Up @@ -74,12 +76,20 @@ def write_video(filename, video_array, fps, video_codec='libx264', options=None)
container.close()


def _read_from_stream(container, start_offset, end_offset, stream, stream_name):
def _read_from_stream(container, start_offset, end_offset, pts_unit, stream, stream_name):
global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
_CALLED_TIMES += 1
if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
gc.collect()

if pts_unit == 'sec':
start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
if end_offset != float("inf"):
end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
else:
warnings.warn("The pts_unit 'pts' gives wrong results and will be removed in a " +
"follow-up version. Please use pts_unit 'sec'.")

frames = {}
should_buffer = False
max_buffer_size = 5
Expand Down Expand Up @@ -145,7 +155,7 @@ def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
return aframes[:, s_idx:e_idx]


def read_video(filename, start_pts=0, end_pts=None):
def read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'):
"""
Reads a video from a file, returning both the video frames as well as
the audio frames
Expand All @@ -154,10 +164,14 @@ def read_video(filename, start_pts=0, end_pts=None):
----------
filename : str
path to the video file
start_pts : int, optional
start_pts : int if pts_unit = 'pts', optional
float / Fraction if pts_unit = 'sec', optional
the start presentation time of the video
end_pts : int, optional
end_pts : int if pts_unit = 'pts', optional
float / Fraction if pts_unit = 'sec', optional
the end presentation time
pts_unit : str, optional
unit in which start_pts and end_pts values will be interpreted, either 'pts' or 'sec'. Defaults to 'pts'.

Returns
-------
Expand All @@ -184,12 +198,12 @@ def read_video(filename, start_pts=0, end_pts=None):

video_frames = []
if container.streams.video:
video_frames = _read_from_stream(container, start_pts, end_pts,
video_frames = _read_from_stream(container, start_pts, end_pts, pts_unit,
container.streams.video[0], {'video': 0})
info["video_fps"] = float(container.streams.video[0].average_rate)
audio_frames = []
if container.streams.audio:
audio_frames = _read_from_stream(container, start_pts, end_pts,
audio_frames = _read_from_stream(container, start_pts, end_pts, pts_unit,
container.streams.audio[0], {'audio': 0})
info["audio_fps"] = container.streams.audio[0].rate

Expand Down Expand Up @@ -217,7 +231,7 @@ def _can_read_timestamps_from_packets(container):
return False


def read_video_timestamps(filename):
def read_video_timestamps(filename, pts_unit='pts'):
"""
List the video frames timestamps.

Expand All @@ -227,27 +241,35 @@ def read_video_timestamps(filename):
----------
filename : str
path to the video file
pts_unit : str, optional
unit in which timestamp values will be returned either 'pts' or 'sec'. Defaults to 'pts'.

Returns
-------
pts : List[int]
pts : List[int] if pts_unit = 'pts'
List[Fraction] if pts_unit = 'sec'
presentation timestamps for each one of the frames in the video.
video_fps : int
the frame rate for the video

"""
_check_av_available()

container = av.open(filename, metadata_errors='ignore')

video_frames = []
video_fps = None
if container.streams.video:
video_stream = container.streams.video[0]
video_time_base = video_stream.time_base
if _can_read_timestamps_from_packets(container):
# fast path
video_frames = [x for x in container.demux(video=0) if x.pts is not None]
else:
video_frames = _read_from_stream(container, 0, float("inf"),
container.streams.video[0], {'video': 0})
video_fps = float(container.streams.video[0].average_rate)
video_frames = _read_from_stream(container, 0, float("inf"), pts_unit,
video_stream, {'video': 0})
video_fps = float(video_stream.average_rate)
container.close()
if pts_unit == 'sec':
return [x.pts * video_time_base for x in video_frames], video_fps
return [x.pts for x in video_frames], video_fps