@@ -38,6 +38,13 @@
 )


+def _get_number_of_gpu_sm() -> int:
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is not available")
+    device_props = torch.cuda.get_device_properties(0)
+    return device_props.multi_processor_count
+
+
 def _str_1d_tensor(t: torch.Tensor) -> str:
     sl = [f"{x:7.4f}" for x in t.tolist()]
     if len(sl) > 5:
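The new `_get_number_of_gpu_sm()` helper reads `multi_processor_count` from `torch.cuda.get_device_properties`, which reports the number of streaming multiprocessors (SMs) on the device. A minimal standalone sketch of the same query (device index 0 is an assumption; any visible CUDA device works):

```python
import torch

# Standalone sketch mirroring _get_number_of_gpu_sm() from this diff:
# query the SM count of the first visible CUDA device.
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    # multi_processor_count varies by model, e.g. 108 SMs on an A100.
    print(f"{props.name}: {props.multi_processor_count} SMs")
else:
    print("No CUDA device visible; the helper in this diff raises instead.")
```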
@@ -48,6 +55,7 @@ def _str_1d_tensor(t: torch.Tensor) -> str:
 def _do_test_all_to_all(
     pgi: ProcessGroupInfo,
     dp_size: int,
+    max_sm_count: int,
     moe: MoEConfig,
     internode: bool,
 ) -> None:
@@ -79,6 +87,7 @@ def _do_test_all_to_all(
                     * torch.float32.itemsize
                 )
             ),
+            max_sm_count=max_sm_count,
         )
     else:
         ata = AllToAll.intranode(
@@ -99,6 +108,7 @@ def _do_test_all_to_all(
                     * torch.float32.itemsize
                 )
             ),
+            max_sm_count=max_sm_count,
         )

     # Generate the same test data on all ranks
@@ -283,6 +293,7 @@ def _worker_test_all_to_all(
     dp_size: int,
     in_dtype: str,
     out_dtype: str,
+    max_sm_count: int,
     moe_config: MoEConfig,
     internode: bool,
 ) -> None:
@@ -295,16 +306,21 @@ def _worker_test_all_to_all(
         in_dtype=getattr(torch, in_dtype),
         out_dtype=getattr(torch, out_dtype),
     )
-    _do_test_all_to_all(pgi, dp_size, moe_config, internode)
+    _do_test_all_to_all(pgi, dp_size, max_sm_count, moe_config, internode)

     nvshmem_finalize()


 @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 GPUs")
 @pytest.mark.parametrize("in_dtype", ["bfloat16", "float8_e4m3fn", "float16"])
 @pytest.mark.parametrize("out_dtype", ["float16", "bfloat16"])
+@pytest.mark.parametrize(
+    "max_sm_count", [_get_number_of_gpu_sm(), _get_number_of_gpu_sm() // 2]
+)
 @pytest.mark.parametrize("internode", [True, False])
-def test_all_to_all_4_gpu(in_dtype: str, out_dtype: str, internode: bool) -> None:
+def test_all_to_all_4_gpu(
+    in_dtype: str, out_dtype: str, max_sm_count: int, internode: bool
+) -> None:
     world_size = 4
     dp_size = 2
     parallel_launch(
@@ -313,6 +329,7 @@ def test_all_to_all_4_gpu(in_dtype: str, out_dtype: str, internode: bool) -> None:
         dp_size,
         in_dtype,
         out_dtype,
+        max_sm_count,
         small_moe,
         internode,
     )
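One thing the parametrization implies: `_get_number_of_gpu_sm()` is evaluated inside the `@pytest.mark.parametrize` decorator, i.e. at collection time, so a CUDA device must be visible even just to collect these tests. A hedged sketch of how the new axis multiplies the existing matrix (108 is a hypothetical full SM count, e.g. an A100):

```python
import itertools

# Hypothetical parameter matrix for test_all_to_all_4_gpu on a 108-SM GPU.
in_dtypes = ["bfloat16", "float8_e4m3fn", "float16"]
out_dtypes = ["float16", "bfloat16"]
max_sm_counts = [108, 108 // 2]  # full SM count and half of it
internode = [True, False]

cases = list(itertools.product(in_dtypes, out_dtypes, max_sm_counts, internode))
print(len(cases))  # 3 * 2 * 2 * 2 = 24 parametrized cases
```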
@@ -322,13 +339,15 @@ def _worker_test_all_to_all_multi_node(
     pgi: ProcessGroupInfo,
     in_dtype: str,
     out_dtype: str,
+    max_sm_count: int,
 ) -> None:
     dp_size = 4
     _worker_test_all_to_all(
         pgi,
         dp_size,
         in_dtype,
         out_dtype,
+        max_sm_count,
         medium_moe,
         True,
     )
@@ -338,4 +357,7 @@ def _worker_test_all_to_all_multi_node(
 @pytest.mark.parametrize("in_dtype", ["bfloat16", "float8_e4m3fn", "float16"])
 @pytest.mark.parametrize("out_dtype", ["float16", "bfloat16"])
 def test_all_to_all_multi_node(in_dtype: str, out_dtype: str) -> None:
-    parallel_launch_from_env(_worker_test_all_to_all_multi_node, in_dtype, out_dtype)
+    max_sm_count = _get_number_of_gpu_sm()
+    parallel_launch_from_env(
+        _worker_test_all_to_all_multi_node, in_dtype, out_dtype, max_sm_count
+    )
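For a quick way to exercise only the new cases, pytest's programmatic entry point can be used; the test file path below is an assumption, adjust it to wherever this file actually lives in the repo:

```python
import pytest

# Run only the 4-GPU parametrized test; -k can also match a specific
# max_sm_count value embedded in the generated test IDs.
pytest.main([
    "tests/test_all_to_all.py",    # assumed path of the file changed in this diff
    "-k", "test_all_to_all_4_gpu",
    "-x",                          # stop at the first failure
])
```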