@@ -234,30 +234,32 @@ def _worker_bench_all_to_all(
234
234
pgi : ProcessGroupInfo ,
235
235
dp_size : int ,
236
236
in_dtype_str : str ,
237
+ out_dtype_str : str ,
237
238
) -> None :
238
239
uid = nvshmem_get_unique_id () if pgi .rank == 0 else nvshmem_alloc_empty_unique_id ()
239
240
torch .distributed .broadcast (uid , src = 0 )
240
241
nvshmem_init (uid , pgi .rank , pgi .world_size )
241
242
242
243
in_dtype = getattr (torch , in_dtype_str )
244
+ out_dtype = getattr (torch , out_dtype_str )
243
245
assert isinstance (in_dtype , torch .dtype )
244
246
configs = [
245
247
# V2-Lite: 64 Experts, 6 Experts per Token, 2048 Hidden Dim
246
- MoEConfig (64 , 6 , 2048 , 1 , in_dtype ),
247
- MoEConfig (64 , 6 , 2048 , 4 , in_dtype ),
248
- MoEConfig (64 , 6 , 2048 , 8 , in_dtype ),
249
- MoEConfig (64 , 6 , 2048 , 16 , in_dtype ),
250
- MoEConfig (64 , 6 , 2048 , 32 , in_dtype ),
251
- MoEConfig (64 , 6 , 2048 , 64 , in_dtype ),
252
- MoEConfig (64 , 6 , 2048 , 128 , in_dtype ),
248
+ MoEConfig (64 , 6 , 2048 , 1 , in_dtype , out_dtype ),
249
+ MoEConfig (64 , 6 , 2048 , 4 , in_dtype , out_dtype ),
250
+ MoEConfig (64 , 6 , 2048 , 8 , in_dtype , out_dtype ),
251
+ MoEConfig (64 , 6 , 2048 , 16 , in_dtype , out_dtype ),
252
+ MoEConfig (64 , 6 , 2048 , 32 , in_dtype , out_dtype ),
253
+ MoEConfig (64 , 6 , 2048 , 64 , in_dtype , out_dtype ),
254
+ MoEConfig (64 , 6 , 2048 , 128 , in_dtype , out_dtype ),
253
255
# R1 : 256 Experts, 8 Experts per Token, 7168 Hidden Dim
254
- MoEConfig (256 , 8 , 7168 , 1 , in_dtype ),
255
- MoEConfig (256 , 8 , 7168 , 4 , in_dtype ),
256
- MoEConfig (256 , 8 , 7168 , 8 , in_dtype ),
257
- MoEConfig (256 , 8 , 7168 , 16 , in_dtype ),
258
- MoEConfig (256 , 8 , 7168 , 32 , in_dtype ),
259
- MoEConfig (256 , 8 , 7168 , 64 , in_dtype ),
260
- MoEConfig (256 , 8 , 7168 , 128 , in_dtype ),
256
+ MoEConfig (256 , 8 , 7168 , 1 , in_dtype , out_dtype ),
257
+ MoEConfig (256 , 8 , 7168 , 4 , in_dtype , out_dtype ),
258
+ MoEConfig (256 , 8 , 7168 , 8 , in_dtype , out_dtype ),
259
+ MoEConfig (256 , 8 , 7168 , 16 , in_dtype , out_dtype ),
260
+ MoEConfig (256 , 8 , 7168 , 32 , in_dtype , out_dtype ),
261
+ MoEConfig (256 , 8 , 7168 , 64 , in_dtype , out_dtype ),
262
+ MoEConfig (256 , 8 , 7168 , 128 , in_dtype , out_dtype ),
261
263
]
262
264
263
265
header = [
@@ -340,18 +342,26 @@ def main() -> None:
340
342
parser .add_argument ("--dp-size" , type = int , default = 1 )
341
343
parser .add_argument (
342
344
"--in-dtype" ,
343
- choices = ["bfloat16" , "float8_e4m3fn" ],
345
+ choices = ["bfloat16" , "float16" , "float8_e4m3fn" ],
344
346
default = "float8_e4m3fn" ,
345
347
)
348
+ parser .add_argument (
349
+ "--out-dtype" ,
350
+ choices = ["bfloat16" , "float16" ],
351
+ default = "bfloat16" ,
352
+ )
346
353
args = parser .parse_args ()
347
354
dp_size = int (args .dp_size )
348
355
in_dtype = str (args .in_dtype )
356
+ out_dtype = str (args .out_dtype )
349
357
350
358
if "MASTER_ADDR" in os .environ :
351
- parallel_launch_from_env (_worker_bench_all_to_all , dp_size , in_dtype )
359
+ parallel_launch_from_env (_worker_bench_all_to_all , dp_size , in_dtype , out_dtype )
352
360
else :
353
361
world_size = torch .cuda .device_count ()
354
- parallel_launch (world_size , _worker_bench_all_to_all , dp_size , in_dtype )
362
+ parallel_launch (
363
+ world_size , _worker_bench_all_to_all , dp_size , in_dtype , out_dtype
364
+ )
355
365
356
366
357
367
if __name__ == "__main__" :
0 commit comments