 #
 # To run these benchmarks, use the following command:
 #
-# torchrun --nproc-per-node=8 --local-ranks-filter=0 benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
+# torchrun --nproc-per-node=4 --local-ranks-filter=0 benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
 #
 #######################################################################
 import argparse
     all_to_all_single,
     all_to_all_single_autograd,
 )
+from torch.nn import functional as F
 from tqdm import tqdm
 
 from benchmarks.utils import profile_fn
@@ -66,33 +67,53 @@ def get_configs() -> List[ExperimentConfig]:
     return configs
 
 
-# Copy/paste a2a impls added in https://github.com/pytorch/torchtitan/pull/1765
-def default_a2a_dispatch(
+def default_a2a_fwd_bwd(
     routed_input: torch.Tensor,
+    labels: torch.Tensor,
     output_splits_list: list[int],
     input_splits_list: list[int],
     device_mesh: DeviceMesh,
 ):
-    """
-    Default implementation of all-to-all dispatch. Incurs device-to-host sync.
-
-    Returns:
-        routed_input: the local tokens after all-to-all dispatch
-        input_splits: the input splits for all-to-all dispatch
-        output_splits: the output splits for all-to-all dispatch
-        num_tokens_per_expert_group: the number of tokens per EP rank after all-to-all dispatch
-    """
-    # perform all-to-all
     routed_input = all_to_all_single_autograd(
         routed_input,
         output_splits_list,
         input_splits_list,
         device_mesh.get_group(),
     )
     routed_input = torch.ops._c10d_functional.wait_tensor(routed_input)
+
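+    # Dummy MSE loss + backward pass so the timed region also exercises the backward all-to-all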
+    loss = F.mse_loss(routed_input, labels)
+    loss.backward()
+
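+    # Block until all queued GPU work (including the backward collective) finishes before timing stops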
+    torch.cuda.synchronize()
     return routed_input
 
 
+def mxfp8_a2a_fwd_bwd(
+    routed_input: torch.Tensor,
+    labels: torch.Tensor,
+    output_splits_list: list[int],
+    input_splits_list: list[int],
+    device_mesh: DeviceMesh,
+):
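+    # mxfp8 variant: quantize to mxfp8, run the all-to-all, and dequantize via
+    # to_mxfp8_a2a_dequant, then apply the same dummy loss/backward as the bf16 path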
+    routed_input = to_mxfp8_a2a_dequant(
+        routed_input,
+        output_splits_list,
+        input_splits_list,
+        device_mesh.get_group(),
+    )
+
+    loss = F.mse_loss(routed_input, labels)
+    loss.backward()
+    torch.cuda.synchronize()
+    return routed_input
+
+
+# Compile target funcs
+default_a2a_sync_compiled = torch.compile(default_a2a_fwd_bwd)
+mxfp8_a2a_sync_compiled = torch.compile(mxfp8_a2a_fwd_bwd)
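+# (compiled once at module level so the same compiled callables are reused across all configs)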
+
+
 def run_experiment(
     config: ExperimentConfig, args: argparse.Namespace
 ) -> ExperimentResult:
@@ -101,8 +122,9 @@ def run_experiment(
         (batch_size * seq_len, dim),
         dtype=torch.bfloat16,
         device=device,
+        requires_grad=True,
     )
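+    # requires_grad on x / ref_x lets loss.backward() reach the all-to-all's backward pass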
-    ref_x = x.detach().clone()
+    ref_x = x.detach().clone().requires_grad_(True)
 
     # Set up device mesh
     mesh = init_device_mesh("cuda", (dist.get_world_size(),))
@@ -121,24 +143,27 @@ def warmup(func_no_args):
     )
     input_splits_list, output_splits_list = get_split_lists(input_splits, mesh)
 
-    # Compile target funcs
-    default_a2a_dispatch_c = torch.compile(default_a2a_dispatch)
-    to_mxfp8_a2a_dequant_c = torch.compile(to_mxfp8_a2a_dequant)
+    # Generate labels matching the a2a output shape (one row per token received by this rank)
+    labels_shape = (sum(output_splits_list), dim)
+    labels = x.new_ones(*labels_shape)
 
     # Bench default a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
     warmup(
-        lambda: default_a2a_dispatch_c(
-            ref_x, output_splits_list, input_splits_list, mesh
+        lambda: default_a2a_sync_compiled(
+            ref_x, labels, output_splits_list, input_splits_list, mesh
         )
     )
     start_sec = time.perf_counter()
-    default_a2a_dispatch_c(ref_x, output_splits_list, input_splits_list, mesh)
+    default_a2a_sync_compiled(
+        ref_x, labels, output_splits_list, input_splits_list, mesh
+    )
     end_sec = time.perf_counter()
     bf16_ms = (end_sec - start_sec) * 1e3
     if args.profile:
         profile_fn(
-            default_a2a_dispatch_c,
+            default_a2a_sync_compiled,
             ref_x,
+            labels,
             output_splits_list,
             input_splits_list,
             mesh,
@@ -148,16 +173,19 @@ def warmup(func_no_args):
 
     # Bench mxfp8 sync a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
     warmup(
-        lambda: to_mxfp8_a2a_dequant_c(x, output_splits_list, input_splits_list, mesh)
+        lambda: mxfp8_a2a_sync_compiled(
+            x, labels, output_splits_list, input_splits_list, mesh
+        )
     )
     start_sec = time.perf_counter()
-    to_mxfp8_a2a_dequant_c(x, output_splits_list, input_splits_list, mesh)
+    mxfp8_a2a_sync_compiled(x, labels, output_splits_list, input_splits_list, mesh)
    end_sec = time.perf_counter()
     mxfp8_ms = (end_sec - start_sec) * 1e3
     if args.profile:
         profile_fn(
-            to_mxfp8_a2a_dequant_c,
+            mxfp8_a2a_sync_compiled,
             x,
+            labels,
             output_splits_list,
             input_splits_list,
             mesh,