Commit b89bedf

add more changes for XPU

1 parent: 88ed5d2

7 files changed: 30 additions & 9 deletions

test/distributed/fsdp/test_fsdp_comm.py
Lines changed: 1 addition & 1 deletion

@@ -382,7 +382,7 @@ def forward(self, x: torch.Tensor):
         model.module.mlps._wait_unshard_streams_on_current_stream()


-devices = ("cuda", "hpu")
+devices = ("cuda", "hpu", "xpu")
 instantiate_device_type_tests(TestCommunication, globals(), only_for=devices)
 instantiate_device_type_tests(TestExplicitUnshard, globals(), only_for=devices)
 if __name__ == "__main__":
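
For readers unfamiliar with the harness: instantiate_device_type_tests generates concrete per-device test classes from a generic class, and only_for limits which device types are considered, which is why adding "xpu" to these tuples is enough to pick the tests up on XPU. A self-contained toy illustration of the same pattern (hypothetical TestExample class, not part of this commit):

import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, TestCase

class TestExample(TestCase):
    # Each generated variant (e.g. TestExampleCPU, TestExampleXPU) passes its
    # device string to the test method.
    def test_ones(self, device):
        t = torch.ones(2, 2, device=device)
        self.assertEqual(t.sum().item(), 4.0)

devices = ("cpu", "cuda", "hpu", "xpu")
instantiate_device_type_tests(TestExample, globals(), only_for=devices)

if __name__ == "__main__":
    run_tests()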

test/distributed/tensor/test_dtensor_compile.py
Lines changed: 9 additions & 1 deletion

@@ -43,6 +43,7 @@
     skipIfTorchDynamo,
     TEST_CUDA,
     TEST_HPU,
+    TEST_XPU,
 )
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,

@@ -108,7 +109,14 @@ def tearDown(self):

     @property
     def device_type(self) -> str:
-        return "cuda" if TEST_CUDA else "hpu" if TEST_HPU else "cpu"
+        if TEST_CUDA:
+            return "cuda"
+        elif TEST_HPU:
+            return "hpu"
+        elif TEST_XPU:
+            return "xpu"
+        else:
+            return "cpu"

     @property
     def world_size(self) -> int:
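
The TEST_CUDA / TEST_HPU / TEST_XPU flags consumed by this property come from torch.testing._internal.common_utils. A rough sketch of what they represent (the upstream definitions differ in detail):

import torch

# Rough sketch, not the upstream definitions: each flag reports whether the
# corresponding runtime sees at least one usable device of that type.
TEST_CUDA = torch.cuda.is_available()
TEST_XPU = torch.xpu.is_available()
# TEST_HPU additionally depends on the out-of-tree Intel Gaudi plugin
# (habana_frameworks) being loaded, so it is not sketched here.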

test/distributed/tensor/test_random_ops.py
Lines changed: 7 additions & 3 deletions

@@ -19,16 +19,20 @@
 )
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
-from torch.testing._internal.common_utils import run_tests, TEST_HPU
+from torch.testing._internal.common_utils import run_tests, TEST_HPU, TEST_XPU
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
     skip_if_lt_x_gpu,
     skip_unless_torch_gpu,
     with_comms,
 )

-
-TYPE_DEVICE = "hpu" if TEST_HPU else "cuda"
+if TEST_XPU:
+    TYPE_DEVICE = "xpu"
+elif TEST_HPU:
+    TYPE_DEVICE = "hpu"
+else:
+    TYPE_DEVICE = "cuda"


 class DistTensorRandomInitTest(DTensorTestBase):
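
The xpu/hpu/cuda priority chain above mirrors the one added to test_dtensor_compile.py. On builds that ship torch.accelerator, an equivalent device-agnostic spelling (an alternative sketch, not what this commit does) could look like:

import torch

def detect_device_type(default: str = "cuda") -> str:
    # Hypothetical helper: ask PyTorch which accelerator backend is present
    # (cuda, xpu, ...) and fall back to `default` when there is none.
    acc = getattr(torch, "accelerator", None)
    dev = acc.current_accelerator() if acc is not None else None
    return dev.type if dev is not None else default

TYPE_DEVICE = detect_device_type()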

test/distributed/tensor/test_redistribute.py
Lines changed: 2 additions & 2 deletions

@@ -9,7 +9,7 @@
 from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.tensor._collective_utils import shard_dim_alltoall
 from torch.distributed.tensor.debug import CommDebugMode
-from torch.testing._internal.common_utils import run_tests, TEST_CUDA, TEST_HPU
+from torch.testing._internal.common_utils import run_tests, TEST_CUDA, TEST_HPU, TEST_XPU
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
     with_comms,

@@ -366,7 +366,7 @@ def test_redistribute_shard_dim_change(self):
         local_out_dt = out_dt.to_local()
         local_expected_dt = expected_dt.to_local()
         self.assertEqual(out_dt.to_local(), expected_dt.to_local())
-        if TEST_HPU or TEST_CUDA:
+        if TEST_HPU or TEST_CUDA or TEST_XPU:
             self.assertEqual(
                 comm_mode.get_comm_counts()[
                     torch.ops._dtensor.shard_dim_alltoall

test/distributed/test_backends.py
Lines changed: 1 addition & 1 deletion

@@ -44,7 +44,7 @@ def test_create_pg(self, device) -> None:
         dist.destroy_process_group()


-devices = ["cpu", "cuda", "hpu"]
+devices = ["cpu", "cuda", "hpu", "xpu"]
 instantiate_device_type_tests(TestMiscCollectiveUtils, globals(), only_for=devices)

 if __name__ == "__main__":

test/distributed/test_functional_api.py
Lines changed: 9 additions & 0 deletions

@@ -34,6 +34,7 @@
     skipIfHpu,
     TEST_CUDA,
     TEST_HPU,
+    TEST_XPU,
     TestCase,
 )


@@ -66,6 +67,9 @@
     DEVICE = "hpu"
 elif TEST_CUDA:
     devices.append("cuda")
+elif TEST_XPU:
+    devices.append("xpu")
+    DEVICE = "xpu"


 def new_subgroups(group_size: int, pg_tag=None):

@@ -474,6 +478,8 @@ def allred_mesh_dim(input):
 # And then set the BACKEND variable appropriately.
 if TEST_HPU:
     BACKEND = dist.Backend.HCCL
+elif TEST_XPU:
+    BACKEND = dist.Backend.XCCL


 # allows you to check for multiple accelerator irrespective of device type

@@ -486,6 +492,9 @@ def exit_if_lt_x_accelerators(x):
     elif TEST_HPU:
         if torch.hpu.device_count() < x:
             sys.exit(TEST_SKIPS[f"multi-hpu-{x}"].exit_code)
+    elif TEST_XPU:
+        if torch.xpu.device_count() < x:
+            sys.exit(TEST_SKIPS[f"multi-hpu-{x}"].exit_code)


 def with_comms(func=None):
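
dist.Backend.XCCL selected above is the XPU collective backend, parallel to HCCL for HPU and NCCL for CUDA. A minimal single-rank sketch of initializing a process group with it (assuming a build with XCCL support and at least one XPU device; the real tests spawn multiple ranks through the test harness):

import os
import torch
import torch.distributed as dist

# Single-rank sketch only; rank/world_size normally come from the launcher.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

if torch.xpu.is_available():
    dist.init_process_group(backend="xccl", rank=0, world_size=1)
    t = torch.ones(4, device="xpu")
    dist.all_reduce(t)  # sum across the single-member group leaves t unchanged
    dist.destroy_process_group()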

torch/testing/_internal/common_fsdp.py
Lines changed: 1 addition & 1 deletion

@@ -1440,7 +1440,7 @@ def _test_fsdp_parity(
             self.assertRaisesRegex(
                 RuntimeError,
                 "An FSDP-managed module with parameter CPU offloading enabled "
-                "has parameters on cuda",
+                "has parameters on xpu", #zl_debug: refine for xpu
             )
             if expects_device_error
             else nullcontext()
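
Since this assertion now hardcodes a second device name, a device-agnostic variant would format the expected device into the regex instead of rewriting the literal per backend. A minimal standalone sketch (hypothetical helper, not part of common_fsdp.py):

def expected_offload_error(device_type: str) -> str:
    # Hypothetical helper: build the expected error message for whichever
    # accelerator the FSDP parity test targets, instead of hardcoding it.
    return (
        "An FSDP-managed module with parameter CPU offloading enabled "
        f"has parameters on {device_type}"
    )

assert "has parameters on xpu" in expected_offload_error("xpu")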
