@@ -38,6 +38,13 @@
 )
 
 
+def _get_number_of_gpu_sm() -> int:
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is not available")
+    device_props = torch.cuda.get_device_properties(0)
+    return device_props.multi_processor_count
+
+
 def _str_1d_tensor(t: torch.Tensor) -> str:
     sl = [f"{x:7.4f}" for x in t.tolist()]
     if len(sl) > 5:
@@ -48,6 +55,7 @@ def _str_1d_tensor(t: torch.Tensor) -> str:
 def _do_test_all_to_all(
     pgi: ProcessGroupInfo,
     dp_size: int,
+    max_sm_count: int,
     moe: MoEConfig,
     internode: bool,
     use_compile: bool,
@@ -80,6 +88,7 @@ def _do_test_all_to_all(
                     * torch.float32.itemsize
                 )
             ),
+            max_sm_count=max_sm_count,
         )
     else:
         ata = AllToAll.intranode(
@@ -100,6 +109,7 @@ def _do_test_all_to_all(
                     * torch.float32.itemsize
                 )
             ),
+            max_sm_count=max_sm_count,
         )
 
     # Generate the same test data on all ranks
@@ -291,6 +301,7 @@ def _worker_test_all_to_all(
     dp_size: int,
     in_dtype: str,
     out_dtype: str,
+    max_sm_count: int,
     moe_config: MoEConfig,
     internode: bool,
     use_compile: bool = False,
@@ -305,18 +316,21 @@ def _worker_test_all_to_all(
         out_dtype=getattr(torch, out_dtype),
     )
 
-    _do_test_all_to_all(pgi, dp_size, moe_config, internode, use_compile)
+    _do_test_all_to_all(pgi, dp_size, max_sm_count, moe_config, internode, use_compile)
 
     nvshmem_finalize()
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 GPUs")
 @pytest.mark.parametrize("in_dtype", ["bfloat16", "float8_e4m3fn", "float16"])
 @pytest.mark.parametrize("out_dtype", ["float16", "bfloat16"])
+@pytest.mark.parametrize(
+    "max_sm_count", [_get_number_of_gpu_sm(), _get_number_of_gpu_sm() // 2]
+)
 @pytest.mark.parametrize("internode", [True, False])
 @pytest.mark.parametrize("use_compile", [False, True])
 def test_all_to_all_4_gpu(
-    in_dtype: str, out_dtype: str, internode: bool, use_compile: bool
+    in_dtype: str, out_dtype: str, max_sm_count: int, internode: bool, use_compile: bool
 ) -> None:
     world_size = 4
     dp_size = 2
@@ -326,6 +340,7 @@ def test_all_to_all_4_gpu(
         dp_size,
         in_dtype,
         out_dtype,
+        max_sm_count,
         small_moe,
         internode,
         use_compile,
@@ -336,13 +351,15 @@ def _worker_test_all_to_all_multi_node(
     pgi: ProcessGroupInfo,
     in_dtype: str,
     out_dtype: str,
+    max_sm_count: int,
 ) -> None:
     dp_size = 4
     _worker_test_all_to_all(
         pgi,
         dp_size,
         in_dtype,
         out_dtype,
+        max_sm_count,
         medium_moe,
         True,
     )
@@ -352,4 +369,7 @@ def _worker_test_all_to_all_multi_node(
 @pytest.mark.parametrize("in_dtype", ["bfloat16", "float8_e4m3fn", "float16"])
 @pytest.mark.parametrize("out_dtype", ["float16", "bfloat16"])
 def test_all_to_all_multi_node(in_dtype: str, out_dtype: str) -> None:
-    parallel_launch_from_env(_worker_test_all_to_all_multi_node, in_dtype, out_dtype)
+    max_sm_count = _get_number_of_gpu_sm()
+    parallel_launch_from_env(
+        _worker_test_all_to_all_multi_node, in_dtype, out_dtype, max_sm_count
+    )
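
For reference, here is a minimal standalone sketch of the SM query that the new _get_number_of_gpu_sm helper performs; the name get_sm_count and the device_index parameter are illustrative additions, not part of the diff.

# Minimal sketch of querying the number of streaming multiprocessors (SMs)
# with PyTorch, as the helper in the diff does. Assumes a CUDA build of torch
# and at least one visible GPU.
import torch


def get_sm_count(device_index: int = 0) -> int:
    # Illustrative helper (not from the diff): multi_processor_count is the
    # SM count the CUDA driver reports for the given device.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available")
    return torch.cuda.get_device_properties(device_index).multi_processor_count


if __name__ == "__main__":
    full = get_sm_count()
    # The test above parametrizes over the full SM count and half of it.
    print(f"full SM count: {full}, halved: {full // 2}")

On an A100, for example, this reports 108 SMs, so the parametrization above would run the tests with max_sm_count values of 108 and 54.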