diff --git a/torchvision/datasets/samplers/clip_sampler.py b/torchvision/datasets/samplers/clip_sampler.py index 2432a6d20de..e1a5775b3db 100644 --- a/torchvision/datasets/samplers/clip_sampler.py +++ b/torchvision/datasets/samplers/clip_sampler.py @@ -3,6 +3,7 @@ from torch.utils.data import Sampler import torch.distributed as dist from torchvision.datasets.video_utils import VideoClips +from typing import Optional, List, Iterator, Sized, Union, cast class DistributedSampler(Sampler): @@ -34,7 +35,14 @@ class DistributedSampler(Sampler): """ - def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False, group_size=1): + def __init__( + self, + dataset: Sized, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = False, + group_size: int = 1, + ) -> None: if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") @@ -60,10 +68,11 @@ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False, group_s self.total_size = self.num_samples * self.num_replicas self.shuffle = shuffle - def __iter__(self): + def __iter__(self) -> Iterator[int]: # deterministically shuffle based on epoch g = torch.Generator() g.manual_seed(self.epoch) + indices: Union[torch.Tensor, List[int]] if self.shuffle: indices = torch.randperm(len(self.dataset), generator=g).tolist() else: @@ -89,10 +98,10 @@ def __iter__(self): return iter(indices) - def __len__(self): + def __len__(self) -> int: return self.num_samples - def set_epoch(self, epoch): + def set_epoch(self, epoch: int) -> None: self.epoch = epoch @@ -106,14 +115,14 @@ class UniformClipSampler(Sampler): video_clips (VideoClips): video clips to sample from num_clips_per_video (int): number of clips to be sampled per video """ - def __init__(self, video_clips, num_clips_per_video): + def __init__(self, video_clips: VideoClips, num_clips_per_video: int) -> None: if not isinstance(video_clips, VideoClips): raise TypeError("Expected video_clips to be an instance of VideoClips, " "got {}".format(type(video_clips))) self.video_clips = video_clips self.num_clips_per_video = num_clips_per_video - def __iter__(self): + def __iter__(self) -> Iterator[int]: idxs = [] s = 0 # select num_clips_per_video for each video, uniformly spaced @@ -130,10 +139,9 @@ def __iter__(self): ) s += length idxs.append(sampled) - idxs = torch.cat(idxs).tolist() - return iter(idxs) + return iter(cast(List[int], torch.cat(idxs).tolist())) - def __len__(self): + def __len__(self) -> int: return sum( self.num_clips_per_video for c in self.video_clips.clips if len(c) > 0 ) @@ -147,14 +155,14 @@ class RandomClipSampler(Sampler): video_clips (VideoClips): video clips to sample from max_clips_per_video (int): maximum number of clips to be sampled per video """ - def __init__(self, video_clips, max_clips_per_video): + def __init__(self, video_clips: VideoClips, max_clips_per_video: int) -> None: if not isinstance(video_clips, VideoClips): raise TypeError("Expected video_clips to be an instance of VideoClips, " "got {}".format(type(video_clips))) self.video_clips = video_clips self.max_clips_per_video = max_clips_per_video - def __iter__(self): + def __iter__(self) -> Iterator[int]: idxs = [] s = 0 # select at most max_clips_per_video for each video, randomly @@ -164,11 +172,10 @@ def __iter__(self): sampled = torch.randperm(length)[:size] + s s += length idxs.append(sampled) - idxs = torch.cat(idxs) + idxs_ = torch.cat(idxs) # shuffle all clips randomly - perm = torch.randperm(len(idxs)) - idxs = idxs[perm].tolist() - return iter(idxs) + perm = torch.randperm(len(idxs_)) + return iter(idxs_[perm].tolist()) - def __len__(self): + def __len__(self) -> int: return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)