import math
import warnings
from fractions import Fraction
-from typing import List, Tuple
+from typing import List, Tuple, Dict, Optional, Union

import torch

@@ -26,10 +26,9 @@ class Timebase:

    def __init__(
        self,
-        numerator,  # type: int
-        denominator,  # type: int
-    ):
-        # type: (...) -> None
+        numerator: int,
+        denominator: int,
+    ) -> None:
        self.numerator = numerator
        self.denominator = denominator

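Timebase simply stores the rational time base of a stream as a numerator/denominator pair; a pts value expressed in that time base converts to seconds by multiplying with the corresponding fraction. A minimal sketch, using an illustrative 1/30000 time base:

from fractions import Fraction

tb = Timebase(1, 30000)          # illustrative stream time base
pts = 45045                      # presentation timestamp in time-base ticks
seconds = pts * Fraction(tb.numerator, tb.denominator)
print(float(seconds))            # 1.5015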
@@ -56,7 +55,7 @@ class VideoMetaData:
        "audio_sample_rate",
    ]

-    def __init__(self):
+    def __init__(self) -> None:
        self.has_video = False
        self.video_timebase = Timebase(0, 1)
        self.video_duration = 0.0
@@ -67,8 +66,7 @@ def __init__(self):
        self.audio_sample_rate = 0.0


-def _validate_pts(pts_range):
-    # type: (List[int]) -> None
+def _validate_pts(pts_range: Tuple[int, int]) -> None:

    if pts_range[1] > 0:
        assert (
@@ -80,8 +78,14 @@ def _validate_pts(pts_range):
        )


-def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration):
-    # type: (torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor) -> VideoMetaData
+def _fill_info(
+    vtimebase: torch.Tensor,
+    vfps: torch.Tensor,
+    vduration: torch.Tensor,
+    atimebase: torch.Tensor,
+    asample_rate: torch.Tensor,
+    aduration: torch.Tensor,
+) -> VideoMetaData:
    """
    Build update VideoMetaData struct with info about the video
    """
@@ -106,8 +110,9 @@ def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration):
    return meta


-def _align_audio_frames(aframes, aframe_pts, audio_pts_range):
-    # type: (torch.Tensor, torch.Tensor, List[int]) -> torch.Tensor
+def _align_audio_frames(
+    aframes: torch.Tensor, aframe_pts: torch.Tensor, audio_pts_range: Tuple[int, int]
+) -> torch.Tensor:
    start, end = aframe_pts[0], aframe_pts[-1]
    num_samples = aframes.size(0)
    step_per_aframe = float(end - start + 1) / float(num_samples)
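The lines above estimate how many pts ticks each decoded audio sample covers; the remainder of the function (outside this hunk) uses that step to clip the decoded samples to audio_pts_range. A rough sketch of the arithmetic with made-up numbers:

# Illustrative values only: 4800 decoded samples spanning pts 0..4799.
start, end, num_samples = 0, 4799, 4800
step_per_aframe = float(end - start + 1) / float(num_samples)   # 1.0 tick per sample
audio_pts_range = (1200, 3599)
s_idx = int((audio_pts_range[0] - start) / step_per_aframe)     # first sample to keep
e_idx = int((audio_pts_range[1] - start) / step_per_aframe)     # last sample to keep
# aframes[s_idx : e_idx + 1] would then hold the pts-aligned samples.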
@@ -121,21 +126,21 @@ def _align_audio_frames(aframes, aframe_pts, audio_pts_range):


def _read_video_from_file(
-    filename,
-    seek_frame_margin=0.25,
-    read_video_stream=True,
-    video_width=0,
-    video_height=0,
-    video_min_dimension=0,
-    video_max_dimension=0,
-    video_pts_range=(0, -1),
-    video_timebase=default_timebase,
-    read_audio_stream=True,
-    audio_samples=0,
-    audio_channels=0,
-    audio_pts_range=(0, -1),
-    audio_timebase=default_timebase,
-):
+    filename: str,
+    seek_frame_margin: float = 0.25,
+    read_video_stream: bool = True,
+    video_width: int = 0,
+    video_height: int = 0,
+    video_min_dimension: int = 0,
+    video_max_dimension: int = 0,
+    video_pts_range: Tuple[int, int] = (0, -1),
+    video_timebase: Fraction = default_timebase,
+    read_audio_stream: bool = True,
+    audio_samples: int = 0,
+    audio_channels: int = 0,
+    audio_pts_range: Tuple[int, int] = (0, -1),
+    audio_timebase: Fraction = default_timebase,
+) -> Tuple[torch.Tensor, torch.Tensor, VideoMetaData]:
    """
    Reads a video from a file, returning both the video frames as well as
    the audio frames
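For context, a minimal call sketch of the annotated signature (the path is hypothetical and the function requires torchvision's video_reader backend):

vframes, aframes, info = _read_video_from_file(
    "example.mp4",               # hypothetical path
    read_video_stream=True,
    read_audio_stream=True,
    video_pts_range=(0, -1),     # the (0, -1) defaults appear to select the full stream
    audio_pts_range=(0, -1),
)
# vframes: decoded video frames, aframes: decoded audio samples,
# info: VideoMetaData describing both streams.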
@@ -217,7 +222,7 @@ def _read_video_from_file(
    return vframes, aframes, info


-def _read_video_timestamps_from_file(filename):
+def _read_video_timestamps_from_file(filename: str) -> Tuple[List[int], List[int], VideoMetaData]:
    """
    Decode all video- and audio frames in the video. Only pts
    (presentation timestamp) is returned. The actual frame pixel data is not
@@ -252,7 +257,7 @@ def _read_video_timestamps_from_file(filename):
    return vframe_pts, aframe_pts, info


-def _probe_video_from_file(filename):
+def _probe_video_from_file(filename: str) -> VideoMetaData:
    """
    Probe a video file and return VideoMetaData with info about the video
    """
@@ -263,24 +268,23 @@ def _probe_video_from_file(filename):


def _read_video_from_memory(
-    video_data,  # type: torch.Tensor
-    seek_frame_margin=0.25,  # type: float
-    read_video_stream=1,  # type: int
-    video_width=0,  # type: int
-    video_height=0,  # type: int
-    video_min_dimension=0,  # type: int
-    video_max_dimension=0,  # type: int
-    video_pts_range=(0, -1),  # type: List[int]
-    video_timebase_numerator=0,  # type: int
-    video_timebase_denominator=1,  # type: int
-    read_audio_stream=1,  # type: int
-    audio_samples=0,  # type: int
-    audio_channels=0,  # type: int
-    audio_pts_range=(0, -1),  # type: List[int]
-    audio_timebase_numerator=0,  # type: int
-    audio_timebase_denominator=1,  # type: int
-):
-    # type: (...) -> Tuple[torch.Tensor, torch.Tensor]
+    video_data: torch.Tensor,
+    seek_frame_margin: float = 0.25,
+    read_video_stream: int = 1,
+    video_width: int = 0,
+    video_height: int = 0,
+    video_min_dimension: int = 0,
+    video_max_dimension: int = 0,
+    video_pts_range: Tuple[int, int] = (0, -1),
+    video_timebase_numerator: int = 0,
+    video_timebase_denominator: int = 1,
+    read_audio_stream: int = 1,
+    audio_samples: int = 0,
+    audio_channels: int = 0,
+    audio_pts_range: Tuple[int, int] = (0, -1),
+    audio_timebase_numerator: int = 0,
+    audio_timebase_denominator: int = 1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Reads a video from memory, returning both the video frames as well as
    the audio frames
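The in-memory variant takes the encoded file contents as a torch.Tensor of bytes; a minimal sketch, assuming a hypothetical file and torch.frombuffer:

with open("example.mp4", "rb") as f:           # hypothetical path
    video_data = torch.frombuffer(bytearray(f.read()), dtype=torch.uint8)

vframes, aframes = _read_video_from_memory(
    video_data,
    read_video_stream=1,
    read_audio_stream=1,
)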
@@ -370,7 +374,9 @@ def _read_video_from_memory(
    return vframes, aframes


-def _read_video_timestamps_from_memory(video_data):
+def _read_video_timestamps_from_memory(
+    video_data: torch.Tensor,
+) -> Tuple[List[int], List[int], VideoMetaData]:
    """
    Decode all frames in the video. Only pts (presentation timestamp) is returned.
    The actual frame pixel data is not copied. Thus, read_video_timestamps(...)
@@ -407,8 +413,9 @@ def _read_video_timestamps_from_memory(video_data):
    return vframe_pts, aframe_pts, info


-def _probe_video_from_memory(video_data):
-    # type: (torch.Tensor) -> VideoMetaData
+def _probe_video_from_memory(
+    video_data: torch.Tensor,
+) -> VideoMetaData:
    """
    Probe a video in memory and return VideoMetaData with info about the video
    This function is torchscriptable
@@ -421,15 +428,22 @@ def _probe_video_from_memory(video_data):
    return info


-def _convert_to_sec(start_pts, end_pts, pts_unit, time_base):
+def _convert_to_sec(
+    start_pts: Union[float, Fraction], end_pts: Union[float, Fraction], pts_unit: str, time_base: Fraction
+) -> Tuple[Union[float, Fraction], Union[float, Fraction], str]:
    if pts_unit == "pts":
        start_pts = float(start_pts * time_base)
        end_pts = float(end_pts * time_base)
        pts_unit = "sec"
    return start_pts, end_pts, pts_unit


-def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
+def _read_video(
+    filename: str,
+    start_pts: Union[float, Fraction] = 0,
+    end_pts: Optional[Union[float, Fraction]] = None,
+    pts_unit: str = "pts",
+) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]:
    if end_pts is None:
        end_pts = float("inf")

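The pts-to-seconds conversion in _convert_to_sec above is plain fraction arithmetic; a worked example with an illustrative 1/90000 time base:

from fractions import Fraction

time_base = Fraction(1, 90000)     # illustrative MPEG-TS-style time base
start, end, unit = _convert_to_sec(90000, 450000, "pts", time_base)
print(start, end, unit)            # 1.0 5.0 sec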
@@ -495,13 +509,16 @@ def get_pts(time_base):
    return vframes, aframes, _info


-def _read_video_timestamps(filename, pts_unit="pts"):
+def _read_video_timestamps(
+    filename: str, pts_unit: str = "pts"
+) -> Tuple[Union[List[int], List[Fraction]], Optional[float]]:
    if pts_unit == "pts":
        warnings.warn(
            "The pts_unit 'pts' gives wrong results and will be removed in a "
            + "follow-up version. Please use pts_unit 'sec'."
        )

+    pts: Union[List[int], List[Fraction]]
    pts, _, info = _read_video_timestamps_from_file(filename)

    if pts_unit == "sec":
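Given the deprecation warning above, a minimal call sketch with the preferred unit (hypothetical path):

pts, video_fps = _read_video_timestamps("example.mp4", pts_unit="sec")   # hypothetical path
# pts: per-frame timestamps (Fractions when pts_unit="sec");
# video_fps: the frame rate as a float, presumably None when no video stream is found.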