Skip to content

[DISCUSSION NEEDED] AV-sync fundamental issues #1248

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions test/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import unittest
import sys
import warnings
from fractions import Fraction

from common_utils import get_tmp_dir

Expand Down Expand Up @@ -76,39 +77,40 @@ def test_read_timestamps(self):
# so we use it as a baseline
container = av.open(f_name)
stream = container.streams[0]
pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
expected_pts = [i * pts_step for i in range(num_frames)]
# pts is a global fraction of a second, we expect
# to see num_frames frames whose global pts is 1/fps
expected_pts = [Fraction(i, stream.average_rate) for i in range(num_frames)]

self.assertEqual(pts, expected_pts)

def test_read_partial_video(self):
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name)
pts, fps = io.read_video_timestamps(f_name)
for start in range(5):
for l in range(1, 4):
lv, _, _ = io.read_video(f_name, pts[start], pts[start + l - 1])
s_data = data[start:(start + l)]
self.assertEqual(len(lv), l)
self.assertTrue(s_data.equal(lv))

lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
lv, _, _ = io.read_video(f_name, pts[4] + 1 / (fps + 1), pts[7])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain the need for the +1 in (fps + 1)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding was that the original test was here to make sure that, if we try to decode from a pts that doesn't exist (pts[4]+1), we return the closest possible frame to that pts - then we check that we get 4 frames (from pts[4] to pts[7]).

In the same way, 1/(fps+1) is a non-existing pts that is the closest to the existing (pts[4]).

self.assertEqual(len(lv), 4)
self.assertTrue(data[4:8].equal(lv))

def test_read_partial_video_bframes(self):
# do not use lossless encoding, to test the presence of B-frames
options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'}
with temp_video(100, 300, 300, 5, options=options) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name)
pts, fps = io.read_video_timestamps(f_name)
for start in range(0, 80, 20):
for l in range(1, 4):
lv, _, _ = io.read_video(f_name, pts[start], pts[start + l - 1])
s_data = data[start:(start + l)]
self.assertEqual(len(lv), l)
self.assertTrue((s_data.float() - lv.float()).abs().max() < self.TOLERANCE)

lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
lv, _, _ = io.read_video(f_name, pts[4] + 1 / (fps + 1), pts[7])
self.assertEqual(len(lv), 4)
self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE)

Expand Down Expand Up @@ -138,13 +140,32 @@ def test_read_timestamps_from_packet(self):
stream = container.streams[0]
# make sure we went through the optimized codepath
self.assertIn(b'Lavc', stream.codec_context.extradata)
pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
expected_pts = [i * pts_step for i in range(num_frames)]

num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
# pts is a global fraction of a second, thus we expect
# to see num_frames frames whose global pts is 1/fps
expected_pts = [Fraction(i, stream.average_rate) for i in range(num_frames)]
self.assertEqual(pts, expected_pts)

# TODO add tests for audio
def test_audio_cutting(self):
with get_tmp_dir() as temp_dir:
name = "hmdb51_Turnk_r_Pippi_Michel_cartwheel_f_cm_np2_le_med_6.avi"
f_name = os.path.join(temp_dir, name)
url = "https://download.pytorch.org/vision_tests/io/" + name
try:
utils.download_url(url, temp_dir)
pts, fps = io.read_video_timestamps(f_name)
self.assertEqual(pts, sorted(pts))
self.assertEqual(fps, 30)
except URLError:
msg = "could not download test file '{}'".format(url)
warnings.warn(msg, RuntimeWarning)
raise unittest.SkipTest(msg)

lv, la, info = io.read_video(f_name, pts[3], pts[7])
# FIXME: add Another video - this one doesn't have audio
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds reasonable - I'll take a look.

# self.assertEqual(lv/info['video_fps'], la/info['audio_fps'])


if __name__ == '__main__':
Expand Down
35 changes: 30 additions & 5 deletions torchvision/io/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import gc
import torch
import numpy as np
import warnings


try:
import av
Expand Down Expand Up @@ -83,6 +85,15 @@ def _read_from_stream(container, start_offset, end_offset, stream, stream_name):
frames = {}
should_buffer = False
max_buffer_size = 5
# we expect the input offset to be in the global format (i.e. seconds)
# to get the correct frame we look for the frames corresponding to
# second / time_base
cur_time_base = stream.time_base
seek_offset = int(round(float(start_offset / cur_time_base)))
start_offset = int(round(float(start_offset / cur_time_base)))
if end_offset != float("inf"):
end_offset = int(round(float(end_offset / cur_time_base)))

if stream.type == "video":
# DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
# so need to buffer some extra frames to sort everything
Expand All @@ -99,7 +110,7 @@ def _read_from_stream(container, start_offset, end_offset, stream, stream_name):
o = re.search(br"DivX(\d+)b(\d+)(\w)", d)
if o is not None:
should_buffer = o.group(3) == b"p"
seek_offset = start_offset

# some files don't seek to the right location, so better be safe here
seek_offset = max(seek_offset - 1, 0)
if should_buffer:
Expand Down Expand Up @@ -133,6 +144,11 @@ def _read_from_stream(container, start_offset, end_offset, stream, stream_name):


def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
# reference frames are global, here we convert them
# to a local audio representation pts values
ref_start = int(round(float(ref_start / audio_frames[0].time_base)))
ref_end = int(round(float(ref_end / audio_frames[0].time_base)))

start, end = audio_frames[0].pts, audio_frames[-1].pts
total_aframes = aframes.shape[1]
step_per_aframe = (end - start + 1) / total_aframes
Expand All @@ -155,9 +171,9 @@ def read_video(filename, start_pts=0, end_pts=None):
filename : str
path to the video file
start_pts : int, optional
the start presentation time of the video
the start presentation time of the video - global in seconds
end_pts : int, optional
the end presentation time
the end presentation time - global in seconds

Returns
-------
Expand Down Expand Up @@ -217,7 +233,7 @@ def _can_read_timestamps_from_packets(container):
return False


def read_video_timestamps(filename):
def read_video_timestamps(filename, output_format="global"):
"""
List the video frames timestamps.

Expand All @@ -232,6 +248,8 @@ def read_video_timestamps(filename):
-------
pts : List[int]
presentation timestamps for each one of the frames in the video.
Note that these are returned as the TRUE timestamps in seconds, i.e.
w.r.t. the global presentation time, not as a function of the stream.
video_fps : int
the frame rate for the video

Expand All @@ -250,4 +268,11 @@ def read_video_timestamps(filename):
container.streams.video[0], {'video': 0})
video_fps = float(container.streams.video[0].average_rate)
container.close()
return [x.pts for x in video_frames], video_fps

if output_format is None:
msg = "Currently, pts are returned as (int) w.r.t. video stream. This behaviour is being\
depreciated"
warinings.warn(msg, RuntimeWarning)
return [x.pts for x in video_frames], video_fps

return [x.pts * x.time_base for x in video_frames], video_fps
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a BC-breaking change, and I wonder if there would be a way of keeping backwards-compatibility for one version before removing the old behavior, with a loud warning.

Maybe we should add an extra option to read_video_timestamps, something like output_format=None, raise a warning if it is None, and keep the current behavior in that case, similar to what has been done in interpolate with the align_corners flag.
Thoughts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that makes a lot of sense.
Is there a preferred warning system in torchvision?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's just use warnings.warn for now

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good.
What do we want to keep as the default behaviour?