3
3
from torch .utils .data import Sampler
4
4
import torch .distributed as dist
5
5
from torchvision .datasets .video_utils import VideoClips
6
+ from typing import Optional , List , Iterator , Sized , Union , cast
6
7
7
8
8
9
class DistributedSampler (Sampler ):
@@ -34,7 +35,14 @@ class DistributedSampler(Sampler):
34
35
35
36
"""
36
37
37
- def __init__ (self , dataset , num_replicas = None , rank = None , shuffle = False , group_size = 1 ):
38
+ def __init__ (
39
+ self ,
40
+ dataset : Sized ,
41
+ num_replicas : Optional [int ] = None ,
42
+ rank : Optional [int ] = None ,
43
+ shuffle : bool = False ,
44
+ group_size : int = 1 ,
45
+ ) -> None :
38
46
if num_replicas is None :
39
47
if not dist .is_available ():
40
48
raise RuntimeError ("Requires distributed package to be available" )
@@ -60,10 +68,11 @@ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False, group_s
60
68
self .total_size = self .num_samples * self .num_replicas
61
69
self .shuffle = shuffle
62
70
63
- def __iter__ (self ):
71
+ def __iter__ (self ) -> Iterator [ int ] :
64
72
# deterministically shuffle based on epoch
65
73
g = torch .Generator ()
66
74
g .manual_seed (self .epoch )
75
+ indices : Union [torch .Tensor , List [int ]]
67
76
if self .shuffle :
68
77
indices = torch .randperm (len (self .dataset ), generator = g ).tolist ()
69
78
else :
@@ -89,10 +98,10 @@ def __iter__(self):
89
98
90
99
return iter (indices )
91
100
92
    def __len__(self) -> int:
        """Return the number of indices this replica yields per epoch."""
        # num_samples is the per-replica share (total_size = num_samples * num_replicas,
        # set in __init__ — see the hunk above).
        return self.num_samples
94
103
95
    def set_epoch(self, epoch: int) -> None:
        """Set the epoch used to seed the shuffling RNG in ``__iter__``.

        Call this before each epoch so that ``torch.Generator().manual_seed(epoch)``
        produces a different (but deterministic) permutation per epoch.
        """
        self.epoch = epoch
97
106
98
107
@@ -106,14 +115,14 @@ class UniformClipSampler(Sampler):
106
115
video_clips (VideoClips): video clips to sample from
107
116
num_clips_per_video (int): number of clips to be sampled per video
108
117
"""
109
- def __init__ (self , video_clips , num_clips_per_video ) :
118
+ def __init__ (self , video_clips : VideoClips , num_clips_per_video : int ) -> None :
110
119
if not isinstance (video_clips , VideoClips ):
111
120
raise TypeError ("Expected video_clips to be an instance of VideoClips, "
112
121
"got {}" .format (type (video_clips )))
113
122
self .video_clips = video_clips
114
123
self .num_clips_per_video = num_clips_per_video
115
124
116
- def __iter__ (self ):
125
+ def __iter__ (self ) -> Iterator [ int ] :
117
126
idxs = []
118
127
s = 0
119
128
# select num_clips_per_video for each video, uniformly spaced
@@ -130,10 +139,9 @@ def __iter__(self):
130
139
)
131
140
s += length
132
141
idxs .append (sampled )
133
- idxs = torch .cat (idxs ).tolist ()
134
- return iter (idxs )
142
+ return iter (cast (List [int ], torch .cat (idxs ).tolist ()))
135
143
136
- def __len__ (self ):
144
+ def __len__ (self ) -> int :
137
145
return sum (
138
146
self .num_clips_per_video for c in self .video_clips .clips if len (c ) > 0
139
147
)
@@ -147,14 +155,14 @@ class RandomClipSampler(Sampler):
147
155
video_clips (VideoClips): video clips to sample from
148
156
max_clips_per_video (int): maximum number of clips to be sampled per video
149
157
"""
150
- def __init__ (self , video_clips , max_clips_per_video ) :
158
+ def __init__ (self , video_clips : VideoClips , max_clips_per_video : int ) -> None :
151
159
if not isinstance (video_clips , VideoClips ):
152
160
raise TypeError ("Expected video_clips to be an instance of VideoClips, "
153
161
"got {}" .format (type (video_clips )))
154
162
self .video_clips = video_clips
155
163
self .max_clips_per_video = max_clips_per_video
156
164
157
- def __iter__ (self ):
165
+ def __iter__ (self ) -> Iterator [ int ] :
158
166
idxs = []
159
167
s = 0
160
168
# select at most max_clips_per_video for each video, randomly
@@ -164,11 +172,10 @@ def __iter__(self):
164
172
sampled = torch .randperm (length )[:size ] + s
165
173
s += length
166
174
idxs .append (sampled )
167
- idxs = torch .cat (idxs )
175
+ idxs_ = torch .cat (idxs )
168
176
# shuffle all clips randomly
169
- perm = torch .randperm (len (idxs ))
170
- idxs = idxs [perm ].tolist ()
171
- return iter (idxs )
177
+ perm = torch .randperm (len (idxs_ ))
178
+ return iter (idxs_ [perm ].tolist ())
172
179
173
- def __len__ (self ):
180
+ def __len__ (self ) -> int :
174
181
return sum (min (len (c ), self .max_clips_per_video ) for c in self .video_clips .clips )
0 commit comments