Commit 999ef25

frgfm, pmeier and prabhat00155 authored
Added missing typing annotations in datasets/video_utils (#4172)
* style: Fixed last missing typing annotation
* style: Fixed typing
* style: Fixed remaining typing annotations
* style: Fixed typing
* style: Fixed typing
* refactor: Removed unused import
* Update torchvision/datasets/video_utils.py

Co-authored-by: Philip Meier <[email protected]>
Co-authored-by: Prabhat Roy <[email protected]>
1 parent ef2e418 commit 999ef25

File tree

1 file changed: +57 -52 lines changed


torchvision/datasets/video_utils.py

Lines changed: 57 additions & 52 deletions
@@ -2,7 +2,7 @@
 import math
 import warnings
 from fractions import Fraction
-from typing import List
+from typing import Any, Dict, List, Optional, Callable, Union, Tuple, TypeVar, cast
 
 import torch
 from torchvision.io import (
@@ -14,8 +14,10 @@
 
 from .utils import tqdm
 
+T = TypeVar("T")
 
-def pts_convert(pts, timebase_from, timebase_to, round_func=math.floor):
+
+def pts_convert(pts: int, timebase_from: Fraction, timebase_to: Fraction, round_func: Callable = math.floor) -> int:
     """convert pts between different time bases
     Args:
         pts: presentation timestamp, float
@@ -27,7 +29,7 @@ def pts_convert(pts, timebase_from, timebase_to, round_func=math.floor):
     return round_func(new_pts)
 
 
-def unfold(tensor, size, step, dilation=1):
+def unfold(tensor: torch.Tensor, size: int, step: int, dilation: int = 1) -> torch.Tensor:
     """
     similar to tensor.unfold, but with the dilation
     and specialized for 1d tensors
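
Per its docstring, with dilation=1 the helper behaves like torch.Tensor.unfold on a 1-D tensor: windows of `size` elements taken every `step` elements. A minimal illustrative sketch of that windowing behaviour (not part of the commit):

import torch

# dilation=1 case: same result as the built-in 1-D unfold
pts = torch.arange(10)
windows = pts.unfold(0, 3, 2)  # size=3, step=2
print(windows)
# tensor([[0, 1, 2],
#         [2, 3, 4],
#         [4, 5, 6],
#         [6, 7, 8]])
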
@@ -55,17 +57,17 @@ class _VideoTimestampsDataset:
     pickled when forking.
     """
 
-    def __init__(self, video_paths: List[str]):
+    def __init__(self, video_paths: List[str]) -> None:
         self.video_paths = video_paths
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.video_paths)
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int) -> Tuple[List[int], Optional[float]]:
         return read_video_timestamps(self.video_paths[idx])
 
 
-def _collate_fn(x):
+def _collate_fn(x: T) -> T:
     """
     Dummy collate function to be used with _VideoTimestampsDataset
     """
@@ -100,19 +102,19 @@ class VideoClips:
 
     def __init__(
         self,
-        video_paths,
-        clip_length_in_frames=16,
-        frames_between_clips=1,
-        frame_rate=None,
-        _precomputed_metadata=None,
-        num_workers=0,
-        _video_width=0,
-        _video_height=0,
-        _video_min_dimension=0,
-        _video_max_dimension=0,
-        _audio_samples=0,
-        _audio_channels=0,
-    ):
+        video_paths: List[str],
+        clip_length_in_frames: int = 16,
+        frames_between_clips: int = 1,
+        frame_rate: Optional[int] = None,
+        _precomputed_metadata: Optional[Dict[str, Any]] = None,
+        num_workers: int = 0,
+        _video_width: int = 0,
+        _video_height: int = 0,
+        _video_min_dimension: int = 0,
+        _video_max_dimension: int = 0,
+        _audio_samples: int = 0,
+        _audio_channels: int = 0,
+    ) -> None:
 
         self.video_paths = video_paths
         self.num_workers = num_workers
@@ -131,16 +133,16 @@ def __init__(
             self._init_from_metadata(_precomputed_metadata)
         self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate)
 
-    def _compute_frame_pts(self):
+    def _compute_frame_pts(self) -> None:
         self.video_pts = []
         self.video_fps = []
 
         # strategy: use a DataLoader to parallelize read_video_timestamps
         # so need to create a dummy dataset first
         import torch.utils.data
 
-        dl = torch.utils.data.DataLoader(
-            _VideoTimestampsDataset(self.video_paths),
+        dl: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
+            _VideoTimestampsDataset(self.video_paths),  # type: ignore[arg-type]
             batch_size=16,
             num_workers=self.num_workers,
             collate_fn=_collate_fn,
@@ -157,23 +159,23 @@ def _compute_frame_pts(self):
                 self.video_pts.extend(clips)
                 self.video_fps.extend(fps)
 
-    def _init_from_metadata(self, metadata):
+    def _init_from_metadata(self, metadata: Dict[str, Any]) -> None:
         self.video_paths = metadata["video_paths"]
         assert len(self.video_paths) == len(metadata["video_pts"])
         self.video_pts = metadata["video_pts"]
         assert len(self.video_paths) == len(metadata["video_fps"])
         self.video_fps = metadata["video_fps"]
 
     @property
-    def metadata(self):
+    def metadata(self) -> Dict[str, Any]:
         _metadata = {
             "video_paths": self.video_paths,
             "video_pts": self.video_pts,
             "video_fps": self.video_fps,
         }
         return _metadata
 
-    def subset(self, indices):
+    def subset(self, indices: List[int]) -> "VideoClips":
         video_paths = [self.video_paths[i] for i in indices]
         video_pts = [self.video_pts[i] for i in indices]
         video_fps = [self.video_fps[i] for i in indices]
@@ -198,29 +200,32 @@ def subset(self, indices):
         )
 
     @staticmethod
-    def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate):
+    def compute_clips_for_video(
+        video_pts: torch.Tensor, num_frames: int, step: int, fps: int, frame_rate: Optional[int] = None
+    ) -> Tuple[torch.Tensor, Union[List[slice], torch.Tensor]]:
         if fps is None:
             # if for some reason the video doesn't have fps (because doesn't have a video stream)
             # set the fps to 1. The value doesn't matter, because video_pts is empty anyway
             fps = 1
         if frame_rate is None:
             frame_rate = fps
         total_frames = len(video_pts) * (float(frame_rate) / fps)
-        idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
-        video_pts = video_pts[idxs]
+        _idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
+        video_pts = video_pts[_idxs]
         clips = unfold(video_pts, num_frames, step)
         if not clips.numel():
             warnings.warn(
                 "There aren't enough frames in the current video to get a clip for the given clip length and "
                 "frames between clips. The video (and potentially others) will be skipped."
             )
-        if isinstance(idxs, slice):
-            idxs = [idxs] * len(clips)
+        idxs: Union[List[slice], torch.Tensor]
+        if isinstance(_idxs, slice):
+            idxs = [_idxs] * len(clips)
         else:
-            idxs = unfold(idxs, num_frames, step)
+            idxs = unfold(_idxs, num_frames, step)
         return clips, idxs
 
-    def compute_clips(self, num_frames, step, frame_rate=None):
+    def compute_clips(self, num_frames: int, step: int, frame_rate: Optional[int] = None) -> None:
         """
         Compute all consecutive sequences of clips from video_pts.
         Always returns clips of size `num_frames`, meaning that the
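
Renaming the intermediate to _idxs and declaring idxs: Union[List[slice], torch.Tensor] up front is a common way to satisfy mypy when a local ends up with a different type in each branch. A self-contained sketch of the idiom (not the library code):

from typing import List, Union

import torch

def pick_indices(use_slice: bool, n: int) -> Union[List[slice], torch.Tensor]:
    out: Union[List[slice], torch.Tensor]  # declare the union once
    if use_slice:
        out = [slice(None, None, 2)] * n   # one branch builds a list of slices
    else:
        out = torch.arange(n)              # the other produces a tensor
    return out
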
@@ -243,19 +248,19 @@ def compute_clips(self, num_frames, step, frame_rate=None):
         clip_lengths = torch.as_tensor([len(v) for v in self.clips])
         self.cumulative_sizes = clip_lengths.cumsum(0).tolist()
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.num_clips()
 
-    def num_videos(self):
+    def num_videos(self) -> int:
         return len(self.video_paths)
 
-    def num_clips(self):
+    def num_clips(self) -> int:
         """
         Number of subclips that are available in the video list.
         """
         return self.cumulative_sizes[-1]
 
-    def get_clip_location(self, idx):
+    def get_clip_location(self, idx: int) -> Tuple[int, int]:
         """
         Converts a flattened representation of the indices into a video_idx, clip_idx
         representation.
@@ -268,7 +273,7 @@ def get_clip_location(self, idx):
         return video_idx, clip_idx
 
     @staticmethod
-    def _resample_video_idx(num_frames, original_fps, new_fps):
+    def _resample_video_idx(num_frames: int, original_fps: int, new_fps: int) -> Union[slice, torch.Tensor]:
         step = float(original_fps) / new_fps
         if step.is_integer():
             # optimization: if step is integer, don't need to perform
@@ -279,7 +284,7 @@ def _resample_video_idx(num_frames, original_fps, new_fps):
         idxs = idxs.floor().to(torch.int64)
         return idxs
 
-    def get_clip(self, idx):
+    def get_clip(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any], int]:
         """
         Gets a subclip from a list of videos.
 
@@ -320,22 +325,22 @@ def get_clip(self, idx):
             end_pts = clip_pts[-1].item()
             video, audio, info = read_video(video_path, start_pts, end_pts)
         else:
-            info = _probe_video_from_file(video_path)
-            video_fps = info.video_fps
+            _info = _probe_video_from_file(video_path)
+            video_fps = _info.video_fps
             audio_fps = None
 
-            video_start_pts = clip_pts[0].item()
-            video_end_pts = clip_pts[-1].item()
+            video_start_pts = cast(int, clip_pts[0].item())
+            video_end_pts = cast(int, clip_pts[-1].item())
 
             audio_start_pts, audio_end_pts = 0, -1
             audio_timebase = Fraction(0, 1)
-            video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
-            if info.has_audio:
-                audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator)
+            video_timebase = Fraction(_info.video_timebase.numerator, _info.video_timebase.denominator)
+            if _info.has_audio:
+                audio_timebase = Fraction(_info.audio_timebase.numerator, _info.audio_timebase.denominator)
                 audio_start_pts = pts_convert(video_start_pts, video_timebase, audio_timebase, math.floor)
                 audio_end_pts = pts_convert(video_end_pts, video_timebase, audio_timebase, math.ceil)
-                audio_fps = info.audio_sample_rate
-            video, audio, info = _read_video_from_file(
+                audio_fps = _info.audio_sample_rate
+            video, audio, _ = _read_video_from_file(
                 video_path,
                 video_width=self._video_width,
                 video_height=self._video_height,
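
The cast(int, ...) calls are there because the torch stubs type Tensor.item() as a plain Python number, while pts_convert above expects an int; typing.cast only narrows the type for the checker and does nothing at runtime. A minimal example of the same idiom (assuming an integer pts tensor):

from typing import cast

import torch

clip_pts = torch.tensor([0, 512, 1024])     # integer pts values
start_pts = cast(int, clip_pts[0].item())   # no runtime conversion, just a type hint
end_pts = cast(int, clip_pts[-1].item())
assert isinstance(start_pts, int) and isinstance(end_pts, int)
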
@@ -362,7 +367,7 @@ def get_clip(self, idx):
         assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}"
         return video, audio, info, video_idx
 
-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, Any]:
         video_pts_sizes = [len(v) for v in self.video_pts]
         # To be back-compatible, we convert data to dtype torch.long as needed
         # because for empty list, in legacy implementation, torch.as_tensor will
@@ -371,10 +376,10 @@ def __getstate__(self):
         video_pts = [x.to(torch.int64) for x in self.video_pts]
         # video_pts can be an empty list if no frames have been decoded
         if video_pts:
-            video_pts = torch.cat(video_pts)
+            video_pts = torch.cat(video_pts)  # type: ignore[assignment]
             # avoid bug in https://github.com/pytorch/pytorch/issues/32351
             # TODO: Revert it once the bug is fixed.
-            video_pts = video_pts.numpy()
+            video_pts = video_pts.numpy()  # type: ignore[attr-defined]
 
         # make a copy of the fields of self
         d = self.__dict__.copy()
@@ -390,7 +395,7 @@ def __getstate__(self):
         d["_version"] = 2
         return d
 
-    def __setstate__(self, d):
+    def __setstate__(self, d: Dict[str, Any]) -> None:
         # for backwards-compatibility
         if "_version" not in d:
             self.__dict__ = d
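
For context, a rough sketch of how the now fully annotated class is typically driven; the file names are placeholders, and the keyword arguments mirror the signature shown in the diff:

from typing import List

from torchvision.datasets.video_utils import VideoClips

video_paths: List[str] = ["clip_a.mp4", "clip_b.mp4"]   # hypothetical local files
clips = VideoClips(
    video_paths,
    clip_length_in_frames=16,
    frames_between_clips=1,
    frame_rate=None,
    num_workers=0,
)
print(clips.num_videos(), len(clips))               # videos indexed, clips available
video, audio, info, video_idx = clips.get_clip(0)   # Tuple[Tensor, Tensor, Dict[str, Any], int]
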
