Skip to content

Commit 66d603b

Browse files
committed
update
1 parent 8028475 commit 66d603b

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

torch/testing/_internal/common_fsdp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1215,7 +1215,7 @@ def _run(cls, rank, test_name, file_name, pipe, **kwargs):
12151215
device_ids = None
12161216
device_id = self.rank % DEVICE_COUNT
12171217
if TEST_CUDA or TEST_XPU:
1218-
torch.accelerator.set_device_idx(device_id)
1218+
torch.accelerator.set_device_index(device_id)
12191219
device_ids = [device_id]
12201220

12211221
# Execute barrier prior to running test to ensure that every process

torch/testing/_internal/distributed/_tensor/common_dtensor.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from torch.testing._internal.common_utils import (
3636
TEST_HPU,
3737
TEST_CUDA,
38+
TEST_XPU,
3839
)
3940
from torch.testing._internal.common_distributed import (
4041
MultiProcessTestCase,
@@ -55,6 +56,10 @@
5556
DEVICE_TYPE = "hpu"
5657
PG_BACKEND = "hccl"
5758
DEVICE_COUNT = _get_device_module("hpu").device_count()
59+
elif TEST_HPU:
60+
DEVICE_TYPE = "xpu"
61+
PG_BACKEND = "xccl"
62+
DEVICE_COUNT = _get_device_module("xpu").device_count()
5863
else:
5964
DEVICE_TYPE = "cpu"
6065
PG_BACKEND = "gloo"
@@ -328,6 +333,8 @@ def backend(self) -> str:
328333
backend = "nccl"
329334
elif TEST_HPU:
330335
backend = "hccl"
336+
elif TEST_XPU:
337+
backend = "xccl"
331338
else:
332339
backend = "gloo"
333340
return backend
@@ -345,7 +352,7 @@ def init_pg(self, eager_init) -> None:
345352
device_id = None
346353
if "nccl" or "xccl" in self.backend:
347354
# set device for nccl pg for collectives
348-
torch.accelerator.set_device_idx(self.rank)
355+
torch.accelerator.set_device_index(self.rank)
349356
# we only need to set device_id for nccl backend with eager init
350357
device_id = torch.device(f"{self.device_type}:{self.rank}") if eager_init else None
351358
# For nccl backend, bind the device to the process if device_id is not None

0 commit comments

Comments
 (0)