Added missing typing annotations in datasets/video_utils #4172
torchvision/datasets/video_utils.py
```diff
@@ -2,7 +2,7 @@
 import math
 import warnings
 from fractions import Fraction
-from typing import List
+from typing import Any, Dict, List, Optional, Callable, Union, Tuple, TypeVar, cast

 import torch
 from torchvision.io import (
```
```diff
@@ -14,8 +14,10 @@

 from .utils import tqdm

+T = TypeVar("T")
+

-def pts_convert(pts, timebase_from, timebase_to, round_func=math.floor):
+def pts_convert(pts: int, timebase_from: Fraction, timebase_to: Fraction, round_func: Callable = math.floor) -> int:
     """convert pts between different time bases
     Args:
         pts: presentation timestamp, float
```
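To make the new signature concrete, here is a hedged usage sketch (the values are made up, and it assumes the body performs the usual rescaling `new_pts = pts * timebase_from / timebase_to`, consistent with the `return round_func(new_pts)` shown in the next hunk):

```python
import math
from fractions import Fraction

# Hypothetical time bases: a 1/30000 video stream and a 1/44100 audio stream.
pts = 1001
timebase_from = Fraction(1, 30000)
timebase_to = Fraction(1, 44100)

new_pts = pts * timebase_from / timebase_to  # Fraction(147147, 100)
print(math.floor(new_pts))  # 1471
```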
```diff
@@ -27,7 +29,7 @@ def pts_convert(pts, timebase_from, timebase_to, round_func=math.floor):
     return round_func(new_pts)


-def unfold(tensor, size, step, dilation=1):
+def unfold(tensor: torch.Tensor, size: int, step: int, dilation: int = 1) -> torch.Tensor:
     """
     similar to tensor.unfold, but with the dilation
     and specialized for 1d tensors
```
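For intuition on `unfold`, here is a small sketch of mine comparing it to the built-in `tensor.unfold` it mirrors; only the `dilation` argument is extra:

```python
import torch

t = torch.arange(10)

# With dilation=1 the custom unfold should match the built-in:
print(t.unfold(0, 3, 2))
# tensor([[0, 1, 2],
#         [2, 3, 4],
#         [4, 5, 6],
#         [6, 7, 8]])

# With dilation=2 each window samples every other element, e.g. the first
# window becomes [0, 2, 4]; the built-in unfold has no dilation argument,
# which is why video_utils carries its own implementation.
```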
```diff
@@ -55,17 +57,17 @@ class _VideoTimestampsDataset:
     pickled when forking.
     """

-    def __init__(self, video_paths: List[str]):
+    def __init__(self, video_paths: List[str]) -> None:
         self.video_paths = video_paths

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.video_paths)

-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int) -> Tuple[List[int], Optional[float]]:
         return read_video_timestamps(self.video_paths[idx])


-def _collate_fn(x):
+def _collate_fn(x: T) -> T:
     """
     Dummy collate function to be used with _VideoTimestampsDataset
     """
```
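The new `T = TypeVar("T")` makes the identity collate generic: mypy now knows the batch comes back with exactly the type it went in with. A minimal sketch:

```python
from typing import List, Optional, Tuple, TypeVar

T = TypeVar("T")

def _collate_fn(x: T) -> T:
    # Identity collate: hand the raw batch through unchanged.
    return x

batch: List[Tuple[List[int], Optional[float]]] = [([0, 1, 2], 30.0)]
same = _collate_fn(batch)  # inferred as List[Tuple[List[int], Optional[float]]]
```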
```diff
@@ -100,19 +102,19 @@ class VideoClips:

     def __init__(
         self,
-        video_paths,
-        clip_length_in_frames=16,
-        frames_between_clips=1,
-        frame_rate=None,
-        _precomputed_metadata=None,
-        num_workers=0,
-        _video_width=0,
-        _video_height=0,
-        _video_min_dimension=0,
-        _video_max_dimension=0,
-        _audio_samples=0,
-        _audio_channels=0,
-    ):
+        video_paths: List[str],
+        clip_length_in_frames: int = 16,
+        frames_between_clips: int = 1,
+        frame_rate: Optional[int] = None,
+        _precomputed_metadata: Optional[Dict[str, Any]] = None,
+        num_workers: int = 0,
+        _video_width: int = 0,
+        _video_height: int = 0,
+        _video_min_dimension: int = 0,
+        _video_max_dimension: int = 0,
+        _audio_samples: int = 0,
+        _audio_channels: int = 0,
+    ) -> None:

         self.video_paths = video_paths
         self.num_workers = num_workers
```
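A hedged usage sketch for the now fully annotated constructor (file names are hypothetical):

```python
from torchvision.datasets.video_utils import VideoClips

clips = VideoClips(
    video_paths=["a.mp4", "b.mp4"],  # List[str]
    clip_length_in_frames=16,        # int
    frames_between_clips=8,          # int
    frame_rate=None,                 # Optional[int]: keep each video's native fps
    num_workers=2,                   # int
)
```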
```diff
@@ -131,16 +133,16 @@ def __init__(
             self._init_from_metadata(_precomputed_metadata)
         self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate)

-    def _compute_frame_pts(self):
+    def _compute_frame_pts(self) -> None:
         self.video_pts = []
         self.video_fps = []

         # strategy: use a DataLoader to parallelize read_video_timestamps
         # so need to create a dummy dataset first
         import torch.utils.data

-        dl = torch.utils.data.DataLoader(
-            _VideoTimestampsDataset(self.video_paths),
+        dl: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
+            _VideoTimestampsDataset(self.video_paths),  # type: ignore[arg-type]
             batch_size=16,
             num_workers=self.num_workers,
             collate_fn=_collate_fn,
```
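The `# type: ignore[arg-type]` is needed because `_VideoTimestampsDataset` implements `__len__`/`__getitem__` without subclassing `torch.utils.data.Dataset`, so mypy rejects it as a `DataLoader` argument even though it works at runtime. A stripped-down reproduction of the pattern (my sketch, not from the PR):

```python
import torch.utils.data

class DuckDataset:
    """Map-style by duck typing only: no Dataset base class."""

    def __len__(self) -> int:
        return 3

    def __getitem__(self, idx: int) -> int:
        return idx * 10

# Works at runtime; mypy flags the argument type, hence the ignore.
dl = torch.utils.data.DataLoader(DuckDataset())  # type: ignore[arg-type]
print(list(dl))  # [tensor([0]), tensor([10]), tensor([20])]
```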
```diff
@@ -157,23 +159,23 @@ def _compute_frame_pts(self):
         self.video_pts.extend(clips)
         self.video_fps.extend(fps)

-    def _init_from_metadata(self, metadata):
+    def _init_from_metadata(self, metadata: Dict[str, Any]) -> None:
         self.video_paths = metadata["video_paths"]
         assert len(self.video_paths) == len(metadata["video_pts"])
         self.video_pts = metadata["video_pts"]
         assert len(self.video_paths) == len(metadata["video_fps"])
         self.video_fps = metadata["video_fps"]

     @property
-    def metadata(self):
+    def metadata(self) -> Dict[str, Any]:
         _metadata = {
             "video_paths": self.video_paths,
             "video_pts": self.video_pts,
             "video_fps": self.video_fps,
         }
         return _metadata

-    def subset(self, indices):
+    def subset(self, indices: List[int]) -> "VideoClips":
         video_paths = [self.video_paths[i] for i in indices]
         video_pts = [self.video_pts[i] for i in indices]
         video_fps = [self.video_fps[i] for i in indices]
```
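`subset` is annotated with the string form `"VideoClips"` because the class is still being defined at that point in the file; a bare `VideoClips` in the annotation would raise `NameError` when the module is evaluated. The same pattern in miniature:

```python
class Box:
    def clone(self) -> "Box":  # string annotation: Box is not bound yet
        return Box()
```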
```diff
@@ -198,29 +200,32 @@ def subset(self, indices):
         )

     @staticmethod
-    def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate):
+    def compute_clips_for_video(
+        video_pts: torch.Tensor, num_frames: int, step: int, fps: int, frame_rate: Optional[int] = None
+    ) -> Tuple[torch.Tensor, Union[List[slice], torch.Tensor]]:
         if fps is None:
             # if for some reason the video doesn't have fps (because doesn't have a video stream)
             # set the fps to 1. The value doesn't matter, because video_pts is empty anyway
             fps = 1
         if frame_rate is None:
             frame_rate = fps
         total_frames = len(video_pts) * (float(frame_rate) / fps)
-        idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
-        video_pts = video_pts[idxs]
+        _idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
+        video_pts = video_pts[_idxs]
         clips = unfold(video_pts, num_frames, step)
         if not clips.numel():
             warnings.warn(
                 "There aren't enough frames in the current video to get a clip for the given clip length and "
                 "frames between clips. The video (and potentially others) will be skipped."
             )
-        if isinstance(idxs, slice):
-            idxs = [idxs] * len(clips)
+        idxs: Union[List[slice], torch.Tensor]
+        if isinstance(_idxs, slice):
+            idxs = [_idxs] * len(clips)
         else:
-            idxs = unfold(idxs, num_frames, step)
+            idxs = unfold(_idxs, num_frames, step)
         return clips, idxs

-    def compute_clips(self, num_frames, step, frame_rate=None):
+    def compute_clips(self, num_frames: int, step: int, frame_rate: Optional[int] = None) -> None:
         """
         Compute all consecutive sequences of clips from video_pts.
         Always returns clips of size `num_frames`, meaning that the
```
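The `idxs` → `_idxs` renaming exists because mypy pins a variable to one type per scope: `_resample_video_idx` returns `Union[slice, torch.Tensor]`, while the value this method ultimately returns is `Union[List[slice], torch.Tensor]`. A stripped-down illustration of mine:

```python
from typing import List, Union

import torch

def resample(integer_step: bool) -> Union[slice, torch.Tensor]:
    return slice(None, None, 2) if integer_step else torch.arange(4)

_idxs = resample(True)                    # Union[slice, torch.Tensor]
idxs: Union[List[slice], torch.Tensor]    # declare the outgoing type up front
if isinstance(_idxs, slice):
    idxs = [_idxs] * 4
else:
    idxs = _idxs
# Reusing the single name for both roles would make mypy reject the
# List[slice] assignment, since it is not part of the original union.
```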
```diff
@@ -243,19 +248,19 @@ def compute_clips(self, num_frames, step, frame_rate=None):
         clip_lengths = torch.as_tensor([len(v) for v in self.clips])
         self.cumulative_sizes = clip_lengths.cumsum(0).tolist()

-    def __len__(self):
+    def __len__(self) -> int:
         return self.num_clips()

-    def num_videos(self):
+    def num_videos(self) -> int:
         return len(self.video_paths)

-    def num_clips(self):
+    def num_clips(self) -> int:
         """
         Number of subclips that are available in the video list.
         """
         return self.cumulative_sizes[-1]

-    def get_clip_location(self, idx):
+    def get_clip_location(self, idx: int) -> Tuple[int, int]:
         """
         Converts a flattened representation of the indices into a video_idx, clip_idx
         representation.
```
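The `Tuple[int, int]` return corresponds to the `(video_idx, clip_idx)` pair described in the docstring. A plausible sketch of the mapping, assuming the usual bisection over `cumulative_sizes` (the body itself is not part of this diff):

```python
import bisect
from typing import List, Tuple

def get_clip_location(cumulative_sizes: List[int], idx: int) -> Tuple[int, int]:
    # cumulative_sizes = [5, 12, 20]: video 0 owns clips 0-4,
    # video 1 owns clips 5-11, video 2 owns clips 12-19.
    video_idx = bisect.bisect_right(cumulative_sizes, idx)
    clip_idx = idx if video_idx == 0 else idx - cumulative_sizes[video_idx - 1]
    return video_idx, clip_idx

print(get_clip_location([5, 12, 20], 7))  # (1, 2)
```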
```diff
@@ -268,7 +273,7 @@ def get_clip_location(self, idx):
         return video_idx, clip_idx

     @staticmethod
-    def _resample_video_idx(num_frames, original_fps, new_fps):
+    def _resample_video_idx(num_frames: int, original_fps: int, new_fps: int) -> Union[slice, torch.Tensor]:
         step = float(original_fps) / new_fps
         if step.is_integer():
             # optimization: if step is integer, don't need to perform
```
```diff
@@ -279,7 +284,7 @@ def _resample_video_idx(num_frames, original_fps, new_fps):
         idxs = idxs.floor().to(torch.int64)
         return idxs

-    def get_clip(self, idx):
+    def get_clip(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any], int]:
         """
         Gets a subclip from a list of videos.
```
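The `Union[slice, torch.Tensor]` return annotation on `_resample_video_idx` reflects the two paths visible in the hunks above: an integer fps ratio yields a cheap `slice`, anything else an explicit index tensor. A quick sketch:

```python
import torch

# 30 fps -> 15 fps: step is an integer, so a slice suffices.
step = 30 / 15
assert step.is_integer()
print(slice(None, None, int(step)))  # slice(None, None, 2)

# 30 fps -> 20 fps: fractional step, so build an index tensor instead.
num_frames, step = 4, 30 / 20
idxs = (torch.arange(num_frames, dtype=torch.float32) * step).floor().to(torch.int64)
print(idxs)  # tensor([0, 1, 3, 4])
```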
```diff
@@ -320,22 +325,22 @@ def get_clip(self, idx):
             end_pts = clip_pts[-1].item()
             video, audio, info = read_video(video_path, start_pts, end_pts)
         else:
-            info = _probe_video_from_file(video_path)
-            video_fps = info.video_fps
+            _info = _probe_video_from_file(video_path)
+            video_fps = _info.video_fps
             audio_fps = None

-            video_start_pts = clip_pts[0].item()
-            video_end_pts = clip_pts[-1].item()
+            video_start_pts = cast(int, clip_pts[0].item())
+            video_end_pts = cast(int, clip_pts[-1].item())

             audio_start_pts, audio_end_pts = 0, -1
             audio_timebase = Fraction(0, 1)
-            video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
-            if info.has_audio:
-                audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator)
+            video_timebase = Fraction(_info.video_timebase.numerator, _info.video_timebase.denominator)
+            if _info.has_audio:
+                audio_timebase = Fraction(_info.audio_timebase.numerator, _info.audio_timebase.denominator)
                 audio_start_pts = pts_convert(video_start_pts, video_timebase, audio_timebase, math.floor)
                 audio_end_pts = pts_convert(video_end_pts, video_timebase, audio_timebase, math.ceil)
-                audio_fps = info.audio_sample_rate
-            video, audio, info = _read_video_from_file(
+                audio_fps = _info.audio_sample_rate
+            video, audio, _ = _read_video_from_file(
                 video_path,
                 video_width=self._video_width,
                 video_height=self._video_height,
```

A review thread was attached to the `info` → `_info` renaming:

- "Since we now have Line 6 in 59baae9, this renaming is probably not needed anymore. Could you check?"
- "I did and unfortunately it's still an issue."
- "It seems even with … Since that is not the case here, we need to keep your fix."
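`cast` is used here because the stubs type `Tensor.item()` as returning a general Python number, while `pts_convert` now demands `int`. `cast` only narrows the static type; it does nothing at runtime. A tiny sketch with a made-up tensor:

```python
import torch
from typing import cast

clip_pts = torch.tensor([1001, 3003])

video_start_pts = cast(int, clip_pts[0].item())   # static type: int; runtime no-op
video_end_pts = cast(int, clip_pts[-1].item())
assert isinstance(video_start_pts, int) and isinstance(video_end_pts, int)
```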
```diff
@@ -362,7 +367,7 @@ def get_clip(self, idx):
         assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}"
         return video, audio, info, video_idx

-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, Any]:
         video_pts_sizes = [len(v) for v in self.video_pts]
         # To be back-compatible, we convert data to dtype torch.long as needed
         # because for empty list, in legacy implementation, torch.as_tensor will
```
```diff
@@ -371,10 +376,10 @@ def __getstate__(self):
         video_pts = [x.to(torch.int64) for x in self.video_pts]
         # video_pts can be an empty list if no frames have been decoded
         if video_pts:
-            video_pts = torch.cat(video_pts)
+            video_pts = torch.cat(video_pts)  # type: ignore[assignment]

             # avoid bug in https://github.com/pytorch/pytorch/issues/32351
             # TODO: Revert it once the bug is fixed.
-            video_pts = video_pts.numpy()
+            video_pts = video_pts.numpy()  # type: ignore[attr-defined]

         # make a copy of the fields of self
         d = self.__dict__.copy()
```
```diff
@@ -390,7 +395,7 @@ def __getstate__(self):
         d["_version"] = 2
         return d

-    def __setstate__(self, d):
+    def __setstate__(self, d: Dict[str, Any]) -> None:
         # for backwards-compatibility
         if "_version" not in d:
             self.__dict__ = d
```