Skip to content

Commit c828632

Browse files
cskanani authored and fmassa committed
modified code of io.read_video and io.read_video_timestamps to interpret pts values in seconds (#1331)
* modified code of io.read_video and io.read_video_timestamps to interpret pts values in seconds * changed default value for pts_unit to pts, corrected formatting * handling both fractions and floats for start_pts and end_pts, added test cases for pts_unit sec * moved unit conversion logic to _read_from_stream method
1 parent 735b748 commit c828632

File tree

2 files changed

+71
-11
lines changed

2 files changed

+71
-11
lines changed

test/test_io.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,44 @@ def test_read_timestamps_from_packet(self):
181181

182182
self.assertEqual(pts, expected_pts)
183183

184+
def test_read_video_pts_unit_sec(self):
    """Read a whole lossless clip with ``pts_unit='sec'``.

    The decoded frames must equal the frames produced by ``temp_video``
    and the reported ``video_fps`` must be 5.
    """
    with temp_video(10, 300, 300, 5, lossless=True) as (path, expected_frames):
        decoded, _, meta = io.read_video(path, pts_unit='sec')

        frames_match = expected_frames.equal(decoded)
        self.assertTrue(frames_match)

        reported_fps = meta["video_fps"]
        self.assertEqual(reported_fps, 5)
190+
191+
def test_read_timestamps_pts_unit_sec(self):
    """Check the timestamps returned by ``read_video_timestamps`` in seconds.

    The expected values are rebuilt from the stream metadata read through
    PyAV: each frame index is multiplied by the per-frame pts step and by
    the stream ``time_base``.
    """
    with temp_video(10, 300, 300, 5) as (f_name, data):
        pts, _ = io.read_video_timestamps(f_name, pts_unit='sec')

        container = av.open(f_name)
        try:
            stream = container.streams[0]
            time_base = stream.time_base
            # Number of pts ticks between two consecutive frames.
            pts_step = int(round(float(1 / (stream.average_rate * time_base))))
            num_frames = int(round(float(stream.average_rate * time_base * stream.duration)))
        finally:
            # Always release the file handle, even if reading metadata fails.
            container.close()

        expected_pts = [i * pts_step * time_base for i in range(num_frames)]

        self.assertEqual(pts, expected_pts)
202+
203+
def test_read_partial_video_pts_unit_sec(self):
    """Read sub-ranges of a lossless clip using start/end times in seconds.

    First every window of 1 to 3 frames starting at frames 0..4 is read
    using timestamps returned by ``read_video_timestamps``. Then a read
    whose start lies between two frame timestamps is checked.
    """
    with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
        pts, _ = io.read_video_timestamps(f_name, pts_unit='sec')

        for start in range(5):
            for length in range(1, 4):
                lv, _, _ = io.read_video(f_name, pts[start], pts[start + length - 1], pts_unit='sec')
                s_data = data[start:(start + length)]
                self.assertEqual(len(lv), length)
                self.assertTrue(s_data.equal(lv))

        container = av.open(f_name)
        try:
            time_base = container.streams[0].time_base
        finally:
            # Always release the file handle once the time base is known.
            container.close()

        # Start one time_base tick after pts[4]; the assertions below
        # still expect frames 4..7 to be returned.
        start_sec = int(pts[4] * (1.0 / time_base) + 1) * time_base
        lv, _, _ = io.read_video(f_name, start_sec, pts[7], pts_unit='sec')
        self.assertEqual(len(lv), 4)
        self.assertTrue(data[4:8].equal(lv))
221+
184222
# TODO add tests for audio
185223

186224

torchvision/io/video.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import gc
33
import torch
44
import numpy as np
5+
import math
6+
import warnings
57

68
try:
79
import av
@@ -74,12 +76,20 @@ def write_video(filename, video_array, fps, video_codec='libx264', options=None)
7476
container.close()
7577

7678

77-
def _read_from_stream(container, start_offset, end_offset, stream, stream_name):
79+
def _read_from_stream(container, start_offset, end_offset, pts_unit, stream, stream_name):
7880
global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
7981
_CALLED_TIMES += 1
8082
if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
8183
gc.collect()
8284

85+
if pts_unit == 'sec':
86+
start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
87+
if end_offset != float("inf"):
88+
end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
89+
else:
90+
warnings.warn("The pts_unit 'pts' gives wrong results and will be removed in a " +
91+
"follow-up version. Please use pts_unit 'sec'.")
92+
8393
frames = {}
8494
should_buffer = False
8595
max_buffer_size = 5
@@ -145,7 +155,7 @@ def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
145155
return aframes[:, s_idx:e_idx]
146156

147157

148-
def read_video(filename, start_pts=0, end_pts=None):
158+
def read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'):
149159
"""
150160
Reads a video from a file, returning both the video frames as well as
151161
the audio frames
@@ -154,10 +164,14 @@ def read_video(filename, start_pts=0, end_pts=None):
154164
----------
155165
filename : str
156166
path to the video file
157-
start_pts : int, optional
167+
start_pts : int if pts_unit = 'pts', optional
168+
float / Fraction if pts_unit = 'sec', optional
158169
the start presentation time of the video
159-
end_pts : int, optional
170+
end_pts : int if pts_unit = 'pts', optional
171+
float / Fraction if pts_unit = 'sec', optional
160172
the end presentation time
173+
pts_unit : str, optional
174+
unit in which start_pts and end_pts values will be interpreted, either 'pts' or 'sec'. Defaults to 'pts'.
161175
162176
Returns
163177
-------
@@ -184,12 +198,12 @@ def read_video(filename, start_pts=0, end_pts=None):
184198

185199
video_frames = []
186200
if container.streams.video:
187-
video_frames = _read_from_stream(container, start_pts, end_pts,
201+
video_frames = _read_from_stream(container, start_pts, end_pts, pts_unit,
188202
container.streams.video[0], {'video': 0})
189203
info["video_fps"] = float(container.streams.video[0].average_rate)
190204
audio_frames = []
191205
if container.streams.audio:
192-
audio_frames = _read_from_stream(container, start_pts, end_pts,
206+
audio_frames = _read_from_stream(container, start_pts, end_pts, pts_unit,
193207
container.streams.audio[0], {'audio': 0})
194208
info["audio_fps"] = container.streams.audio[0].rate
195209

@@ -217,7 +231,7 @@ def _can_read_timestamps_from_packets(container):
217231
return False
218232

219233

220-
def read_video_timestamps(filename):
234+
def read_video_timestamps(filename, pts_unit='pts'):
221235
"""
222236
List the video frames timestamps.
223237
@@ -227,27 +241,35 @@ def read_video_timestamps(filename):
227241
----------
228242
filename : str
229243
path to the video file
244+
pts_unit : str, optional
245+
unit in which timestamp values will be returned either 'pts' or 'sec'. Defaults to 'pts'.
230246
231247
Returns
232248
-------
233-
pts : List[int]
249+
pts : List[int] if pts_unit = 'pts'
250+
List[Fraction] if pts_unit = 'sec'
234251
presentation timestamps for each one of the frames in the video.
235252
video_fps : int
236253
the frame rate for the video
237254
238255
"""
239256
_check_av_available()
257+
240258
container = av.open(filename, metadata_errors='ignore')
241259

242260
video_frames = []
243261
video_fps = None
244262
if container.streams.video:
263+
video_stream = container.streams.video[0]
264+
video_time_base = video_stream.time_base
245265
if _can_read_timestamps_from_packets(container):
246266
# fast path
247267
video_frames = [x for x in container.demux(video=0) if x.pts is not None]
248268
else:
249-
video_frames = _read_from_stream(container, 0, float("inf"),
250-
container.streams.video[0], {'video': 0})
251-
video_fps = float(container.streams.video[0].average_rate)
269+
video_frames = _read_from_stream(container, 0, float("inf"), pts_unit,
270+
video_stream, {'video': 0})
271+
video_fps = float(video_stream.average_rate)
252272
container.close()
273+
if pts_unit == 'sec':
274+
return [x.pts * video_time_base for x in video_frames], video_fps
253275
return [x.pts for x in video_frames], video_fps

0 commit comments

Comments
 (0)