Skip to content

Commit f60df2d

Browse files
stephenyan1231fmassa
authored andcommitted
add _backend argument to __init__() of class VideoClips (#1363)
* add _backend argument to __init__() of class VideoClips * minor fix * minor fix * Make backend private in VideoClips * Fix lint * Fix conflict due to cherry-pick for 0.4.2
1 parent 914132c commit f60df2d

File tree

5 files changed

+203
-41
lines changed

5 files changed

+203
-41
lines changed

test/test_datasets_video_utils.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from torchvision import io
77
from torchvision.datasets.video_utils import VideoClips, unfold
8+
from torchvision import get_video_backend
89

910
from common_utils import get_tmp_dir
1011

@@ -59,22 +60,23 @@ def test_unfold(self):
5960

6061
@unittest.skipIf(not io.video._av_available(), "this test requires av")
6162
def test_video_clips(self):
63+
_backend = get_video_backend()
6264
with get_list_of_videos(num_videos=3) as video_list:
63-
video_clips = VideoClips(video_list, 5, 5)
65+
video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
6466
self.assertEqual(video_clips.num_clips(), 1 + 2 + 3)
6567
for i, (v_idx, c_idx) in enumerate([(0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2)]):
6668
video_idx, clip_idx = video_clips.get_clip_location(i)
6769
self.assertEqual(video_idx, v_idx)
6870
self.assertEqual(clip_idx, c_idx)
6971

70-
video_clips = VideoClips(video_list, 6, 6)
72+
video_clips = VideoClips(video_list, 6, 6, _backend=_backend)
7173
self.assertEqual(video_clips.num_clips(), 0 + 1 + 2)
7274
for i, (v_idx, c_idx) in enumerate([(1, 0), (2, 0), (2, 1)]):
7375
video_idx, clip_idx = video_clips.get_clip_location(i)
7476
self.assertEqual(video_idx, v_idx)
7577
self.assertEqual(clip_idx, c_idx)
7678

77-
video_clips = VideoClips(video_list, 6, 1)
79+
video_clips = VideoClips(video_list, 6, 1, _backend=_backend)
7880
self.assertEqual(video_clips.num_clips(), 0 + (10 - 6 + 1) + (15 - 6 + 1))
7981
for i, v_idx, c_idx in [(0, 1, 0), (4, 1, 4), (5, 2, 0), (6, 2, 1)]:
8082
video_idx, clip_idx = video_clips.get_clip_location(i)
@@ -83,8 +85,9 @@ def test_video_clips(self):
8385

8486
@unittest.skip("Moved to reference scripts for now")
8587
def test_video_sampler(self):
88+
_backend = get_video_backend()
8689
with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
87-
video_clips = VideoClips(video_list, 5, 5)
90+
video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
8891
sampler = RandomClipSampler(video_clips, 3) # noqa: F821
8992
self.assertEqual(len(sampler), 3 * 3)
9093
indices = torch.tensor(list(iter(sampler)))
@@ -95,8 +98,9 @@ def test_video_sampler(self):
9598

9699
@unittest.skip("Moved to reference scripts for now")
97100
def test_video_sampler_unequal(self):
101+
_backend = get_video_backend()
98102
with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
99-
video_clips = VideoClips(video_list, 5, 5)
103+
video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
100104
sampler = RandomClipSampler(video_clips, 3) # noqa: F821
101105
self.assertEqual(len(sampler), 2 + 3 + 3)
102106
indices = list(iter(sampler))
@@ -113,10 +117,11 @@ def test_video_sampler_unequal(self):
113117

114118
@unittest.skipIf(not io.video._av_available(), "this test requires av")
115119
def test_video_clips_custom_fps(self):
120+
_backend = get_video_backend()
116121
with get_list_of_videos(num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) as video_list:
117122
num_frames = 4
118123
for fps in [1, 3, 4, 10]:
119-
video_clips = VideoClips(video_list, num_frames, num_frames, fps)
124+
video_clips = VideoClips(video_list, num_frames, num_frames, fps, _backend=_backend)
120125
for i in range(video_clips.num_clips()):
121126
video, audio, info, video_idx = video_clips.get_clip(i)
122127
self.assertEqual(video.shape[0], num_frames)

test/test_io.py

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55
import torchvision.datasets.utils as utils
66
import torchvision.io as io
7+
from torchvision import get_video_backend
78
import unittest
89
import sys
910
import warnings
@@ -22,6 +23,20 @@
2223
except ImportError:
2324
av = None
2425

26+
_video_backend = get_video_backend()
27+
28+
29+
def _read_video(filename, start_pts=0, end_pts=None):
30+
if _video_backend == "pyav":
31+
return io.read_video(filename, start_pts, end_pts)
32+
else:
33+
if end_pts is None:
34+
end_pts = -1
35+
return io._read_video_from_file(
36+
filename,
37+
video_pts_range=(start_pts, end_pts),
38+
)
39+
2540

2641
def _create_video_frames(num_frames, height, width):
2742
y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width))
@@ -44,7 +59,12 @@ def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None,
4459
options = {'crf': '0'}
4560

4661
if video_codec is None:
47-
video_codec = 'libx264'
62+
if _video_backend == "pyav":
63+
video_codec = 'libx264'
64+
else:
65+
# when video_codec is not set, we assume it is libx264rgb which accepts
66+
# RGB pixel formats as input instead of YUV
67+
video_codec = 'libx264rgb'
4868
if options is None:
4969
options = {}
5070

@@ -62,15 +82,16 @@ class Tester(unittest.TestCase):
6282

6383
def test_write_read_video(self):
6484
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
65-
lv, _, info = io.read_video(f_name)
66-
85+
lv, _, info = _read_video(f_name)
6786
self.assertTrue(data.equal(lv))
6887
self.assertEqual(info["video_fps"], 5)
6988

7089
def test_read_timestamps(self):
7190
with temp_video(10, 300, 300, 5) as (f_name, data):
72-
pts, _ = io.read_video_timestamps(f_name)
73-
91+
if _video_backend == "pyav":
92+
pts, _ = io.read_video_timestamps(f_name)
93+
else:
94+
pts, _, _ = io._read_video_timestamps_from_file(f_name)
7495
# note: not all formats/codecs provide accurate information for computing the
7596
# timestamps. For the format that we use here, this information is available,
7697
# so we use it as a baseline
@@ -84,26 +105,35 @@ def test_read_timestamps(self):
84105

85106
def test_read_partial_video(self):
86107
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
87-
pts, _ = io.read_video_timestamps(f_name)
108+
if _video_backend == "pyav":
109+
pts, _ = io.read_video_timestamps(f_name)
110+
else:
111+
pts, _, _ = io._read_video_timestamps_from_file(f_name)
88112
for start in range(5):
89113
for l in range(1, 4):
90-
lv, _, _ = io.read_video(f_name, pts[start], pts[start + l - 1])
114+
lv, _, _ = _read_video(f_name, pts[start], pts[start + l - 1])
91115
s_data = data[start:(start + l)]
92116
self.assertEqual(len(lv), l)
93117
self.assertTrue(s_data.equal(lv))
94118

95-
lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
96-
self.assertEqual(len(lv), 4)
97-
self.assertTrue(data[4:8].equal(lv))
119+
if _video_backend == "pyav":
120+
# for "video_reader" backend, we don't decode the closest early frame
121+
# when the given start pts is not matching any frame pts
122+
lv, _, _ = _read_video(f_name, pts[4] + 1, pts[7])
123+
self.assertEqual(len(lv), 4)
124+
self.assertTrue(data[4:8].equal(lv))
98125

99126
def test_read_partial_video_bframes(self):
100127
# do not use lossless encoding, to test the presence of B-frames
101128
options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'}
102129
with temp_video(100, 300, 300, 5, options=options) as (f_name, data):
103-
pts, _ = io.read_video_timestamps(f_name)
130+
if _video_backend == "pyav":
131+
pts, _ = io.read_video_timestamps(f_name)
132+
else:
133+
pts, _, _ = io._read_video_timestamps_from_file(f_name)
104134
for start in range(0, 80, 20):
105135
for l in range(1, 4):
106-
lv, _, _ = io.read_video(f_name, pts[start], pts[start + l - 1])
136+
lv, _, _ = _read_video(f_name, pts[start], pts[start + l - 1])
107137
s_data = data[start:(start + l)]
108138
self.assertEqual(len(lv), l)
109139
self.assertTrue((s_data.float() - lv.float()).abs().max() < self.TOLERANCE)
@@ -119,7 +149,12 @@ def test_read_packed_b_frames_divx_file(self):
119149
url = "https://download.pytorch.org/vision_tests/io/" + name
120150
try:
121151
utils.download_url(url, temp_dir)
122-
pts, fps = io.read_video_timestamps(f_name)
152+
if _video_backend == "pyav":
153+
pts, fps = io.read_video_timestamps(f_name)
154+
else:
155+
pts, _, info = io._read_video_timestamps_from_file(f_name)
156+
fps = info["video_fps"]
157+
123158
self.assertEqual(pts, sorted(pts))
124159
self.assertEqual(fps, 30)
125160
except URLError:
@@ -129,8 +164,10 @@ def test_read_packed_b_frames_divx_file(self):
129164

130165
def test_read_timestamps_from_packet(self):
131166
with temp_video(10, 300, 300, 5, video_codec='mpeg4') as (f_name, data):
132-
pts, _ = io.read_video_timestamps(f_name)
133-
167+
if _video_backend == "pyav":
168+
pts, _ = io.read_video_timestamps(f_name)
169+
else:
170+
pts, _, _ = io._read_video_timestamps_from_file(f_name)
134171
# note: not all formats/codecs provide accurate information for computing the
135172
# timestamps. For the format that we use here, this information is available,
136173
# so we use it as a baseline

torchvision/__init__.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
_image_backend = 'PIL'
1414

15+
_video_backend = "pyav"
16+
1517

1618
def set_image_backend(backend):
1719
"""
@@ -34,3 +36,27 @@ def get_image_backend():
3436
Gets the name of the package used to load images
3537
"""
3638
return _image_backend
39+
40+
41+
def set_video_backend(backend):
42+
"""
43+
Specifies the package used to decode videos.
44+
45+
Args:
46+
backend (string): Name of the video backend. one of {'pyav', 'video_reader'}.
47+
The :mod:`pyav` package uses the 3rd party PyAv library. It is a Pythonic
48+
binding for the FFmpeg libraries.
49+
The :mod:`video_reader` package includes a native c++ implementation on
50+
top of FFMPEG libraries, and a python API of TorchScript custom operator.
51+
It is generally decoding faster than pyav, but perhaps is less robust.
52+
"""
53+
global _video_backend
54+
if backend not in ["pyav", "video_reader"]:
55+
raise ValueError(
56+
"Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend
57+
)
58+
_video_backend = backend
59+
60+
61+
def get_video_backend():
62+
return _video_backend

0 commit comments

Comments
 (0)