-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Added typing annotations to io/_video_opts #4173
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
29a1c92
add76b3
a3f9748
7678445
eeadcab
8c49207
0797f2f
496984c
f6489b9
a2483c4
a6fe091
c23cf7d
da0949f
60f5bfe
752015e
aff69b1
ce47bb1
6cfabdc
56bd64c
642fa6d
a943e73
a4a49f9
cd6fb5a
f6e7e03
0aab926
4c5246c
8a9dab3
c65cb53
b6aca1f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -3,7 +3,7 @@ | |||
import os | ||||
import warnings | ||||
from fractions import Fraction | ||||
from typing import List, Tuple | ||||
from typing import List, Tuple, Dict, Any, Optional, Type, cast | ||||
|
||||
import numpy as np | ||||
import torch | ||||
|
@@ -24,21 +24,20 @@ | |||
# simple class for torch scripting | ||||
# the complex Fraction class from fractions module is not scriptable | ||||
class Timebase(object): | ||||
__annotations__ = {"numerator": int, "denominator": int} | ||||
__slots__ = ["numerator", "denominator"] | ||||
__annotations__: Dict[str, Type[int]] = {"numerator": int, "denominator": int} | ||||
__slots__: List[str] = ["numerator", "denominator"] | ||||
pmeier marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
|
||||
def __init__( | ||||
self, | ||||
numerator, # type: int | ||||
denominator, # type: int | ||||
): | ||||
# type: (...) -> None | ||||
numerator: int, | ||||
denominator: int, | ||||
) -> None: | ||||
self.numerator = numerator | ||||
self.denominator = denominator | ||||
|
||||
|
||||
class VideoMetaData(object): | ||||
__annotations__ = { | ||||
__annotations__: Dict[str, Any] = { | ||||
frgfm marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
"has_video": bool, | ||||
"video_timebase": Timebase, | ||||
"video_duration": float, | ||||
|
@@ -48,7 +47,7 @@ class VideoMetaData(object): | |||
"audio_duration": float, | ||||
"audio_sample_rate": float, | ||||
} | ||||
__slots__ = [ | ||||
__slots__: List[str] = [ | ||||
frgfm marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
"has_video", | ||||
"video_timebase", | ||||
"video_duration", | ||||
|
@@ -59,7 +58,7 @@ class VideoMetaData(object): | |||
"audio_sample_rate", | ||||
] | ||||
|
||||
def __init__(self): | ||||
def __init__(self) -> None: | ||||
self.has_video = False | ||||
self.video_timebase = Timebase(0, 1) | ||||
self.video_duration = 0.0 | ||||
|
@@ -70,8 +69,7 @@ def __init__(self): | |||
self.audio_sample_rate = 0.0 | ||||
|
||||
|
||||
def _validate_pts(pts_range): | ||||
# type: (List[int]) -> None | ||||
def _validate_pts(pts_range: Tuple[int, int]) -> None: | ||||
|
||||
if pts_range[1] > 0: | ||||
assert ( | ||||
|
@@ -83,8 +81,14 @@ def _validate_pts(pts_range): | |||
) | ||||
|
||||
|
||||
def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration): | ||||
# type: (torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor,torch.Tensor) -> VideoMetaData | ||||
def _fill_info( | ||||
vtimebase: torch.Tensor, | ||||
vfps: torch.Tensor, | ||||
vduration: torch.Tensor, | ||||
atimebase: torch.Tensor, | ||||
asample_rate: torch.Tensor, | ||||
aduration: torch.Tensor, | ||||
) -> VideoMetaData: | ||||
""" | ||||
Build and update the VideoMetaData struct with info about the video | ||||
""" | ||||
|
@@ -113,8 +117,11 @@ def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration): | |||
return meta | ||||
|
||||
|
||||
def _align_audio_frames(aframes, aframe_pts, audio_pts_range): | ||||
# type: (torch.Tensor, torch.Tensor, List[int]) -> torch.Tensor | ||||
def _align_audio_frames( | ||||
aframes: torch.Tensor, | ||||
aframe_pts: torch.Tensor, | ||||
datumbox marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
audio_pts_range: Tuple[int, int] | ||||
) -> torch.Tensor: | ||||
start, end = aframe_pts[0], aframe_pts[-1] | ||||
num_samples = aframes.size(0) | ||||
step_per_aframe = float(end - start + 1) / float(num_samples) | ||||
|
@@ -128,21 +135,21 @@ def _align_audio_frames(aframes, aframe_pts, audio_pts_range): | |||
|
||||
|
||||
def _read_video_from_file( | ||||
filename, | ||||
seek_frame_margin=0.25, | ||||
read_video_stream=True, | ||||
video_width=0, | ||||
video_height=0, | ||||
video_min_dimension=0, | ||||
video_max_dimension=0, | ||||
video_pts_range=(0, -1), | ||||
video_timebase=default_timebase, | ||||
read_audio_stream=True, | ||||
audio_samples=0, | ||||
audio_channels=0, | ||||
audio_pts_range=(0, -1), | ||||
audio_timebase=default_timebase, | ||||
): | ||||
filename: str, | ||||
seek_frame_margin: float = 0.25, | ||||
read_video_stream: bool = True, | ||||
video_width: int = 0, | ||||
video_height: int = 0, | ||||
video_min_dimension: int = 0, | ||||
video_max_dimension: int = 0, | ||||
video_pts_range: Tuple[int, int] = (0, -1), | ||||
video_timebase: Fraction = default_timebase, | ||||
read_audio_stream: bool = True, | ||||
audio_samples: int = 0, | ||||
audio_channels: int = 0, | ||||
audio_pts_range: Tuple[int, int] = (0, -1), | ||||
audio_timebase: Fraction = default_timebase, | ||||
) -> Tuple[torch.Tensor, torch.Tensor, VideoMetaData]: | ||||
""" | ||||
Reads a video from a file, returning both the video frames as well as | ||||
the audio frames | ||||
|
@@ -227,7 +234,7 @@ def _read_video_from_file( | |||
return vframes, aframes, info | ||||
|
||||
|
||||
def _read_video_timestamps_from_file(filename): | ||||
def _read_video_timestamps_from_file(filename: str) -> Tuple[List[int], List[int], VideoMetaData]: | ||||
""" | ||||
Decode all video- and audio frames in the video. Only pts | ||||
(presentation timestamp) is returned. The actual frame pixel data is not | ||||
|
@@ -263,7 +270,7 @@ def _read_video_timestamps_from_file(filename): | |||
return vframe_pts, aframe_pts, info | ||||
|
||||
|
||||
def _probe_video_from_file(filename): | ||||
def _probe_video_from_file(filename: str) -> VideoMetaData: | ||||
""" | ||||
Probe a video file and return VideoMetaData with info about the video | ||||
""" | ||||
|
@@ -274,24 +281,23 @@ def _probe_video_from_file(filename): | |||
|
||||
|
||||
def _read_video_from_memory( | ||||
video_data, # type: torch.Tensor | ||||
seek_frame_margin=0.25, # type: float | ||||
read_video_stream=1, # type: int | ||||
video_width=0, # type: int | ||||
video_height=0, # type: int | ||||
video_min_dimension=0, # type: int | ||||
video_max_dimension=0, # type: int | ||||
video_pts_range=(0, -1), # type: List[int] | ||||
video_timebase_numerator=0, # type: int | ||||
video_timebase_denominator=1, # type: int | ||||
read_audio_stream=1, # type: int | ||||
audio_samples=0, # type: int | ||||
audio_channels=0, # type: int | ||||
audio_pts_range=(0, -1), # type: List[int] | ||||
audio_timebase_numerator=0, # type: int | ||||
audio_timebase_denominator=1, # type: int | ||||
): | ||||
# type: (...) -> Tuple[torch.Tensor, torch.Tensor] | ||||
video_data: torch.Tensor, | ||||
seek_frame_margin: float = 0.25, | ||||
read_video_stream: int = 1, | ||||
video_width: int = 0, | ||||
video_height: int = 0, | ||||
video_min_dimension: int = 0, | ||||
video_max_dimension: int = 0, | ||||
video_pts_range: Tuple[int, int] = (0, -1), | ||||
video_timebase_numerator: int = 0, | ||||
video_timebase_denominator: int = 1, | ||||
read_audio_stream: int = 1, | ||||
audio_samples: int = 0, | ||||
audio_channels: int = 0, | ||||
audio_pts_range: Tuple[int, int] = (0, -1), | ||||
audio_timebase_numerator: int = 0, | ||||
audio_timebase_denominator: int = 1, | ||||
) -> Tuple[torch.Tensor, torch.Tensor]: | ||||
""" | ||||
Reads a video from memory, returning both the video frames as well as | ||||
the audio frames | ||||
|
@@ -384,7 +390,9 @@ def _read_video_from_memory( | |||
return vframes, aframes | ||||
|
||||
|
||||
def _read_video_timestamps_from_memory(video_data): | ||||
def _read_video_timestamps_from_memory( | ||||
video_data: torch.Tensor, | ||||
) -> Tuple[List[int], List[int], VideoMetaData]: | ||||
""" | ||||
Decode all frames in the video. Only pts (presentation timestamp) is returned. | ||||
The actual frame pixel data is not copied. Thus, read_video_timestamps(...) | ||||
|
@@ -424,8 +432,9 @@ def _read_video_timestamps_from_memory(video_data): | |||
return vframe_pts, aframe_pts, info | ||||
|
||||
|
||||
def _probe_video_from_memory(video_data): | ||||
# type: (torch.Tensor) -> VideoMetaData | ||||
def _probe_video_from_memory( | ||||
video_data: torch.Tensor, | ||||
) -> VideoMetaData: | ||||
""" | ||||
Probe a video in memory and return VideoMetaData with info about the video | ||||
This function is torchscriptable | ||||
|
@@ -438,15 +447,20 @@ def _probe_video_from_memory(video_data): | |||
return info | ||||
|
||||
|
||||
def _convert_to_sec(start_pts, end_pts, pts_unit, time_base): | ||||
def _convert_to_sec(start_pts: float, end_pts: float, pts_unit: str, time_base: Fraction) -> Tuple[float, float, str]: | ||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It seems that There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. As mentioned in the other comment, there is no clear dtype in
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sorry, I missed the other comment. I understand it's not straightforward to identify the types. That's why I think finding out what they are and annotating the code-base is useful. Nevertheless, if we fail to annotate them correctly (say they are floats and we mark them as ints) it's going to be really confusing. Can we confirm the types of the various pts vars by running the unit-tests and putting a debugger to observe their materialized types? |
||||
if pts_unit == 'pts': | ||||
start_pts = float(start_pts * time_base) | ||||
end_pts = float(end_pts * time_base) | ||||
pts_unit = 'sec' | ||||
return start_pts, end_pts, pts_unit | ||||
|
||||
|
||||
def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"): | ||||
def _read_video( | ||||
filename: str, | ||||
start_pts: int = 0, | ||||
prabhat00155 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
end_pts: Optional[float] = None, | ||||
frgfm marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
pts_unit: str = "pts" | ||||
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]: | ||||
if end_pts is None: | ||||
end_pts = float("inf") | ||||
|
||||
|
@@ -517,7 +531,7 @@ def get_pts(time_base): | |||
return vframes, aframes, _info | ||||
|
||||
|
||||
def _read_video_timestamps(filename, pts_unit="pts"): | ||||
def _read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[Fraction], Optional[float]]: | ||||
frgfm marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
if pts_unit == "pts": | ||||
warnings.warn( | ||||
"The pts_unit 'pts' gives wrong results and will be removed in a " | ||||
|
@@ -530,8 +544,8 @@ def _read_video_timestamps(filename, pts_unit="pts"): | |||
video_time_base = Fraction( | ||||
info.video_timebase.numerator, info.video_timebase.denominator | ||||
) | ||||
pts = [x * video_time_base for x in pts] | ||||
pts = [x * video_time_base for x in pts] # type: ignore[misc] | ||||
frgfm marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
|
||||
video_fps = info.video_fps if info.has_video else None | ||||
|
||||
return pts, video_fps | ||||
return pts, video_fps # type: ignore[return-value] | ||||
frgfm marked this conversation as resolved.
Show resolved
Hide resolved
|
Uh oh!
There was an error while loading. Please reload this page.