Skip to content

Commit 3fb129a

Browse files
committed
Add dtensor support
1 parent 44c5a33 commit 3fb129a

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

torch/testing/_internal/distributed/_tensor/common_dtensor.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
TEST_CUDA,
4242
TEST_HPU,
4343
TEST_PRIVATEUSE1,
44+
TEST_WITH_EXTERNAL_MULTIPROCESSING,
4445
TEST_XPU,
4546
)
4647
from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
@@ -372,6 +373,13 @@ def backend(self) -> str:
372373
backend = dist.get_default_backend_for_device(DEVICE_TYPE)
373374
return backend
374375

376+
@property
377+
def init_method(self):
378+
if TEST_WITH_EXTERNAL_MULTIPROCESSING:
379+
return None
380+
return f"file://{self.file_name}"
381+
382+
375383
def build_device_mesh(self) -> DeviceMesh:
376384
return init_device_mesh(self.device_type, (self.world_size,))
377385

@@ -413,7 +421,7 @@ def init_pg(self, eager_init, backend: Optional[str] = None) -> None:
413421
backend=backend,
414422
world_size=self.world_size,
415423
rank=self.rank, # pyre-ignore[16]
416-
init_method=f"file://{self.file_name}", # pyre-ignore[16]
424+
init_method=self.init_method, # pyre-ignore[16]
417425
device_id=device_id,
418426
)
419427

@@ -443,7 +451,10 @@ def destroy_pg(self, device_id: Optional[int] = None) -> None:
443451

444452
def setUp(self) -> None:
445453
super().setUp()
446-
self._spawn_processes()
454+
if TEST_WITH_EXTERNAL_MULTIPROCESSING:
455+
self._run_external_multiprocessing()
456+
else:
457+
self._spawn_processes()
447458

448459
def _test_op_on_dtensor(self, op_call, *args, **kwargs) -> None:
449460
"""

0 commit comments

Comments
 (0)