Skip to content

Commit 6662b30

Browse files
authored
add typehints for .datasets.samplers (#2667)
1 parent f8bf06d commit 6662b30

File tree

1 file changed

+23
-16
lines changed

1 file changed

+23
-16
lines changed

torchvision/datasets/samplers/clip_sampler.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from torch.utils.data import Sampler
44
import torch.distributed as dist
55
from torchvision.datasets.video_utils import VideoClips
6+
from typing import Optional, List, Iterator, Sized, Union, cast
67

78

89
class DistributedSampler(Sampler):
@@ -34,7 +35,14 @@ class DistributedSampler(Sampler):
3435
3536
"""
3637

37-
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False, group_size=1):
38+
def __init__(
39+
self,
40+
dataset: Sized,
41+
num_replicas: Optional[int] = None,
42+
rank: Optional[int] = None,
43+
shuffle: bool = False,
44+
group_size: int = 1,
45+
) -> None:
3846
if num_replicas is None:
3947
if not dist.is_available():
4048
raise RuntimeError("Requires distributed package to be available")
@@ -60,10 +68,11 @@ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False, group_s
6068
self.total_size = self.num_samples * self.num_replicas
6169
self.shuffle = shuffle
6270

63-
def __iter__(self):
71+
def __iter__(self) -> Iterator[int]:
6472
# deterministically shuffle based on epoch
6573
g = torch.Generator()
6674
g.manual_seed(self.epoch)
75+
indices: Union[torch.Tensor, List[int]]
6776
if self.shuffle:
6877
indices = torch.randperm(len(self.dataset), generator=g).tolist()
6978
else:
@@ -89,10 +98,10 @@ def __iter__(self):
8998

9099
return iter(indices)
91100

92-
def __len__(self):
101+
def __len__(self) -> int:
93102
return self.num_samples
94103

95-
def set_epoch(self, epoch):
104+
def set_epoch(self, epoch: int) -> None:
96105
self.epoch = epoch
97106

98107

@@ -106,14 +115,14 @@ class UniformClipSampler(Sampler):
106115
video_clips (VideoClips): video clips to sample from
107116
num_clips_per_video (int): number of clips to be sampled per video
108117
"""
109-
def __init__(self, video_clips, num_clips_per_video):
118+
def __init__(self, video_clips: VideoClips, num_clips_per_video: int) -> None:
110119
if not isinstance(video_clips, VideoClips):
111120
raise TypeError("Expected video_clips to be an instance of VideoClips, "
112121
"got {}".format(type(video_clips)))
113122
self.video_clips = video_clips
114123
self.num_clips_per_video = num_clips_per_video
115124

116-
def __iter__(self):
125+
def __iter__(self) -> Iterator[int]:
117126
idxs = []
118127
s = 0
119128
# select num_clips_per_video for each video, uniformly spaced
@@ -130,10 +139,9 @@ def __iter__(self):
130139
)
131140
s += length
132141
idxs.append(sampled)
133-
idxs = torch.cat(idxs).tolist()
134-
return iter(idxs)
142+
return iter(cast(List[int], torch.cat(idxs).tolist()))
135143

136-
def __len__(self):
144+
def __len__(self) -> int:
137145
return sum(
138146
self.num_clips_per_video for c in self.video_clips.clips if len(c) > 0
139147
)
@@ -147,14 +155,14 @@ class RandomClipSampler(Sampler):
147155
video_clips (VideoClips): video clips to sample from
148156
max_clips_per_video (int): maximum number of clips to be sampled per video
149157
"""
150-
def __init__(self, video_clips, max_clips_per_video):
158+
def __init__(self, video_clips: VideoClips, max_clips_per_video: int) -> None:
151159
if not isinstance(video_clips, VideoClips):
152160
raise TypeError("Expected video_clips to be an instance of VideoClips, "
153161
"got {}".format(type(video_clips)))
154162
self.video_clips = video_clips
155163
self.max_clips_per_video = max_clips_per_video
156164

157-
def __iter__(self):
165+
def __iter__(self) -> Iterator[int]:
158166
idxs = []
159167
s = 0
160168
# select at most max_clips_per_video for each video, randomly
@@ -164,11 +172,10 @@ def __iter__(self):
164172
sampled = torch.randperm(length)[:size] + s
165173
s += length
166174
idxs.append(sampled)
167-
idxs = torch.cat(idxs)
175+
idxs_ = torch.cat(idxs)
168176
# shuffle all clips randomly
169-
perm = torch.randperm(len(idxs))
170-
idxs = idxs[perm].tolist()
171-
return iter(idxs)
177+
perm = torch.randperm(len(idxs_))
178+
return iter(idxs_[perm].tolist())
172179

173-
def __len__(self):
180+
def __len__(self) -> int:
174181
return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)

0 commit comments

Comments
 (0)