|
41 | 41 | TEST_CUDA, |
42 | 42 | TEST_HPU, |
43 | 43 | TEST_PRIVATEUSE1, |
| 44 | + TEST_WITH_EXTERNAL_MULTIPROCESSING, |
44 | 45 | TEST_XPU, |
45 | 46 | ) |
46 | 47 | from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec |
@@ -372,6 +373,13 @@ def backend(self) -> str: |
372 | 373 | backend = dist.get_default_backend_for_device(DEVICE_TYPE) |
373 | 374 | return backend |
374 | 375 |
|
| 376 | + @property |
| 377 | + def init_method(self): |
| 378 | + if TEST_WITH_EXTERNAL_MULTIPROCESSING: |
| 379 | + return None |
| 380 | + return f"file://{self.file_name}" |
| 381 | + |
| 382 | + |
375 | 383 | def build_device_mesh(self) -> DeviceMesh: |
376 | 384 | return init_device_mesh(self.device_type, (self.world_size,)) |
377 | 385 |
|
@@ -413,7 +421,7 @@ def init_pg(self, eager_init, backend: Optional[str] = None) -> None: |
413 | 421 | backend=backend, |
414 | 422 | world_size=self.world_size, |
415 | 423 | rank=self.rank, # pyre-ignore[16] |
416 | | - init_method=f"file://{self.file_name}", # pyre-ignore[16] |
| 424 | + init_method=self.init_method, # pyre-ignore[16] |
417 | 425 | device_id=device_id, |
418 | 426 | ) |
419 | 427 |
|
@@ -443,7 +451,10 @@ def destroy_pg(self, device_id: Optional[int] = None) -> None: |
443 | 451 |
|
444 | 452 | def setUp(self) -> None: |
445 | 453 | super().setUp() |
446 | | - self._spawn_processes() |
| 454 | + if TEST_WITH_EXTERNAL_MULTIPROCESSING: |
| 455 | + self._run_external_multiprocessing() |
| 456 | + else: |
| 457 | + self._spawn_processes() |
447 | 458 |
|
448 | 459 | def _test_op_on_dtensor(self, op_call, *args, **kwargs) -> None: |
449 | 460 | """ |
|
0 commit comments