38
38
)
39
39
40
40
41
+ def _get_number_of_gpu_sm () -> int :
42
+ if not torch .cuda .is_available ():
43
+ raise RuntimeError ("CUDA is not available" )
44
+ device_props = torch .cuda .get_device_properties (0 )
45
+ return device_props .multi_processor_count
46
+
47
+
41
48
def _str_1d_tensor (t : torch .Tensor ) -> str :
42
49
sl = [f"{ x :7.4f} " for x in t .tolist ()]
43
50
if len (sl ) > 5 :
@@ -48,6 +55,7 @@ def _str_1d_tensor(t: torch.Tensor) -> str:
48
55
def _do_test_all_to_all (
49
56
pgi : ProcessGroupInfo ,
50
57
dp_size : int ,
58
+ max_sm_count : int ,
51
59
moe : MoEConfig ,
52
60
internode : bool ,
53
61
) -> None :
@@ -79,6 +87,7 @@ def _do_test_all_to_all(
79
87
* torch .float32 .itemsize
80
88
)
81
89
),
90
+ max_sm_count = max_sm_count ,
82
91
)
83
92
else :
84
93
ata = AllToAll .intranode (
@@ -99,6 +108,7 @@ def _do_test_all_to_all(
99
108
* torch .float32 .itemsize
100
109
)
101
110
),
111
+ max_sm_count = max_sm_count ,
102
112
)
103
113
104
114
# Generate the same test data on all ranks
@@ -283,6 +293,7 @@ def _worker_test_all_to_all(
283
293
dp_size : int ,
284
294
in_dtype : str ,
285
295
out_dtype : str ,
296
+ max_sm_count : int ,
286
297
moe_config : MoEConfig ,
287
298
internode : bool ,
288
299
) -> None :
@@ -295,16 +306,17 @@ def _worker_test_all_to_all(
295
306
in_dtype = getattr (torch , in_dtype ),
296
307
out_dtype = getattr (torch , out_dtype ),
297
308
)
298
- _do_test_all_to_all (pgi , dp_size , moe_config , internode )
309
+ _do_test_all_to_all (pgi , dp_size , max_sm_count , moe_config , internode )
299
310
300
311
nvshmem_finalize ()
301
312
302
313
303
314
def _sm_count_params() -> list[int]:
    """SM-count values to parametrize over: the full SM count and half of it.

    Returns a harmless placeholder when CUDA is unavailable. Parametrize
    arguments are evaluated at collection/import time, so calling
    ``_get_number_of_gpu_sm()`` directly in the decorator would raise
    ``RuntimeError`` and break collection of this whole module on a
    CPU-only machine — even though the test itself is already skipped by
    its ``skipif`` marker.
    """
    if not torch.cuda.is_available():
        return [0]
    full_sm_count = _get_number_of_gpu_sm()
    return [full_sm_count, full_sm_count // 2]


@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 GPUs")
@pytest.mark.parametrize("in_dtype", ["bfloat16", "float8_e4m3fn", "float16"])
@pytest.mark.parametrize("out_dtype", ["float16", "bfloat16"])
@pytest.mark.parametrize("max_sm_count", _sm_count_params())
@pytest.mark.parametrize("internode", [True, False])
def test_all_to_all_4_gpu(
    in_dtype: str, out_dtype: str, max_sm_count: int, internode: bool
) -> None:
    """Run the 4-GPU all-to-all test over dtype / SM-count / internode combos."""
    world_size = 4
    dp_size = 2
    parallel_launch(
        # NOTE(review): the first two launch arguments were hidden in a diff
        # hunk gap and reconstructed as (world_size, worker fn) — confirm
        # against the original file.
        world_size,
        _worker_test_all_to_all,
        dp_size,
        in_dtype,
        out_dtype,
        max_sm_count,
        small_moe,
        internode,
    )
def _worker_test_all_to_all_multi_node(
    pgi: ProcessGroupInfo,
    in_dtype: str,
    out_dtype: str,
    max_sm_count: int,
) -> None:
    """Multi-node worker: run the shared all-to-all test body.

    Fixes dp_size at 4 and always uses the medium MoE config with
    internode=True.
    """
    multi_node_dp_size = 4
    _worker_test_all_to_all(
        pgi,
        multi_node_dp_size,
        in_dtype,
        out_dtype,
        max_sm_count,
        medium_moe,
        True,
    )
@@ -338,4 +353,5 @@ def _worker_test_all_to_all_multi_node(
338
353
@pytest .mark .parametrize ("in_dtype" , ["bfloat16" , "float8_e4m3fn" , "float16" ])
339
354
@pytest .mark .parametrize ("out_dtype" , ["float16" , "bfloat16" ])
340
355
def test_all_to_all_multi_node(in_dtype: str, out_dtype: str) -> None:
    """Launch the multi-node all-to-all worker via environment-configured
    ranks, using the full SM count of the local GPU."""
    parallel_launch_from_env(
        _worker_test_all_to_all_multi_node,
        in_dtype,
        out_dtype,
        _get_number_of_gpu_sm(),
    )
0 commit comments