Commit 355e9d2

extend DistributedSampler to support group_size (#1512)

* extend DistributedSampler to support group_size
* Fix lint

stephenyan1231 authored and fmassa committed
1 parent b60cb72 · commit 355e9d2

File tree

2 files changed: +71 −4 lines changed

- test/test_datasets_samplers.py
- torchvision/datasets/samplers/clip_sampler.py

test/test_datasets_samplers.py — 30 additions, 1 deletion

```diff
@@ -5,7 +5,11 @@
 import unittest

 from torchvision import io
-from torchvision.datasets.samplers import RandomClipSampler, UniformClipSampler
+from torchvision.datasets.samplers import (
+    DistributedSampler,
+    RandomClipSampler,
+    UniformClipSampler,
+)
 from torchvision.datasets.video_utils import VideoClips, unfold
 from torchvision import get_video_backend
@@ -83,6 +87,31 @@ def test_uniform_clip_sampler_insufficient_clips(self):
         indices = torch.tensor(list(iter(sampler)))
         self.assertTrue(indices.equal(torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11])))

+    def test_distributed_sampler_and_uniform_clip_sampler(self):
+        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
+            video_clips = VideoClips(video_list, 5, 5)
+            clip_sampler = UniformClipSampler(video_clips, 3)
+
+            distributed_sampler_rank0 = DistributedSampler(
+                clip_sampler,
+                num_replicas=2,
+                rank=0,
+                group_size=3,
+            )
+            indices = torch.tensor(list(iter(distributed_sampler_rank0)))
+            self.assertEqual(len(distributed_sampler_rank0), 6)
+            self.assertTrue(indices.equal(torch.tensor([0, 2, 4, 10, 12, 14])))
+
+            distributed_sampler_rank1 = DistributedSampler(
+                clip_sampler,
+                num_replicas=2,
+                rank=1,
+                group_size=3,
+            )
+            indices = torch.tensor(list(iter(distributed_sampler_rank1)))
+            self.assertEqual(len(distributed_sampler_rank1), 6)
+            self.assertTrue(indices.equal(torch.tensor([5, 7, 9, 0, 2, 4])))
+

 if __name__ == '__main__':
     unittest.main()
```
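
Where the expected indices come from: 3 videos of 25 frames with 5-frame clips at step 5 give 5 clips per video (15 total), and `UniformClipSampler(video_clips, 3)` should yield 3 evenly spaced clips per video, i.e. the order `[0, 2, 4, 5, 7, 9, 10, 12, 14]`. A minimal sketch of the group-wise sharding arithmetic (not part of the commit; the sampler order is inferred from the assertions above) then reproduces both ranks' expectations:

```python
import math

# Clip order UniformClipSampler should produce in the test: 3 evenly
# spaced clips out of each video's 5 (clips 0-4, 5-9, 10-14).
indices = [0, 2, 4, 5, 7, 9, 10, 12, 14]
group_size, num_replicas = 3, 2

num_groups = len(indices) // group_size                              # 3
num_group_samples = int(math.ceil(num_groups * 1.0 / num_replicas))  # 2
total_size = num_group_samples * group_size * num_replicas           # 12

# Pad by wrapping around, then cut into groups of 3 (one group per video here).
padded = indices + indices[:total_size - len(indices)]
groups = [padded[i:i + group_size] for i in range(0, total_size, group_size)]
# groups == [[0, 2, 4], [5, 7, 9], [10, 12, 14], [0, 2, 4]]

for rank in range(num_replicas):
    print(rank, [x for g in groups[rank::num_replicas] for x in g])
# 0 [0, 2, 4, 10, 12, 14]   -> matches distributed_sampler_rank0
# 1 [5, 7, 9, 0, 2, 4]      -> matches distributed_sampler_rank1
```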

torchvision/datasets/samplers/clip_sampler.py — 41 additions, 3 deletions

```diff
@@ -9,9 +9,32 @@ class DistributedSampler(Sampler):
     """
     Extension of DistributedSampler, as discussed in
     https://github.com/pytorch/pytorch/issues/23430
+
+    Example:
+        dataset: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
+        num_replicas: 4
+        shuffle: False
+
+        when group_size = 1
+            RANK    |  shard_dataset
+            =========================
+            rank_0  |  [0, 4, 8, 12]
+            rank_1  |  [1, 5, 9, 13]
+            rank_2  |  [2, 6, 10, 0]
+            rank_3  |  [3, 7, 11, 1]
+
+        when group_size = 2
+
+            RANK    |  shard_dataset
+            =========================
+            rank_0  |  [0, 1, 8, 9]
+            rank_1  |  [2, 3, 10, 11]
+            rank_2  |  [4, 5, 12, 13]
+            rank_3  |  [6, 7, 0, 1]
+
     """

-    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False):
+    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False, group_size=1):
         if num_replicas is None:
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
```
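
To make the docstring tables concrete, here is a self-contained sketch (a reading aid, not part of the commit) of the group-wise round-robin sharding described above; `shard` is a hypothetical helper that mirrors the padding and group-striding the sampler performs:

```python
import math

def shard(dataset, num_replicas, group_size):
    """Group-wise round-robin sharding, mirroring the docstring tables."""
    num_groups = len(dataset) // group_size
    num_group_samples = int(math.ceil(num_groups * 1.0 / num_replicas))
    total_size = num_group_samples * group_size * num_replicas
    # Pad by wrapping around, then cut into consecutive groups.
    padded = list(dataset) + list(dataset)[:total_size - len(dataset)]
    groups = [padded[i:i + group_size] for i in range(0, total_size, group_size)]
    # Each rank takes every num_replicas-th group, starting at its rank.
    return [[x for g in groups[r::num_replicas] for x in g]
            for r in range(num_replicas)]

data = list(range(14))
print(shard(data, num_replicas=4, group_size=1))
# [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 0], [3, 7, 11, 1]]
print(shard(data, num_replicas=4, group_size=2))
# [[0, 1, 8, 9], [2, 3, 10, 11], [4, 5, 12, 13], [6, 7, 0, 1]]
```
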
```diff
@@ -20,11 +43,20 @@ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False):
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
             rank = dist.get_rank()
+        assert len(dataset) % group_size == 0, (
+            "dataset length must be a multiplier of group size"
+            "dataset length: %d, group size: %d" % (len(dataset), group_size)
+        )
         self.dataset = dataset
+        self.group_size = group_size
         self.num_replicas = num_replicas
         self.rank = rank
         self.epoch = 0
-        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+        dataset_group_length = len(dataset) // group_size
+        self.num_group_samples = int(
+            math.ceil(dataset_group_length * 1.0 / self.num_replicas)
+        )
+        self.num_samples = self.num_group_samples * group_size
         self.total_size = self.num_samples * self.num_replicas
         self.shuffle = shuffle
```
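
Tracing the new length computation with the docstring's numbers (14 items, group_size=2, 4 replicas) shows where each rank's 4 samples come from; a quick check, assuming those values:

```python
import math

dataset_len, group_size, num_replicas = 14, 2, 4

dataset_group_length = dataset_len // group_size          # 7 groups
num_group_samples = int(
    math.ceil(dataset_group_length * 1.0 / num_replicas)  # 2 groups per rank
)
num_samples = num_group_samples * group_size              # 4 items per rank
total_size = num_samples * num_replicas                   # 16, so 2 indices wrap around
```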

```diff
@@ -41,8 +73,14 @@ def __iter__(self):
         indices += indices[:(self.total_size - len(indices))]
         assert len(indices) == self.total_size

+        total_group_size = self.total_size // self.group_size
+        indices = torch.reshape(
+            torch.LongTensor(indices), (total_group_size, self.group_size)
+        )
+
         # subsample
-        indices = indices[self.rank:self.total_size:self.num_replicas]
+        indices = indices[self.rank:total_group_size:self.num_replicas, :]
+        indices = torch.reshape(indices, (-1,)).tolist()
         assert len(indices) == self.num_samples

         if isinstance(self.dataset, Sampler):
```
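
The reshape-then-stride subsampling can be exercised in isolation. A minimal sketch (mine, not the commit's) with the docstring's 14-item dataset, padded to total_size = 16:

```python
import torch

# Padded index list, as __iter__ would build it: 14 items wrapped to 16.
indices = list(range(14)) + [0, 1]
group_size, num_replicas = 2, 4
total_group_size = len(indices) // group_size  # 8 rows of group_size

# One row per group; each rank takes every num_replicas-th row.
rows = torch.reshape(torch.LongTensor(indices), (total_group_size, group_size))
for rank in range(num_replicas):
    shard = torch.reshape(
        rows[rank:total_group_size:num_replicas, :], (-1,)
    ).tolist()
    print(rank, shard)
# 0 [0, 1, 8, 9]
# 1 [2, 3, 10, 11]
# 2 [4, 5, 12, 13]
# 3 [6, 7, 0, 1]
```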
