Skip to content

Commit e99fa00

Browse files
wz337 authored and pytorchmergebot committed
Back out "[DeviceMesh] Add support for nD slicing (pytorch#119752)" (pytorch#121763)
Summary:
Original commit changeset: e52b8809c8d8
Original Phabricator Diff: D54778906

We have to back out this diff. D54778906 seems to be causing test failures for APF, which is blocking trunk health and hence the release. Just starting to look at the issue: T182209248.

Test Plan: Sandcastle
Reviewed By: satgera
Differential Revision: D54825114
Pull Request resolved: pytorch#121763
Approved by: https://github.com/osalpekar
1 parent be33d31 commit e99fa00

File tree

2 files changed

+45
-160
lines changed

2 files changed

+45
-160
lines changed

test/distributed/test_device_mesh.py

Lines changed: 21 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
)
2929
from torch.testing._internal.distributed._tensor.common_dtensor import (
3030
DTensorTestBase,
31-
skip_if_lt_x_gpu,
3231
skip_unless_torch_gpu,
3332
with_comms,
3433
)
@@ -79,10 +78,7 @@ def test_assert_invalid_mesh_tensor(self):
7978

8079
@with_comms
8180
@run_with_both_funcol_impls
82-
@skip_unless_torch_gpu
8381
def test_get_group(self):
84-
# TODO: `test_get_group` still periodically timeout on cpu
85-
# remove `@skip_unless_torch_gpu` after the problem is fixed.
8682
mesh_shape = (2, self.world_size // 2)
8783
mesh_2d = init_device_mesh(
8884
self.device_type, mesh_shape, mesh_dim_names=("dp", "tp")
@@ -103,10 +99,7 @@ def test_get_group(self):
10399

104100
@with_comms
105101
@run_with_both_funcol_impls
106-
@skip_unless_torch_gpu
107102
def test_get_local_rank_raises_exception(self):
108-
# TODO: `test_get_local_rank_raises_exception` still periodically timeout on cpu
109-
# remove `@skip_unless_torch_gpu` after the problem is fixed.
110103
mesh_shape = (2, self.world_size // 2)
111104
mesh_2d = init_device_mesh(
112105
self.device_type, mesh_shape, mesh_dim_names=("dp", "tp")
@@ -120,10 +113,7 @@ def test_get_local_rank_raises_exception(self):
120113

121114
@with_comms
122115
@run_with_both_funcol_impls
123-
@skip_unless_torch_gpu
124116
def test_get_local_rank(self):
125-
# TODO: `test_get_local_rank_raises_exception` still periodically timeout on cpu
126-
# remove `@skip_unless_torch_gpu` after the problem is fixed.
127117
mesh_shape = (2, self.world_size // 2)
128118
mesh_2d = init_device_mesh(
129119
self.device_type, mesh_shape, mesh_dim_names=("dp", "tp")
@@ -276,71 +266,47 @@ def world_size(self):
276266

277267
@with_comms
278268
@run_with_both_funcol_impls
279-
def test_raises_invalid_mesh_dim_names(self):
280-
error_msg = "Invalid mesh_dim_name"
281-
# Case 1: the DeviceMesh does not have a mesh_dim_names attribute
282-
with self.assertRaisesRegex(
283-
RuntimeError, "Cannot slice a DeviceMesh without mesh_dim_names."
284-
):
269+
def test_raises_no_mesh_dim_found(self):
270+
with self.assertRaisesRegex(KeyError, "No `mesh_dim_names` found."):
285271
mesh = init_device_mesh(self.device_type, (2, 4))
286272
child_mesh = mesh["DP"]
287273

288-
child_mesh_dim_names = "PP"
289-
with self.assertRaisesRegex(ValueError, error_msg):
290-
mesh_dim_names = ("DP", "TP")
291-
mesh = init_device_mesh(
292-
self.device_type, (2, 4), mesh_dim_names=mesh_dim_names
293-
)
294-
child_mesh = mesh[child_mesh_dim_names]
295-
296-
# Case 2
297-
child_mesh_dim_names = ["PP", "CP"]
298-
with self.assertRaisesRegex(ValueError, error_msg):
299-
mesh_dim_names = ("DP", "TP")
300-
mesh = init_device_mesh(
301-
self.device_type, (2, 4), mesh_dim_names=mesh_dim_names
302-
)
303-
child_mesh = mesh[child_mesh_dim_names]
304-
305-
# Case 3: a given child_mesh_dim_name is not a contiguous subset of the parent mesh's mesh_dim_names.
306-
child_mesh_dim_names = ("TP", "DP")
307-
with self.assertRaisesRegex(ValueError, error_msg):
274+
@with_comms
275+
@run_with_both_funcol_impls
276+
def test_raises_invalid_mesh_dim_name(self):
277+
child_mesh_dim_name = "PP"
278+
with self.assertRaisesRegex(
279+
KeyError, f"Mesh dimension '{child_mesh_dim_name}' does not exist."
280+
):
308281
mesh_dim_names = ("DP", "TP")
309282
mesh = init_device_mesh(
310283
self.device_type, (2, 4), mesh_dim_names=mesh_dim_names
311284
)
312-
child_mesh = mesh[child_mesh_dim_names]
313-
314-
# Case 3
315-
child_mesh_dim_names = ("PP", "TP")
316-
with self.assertRaisesRegex(ValueError, error_msg):
317-
mesh_dim_names = ("PP", "DP", "TP")
318-
mesh = init_device_mesh(
319-
self.device_type, (2, 2, 2), mesh_dim_names=mesh_dim_names
320-
)
321-
child_mesh = mesh[child_mesh_dim_names]
285+
child_mesh = mesh[child_mesh_dim_name]
322286

323287
@with_comms
324288
@run_with_both_funcol_impls
325-
@skip_if_lt_x_gpu(8)
326-
def test_get_item_2d(self):
327-
# TODO: `test_get_item_2d` still periodically timeout on cpu
328-
# remove `@skip_if_lt_x_gpu` after the problem is fixed.
289+
def test_get_item(self):
329290
mesh_shape = (2, 4)
330291
mesh_dim_names = ("DP", "TP")
331292
mesh_2d = init_device_mesh(
332293
self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
333294
)
334295

296+
pg_ranks_by_dim_name = {}
297+
for mesh_dim_name in mesh_dim_names:
298+
mesh_dim = mesh_dim_names.index(mesh_dim_name)
299+
pg_ranks_by_dim_name[mesh_dim_name] = mesh_2d.mesh.swapdims(
300+
-1, mesh_dim
301+
).reshape(-1, mesh_2d.mesh.size(mesh_dim))
302+
335303
tp_mesh = mesh_2d["TP"]
336-
tp_group = [[0, 1, 2, 3], [4, 5, 6, 7]]
337304
tp_group_idx = self.rank // 4
338-
self.assertEqual(tp_mesh.mesh.tolist(), tp_group[tp_group_idx])
305+
self.assertEqual(tp_mesh.mesh, pg_ranks_by_dim_name["TP"][tp_group_idx])
339306

340307
dp_mesh = mesh_2d["DP"]
341-
dp_group = [[0, 4], [1, 5], [2, 6], [3, 7]]
342308
dp_group_idx = self.rank % 4
343-
self.assertEqual(dp_mesh.mesh.tolist(), dp_group[dp_group_idx])
309+
self.assertEqual(mesh_2d["DP"].mesh, pg_ranks_by_dim_name["DP"][dp_group_idx])
344310

345311
@with_comms
346312
@run_with_both_funcol_impls
@@ -351,50 +317,14 @@ def test_get_item_1d(self):
351317
dp_mesh = mesh["dp"]
352318
self.assertEqual(dp_mesh, mesh)
353319

354-
with self.assertRaisesRegex(ValueError, "Invalid mesh_dim_name"):
320+
with self.assertRaisesRegex(RuntimeError, "Invalid mesh_dim_name"):
355321
dp_mesh = mesh["dim0"]
356322

357-
@with_comms
358-
@skip_if_lt_x_gpu(8)
359-
def test_get_item_3d(self):
360-
# TODO: `test_get_item_3d` still periodically timeout on cpu
361-
# remove `@skip_if_lt_x_gpu` after the problem is fixed.
362-
mesh_shape = (2, 2, 2)
363-
mesh_dim_names = ("Replicate", "Shard", "TP")
364-
mesh_3d = init_device_mesh(
365-
self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
366-
)
367-
368-
tp_group = [[0, 1], [2, 3], [4, 5], [6, 7]]
369-
tp_group_idx = int(self.rank / 2)
370-
self.assertEqual(mesh_3d["TP"].mesh.tolist(), tp_group[tp_group_idx])
371-
372-
shard_group = [[0, 2], [1, 3], [4, 6], [5, 7]]
373-
shard_group_idx = self.rank % 2 + self.rank // 4 * 2
374-
self.assertEqual(mesh_3d["Shard"].mesh.tolist(), shard_group[shard_group_idx])
375-
376-
replicate_group = [[0, 4], [1, 5], [2, 6], [3, 7]]
377-
replicate_group_idx = self.rank % 4
378-
self.assertEqual(
379-
mesh_3d["Replicate"].mesh.tolist(), replicate_group[replicate_group_idx]
380-
)
381-
382-
# We support both UX for nD slicing.
383-
# mesh_3d[["Replicate", "Shard"]] or mesh_3d["Replicate", "Shard"]
384-
hsdp_mesh_1 = mesh_3d[["Replicate", "Shard"]]
385-
hsdp_mesh_2 = mesh_3d["Replicate", "Shard"]
386-
hsdp_group = [[[0, 2], [4, 6]], [[1, 3], [5, 7]]]
387-
hsdp_group_idx = self.rank % 2
388-
self.assertEqual(hsdp_mesh_1.mesh.tolist(), hsdp_group[hsdp_group_idx])
389-
self.assertEqual(hsdp_mesh_2.mesh.tolist(), hsdp_group[hsdp_group_idx])
390-
self.assertEqual(hsdp_mesh_1, hsdp_mesh_2)
391-
392323

393324
@instantiate_parametrized_tests
394325
class TestMeshEnv(DTensorTestBase):
395326
@with_comms
396327
@run_with_both_funcol_impls
397-
@skip_unless_torch_gpu
398328
def test_get_parent_mesh(self):
399329
mesh_shape = (2, self.world_size // 2)
400330
mesh_dim_names = ("DP", "TP")
@@ -415,7 +345,6 @@ def test_get_parent_mesh(self):
415345

416346
@with_comms
417347
@run_with_both_funcol_impls
418-
@skip_unless_torch_gpu
419348
def test_get_parent_mesh_dim_exist(self):
420349
mesh_shape = (2, self.world_size // 2)
421350
mesh_dim_names = ("DP", "TP")
@@ -428,7 +357,6 @@ def test_get_parent_mesh_dim_exist(self):
428357

429358
@with_comms
430359
@run_with_both_funcol_impls
431-
@skip_unless_torch_gpu
432360
def test_get_parent_mesh_dim_not_exist(self):
433361
mesh_shape = (self.world_size,)
434362
mesh = init_device_mesh(self.device_type, mesh_shape)
@@ -437,7 +365,6 @@ def test_get_parent_mesh_dim_not_exist(self):
437365

438366
@with_comms
439367
@run_with_both_funcol_impls
440-
@skip_unless_torch_gpu
441368
def test_get_mesh_dim_by_name(self):
442369
mesh_shape = (2, self.world_size // 2)
443370
mesh_dim_names = ("DP", "TP")

torch/distributed/device_mesh.py

Lines changed: 24 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -67,43 +67,25 @@ def get_current_mesh(self) -> "DeviceMesh":
6767
return self.mesh_stack[-1]
6868

6969
def create_child_mesh(
70-
self,
71-
device_mesh: "DeviceMesh",
72-
mesh_dim_names: Tuple[str],
70+
self, device_mesh: "DeviceMesh", mesh_dim: int, mesh_dim_name: str
7371
) -> "DeviceMesh":
7472
# swap the current dim to the last dim then reshape to flatten out other
7573
# dims, so we can just extract the list of ranks which contains cur_rank.
76-
mesh_dims = [
77-
not_none(device_mesh.mesh_dim_names).index(mesh_dim_name)
78-
for mesh_dim_name in mesh_dim_names
79-
]
8074
cur_rank = device_mesh.get_rank()
81-
mesh = device_mesh.mesh
82-
all_mesh_dims = list(range(mesh.ndim))
83-
for mesh_dim in mesh_dims:
84-
# remove not pop b/c we want the value of the ind removed not it's position in the list
85-
# because this list dynamically changes.
86-
all_mesh_dims.remove(mesh_dim)
87-
88-
mesh_sizes = [device_mesh.mesh.size(mesh_dim) for mesh_dim in mesh_dims]
89-
90-
pg_ranks_by_dim = device_mesh.mesh.permute(
91-
*all_mesh_dims, *mesh_dims
92-
).reshape(-1, *mesh_sizes)
93-
94-
for mesh_nd in pg_ranks_by_dim:
95-
if cur_rank in mesh_nd:
96-
sub_mesh = DeviceMesh(
97-
device_mesh.device_type,
98-
mesh_nd,
99-
mesh_dim_names=mesh_dim_names,
100-
)
101-
res_sub_mesh = sub_mesh
75+
pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape(
76+
-1, device_mesh.mesh.size(mesh_dim)
77+
)
10278

103-
res_sub_mesh._dim_group_infos = [ # type: ignore[possibly-undefined]
104-
device_mesh._dim_group_infos[mesh_dim] for mesh_dim in mesh_dims
105-
]
79+
for mesh_1d in pg_ranks_by_dim:
80+
sub_mesh = DeviceMesh(
81+
device_mesh.device_type,
82+
mesh_1d,
83+
mesh_dim_names=(mesh_dim_name,),
84+
)
85+
if cur_rank in mesh_1d:
86+
res_sub_mesh = sub_mesh
10687

88+
res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]] # type: ignore[possibly-undefined]
10789
# Assign the current DeviceMesh as the parent of the child DeviceMesh.
10890
self.child_to_parent_mapping[res_sub_mesh] = device_mesh
10991
return res_sub_mesh
@@ -228,7 +210,6 @@ def __init__(
228210
# private field to pre-generate DeviceMesh's hash
229211
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
230212
self._hash = hash((self._flatten_mesh_list, self.mesh.shape, id(self)))
231-
self._parent_mesh = _mesh_resources.get_parent_mesh(self)
232213

233214
# Skip process group initialization if xla device.
234215
# TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
@@ -237,8 +218,7 @@ def __init__(
237218
# already. The world pg is used for device mesh identity (rank) on each
238219
# process (we need to know if the current global rank is in the mesh or not).
239220
self._get_or_create_default_group()
240-
if not self._parent_mesh:
241-
self._init_process_groups()
221+
self._init_process_groups()
242222

243223
def _get_or_create_default_group(self):
244224
default_initialized = is_initialized()
@@ -360,7 +340,7 @@ def __eq__(self, other: object) -> bool:
360340
and self._flatten_mesh_list == other._flatten_mesh_list
361341
)
362342

363-
def __getitem__(self, mesh_dim_names: Union[str, Tuple[str]]) -> "DeviceMesh":
343+
def __getitem__(self, mesh_dim_name: str) -> "DeviceMesh":
364344
"""
365345
Slice the current DeviceMesh based on the mesh_dim_name given to create a child
366346
DeviceMesh.
@@ -388,39 +368,17 @@ def __getitem__(self, mesh_dim_names: Union[str, Tuple[str]]) -> "DeviceMesh":
388368
>>> # of cross-host(dim 0), and within-host (dim 1).
389369
>>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
390370
"""
391-
if not self.mesh_dim_names:
392-
raise RuntimeError("Cannot slice a DeviceMesh without mesh_dim_names.")
393-
394-
mesh_dim_names = (
395-
(mesh_dim_names,) if isinstance(mesh_dim_names, str) else mesh_dim_names
396-
)
397-
398-
error_msg = (
399-
f"Invalid mesh_dim_name {mesh_dim_names} specified. "
400-
f"Valid mesh_dim_names should be a contiguous subsequence of {self.mesh_dim_names}."
401-
)
371+
if self.mesh.ndim == 1:
372+
if self.mesh_dim_names and mesh_dim_name == self.mesh_dim_names[0]:
373+
return self
374+
else:
375+
raise RuntimeError(
376+
f"Invalid mesh_dim_name {mesh_dim_name} specified."
377+
)
402378

403-
# When the dimension slicing out is equal to the mesh dimensions of the current DeviceMesh,
404-
# we simply return self if the given slicing is valid.
405-
if mesh_dim_names == self.mesh_dim_names:
406-
return self
407-
# Check if the user-provided slicing is a valid contiguous subsequence of the mesh_dim_names
408-
# of the current DeviceMesh.
409-
elif len(mesh_dim_names) < len(self.mesh_dim_names):
410-
outermost_dim_name = mesh_dim_names[0]
411-
if outermost_dim_name not in self.mesh_dim_names:
412-
raise ValueError(error_msg)
413-
outermost_dim_idx = self.mesh_dim_names.index(outermost_dim_name)
414-
for i, j in zip(
415-
mesh_dim_names,
416-
self.mesh_dim_names[outermost_dim_idx : len(mesh_dim_names)],
417-
):
418-
if i != j:
419-
raise ValueError(error_msg)
420-
else:
421-
raise ValueError(error_msg)
379+
mesh_dim = _mesh_resources.get_mesh_dim_by_name(self, mesh_dim_name)
380+
submesh = _mesh_resources.create_child_mesh(self, mesh_dim, mesh_dim_name)
422381

423-
submesh = _mesh_resources.create_child_mesh(self, mesh_dim_names)
424382
return submesh
425383

426384
def get_group(

0 commit comments

Comments
 (0)