@@ -9,9 +9,32 @@ class DistributedSampler(Sampler):
     """
     Extension of DistributedSampler, as discussed in
     https://github.com/pytorch/pytorch/issues/23430
+
+    Example:
+      dataset: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
+      num_replicas: 4
+      shuffle: False
+
+      when group_size = 1
+        RANK    |  shard_dataset
+        =========================
+        rank_0  |  [0, 4, 8, 12]
+        rank_1  |  [1, 5, 9, 13]
+        rank_2  |  [2, 6, 10, 0]
+        rank_3  |  [3, 7, 11, 1]
+
+      when group_size = 2
+
+        RANK    |  shard_dataset
+        =========================
+        rank_0  |  [0, 1, 8, 9]
+        rank_1  |  [2, 3, 10, 11]
+        rank_2  |  [4, 5, 12, 13]
+        rank_3  |  [6, 7, 0, 1]
+
     """
 
-    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False):
+    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False, group_size=1):
         if num_replicas is None:
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
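
The docstring example can be reproduced outside the sampler with a small standalone sketch of the grouped round-robin assignment (a minimal illustration only; group_shards is a hypothetical helper, not part of this diff):

import math

def group_shards(dataset, num_replicas, group_size):
    # Pad the index list so whole groups divide evenly across replicas,
    # then deal out groups round-robin, one group per rank per pass.
    indices = list(range(len(dataset)))
    num_group_samples = math.ceil(len(indices) // group_size / num_replicas)
    total_size = num_group_samples * group_size * num_replicas
    indices += indices[:total_size - len(indices)]
    groups = [indices[i:i + group_size] for i in range(0, total_size, group_size)]
    return {rank: sum(groups[rank::num_replicas], []) for rank in range(num_replicas)}

print(group_shards(range(14), num_replicas=4, group_size=2))
# {0: [0, 1, 8, 9], 1: [2, 3, 10, 11], 2: [4, 5, 12, 13], 3: [6, 7, 0, 1]}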
@@ -20,11 +43,20 @@ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False):
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
             rank = dist.get_rank()
+        assert len(dataset) % group_size == 0, (
+            "dataset length must be a multiple of group size; "
+            "dataset length: %d, group size: %d" % (len(dataset), group_size)
+        )
         self.dataset = dataset
+        self.group_size = group_size
         self.num_replicas = num_replicas
         self.rank = rank
         self.epoch = 0
-        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+        dataset_group_length = len(dataset) // group_size
+        self.num_group_samples = int(
+            math.ceil(dataset_group_length * 1.0 / self.num_replicas)
+        )
+        self.num_samples = self.num_group_samples * group_size
         self.total_size = self.num_samples * self.num_replicas
         self.shuffle = shuffle
 
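
The rewritten __init__ rounds the per-replica sample count up at group granularity instead of per element; a quick check of the arithmetic for the docstring example, mirroring the expressions above:

import math

len_dataset, num_replicas, group_size = 14, 4, 2
dataset_group_length = len_dataset // group_size                               # 7 groups of 2
num_group_samples = int(math.ceil(dataset_group_length * 1.0 / num_replicas))  # 2 groups per rank
num_samples = num_group_samples * group_size                                   # 4 indices per rank
total_size = num_samples * num_replicas                                        # 16, so 2 indices get repeated
print(num_samples, total_size)                                                 # 4 16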
@@ -41,8 +73,14 @@ def __iter__(self):
         indices += indices[:(self.total_size - len(indices))]
         assert len(indices) == self.total_size
 
+        total_group_size = self.total_size // self.group_size
+        indices = torch.reshape(
+            torch.LongTensor(indices), (total_group_size, self.group_size)
+        )
+
         # subsample
-        indices = indices[self.rank:self.total_size:self.num_replicas]
+        indices = indices[self.rank:total_group_size:self.num_replicas, :]
+        indices = torch.reshape(indices, (-1,)).tolist()
         assert len(indices) == self.num_samples
 
         if isinstance(self.dataset, Sampler):
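
The reshape-then-slice in __iter__ views the padded index list as a (total_group_size, group_size) matrix and takes every num_replicas-th row starting at this rank; a minimal standalone illustration using the numbers from the docstring example:

import torch

indices = list(range(14)) + [0, 1]               # padded to total_size = 16
group_size, num_replicas, rank = 2, 4, 0
total_group_size = len(indices) // group_size    # 8 rows of 2
mat = torch.reshape(torch.LongTensor(indices), (total_group_size, group_size))
shard = torch.reshape(mat[rank:total_group_size:num_replicas, :], (-1,)).tolist()
print(shard)                                     # [0, 1, 8, 9] for rank 0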