Skip to content

modified code of io.read_video to interpret start_pts and end_pts in seconds #1313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 43 additions & 11 deletions torchvision/io/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import gc
import torch
import numpy as np
import math
import warnings

try:
import av
Expand Down Expand Up @@ -145,7 +147,7 @@ def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
return aframes[:, s_idx:e_idx]


def read_video(filename, start_pts=0, end_pts=None):
def read_video(filename, start_pts=0, end_pts=None, pts_unit='sec'):
"""
Reads a video from a file, returning both the video frames as well as
the audio frames
Expand All @@ -158,6 +160,8 @@ def read_video(filename, start_pts=0, end_pts=None):
the start presentation time of the video
end_pts : int, optional
the end presentation time
pts_unit : str, optional
unit in which start_pts and end_pts values will be interpreted, either 'pts' or 'sec'. Defaults to 'sec'.

Returns
-------
Expand All @@ -184,14 +188,34 @@ def read_video(filename, start_pts=0, end_pts=None):

video_frames = []
if container.streams.video:
video_frames = _read_from_stream(container, start_pts, end_pts,
container.streams.video[0], {'video': 0})
info["video_fps"] = float(container.streams.video[0].average_rate)
video_stream = container.streams.video[0]
if pts_unit == 'pts':
warnings.warn("The pts_unit 'pts' produces wrong results and will be removed
in a follow-up version. Please use pts_unit 'sec'.")
video_start_pts = start_pts
video_end_pts = end_pts
else:
video_start_pts = math.floor(start_pts*(1/video_stream.time_base))
if video_end_pts != float("inf"):
video_end_pts = math.ceil(end_pts*(1/video_stream.time_base))
video_frames = _read_from_stream(container, video_start_pts, video_end_pts,
video_stream, {'video': 0})
info["video_fps"] = float(video_stream.average_rate)
audio_frames = []
if container.streams.audio:
audio_frames = _read_from_stream(container, start_pts, end_pts,
container.streams.audio[0], {'audio': 0})
info["audio_fps"] = container.streams.audio[0].rate
audio_stream = container.streams.audio[0]
if pts_unit == 'pts':
warnings.warn("The pts_unit 'pts' produces wrong results and will be removed
in a follow-up version. Please use pts_unit 'sec'.")
audio_start_pts = start_pts
audio_end_pts = end_pts
else:
audio_start_pts = math.floor(start_pts*(1/audio_stream.time_base))
if end_pts != float("inf"):
audio_end_pts = math.ceil(end_pts*(1/audio_stream.time_base))
audio_frames = _read_from_stream(container, audio_start_pts, audio_end_pts,
audio_stream , {'audio': 0})
info["audio_fps"] = audio_stream.rate

container.close()

Expand All @@ -217,7 +241,7 @@ def _can_read_timestamps_from_packets(container):
return False


def read_video_timestamps(filename):
def read_video_timestamps(filename, pts_unit='pts'):
"""
List the video frames timestamps.

Expand All @@ -227,6 +251,8 @@ def read_video_timestamps(filename):
----------
filename : str
path to the video file
pts_unit : str, optional
unit in which timestamp values will be returned, either 'pts' or 'sec'. Defaults to 'pts'.

Returns
-------
Expand All @@ -242,12 +268,18 @@ def read_video_timestamps(filename):
video_frames = []
video_fps = None
if container.streams.video:
video_stream = container.streams.video[0]
if _can_read_timestamps_from_packets(container):
# fast path
video_frames = [x for x in container.demux(video=0) if x.pts is not None]
else:
video_frames = _read_from_stream(container, 0, float("inf"),
container.streams.video[0], {'video': 0})
video_fps = float(container.streams.video[0].average_rate)
video_stream, {'video': 0})
video_fps = float(video_stream.average_rate)
container.close()
return [x.pts for x in video_frames], video_fps
if pts_unit == 'pts':
warnings.warn("The pts_unit 'pts' produces wrong results and will be removed in a
follow-up version. Please use pts_unit 'sec'.")
return [x.pts for x in video_frames], video_fps
else:
return [x.pts*video_stream.time_base for x in video_frames], video_fps