
Commit c9485f8

daisyden authored and pytorchmergebot committed
[Reland][2/N]Port several test files under test/distributed to Intel GPU (pytorch#159473)
For pytorch#114850, we are porting distributed tests to Intel GPU. This PR covers several test files under test/distributed. We enable Intel GPU with the following methods, keeping the original code style as much as possible:

- use instantiate_device_type_tests()
- use torch.accelerator.current_accelerator() to determine the accelerator backend
- use requires_accelerator_dist_backend to allow both nccl and xccl tests
- enable XPU for some test paths
- change the hardcoded world_size according to device_count
- unify some common code under torch/testing/_internal for multiple backends, for example: add xpu to Backend.backend_capability and dist.Backend.register_backend()

Pull Request resolved: pytorch#159473
Approved by: https://github.com/guangyey, https://github.com/d4l3k
1 parent 71b272e commit c9485f8
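The device-selection pattern described in the commit message recurs throughout the diff below. A minimal, self-contained sketch of that pattern follows, assuming a PyTorch build recent enough to ship the torch.accelerator API; devices_for_rank is a hypothetical helper name mirroring the reworked gpus_for_rank() in the diff, not part of the change itself.

```python
import torch

# Resolve the active accelerator (CUDA, XPU, ...) once; fall back to CPU when
# no accelerator backend is available.
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"


def devices_for_rank(world_size: int, rank: int) -> list[int]:
    """Evenly split the visible accelerator devices across ranks.

    Hypothetical helper mirroring the reworked gpus_for_rank() below:
    each rank gets device_count // world_size devices.
    """
    device_count = torch.accelerator.device_count() if device_type != "cpu" else 0
    per_rank = device_count // world_size if world_size else 0
    return list(range(device_count))[rank * per_rank : (rank + 1) * per_rank]


if __name__ == "__main__":
    # On a CPU-only build this prints: cpu []
    print(device_type, devices_for_rank(world_size=2, rank=0))
```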

File tree

11 files changed: +345 -233 lines


test/distributed/test_c10d_common.py

Lines changed: 43 additions & 24 deletions
@@ -43,6 +43,7 @@
     retry_on_connect_failures,
     run_tests,
     TEST_WITH_DEV_DBG_ASAN,
+    TEST_XPU,
     TestCase,
 )
 from torch.utils.checkpoint import checkpoint
@@ -63,15 +64,18 @@
 
 torch.backends.cuda.matmul.allow_tf32 = False
 
+device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+
 
 def gpus_for_rank(world_size):
     """Multigpu tests are designed to simulate the multi nodes with multi
     GPUs on each node. Nccl backend requires equal #GPUs in each process.
     On a single node, all visible GPUs are evenly
     divided to subsets, each process only uses a subset.
     """
-    visible_devices = list(range(torch.cuda.device_count()))
-    gpus_per_process = torch.cuda.device_count() // world_size
+    device_count = torch.accelerator.device_count()
+    visible_devices = list(range(device_count))
+    gpus_per_process = device_count // world_size
     gpus_for_rank = []
     for rank in range(world_size):
         gpus_for_rank.append(
@@ -401,7 +405,7 @@ def _prepare_multi_device_module(
             gradient_as_bucket_view=gradient_as_bucket_view,
         )
 
-        input = torch.randn(global_batch_size, 2).cuda(devices[0])
+        input = torch.randn(global_batch_size, 2).to(devices[0])
         target = torch.randn(global_batch_size, 4)
 
         return model, ddp_model, input, target
@@ -435,10 +439,10 @@ def _test_ddp_checkpointing(
         allow_none_grads=False,
     ):
         # to reproduce the same training results
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)
         torch.manual_seed(31415)
-        model = copy.deepcopy(input_model).cuda()
-        ddp_model = copy.deepcopy(input_model).cuda()
+        model = copy.deepcopy(input_model).to(device_type)
+        ddp_model = copy.deepcopy(input_model).to(device_type)
         ddp_model = nn.parallel.DistributedDataParallel(
             ddp_model,
             bucket_cap_mb=1,
@@ -554,8 +558,8 @@ def __init__(self, use_reentrant=True):
     def _prepare_dummy_data(self):
         ddp_bs = 16
         bs = ddp_bs * self.world_size
-        input = torch.rand((bs, 20), device="cuda", requires_grad=True)
-        target = torch.randn((bs, 20), device="cuda")
+        input = torch.rand((bs, 20), device=device_type, requires_grad=True)
+        target = torch.randn((bs, 20), device=device_type)
         offset = self.rank * ddp_bs
         ddp_input = input[offset : offset + ddp_bs]
         ddp_target = target[offset : offset + ddp_bs]
@@ -715,7 +719,7 @@ def test_ddp_checkpointing_weight_sharing(self, use_reentrant):
         Test that checkpointing with weight sharing works.
         """
         process_group = self._get_process_group()
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)
         for use_bucket_view, static_graph in product((False, True), (False, True)):
             torch.manual_seed(31415)
             l1 = nn.Linear(20, 20)
@@ -738,7 +742,7 @@ def test_ddp_checkpointing_twice_weight_sharing(self):
         same layer twice and having weights shared across layers.
         """
         process_group = self._get_process_group()
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)
         for use_bucket_view in (True, False):
             self._test_ddp_checkpointing(
                 self.CheckpointTwiceModuleWeightSharing(),
@@ -1162,7 +1166,7 @@ def _test_sequence_num_incremented(self, process_group, ranks):
 
         # Verify sequence numbers are appropriately incremented
         for i in range(10):
-            t = torch.ones(1, device=torch.cuda.current_device())
+            t = torch.ones(1, device=device_type)
             dist.all_reduce(t, group=process_group)
             if not c10d._rank_not_in_group(process_group):
                 seq_num = self._verify_sequence_number_across_pg(
@@ -1193,7 +1197,7 @@ def _test_sequence_num_incremented(self, process_group, ranks):
             self.assertEqual(rank_to_seq_num[0] + 1, rank_to_seq_num[1])
 
     def _test_sequence_num_incremented_default_group(self, backend_name):
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)
         store = dist.FileStore(self.file_name, self.world_size)
         dist.init_process_group(
             backend_name,
@@ -1207,7 +1211,7 @@ def _test_sequence_num_incremented_default_group(self, backend_name):
         )
 
     def _test_sequence_num_incremented_subgroup(self, backend_name):
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)
         store = dist.FileStore(self.file_name, self.world_size)
         dist.init_process_group(
             backend_name,
@@ -1262,8 +1266,8 @@ def _test_warn_not_in_group(self, backend):
         in_group_ranks = list(filter(lambda x: x % 2 == 0, range(self.world_size)))
         group = dist.new_group(in_group_ranks)
 
-        x = torch.zeros(2, 2).cuda(self.rank)
-        xs = [torch.zeros(2, 2).cuda(self.rank) for _ in range(len(in_group_ranks))]
+        x = torch.zeros(2, 2).to(self.rank)
+        xs = [torch.zeros(2, 2).to(self.rank) for _ in range(len(in_group_ranks))]
         if self.rank not in in_group_ranks:
             msg = ".*{}.*does not belong to.*"
             with self.assertWarnsOnceRegex(UserWarning, msg.format("all_gather")):
@@ -1392,7 +1396,7 @@ def _test_bool_tensors(self, backend):
             rank=self.rank,
             store=store,
         )
-        device = "cuda" if backend == "nccl" else "cpu"
+        device = "cuda" if backend == "nccl" else "xpu" if backend == "xccl" else "cpu"
        # test alltoall_base
         tensor = torch.tensor([1, 0, 0, 1], dtype=torch.bool, device=device)
         zeros = torch.tensor([0, 0, 0, 0], dtype=torch.bool, device=device)
@@ -1574,8 +1578,8 @@ def test_debug_level(self):
 
 class DummyWork(dist._Work):
     def wait(self, timeout=5.0):
-        if torch.cuda.is_available():
-            torch.cuda.current_stream().synchronize()
+        if torch.accelerator.is_available():
+            torch.accelerator.current_stream().synchronize()
         return True
 
 
@@ -1790,6 +1794,18 @@ def test_backend_config(self):
             ("cpu:gloo,cuda:nccl", "cpu:gloo,cuda:nccl"),
         ]
 
+        if TEST_XPU:
+            # Override backend_config_strings_and_expected_values for Intel GPU.
+            backend_config_strings_and_expected_values[4:10] = [
+                (dist.Backend.DUMMY, "cpu:dummy,cuda:dummy,xpu:dummy"),
+                ("DUMMY", "cpu:dummy,cuda:dummy,xpu:dummy"),
+                ("dummy", "cpu:dummy,cuda:dummy,xpu:dummy"),
+                ("cpu:dummy,xpu:dummy", "cpu:dummy,xpu:dummy"),
+                ("cpu:dummy,xpu:xccl", "cpu:dummy,xpu:xccl"),
+                ("cpu:gloo,xpu:dummy", "cpu:gloo,xpu:dummy"),
+                ("cpu:gloo,xpu:xccl", "cpu:gloo,xpu:xccl"),
+            ]
+
         for config_str, expected_value in backend_config_strings_and_expected_values:
             with self.subTest(config_str):
                 # ensures these configs strings are valid and no ValueError is raised
@@ -1800,6 +1816,8 @@ def test_backend_config(self):
         invalid_backend_config_strings = [
             "cpu:gloo,cuda:nccl,",  # trailing comma
             "cpu:gloo,cuda:nccl,cpu:dummy",  # duplicate device
+            "cpu:gloo,xpu:xccl,",  # trailing comma
+            "cpu:gloo,xpu:xccl,cpu:dummy",  # duplicate device
         ]
         for config_str in invalid_backend_config_strings:
             with self.subTest(config_str):
@@ -1814,7 +1832,7 @@ def test_init_process_group_with_multiple_backends(self):
         os.environ["MASTER_ADDR"] = "localhost"
         os.environ["MASTER_PORT"] = "6789"
         dist.init_process_group(
-            "cpu:dummy,cuda:dummy", rank=self.rank, world_size=self.world_size
+            "cpu:dummy,cuda:dummy,xpu:dummy", rank=self.rank, world_size=self.world_size
         )
 
         # test all_gather
@@ -2053,7 +2071,7 @@ def _call_collective_with_varying_tensors(self, backend, collective, *args):
         # correctly dispatched
 
         # TODO: this will be updated in the future to not be backend specific
-        device = "cuda" if backend == "nccl" else "cpu"
+        device = "cuda" if backend == "nccl" else "xpu" if backend == "xccl" else "cpu"
         # ensure supported devices (cpu, cuda) succeeds during dispatch call
         tensor = torch.zeros(2, 2, device=torch.device(device))
         # multi tensor collectives
@@ -2119,7 +2137,7 @@ def _test_all_to_all_single(self, backend):
             rank=self.rank,
             store=store,
         )
-        device = "cuda" if backend == "nccl" else "cpu"
+        device = "cuda" if backend == "nccl" else "xpu" if backend == "xccl" else "cpu"
         # test alltoall_base
         input_tensor = torch.ones(2, 2, device=torch.device(device))
         output_tensor = torch.zeros(2, 2, device=torch.device(device))
@@ -2251,8 +2269,9 @@ def testNodeLocalRank(self):
 
 
 if __name__ == "__main__":
-    assert not torch.cuda._initialized, (
-        "test_distributed must not have initialized CUDA context on main process"
-    )
+    if device_type != "cpu":
+        assert not torch.get_device_module()._initialized, (
+            "test_distributed must not have initialized {device_type} context on main process"
+        )
 
     run_tests()
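For reference, the "device:backend" config strings exercised by test_backend_config and test_init_process_group_with_multiple_backends can be tried in isolation. Below is a minimal single-process sketch under assumed conditions (gloo available, hypothetical rendezvous address/port, world_size 1), using the portable cpu:gloo mapping rather than the dummy backend registered by the test.

```python
import os

import torch
import torch.distributed as dist

# A "device:backend" map tells the process group which backend serves which
# device type. On an Intel GPU build, "cpu:gloo,xpu:xccl" would additionally
# route XPU tensors through XCCL; here we stick to the CPU-only mapping.
os.environ.setdefault("MASTER_ADDR", "localhost")  # hypothetical rendezvous
os.environ.setdefault("MASTER_PORT", "29501")

dist.init_process_group("cpu:gloo", rank=0, world_size=1)

t = torch.ones(2)
dist.all_reduce(t)  # dispatched to gloo because the tensor lives on CPU
print(t)  # tensor([1., 1.]) with a single rank

dist.destroy_process_group()
```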
