Commit b941ba3

prabhat00155 and pmeier authored and committed
[fbsync] Added typing annotations to io/_video_opts (#4173)
Summary:

* style: Added typing annotations
* style: Fixed lint
* style: Fixed typing
* chore: Updated mypy.ini
* style: Fixed typing
* chore: Updated mypy.ini
* style: Fixed typing compatibility with jit
* style: Fixed typing
* style: Fixed typing
* style: Fixed missing import
* style: Fixed typing of __iter__
* style: Fixed typing
* style: Fixed lint
* style: Finished typing
* style: ufmt the file
* style: Removed unnecessary typing
* style: Fixed typing of iterator

Reviewed By: NicolasHug

Differential Revision: D32694320

fbshipit-source-id: afdd8c0a70cfdba91a5c349a5961051b993185a1

Co-authored-by: Philip Meier <[email protected]>
Co-authored-by: Prabhat Roy <[email protected]>
1 parent 375c206 commit b941ba3
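The summary above boils down to migrating mypy type comments to inline annotations throughout torchvision/io/_video_opt.py. As a minimal before/after sketch of that pattern (made-up function names, not code from this diff):

# Before: the signature types live in a mypy type comment.
def scale(value, factor=1.0):
    # type: (int, float) -> float
    return value * factor


# After: the same signature expressed with inline annotations,
# which is the style the diffs below adopt.
def scale_annotated(value: int, factor: float = 1.0) -> float:
    return value * factor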

3 files changed: +73 lines, -56 lines


mypy.ini

Lines changed: 2 additions & 2 deletions
@@ -22,11 +22,11 @@ warn_unreachable = True
 ; miscellaneous strictness flags
 allow_redefinition = True

-[mypy-torchvision.io._video_opt.*]
+[mypy-torchvision.io.image.*]

 ignore_errors = True

-[mypy-torchvision.io.*]
+[mypy-torchvision.io.video.*]

 ignore_errors = True

torchvision/io/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -132,7 +132,7 @@ def __next__(self) -> Dict[str, Any]:
             raise StopIteration
         return {"data": frame, "pts": pts}

-    def __iter__(self) -> Iterator["VideoReader"]:
+    def __iter__(self) -> Iterator[Dict[str, Any]]:
         return self

     def seek(self, time_s: float, keyframes_only: bool = False) -> "VideoReader":
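The corrected __iter__ return type matches what iterating a VideoReader actually yields: the same {"data", "pts"} dictionaries that __next__ is annotated to return. A small, hedged usage sketch (the file path is a placeholder):

from typing import Any, Dict

from torchvision.io import VideoReader

reader = VideoReader("example.mp4", "video")  # placeholder path
frame: Dict[str, Any]
for frame in reader:
    # each item is a dict holding a "data" tensor and a "pts" timestamp in seconds
    print(frame["pts"], frame["data"].shape)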

torchvision/io/_video_opt.py

Lines changed: 70 additions & 53 deletions
@@ -1,7 +1,7 @@
 import math
 import warnings
 from fractions import Fraction
-from typing import List, Tuple
+from typing import List, Tuple, Dict, Optional, Union

 import torch

@@ -26,10 +26,9 @@ class Timebase:

     def __init__(
         self,
-        numerator,  # type: int
-        denominator,  # type: int
-    ):
-        # type: (...) -> None
+        numerator: int,
+        denominator: int,
+    ) -> None:
         self.numerator = numerator
         self.denominator = denominator

@@ -56,7 +55,7 @@ class VideoMetaData:
         "audio_sample_rate",
     ]

-    def __init__(self):
+    def __init__(self) -> None:
         self.has_video = False
         self.video_timebase = Timebase(0, 1)
         self.video_duration = 0.0
@@ -67,8 +66,7 @@ def __init__(self):
         self.audio_sample_rate = 0.0


-def _validate_pts(pts_range):
-    # type: (List[int]) -> None
+def _validate_pts(pts_range: Tuple[int, int]) -> None:

     if pts_range[1] > 0:
         assert (
@@ -80,8 +78,14 @@ def _validate_pts(pts_range):
         )


-def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration):
-    # type: (torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor) -> VideoMetaData
+def _fill_info(
+    vtimebase: torch.Tensor,
+    vfps: torch.Tensor,
+    vduration: torch.Tensor,
+    atimebase: torch.Tensor,
+    asample_rate: torch.Tensor,
+    aduration: torch.Tensor,
+) -> VideoMetaData:
     """
     Build update VideoMetaData struct with info about the video
     """
@@ -106,8 +110,9 @@ def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration):
     return meta


-def _align_audio_frames(aframes, aframe_pts, audio_pts_range):
-    # type: (torch.Tensor, torch.Tensor, List[int]) -> torch.Tensor
+def _align_audio_frames(
+    aframes: torch.Tensor, aframe_pts: torch.Tensor, audio_pts_range: Tuple[int, int]
+) -> torch.Tensor:
     start, end = aframe_pts[0], aframe_pts[-1]
     num_samples = aframes.size(0)
     step_per_aframe = float(end - start + 1) / float(num_samples)
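In the hunks above, the pts-range parameters move from a # type: List[int] comment to Tuple[int, int], which describes the (0, -1) sentinel default exactly: a fixed-size (start, end) pair rather than a list of unknown length. An illustrative sketch of that choice, using a made-up helper that is not part of the diff:

from typing import Tuple


def clip_length(pts_range: Tuple[int, int] = (0, -1)) -> int:
    """Return end - start, treating a non-positive end as 'read to the end of the stream'."""
    start, end = pts_range  # a Tuple[int, int] always unpacks into exactly two ints
    return end - start if end > 0 else -1


print(clip_length((10, 250)))  # 240
print(clip_length())           # -1, i.e. no explicit end pts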
@@ -121,21 +126,21 @@ def _align_audio_frames(aframes, aframe_pts, audio_pts_range):


 def _read_video_from_file(
-    filename,
-    seek_frame_margin=0.25,
-    read_video_stream=True,
-    video_width=0,
-    video_height=0,
-    video_min_dimension=0,
-    video_max_dimension=0,
-    video_pts_range=(0, -1),
-    video_timebase=default_timebase,
-    read_audio_stream=True,
-    audio_samples=0,
-    audio_channels=0,
-    audio_pts_range=(0, -1),
-    audio_timebase=default_timebase,
-):
+    filename: str,
+    seek_frame_margin: float = 0.25,
+    read_video_stream: bool = True,
+    video_width: int = 0,
+    video_height: int = 0,
+    video_min_dimension: int = 0,
+    video_max_dimension: int = 0,
+    video_pts_range: Tuple[int, int] = (0, -1),
+    video_timebase: Fraction = default_timebase,
+    read_audio_stream: bool = True,
+    audio_samples: int = 0,
+    audio_channels: int = 0,
+    audio_pts_range: Tuple[int, int] = (0, -1),
+    audio_timebase: Fraction = default_timebase,
+) -> Tuple[torch.Tensor, torch.Tensor, VideoMetaData]:
     """
     Reads a video from a file, returning both the video frames as well as
     the audio frames
@@ -217,7 +222,7 @@ def _read_video_from_file(
     return vframes, aframes, info


-def _read_video_timestamps_from_file(filename):
+def _read_video_timestamps_from_file(filename: str) -> Tuple[List[int], List[int], VideoMetaData]:
     """
     Decode all video- and audio frames in the video. Only pts
     (presentation timestamp) is returned. The actual frame pixel data is not
@@ -252,7 +257,7 @@ def _read_video_timestamps_from_file(filename):
     return vframe_pts, aframe_pts, info


-def _probe_video_from_file(filename):
+def _probe_video_from_file(filename: str) -> VideoMetaData:
     """
     Probe a video file and return VideoMetaData with info about the video
     """
@@ -263,24 +268,23 @@ def _probe_video_from_file(filename):


 def _read_video_from_memory(
-    video_data,  # type: torch.Tensor
-    seek_frame_margin=0.25,  # type: float
-    read_video_stream=1,  # type: int
-    video_width=0,  # type: int
-    video_height=0,  # type: int
-    video_min_dimension=0,  # type: int
-    video_max_dimension=0,  # type: int
-    video_pts_range=(0, -1),  # type: List[int]
-    video_timebase_numerator=0,  # type: int
-    video_timebase_denominator=1,  # type: int
-    read_audio_stream=1,  # type: int
-    audio_samples=0,  # type: int
-    audio_channels=0,  # type: int
-    audio_pts_range=(0, -1),  # type: List[int]
-    audio_timebase_numerator=0,  # type: int
-    audio_timebase_denominator=1,  # type: int
-):
-    # type: (...) -> Tuple[torch.Tensor, torch.Tensor]
+    video_data: torch.Tensor,
+    seek_frame_margin: float = 0.25,
+    read_video_stream: int = 1,
+    video_width: int = 0,
+    video_height: int = 0,
+    video_min_dimension: int = 0,
+    video_max_dimension: int = 0,
+    video_pts_range: Tuple[int, int] = (0, -1),
+    video_timebase_numerator: int = 0,
+    video_timebase_denominator: int = 1,
+    read_audio_stream: int = 1,
+    audio_samples: int = 0,
+    audio_channels: int = 0,
+    audio_pts_range: Tuple[int, int] = (0, -1),
+    audio_timebase_numerator: int = 0,
+    audio_timebase_denominator: int = 1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Reads a video from memory, returning both the video frames as well as
     the audio frames
@@ -370,7 +374,9 @@ def _read_video_from_memory(
     return vframes, aframes


-def _read_video_timestamps_from_memory(video_data):
+def _read_video_timestamps_from_memory(
+    video_data: torch.Tensor,
+) -> Tuple[List[int], List[int], VideoMetaData]:
     """
     Decode all frames in the video. Only pts (presentation timestamp) is returned.
     The actual frame pixel data is not copied. Thus, read_video_timestamps(...)
@@ -407,8 +413,9 @@ def _read_video_timestamps_from_memory(video_data):
     return vframe_pts, aframe_pts, info


-def _probe_video_from_memory(video_data):
-    # type: (torch.Tensor) -> VideoMetaData
+def _probe_video_from_memory(
+    video_data: torch.Tensor,
+) -> VideoMetaData:
     """
     Probe a video in memory and return VideoMetaData with info about the video
     This function is torchscriptable
@@ -421,15 +428,22 @@ def _probe_video_from_memory(video_data):
     return info


-def _convert_to_sec(start_pts, end_pts, pts_unit, time_base):
+def _convert_to_sec(
+    start_pts: Union[float, Fraction], end_pts: Union[float, Fraction], pts_unit: str, time_base: Fraction
+) -> Tuple[Union[float, Fraction], Union[float, Fraction], str]:
     if pts_unit == "pts":
         start_pts = float(start_pts * time_base)
         end_pts = float(end_pts * time_base)
         pts_unit = "sec"
     return start_pts, end_pts, pts_unit


-def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
+def _read_video(
+    filename: str,
+    start_pts: Union[float, Fraction] = 0,
+    end_pts: Optional[Union[float, Fraction]] = None,
+    pts_unit: str = "pts",
+) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]:
     if end_pts is None:
         end_pts = float("inf")

@@ -495,13 +509,16 @@ def get_pts(time_base):
     return vframes, aframes, _info


-def _read_video_timestamps(filename, pts_unit="pts"):
+def _read_video_timestamps(
+    filename: str, pts_unit: str = "pts"
+) -> Tuple[Union[List[int], List[Fraction]], Optional[float]]:
     if pts_unit == "pts":
         warnings.warn(
             "The pts_unit 'pts' gives wrong results and will be removed in a "
             + "follow-up version. Please use pts_unit 'sec'."
         )

+    pts: Union[List[int], List[Fraction]]
     pts, _, info = _read_video_timestamps_from_file(filename)

     if pts_unit == "sec":