@@ -2,7 +2,7 @@
 import math
 import warnings
 from fractions import Fraction
-from typing import List
+from typing import Any, Dict, List, Optional, Callable, Union, Tuple, TypeVar, cast
 
 import torch
 from torchvision.io import (
@@ -14,8 +14,10 @@
 
 from .utils import tqdm
 
+T = TypeVar("T")
 
-def pts_convert(pts, timebase_from, timebase_to, round_func=math.floor):
+
+def pts_convert(pts: int, timebase_from: Fraction, timebase_to: Fraction, round_func: Callable = math.floor) -> int:
     """convert pts between different time bases
     Args:
         pts: presentation timestamp, float
@@ -27,7 +29,7 @@ def pts_convert(pts, timebase_from, timebase_to, round_func=math.floor):
     return round_func(new_pts)
 
 
-def unfold(tensor, size, step, dilation=1):
+def unfold(tensor: torch.Tensor, size: int, step: int, dilation: int = 1) -> torch.Tensor:
     """
     similar to tensor.unfold, but with the dilation
     and specialized for 1d tensors
@@ -55,17 +57,17 @@ class _VideoTimestampsDataset:
     pickled when forking.
     """
 
-    def __init__(self, video_paths: List[str]):
+    def __init__(self, video_paths: List[str]) -> None:
         self.video_paths = video_paths
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.video_paths)
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int) -> Tuple[List[int], Optional[float]]:
         return read_video_timestamps(self.video_paths[idx])
 
 
-def _collate_fn(x):
+def _collate_fn(x: T) -> T:
     """
     Dummy collate function to be used with _VideoTimestampsDataset
     """
@@ -100,19 +102,19 @@ class VideoClips:
 
     def __init__(
         self,
-        video_paths,
-        clip_length_in_frames=16,
-        frames_between_clips=1,
-        frame_rate=None,
-        _precomputed_metadata=None,
-        num_workers=0,
-        _video_width=0,
-        _video_height=0,
-        _video_min_dimension=0,
-        _video_max_dimension=0,
-        _audio_samples=0,
-        _audio_channels=0,
-    ):
+        video_paths: List[str],
+        clip_length_in_frames: int = 16,
+        frames_between_clips: int = 1,
+        frame_rate: Optional[int] = None,
+        _precomputed_metadata: Optional[Dict[str, Any]] = None,
+        num_workers: int = 0,
+        _video_width: int = 0,
+        _video_height: int = 0,
+        _video_min_dimension: int = 0,
+        _video_max_dimension: int = 0,
+        _audio_samples: int = 0,
+        _audio_channels: int = 0,
+    ) -> None:
 
         self.video_paths = video_paths
         self.num_workers = num_workers
@@ -131,16 +133,16 @@ def __init__(
             self._init_from_metadata(_precomputed_metadata)
         self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate)
 
-    def _compute_frame_pts(self):
+    def _compute_frame_pts(self) -> None:
         self.video_pts = []
         self.video_fps = []
 
         # strategy: use a DataLoader to parallelize read_video_timestamps
         # so need to create a dummy dataset first
         import torch.utils.data
 
-        dl = torch.utils.data.DataLoader(
-            _VideoTimestampsDataset(self.video_paths),
+        dl: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
+            _VideoTimestampsDataset(self.video_paths),  # type: ignore[arg-type]
             batch_size=16,
             num_workers=self.num_workers,
             collate_fn=_collate_fn,
@@ -157,23 +159,23 @@ def _compute_frame_pts(self):
                 self.video_pts.extend(clips)
                 self.video_fps.extend(fps)
 
-    def _init_from_metadata(self, metadata):
+    def _init_from_metadata(self, metadata: Dict[str, Any]) -> None:
         self.video_paths = metadata["video_paths"]
         assert len(self.video_paths) == len(metadata["video_pts"])
         self.video_pts = metadata["video_pts"]
         assert len(self.video_paths) == len(metadata["video_fps"])
         self.video_fps = metadata["video_fps"]
 
     @property
-    def metadata(self):
+    def metadata(self) -> Dict[str, Any]:
         _metadata = {
             "video_paths": self.video_paths,
             "video_pts": self.video_pts,
             "video_fps": self.video_fps,
         }
         return _metadata
 
-    def subset(self, indices):
+    def subset(self, indices: List[int]) -> "VideoClips":
         video_paths = [self.video_paths[i] for i in indices]
         video_pts = [self.video_pts[i] for i in indices]
         video_fps = [self.video_fps[i] for i in indices]
@@ -198,29 +200,32 @@ def subset(self, indices):
         )
 
     @staticmethod
-    def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate):
+    def compute_clips_for_video(
+        video_pts: torch.Tensor, num_frames: int, step: int, fps: int, frame_rate: Optional[int] = None
+    ) -> Tuple[torch.Tensor, Union[List[slice], torch.Tensor]]:
         if fps is None:
             # if for some reason the video doesn't have fps (because doesn't have a video stream)
             # set the fps to 1. The value doesn't matter, because video_pts is empty anyway
             fps = 1
         if frame_rate is None:
             frame_rate = fps
         total_frames = len(video_pts) * (float(frame_rate) / fps)
-        idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
-        video_pts = video_pts[idxs]
+        _idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
+        video_pts = video_pts[_idxs]
         clips = unfold(video_pts, num_frames, step)
         if not clips.numel():
             warnings.warn(
                 "There aren't enough frames in the current video to get a clip for the given clip length and "
                 "frames between clips. The video (and potentially others) will be skipped."
             )
-        if isinstance(idxs, slice):
-            idxs = [idxs] * len(clips)
+        idxs: Union[List[slice], torch.Tensor]
+        if isinstance(_idxs, slice):
+            idxs = [_idxs] * len(clips)
         else:
-            idxs = unfold(idxs, num_frames, step)
+            idxs = unfold(_idxs, num_frames, step)
         return clips, idxs
 
-    def compute_clips(self, num_frames, step, frame_rate=None):
+    def compute_clips(self, num_frames: int, step: int, frame_rate: Optional[int] = None) -> None:
         """
         Compute all consecutive sequences of clips from video_pts.
         Always returns clips of size `num_frames`, meaning that the
@@ -243,19 +248,19 @@ def compute_clips(self, num_frames, step, frame_rate=None):
         clip_lengths = torch.as_tensor([len(v) for v in self.clips])
         self.cumulative_sizes = clip_lengths.cumsum(0).tolist()
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.num_clips()
 
-    def num_videos(self):
+    def num_videos(self) -> int:
         return len(self.video_paths)
 
-    def num_clips(self):
+    def num_clips(self) -> int:
         """
        Number of subclips that are available in the video list.
         """
         return self.cumulative_sizes[-1]
 
-    def get_clip_location(self, idx):
+    def get_clip_location(self, idx: int) -> Tuple[int, int]:
         """
         Converts a flattened representation of the indices into a video_idx, clip_idx
         representation.
@@ -268,7 +273,7 @@ def get_clip_location(self, idx):
         return video_idx, clip_idx
 
     @staticmethod
-    def _resample_video_idx(num_frames, original_fps, new_fps):
+    def _resample_video_idx(num_frames: int, original_fps: int, new_fps: int) -> Union[slice, torch.Tensor]:
         step = float(original_fps) / new_fps
         if step.is_integer():
             # optimization: if step is integer, don't need to perform
@@ -279,7 +284,7 @@ def _resample_video_idx(num_frames, original_fps, new_fps):
         idxs = idxs.floor().to(torch.int64)
         return idxs
 
-    def get_clip(self, idx):
+    def get_clip(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any], int]:
         """
         Gets a subclip from a list of videos.
 
@@ -320,22 +325,22 @@ def get_clip(self, idx):
             end_pts = clip_pts[-1].item()
             video, audio, info = read_video(video_path, start_pts, end_pts)
         else:
-            info = _probe_video_from_file(video_path)
-            video_fps = info.video_fps
+            _info = _probe_video_from_file(video_path)
+            video_fps = _info.video_fps
             audio_fps = None
 
-            video_start_pts = clip_pts[0].item()
-            video_end_pts = clip_pts[-1].item()
+            video_start_pts = cast(int, clip_pts[0].item())
+            video_end_pts = cast(int, clip_pts[-1].item())
 
             audio_start_pts, audio_end_pts = 0, -1
             audio_timebase = Fraction(0, 1)
-            video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
-            if info.has_audio:
-                audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator)
+            video_timebase = Fraction(_info.video_timebase.numerator, _info.video_timebase.denominator)
+            if _info.has_audio:
+                audio_timebase = Fraction(_info.audio_timebase.numerator, _info.audio_timebase.denominator)
                 audio_start_pts = pts_convert(video_start_pts, video_timebase, audio_timebase, math.floor)
                 audio_end_pts = pts_convert(video_end_pts, video_timebase, audio_timebase, math.ceil)
-                audio_fps = info.audio_sample_rate
-            video, audio, info = _read_video_from_file(
+                audio_fps = _info.audio_sample_rate
+            video, audio, _ = _read_video_from_file(
                 video_path,
                 video_width=self._video_width,
                 video_height=self._video_height,
@@ -362,7 +367,7 @@ def get_clip(self, idx):
         assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}"
         return video, audio, info, video_idx
 
-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, Any]:
         video_pts_sizes = [len(v) for v in self.video_pts]
         # To be back-compatible, we convert data to dtype torch.long as needed
         # because for empty list, in legacy implementation, torch.as_tensor will
@@ -371,10 +376,10 @@ def __getstate__(self):
         video_pts = [x.to(torch.int64) for x in self.video_pts]
         # video_pts can be an empty list if no frames have been decoded
         if video_pts:
-            video_pts = torch.cat(video_pts)
+            video_pts = torch.cat(video_pts)  # type: ignore[assignment]
             # avoid bug in https://github.com/pytorch/pytorch/issues/32351
             # TODO: Revert it once the bug is fixed.
-            video_pts = video_pts.numpy()
+            video_pts = video_pts.numpy()  # type: ignore[attr-defined]
 
         # make a copy of the fields of self
         d = self.__dict__.copy()
@@ -390,7 +395,7 @@ def __getstate__(self):
         d["_version"] = 2
         return d
 
-    def __setstate__(self, d):
+    def __setstate__(self, d: Dict[str, Any]) -> None:
         # for backwards-compatibility
         if "_version" not in d:
             self.__dict__ = d
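For reference, a minimal usage sketch of the API annotated in this diff, assuming the file is torchvision/datasets/video_utils.py; the video paths and numeric values below are illustrative placeholders, not part of the commit:

import math
from fractions import Fraction

from torchvision.datasets.video_utils import VideoClips, pts_convert

# pts_convert rescales a presentation timestamp between time bases:
# 1000 * (1/30000) / (1/90000) == 3000
new_pts = pts_convert(1000, Fraction(1, 30000), Fraction(1, 90000), math.floor)

# VideoClips slices a list of videos into fixed-length clips (here 16 frames,
# stride 8); get_clip returns (video, audio, info, video_idx) as typed above.
clips = VideoClips(["a.mp4", "b.mp4"], clip_length_in_frames=16, frames_between_clips=8)
video, audio, info, video_idx = clips.get_clip(0)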