4
4
import torch
5
5
import torchvision .datasets .utils as utils
6
6
import torchvision .io as io
7
+ from torchvision import get_video_backend
7
8
import unittest
8
9
import sys
9
10
import warnings
22
23
except ImportError :
23
24
av = None
24
25
26
+ _video_backend = get_video_backend ()
27
+
28
+
29
+ def _read_video (filename , start_pts = 0 , end_pts = None ):
30
+ if _video_backend == "pyav" :
31
+ return io .read_video (filename , start_pts , end_pts )
32
+ else :
33
+ if end_pts is None :
34
+ end_pts = - 1
35
+ return io ._read_video_from_file (
36
+ filename ,
37
+ video_pts_range = (start_pts , end_pts ),
38
+ )
39
+
25
40
26
41
def _create_video_frames (num_frames , height , width ):
27
42
y , x = torch .meshgrid (torch .linspace (- 2 , 2 , height ), torch .linspace (- 2 , 2 , width ))
@@ -44,7 +59,12 @@ def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None,
44
59
options = {'crf' : '0' }
45
60
46
61
if video_codec is None :
47
- video_codec = 'libx264'
62
+ if _video_backend == "pyav" :
63
+ video_codec = 'libx264'
64
+ else :
65
+ # when video_codec is not set, we assume it is libx264rgb which accepts
66
+ # RGB pixel formats as input instead of YUV
67
+ video_codec = 'libx264rgb'
48
68
if options is None :
49
69
options = {}
50
70
@@ -62,15 +82,16 @@ class Tester(unittest.TestCase):
62
82
63
83
def test_write_read_video (self ):
64
84
with temp_video (10 , 300 , 300 , 5 , lossless = True ) as (f_name , data ):
65
- lv , _ , info = io .read_video (f_name )
66
-
85
+ lv , _ , info = _read_video (f_name )
67
86
self .assertTrue (data .equal (lv ))
68
87
self .assertEqual (info ["video_fps" ], 5 )
69
88
70
89
def test_read_timestamps (self ):
71
90
with temp_video (10 , 300 , 300 , 5 ) as (f_name , data ):
72
- pts , _ = io .read_video_timestamps (f_name )
73
-
91
+ if _video_backend == "pyav" :
92
+ pts , _ = io .read_video_timestamps (f_name )
93
+ else :
94
+ pts , _ , _ = io ._read_video_timestamps_from_file (f_name )
74
95
# note: not all formats/codecs provide accurate information for computing the
75
96
# timestamps. For the format that we use here, this information is available,
76
97
# so we use it as a baseline
@@ -84,26 +105,35 @@ def test_read_timestamps(self):
84
105
85
106
def test_read_partial_video (self ):
86
107
with temp_video (10 , 300 , 300 , 5 , lossless = True ) as (f_name , data ):
87
- pts , _ = io .read_video_timestamps (f_name )
108
+ if _video_backend == "pyav" :
109
+ pts , _ = io .read_video_timestamps (f_name )
110
+ else :
111
+ pts , _ , _ = io ._read_video_timestamps_from_file (f_name )
88
112
for start in range (5 ):
89
113
for l in range (1 , 4 ):
90
- lv , _ , _ = io . read_video (f_name , pts [start ], pts [start + l - 1 ])
114
+ lv , _ , _ = _read_video (f_name , pts [start ], pts [start + l - 1 ])
91
115
s_data = data [start :(start + l )]
92
116
self .assertEqual (len (lv ), l )
93
117
self .assertTrue (s_data .equal (lv ))
94
118
95
- lv , _ , _ = io .read_video (f_name , pts [4 ] + 1 , pts [7 ])
96
- self .assertEqual (len (lv ), 4 )
97
- self .assertTrue (data [4 :8 ].equal (lv ))
119
+ if _video_backend == "pyav" :
120
+ # for "video_reader" backend, we don't decode the closest early frame
121
+ # when the given start pts is not matching any frame pts
122
+ lv , _ , _ = _read_video (f_name , pts [4 ] + 1 , pts [7 ])
123
+ self .assertEqual (len (lv ), 4 )
124
+ self .assertTrue (data [4 :8 ].equal (lv ))
98
125
99
126
def test_read_partial_video_bframes (self ):
100
127
# do not use lossless encoding, to test the presence of B-frames
101
128
options = {'bframes' : '16' , 'keyint' : '10' , 'min-keyint' : '4' }
102
129
with temp_video (100 , 300 , 300 , 5 , options = options ) as (f_name , data ):
103
- pts , _ = io .read_video_timestamps (f_name )
130
+ if _video_backend == "pyav" :
131
+ pts , _ = io .read_video_timestamps (f_name )
132
+ else :
133
+ pts , _ , _ = io ._read_video_timestamps_from_file (f_name )
104
134
for start in range (0 , 80 , 20 ):
105
135
for l in range (1 , 4 ):
106
- lv , _ , _ = io . read_video (f_name , pts [start ], pts [start + l - 1 ])
136
+ lv , _ , _ = _read_video (f_name , pts [start ], pts [start + l - 1 ])
107
137
s_data = data [start :(start + l )]
108
138
self .assertEqual (len (lv ), l )
109
139
self .assertTrue ((s_data .float () - lv .float ()).abs ().max () < self .TOLERANCE )
@@ -119,7 +149,12 @@ def test_read_packed_b_frames_divx_file(self):
119
149
url = "https://download.pytorch.org/vision_tests/io/" + name
120
150
try :
121
151
utils .download_url (url , temp_dir )
122
- pts , fps = io .read_video_timestamps (f_name )
152
+ if _video_backend == "pyav" :
153
+ pts , fps = io .read_video_timestamps (f_name )
154
+ else :
155
+ pts , _ , info = io ._read_video_timestamps_from_file (f_name )
156
+ fps = info ["video_fps" ]
157
+
123
158
self .assertEqual (pts , sorted (pts ))
124
159
self .assertEqual (fps , 30 )
125
160
except URLError :
@@ -129,8 +164,10 @@ def test_read_packed_b_frames_divx_file(self):
129
164
130
165
def test_read_timestamps_from_packet (self ):
131
166
with temp_video (10 , 300 , 300 , 5 , video_codec = 'mpeg4' ) as (f_name , data ):
132
- pts , _ = io .read_video_timestamps (f_name )
133
-
167
+ if _video_backend == "pyav" :
168
+ pts , _ = io .read_video_timestamps (f_name )
169
+ else :
170
+ pts , _ , _ = io ._read_video_timestamps_from_file (f_name )
134
171
# note: not all formats/codecs provide accurate information for computing the
135
172
# timestamps. For the format that we use here, this information is available,
136
173
# so we use it as a baseline
0 commit comments