@@ -43,6 +43,7 @@
     retry_on_connect_failures,
     run_tests,
     TEST_WITH_DEV_DBG_ASAN,
+    TEST_XPU,
     TestCase,
 )
 from torch.utils.checkpoint import checkpoint
@@ -63,15 +64,18 @@
 
 torch.backends.cuda.matmul.allow_tf32 = False
 
+device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+
 
 def gpus_for_rank(world_size):
     """Multigpu tests are designed to simulate the multi nodes with multi
     GPUs on each node. Nccl backend requires equal #GPUs in each process.
     On a single node, all visible GPUs are evenly
     divided to subsets, each process only uses a subset.
     """
-    visible_devices = list(range(torch.cuda.device_count()))
-    gpus_per_process = torch.cuda.device_count() // world_size
+    device_count = torch.accelerator.device_count()
+    visible_devices = list(range(device_count))
+    gpus_per_process = device_count // world_size
     gpus_for_rank = []
     for rank in range(world_size):
         gpus_for_rank.append(
@@ -401,7 +405,7 @@ def _prepare_multi_device_module(
         gradient_as_bucket_view=gradient_as_bucket_view,
     )
 
-    input = torch.randn(global_batch_size, 2).cuda(devices[0])
+    input = torch.randn(global_batch_size, 2).to(devices[0])
     target = torch.randn(global_batch_size, 4)
 
     return model, ddp_model, input, target
@@ -435,10 +439,10 @@ def _test_ddp_checkpointing(
         allow_none_grads=False,
     ):
         # to reproduce the same training results
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)
         torch.manual_seed(31415)
-        model = copy.deepcopy(input_model).cuda()
-        ddp_model = copy.deepcopy(input_model).cuda()
+        model = copy.deepcopy(input_model).to(device_type)
+        ddp_model = copy.deepcopy(input_model).to(device_type)
         ddp_model = nn.parallel.DistributedDataParallel(
             ddp_model,
             bucket_cap_mb=1,
@@ -554,8 +558,8 @@ def __init__(self, use_reentrant=True):
     def _prepare_dummy_data(self):
         ddp_bs = 16
         bs = ddp_bs * self.world_size
-        input = torch.rand((bs, 20), device="cuda", requires_grad=True)
-        target = torch.randn((bs, 20), device="cuda")
+        input = torch.rand((bs, 20), device=device_type, requires_grad=True)
+        target = torch.randn((bs, 20), device=device_type)
         offset = self.rank * ddp_bs
         ddp_input = input[offset : offset + ddp_bs]
         ddp_target = target[offset : offset + ddp_bs]
@@ -715,7 +719,7 @@ def test_ddp_checkpointing_weight_sharing(self, use_reentrant):
         Test that checkpointing with weight sharing works.
         """
         process_group = self._get_process_group()
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)
        for use_bucket_view, static_graph in product((False, True), (False, True)):
             torch.manual_seed(31415)
             l1 = nn.Linear(20, 20)
@@ -738,7 +742,7 @@ def test_ddp_checkpointing_twice_weight_sharing(self):
         same layer twice and having weights shared across layers.
         """
         process_group = self._get_process_group()
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)
         for use_bucket_view in (True, False):
             self._test_ddp_checkpointing(
                 self.CheckpointTwiceModuleWeightSharing(),
@@ -1162,7 +1166,7 @@ def _test_sequence_num_incremented(self, process_group, ranks):
 
         # Verify sequence numbers are appropriately incremented
         for i in range(10):
-            t = torch.ones(1, device=torch.cuda.current_device())
+            t = torch.ones(1, device=device_type)
             dist.all_reduce(t, group=process_group)
             if not c10d._rank_not_in_group(process_group):
                 seq_num = self._verify_sequence_number_across_pg(
@@ -1193,7 +1197,7 @@ def _test_sequence_num_incremented(self, process_group, ranks):
             self.assertEqual(rank_to_seq_num[0] + 1, rank_to_seq_num[1])
 
     def _test_sequence_num_incremented_default_group(self, backend_name):
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)
         store = dist.FileStore(self.file_name, self.world_size)
         dist.init_process_group(
             backend_name,
@@ -1207,7 +1211,7 @@ def _test_sequence_num_incremented_default_group(self, backend_name):
         )
 
     def _test_sequence_num_incremented_subgroup(self, backend_name):
-        torch.cuda.set_device(self.rank)
+        torch.accelerator.set_device_index(self.rank)
         store = dist.FileStore(self.file_name, self.world_size)
         dist.init_process_group(
             backend_name,
@@ -1262,8 +1266,8 @@ def _test_warn_not_in_group(self, backend):
         in_group_ranks = list(filter(lambda x: x % 2 == 0, range(self.world_size)))
         group = dist.new_group(in_group_ranks)
 
-        x = torch.zeros(2, 2).cuda(self.rank)
-        xs = [torch.zeros(2, 2).cuda(self.rank) for _ in range(len(in_group_ranks))]
+        x = torch.zeros(2, 2).to(self.rank)
+        xs = [torch.zeros(2, 2).to(self.rank) for _ in range(len(in_group_ranks))]
         if self.rank not in in_group_ranks:
             msg = ".*{}.*does not belong to.*"
             with self.assertWarnsOnceRegex(UserWarning, msg.format("all_gather")):
@@ -1392,7 +1396,7 @@ def _test_bool_tensors(self, backend):
             rank=self.rank,
             store=store,
         )
-        device = "cuda" if backend == "nccl" else "cpu"
+        device = "cuda" if backend == "nccl" else "xpu" if backend == "xccl" else "cpu"
         # test alltoall_base
         tensor = torch.tensor([1, 0, 0, 1], dtype=torch.bool, device=device)
         zeros = torch.tensor([0, 0, 0, 0], dtype=torch.bool, device=device)
@@ -1574,8 +1578,8 @@ def test_debug_level(self):
 
 class DummyWork(dist._Work):
     def wait(self, timeout=5.0):
-        if torch.cuda.is_available():
-            torch.cuda.current_stream().synchronize()
+        if torch.accelerator.is_available():
+            torch.accelerator.current_stream().synchronize()
         return True
 
 
@@ -1790,6 +1794,18 @@ def test_backend_config(self):
             ("cpu:gloo,cuda:nccl", "cpu:gloo,cuda:nccl"),
         ]
 
+        if TEST_XPU:
+            # Override backend_config_strings_and_expected_values for Intel GPU.
+            backend_config_strings_and_expected_values[4:10] = [
+                (dist.Backend.DUMMY, "cpu:dummy,cuda:dummy,xpu:dummy"),
+                ("DUMMY", "cpu:dummy,cuda:dummy,xpu:dummy"),
+                ("dummy", "cpu:dummy,cuda:dummy,xpu:dummy"),
+                ("cpu:dummy,xpu:dummy", "cpu:dummy,xpu:dummy"),
+                ("cpu:dummy,xpu:xccl", "cpu:dummy,xpu:xccl"),
+                ("cpu:gloo,xpu:dummy", "cpu:gloo,xpu:dummy"),
+                ("cpu:gloo,xpu:xccl", "cpu:gloo,xpu:xccl"),
+            ]
+
         for config_str, expected_value in backend_config_strings_and_expected_values:
             with self.subTest(config_str):
                 # ensures these configs strings are valid and no ValueError is raised
@@ -1800,6 +1816,8 @@ def test_backend_config(self):
         invalid_backend_config_strings = [
             "cpu:gloo,cuda:nccl,",  # trailing comma
             "cpu:gloo,cuda:nccl,cpu:dummy",  # duplicate device
+            "cpu:gloo,xpu:xccl,",  # trailing comma
+            "cpu:gloo,xpu:xccl,cpu:dummy",  # duplicate device
         ]
         for config_str in invalid_backend_config_strings:
             with self.subTest(config_str):
@@ -1814,7 +1832,7 @@ def test_init_process_group_with_multiple_backends(self):
         os.environ["MASTER_ADDR"] = "localhost"
         os.environ["MASTER_PORT"] = "6789"
         dist.init_process_group(
-            "cpu:dummy,cuda:dummy", rank=self.rank, world_size=self.world_size
+            "cpu:dummy,cuda:dummy,xpu:dummy", rank=self.rank, world_size=self.world_size
         )
 
         # test all_gather
@@ -2053,7 +2071,7 @@ def _call_collective_with_varying_tensors(self, backend, collective, *args):
         # correctly dispatched
 
         # TODO: this will be updated in the future to not be backend specific
-        device = "cuda" if backend == "nccl" else "cpu"
+        device = "cuda" if backend == "nccl" else "xpu" if backend == "xccl" else "cpu"
         # ensure supported devices (cpu, cuda) succeeds during dispatch call
         tensor = torch.zeros(2, 2, device=torch.device(device))
         # multi tensor collectives
@@ -2119,7 +2137,7 @@ def _test_all_to_all_single(self, backend):
             rank=self.rank,
             store=store,
         )
-        device = "cuda" if backend == "nccl" else "cpu"
+        device = "cuda" if backend == "nccl" else "xpu" if backend == "xccl" else "cpu"
         # test alltoall_base
         input_tensor = torch.ones(2, 2, device=torch.device(device))
         output_tensor = torch.zeros(2, 2, device=torch.device(device))
@@ -2251,8 +2269,9 @@ def testNodeLocalRank(self):
 
 
 if __name__ == "__main__":
-    assert not torch.cuda._initialized, (
-        "test_distributed must not have initialized CUDA context on main process"
-    )
+    if device_type != "cpu":
+        assert not torch.get_device_module()._initialized, (
+            f"test_distributed must not have initialized {device_type} context on main process"
+        )
 
     run_tests()
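
For reference, here is a minimal standalone sketch of the device-agnostic pattern this diff adopts, shown outside the diff itself. It assumes a PyTorch build that ships the torch.accelerator API (the same API the patch calls); the demo() helper and the tensor shape are purely illustrative and are not part of the test suite.

import torch

# Resolve the active accelerator once; fall back to "cpu" on machines without one.
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"


def demo(rank: int) -> torch.Tensor:
    # Bind this process to a device by index (replaces torch.cuda.set_device).
    if device_type != "cpu":
        torch.accelerator.set_device_index(rank)
    # Allocate on whichever backend was detected (cuda, xpu, or cpu).
    return torch.ones(4, device=device_type)


if __name__ == "__main__":
    print(demo(rank=0).device)

The same device_type string then flows into device= arguments and .to() calls, which is why the patch can delete most of the hard-coded "cuda" literals.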