28
28
)
29
29
from torch .testing ._internal .distributed ._tensor .common_dtensor import (
30
30
DTensorTestBase ,
31
- skip_if_lt_x_gpu ,
32
31
skip_unless_torch_gpu ,
33
32
with_comms ,
34
33
)
@@ -79,10 +78,7 @@ def test_assert_invalid_mesh_tensor(self):
79
78
80
79
@with_comms
81
80
@run_with_both_funcol_impls
82
- @skip_unless_torch_gpu
83
81
def test_get_group (self ):
84
- # TODO: `test_get_group` still periodically timeout on cpu
85
- # remove `@skip_unless_torch_gpu` after the problem is fixed.
86
82
mesh_shape = (2 , self .world_size // 2 )
87
83
mesh_2d = init_device_mesh (
88
84
self .device_type , mesh_shape , mesh_dim_names = ("dp" , "tp" )
@@ -103,10 +99,7 @@ def test_get_group(self):
103
99
104
100
@with_comms
105
101
@run_with_both_funcol_impls
106
- @skip_unless_torch_gpu
107
102
def test_get_local_rank_raises_exception (self ):
108
- # TODO: `test_get_local_rank_raises_exception` still periodically timeout on cpu
109
- # remove `@skip_unless_torch_gpu` after the problem is fixed.
110
103
mesh_shape = (2 , self .world_size // 2 )
111
104
mesh_2d = init_device_mesh (
112
105
self .device_type , mesh_shape , mesh_dim_names = ("dp" , "tp" )
@@ -120,10 +113,7 @@ def test_get_local_rank_raises_exception(self):
120
113
121
114
@with_comms
122
115
@run_with_both_funcol_impls
123
- @skip_unless_torch_gpu
124
116
def test_get_local_rank (self ):
125
- # TODO: `test_get_local_rank_raises_exception` still periodically timeout on cpu
126
- # remove `@skip_unless_torch_gpu` after the problem is fixed.
127
117
mesh_shape = (2 , self .world_size // 2 )
128
118
mesh_2d = init_device_mesh (
129
119
self .device_type , mesh_shape , mesh_dim_names = ("dp" , "tp" )
@@ -276,71 +266,47 @@ def world_size(self):
276
266
277
267
@with_comms
278
268
@run_with_both_funcol_impls
279
- def test_raises_invalid_mesh_dim_names (self ):
280
- error_msg = "Invalid mesh_dim_name"
281
- # Case 1: the DeviceMesh does not have a mesh_dim_names attribute
282
- with self .assertRaisesRegex (
283
- RuntimeError , "Cannot slice a DeviceMesh without mesh_dim_names."
284
- ):
269
+ def test_raises_no_mesh_dim_found (self ):
270
+ with self .assertRaisesRegex (KeyError , "No `mesh_dim_names` found." ):
285
271
mesh = init_device_mesh (self .device_type , (2 , 4 ))
286
272
child_mesh = mesh ["DP" ]
287
273
288
- child_mesh_dim_names = "PP"
289
- with self .assertRaisesRegex (ValueError , error_msg ):
290
- mesh_dim_names = ("DP" , "TP" )
291
- mesh = init_device_mesh (
292
- self .device_type , (2 , 4 ), mesh_dim_names = mesh_dim_names
293
- )
294
- child_mesh = mesh [child_mesh_dim_names ]
295
-
296
- # Case 2
297
- child_mesh_dim_names = ["PP" , "CP" ]
298
- with self .assertRaisesRegex (ValueError , error_msg ):
299
- mesh_dim_names = ("DP" , "TP" )
300
- mesh = init_device_mesh (
301
- self .device_type , (2 , 4 ), mesh_dim_names = mesh_dim_names
302
- )
303
- child_mesh = mesh [child_mesh_dim_names ]
304
-
305
- # Case 3: a given child_mesh_dim_name is not a contiguous subset of the parent mesh's mesh_dim_names.
306
- child_mesh_dim_names = ("TP" , "DP" )
307
- with self .assertRaisesRegex (ValueError , error_msg ):
274
+ @with_comms
275
+ @run_with_both_funcol_impls
276
+ def test_raises_invalid_mesh_dim_name (self ):
277
+ child_mesh_dim_name = "PP"
278
+ with self .assertRaisesRegex (
279
+ KeyError , f"Mesh dimension '{ child_mesh_dim_name } ' does not exist."
280
+ ):
308
281
mesh_dim_names = ("DP" , "TP" )
309
282
mesh = init_device_mesh (
310
283
self .device_type , (2 , 4 ), mesh_dim_names = mesh_dim_names
311
284
)
312
- child_mesh = mesh [child_mesh_dim_names ]
313
-
314
- # Case 3
315
- child_mesh_dim_names = ("PP" , "TP" )
316
- with self .assertRaisesRegex (ValueError , error_msg ):
317
- mesh_dim_names = ("PP" , "DP" , "TP" )
318
- mesh = init_device_mesh (
319
- self .device_type , (2 , 2 , 2 ), mesh_dim_names = mesh_dim_names
320
- )
321
- child_mesh = mesh [child_mesh_dim_names ]
285
+ child_mesh = mesh [child_mesh_dim_name ]
322
286
323
287
@with_comms
324
288
@run_with_both_funcol_impls
325
- @skip_if_lt_x_gpu (8 )
326
- def test_get_item_2d (self ):
327
- # TODO: `test_get_item_2d` still periodically timeout on cpu
328
- # remove `@skip_if_lt_x_gpu` after the problem is fixed.
289
+ def test_get_item (self ):
329
290
mesh_shape = (2 , 4 )
330
291
mesh_dim_names = ("DP" , "TP" )
331
292
mesh_2d = init_device_mesh (
332
293
self .device_type , mesh_shape , mesh_dim_names = mesh_dim_names
333
294
)
334
295
296
+ pg_ranks_by_dim_name = {}
297
+ for mesh_dim_name in mesh_dim_names :
298
+ mesh_dim = mesh_dim_names .index (mesh_dim_name )
299
+ pg_ranks_by_dim_name [mesh_dim_name ] = mesh_2d .mesh .swapdims (
300
+ - 1 , mesh_dim
301
+ ).reshape (- 1 , mesh_2d .mesh .size (mesh_dim ))
302
+
335
303
tp_mesh = mesh_2d ["TP" ]
336
- tp_group = [[0 , 1 , 2 , 3 ], [4 , 5 , 6 , 7 ]]
337
304
tp_group_idx = self .rank // 4
338
- self .assertEqual (tp_mesh .mesh . tolist (), tp_group [tp_group_idx ])
305
+ self .assertEqual (tp_mesh .mesh , pg_ranks_by_dim_name [ "TP" ] [tp_group_idx ])
339
306
340
307
dp_mesh = mesh_2d ["DP" ]
341
- dp_group = [[0 , 4 ], [1 , 5 ], [2 , 6 ], [3 , 7 ]]
342
308
dp_group_idx = self .rank % 4
343
- self .assertEqual (dp_mesh .mesh . tolist (), dp_group [dp_group_idx ])
309
+ self .assertEqual (mesh_2d [ "DP" ] .mesh , pg_ranks_by_dim_name [ "DP" ] [dp_group_idx ])
344
310
345
311
@with_comms
346
312
@run_with_both_funcol_impls
@@ -351,50 +317,14 @@ def test_get_item_1d(self):
351
317
dp_mesh = mesh ["dp" ]
352
318
self .assertEqual (dp_mesh , mesh )
353
319
354
- with self .assertRaisesRegex (ValueError , "Invalid mesh_dim_name" ):
320
+ with self .assertRaisesRegex (RuntimeError , "Invalid mesh_dim_name" ):
355
321
dp_mesh = mesh ["dim0" ]
356
322
357
- @with_comms
358
- @skip_if_lt_x_gpu (8 )
359
- def test_get_item_3d (self ):
360
- # TODO: `test_get_item_3d` still periodically timeout on cpu
361
- # remove `@skip_if_lt_x_gpu` after the problem is fixed.
362
- mesh_shape = (2 , 2 , 2 )
363
- mesh_dim_names = ("Replicate" , "Shard" , "TP" )
364
- mesh_3d = init_device_mesh (
365
- self .device_type , mesh_shape , mesh_dim_names = mesh_dim_names
366
- )
367
-
368
- tp_group = [[0 , 1 ], [2 , 3 ], [4 , 5 ], [6 , 7 ]]
369
- tp_group_idx = int (self .rank / 2 )
370
- self .assertEqual (mesh_3d ["TP" ].mesh .tolist (), tp_group [tp_group_idx ])
371
-
372
- shard_group = [[0 , 2 ], [1 , 3 ], [4 , 6 ], [5 , 7 ]]
373
- shard_group_idx = self .rank % 2 + self .rank // 4 * 2
374
- self .assertEqual (mesh_3d ["Shard" ].mesh .tolist (), shard_group [shard_group_idx ])
375
-
376
- replicate_group = [[0 , 4 ], [1 , 5 ], [2 , 6 ], [3 , 7 ]]
377
- replicate_group_idx = self .rank % 4
378
- self .assertEqual (
379
- mesh_3d ["Replicate" ].mesh .tolist (), replicate_group [replicate_group_idx ]
380
- )
381
-
382
- # We support both UX for nD slicing.
383
- # mesh_3d[["Replicate", "Shard"]] or mesh_3d["Replicate", "Shard"]
384
- hsdp_mesh_1 = mesh_3d [["Replicate" , "Shard" ]]
385
- hsdp_mesh_2 = mesh_3d ["Replicate" , "Shard" ]
386
- hsdp_group = [[[0 , 2 ], [4 , 6 ]], [[1 , 3 ], [5 , 7 ]]]
387
- hsdp_group_idx = self .rank % 2
388
- self .assertEqual (hsdp_mesh_1 .mesh .tolist (), hsdp_group [hsdp_group_idx ])
389
- self .assertEqual (hsdp_mesh_2 .mesh .tolist (), hsdp_group [hsdp_group_idx ])
390
- self .assertEqual (hsdp_mesh_1 , hsdp_mesh_2 )
391
-
392
323
393
324
@instantiate_parametrized_tests
394
325
class TestMeshEnv (DTensorTestBase ):
395
326
@with_comms
396
327
@run_with_both_funcol_impls
397
- @skip_unless_torch_gpu
398
328
def test_get_parent_mesh (self ):
399
329
mesh_shape = (2 , self .world_size // 2 )
400
330
mesh_dim_names = ("DP" , "TP" )
@@ -415,7 +345,6 @@ def test_get_parent_mesh(self):
415
345
416
346
@with_comms
417
347
@run_with_both_funcol_impls
418
- @skip_unless_torch_gpu
419
348
def test_get_parent_mesh_dim_exist (self ):
420
349
mesh_shape = (2 , self .world_size // 2 )
421
350
mesh_dim_names = ("DP" , "TP" )
@@ -428,7 +357,6 @@ def test_get_parent_mesh_dim_exist(self):
428
357
429
358
@with_comms
430
359
@run_with_both_funcol_impls
431
- @skip_unless_torch_gpu
432
360
def test_get_parent_mesh_dim_not_exist (self ):
433
361
mesh_shape = (self .world_size ,)
434
362
mesh = init_device_mesh (self .device_type , mesh_shape )
@@ -437,7 +365,6 @@ def test_get_parent_mesh_dim_not_exist(self):
437
365
438
366
@with_comms
439
367
@run_with_both_funcol_impls
440
- @skip_unless_torch_gpu
441
368
def test_get_mesh_dim_by_name (self ):
442
369
mesh_shape = (2 , self .world_size // 2 )
443
370
mesh_dim_names = ("DP" , "TP" )
0 commit comments