Commit cd21d0e

[mxfp8 moe training] fix mxfp8 a2a bench script; set mxfp8 a2a scaling type to RCEIL (#3114)
[mxfp8 moe training] fix mxfp8 a2a bench script
1 parent afd6096 commit cd21d0e

2 files changed: +57 −26 lines changed

benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py

Lines changed: 52 additions & 24 deletions
```diff
@@ -7,7 +7,7 @@
 #
 # To run these benchmarks, use the following command:
 #
-# torchrun --nproc-per-node=8 --local-ranks-filter=0 benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
+# torchrun --nproc-per-node=4 --local-ranks-filter=0 benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
 #
 #######################################################################
 import argparse
```
```diff
@@ -24,6 +24,7 @@
     all_to_all_single,
     all_to_all_single_autograd,
 )
+from torch.nn import functional as F
 from tqdm import tqdm
 
 from benchmarks.utils import profile_fn
```
```diff
@@ -66,33 +67,53 @@ def get_configs() -> List[ExperimentConfig]:
     return configs
 
 
-# Copy/paste a2a impls added in https://github.com/pytorch/torchtitan/pull/1765
-def default_a2a_dispatch(
+def default_a2a_fwd_bwd(
     routed_input: torch.Tensor,
+    labels: torch.Tensor,
     output_splits_list: list[int],
     input_splits_list: list[int],
     device_mesh: DeviceMesh,
 ):
-    """
-    Default implementation of all-to-all dispatch. Incurs device-to-host sync.
-
-    Returns:
-        routed_input: the local tokens after all-to-all dispatch
-        input_splits: the input splits for all-to-all dispatch
-        output_splits: the output splits for all-to-all dispatch
-        num_tokens_per_expert_group: the number of tokens per EP rank after all-to-all dispatch
-    """
-    # perform all-to-all
     routed_input = all_to_all_single_autograd(
         routed_input,
         output_splits_list,
         input_splits_list,
         device_mesh.get_group(),
     )
     routed_input = torch.ops._c10d_functional.wait_tensor(routed_input)
+
+    loss = F.mse_loss(routed_input, labels)
+    loss.backward()
+
+    torch.cuda.synchronize()
     return routed_input
 
 
+def mxfp8_a2a_fwd_bwd(
+    routed_input: torch.Tensor,
+    labels: torch.Tensor,
+    output_splits_list: list[int],
+    input_splits_list: list[int],
+    device_mesh: DeviceMesh,
+):
+    routed_input = to_mxfp8_a2a_dequant(
+        routed_input,
+        output_splits_list,
+        input_splits_list,
+        device_mesh.get_group(),
+    )
+
+    loss = F.mse_loss(routed_input, labels)
+    loss.backward()
+    torch.cuda.synchronize()
+    return routed_input
+
+
+# Compile target funcs
+default_a2a_sync_compiled = torch.compile(default_a2a_fwd_bwd)
+mxfp8_a2a_sync_compiled = torch.compile(mxfp8_a2a_fwd_bwd)
+
+
 def run_experiment(
     config: ExperimentConfig, args: argparse.Namespace
 ) -> ExperimentResult:
```
```diff
@@ -101,8 +122,9 @@ def run_experiment(
         (batch_size * seq_len, dim),
         dtype=torch.bfloat16,
         device=device,
+        requires_grad=True,
     )
-    ref_x = x.detach().clone()
+    ref_x = x.detach().clone().requires_grad_(True)
 
     # Set up device mesh
     mesh = init_device_mesh("cuda", (dist.get_world_size(),))
```
```diff
@@ -121,24 +143,27 @@ def warmup(func_no_args):
     )
     input_splits_list, output_splits_list = get_split_lists(input_splits, mesh)
 
-    # Compile target funcs
-    default_a2a_dispatch_c = torch.compile(default_a2a_dispatch)
-    to_mxfp8_a2a_dequant_c = torch.compile(to_mxfp8_a2a_dequant)
+    # Generate labels
+    labels_shape = (sum(output_splits_list), dim)
+    labels = x.new_ones(*labels_shape)
 
     # Bench default a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
     warmup(
-        lambda: default_a2a_dispatch_c(
-            ref_x, output_splits_list, input_splits_list, mesh
+        lambda: default_a2a_sync_compiled(
+            ref_x, labels, output_splits_list, input_splits_list, mesh
         )
     )
     start_sec = time.perf_counter()
-    default_a2a_dispatch_c(ref_x, output_splits_list, input_splits_list, mesh)
+    default_a2a_sync_compiled(
+        ref_x, labels, output_splits_list, input_splits_list, mesh
+    )
     end_sec = time.perf_counter()
    bf16_ms = (end_sec - start_sec) * 1e3
     if args.profile:
         profile_fn(
-            default_a2a_dispatch_c,
+            default_a2a_sync_compiled,
             ref_x,
+            labels,
             output_splits_list,
             input_splits_list,
             mesh,
```
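
For context on the new `labels` tensor: the all-to-all output on each rank has `sum(output_splits_list)` rows (the tokens received from every peer), so the MSE target must match that shape, and `x.new_ones` keeps the dtype and device of `x`. A tiny sketch of the shape arithmetic, with made-up split sizes:

```python
import torch

# Hypothetical splits: this rank sends 3/1/2 rows and receives 2/4/1 rows.
input_splits_list = [3, 1, 2]
output_splits_list = [2, 4, 1]
dim = 8

x = torch.randn(sum(input_splits_list), dim, dtype=torch.bfloat16)
# Labels must match the a2a output shape: total received rows x dim.
labels = x.new_ones(sum(output_splits_list), dim)  # inherits dtype/device from x
assert labels.shape == (7, dim)
```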
```diff
@@ -148,16 +173,19 @@ def warmup(func_no_args):
 
     # Bench mxfp8 sync a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
     warmup(
-        lambda: to_mxfp8_a2a_dequant_c(x, output_splits_list, input_splits_list, mesh)
+        lambda: mxfp8_a2a_sync_compiled(
+            x, labels, output_splits_list, input_splits_list, mesh
+        )
     )
     start_sec = time.perf_counter()
-    to_mxfp8_a2a_dequant_c(x, output_splits_list, input_splits_list, mesh)
+    mxfp8_a2a_sync_compiled(x, labels, output_splits_list, input_splits_list, mesh)
     end_sec = time.perf_counter()
     mxfp8_ms = (end_sec - start_sec) * 1e3
     if args.profile:
         profile_fn(
-            to_mxfp8_a2a_dequant_c,
+            mxfp8_a2a_sync_compiled,
             x,
+            labels,
             output_splits_list,
             input_splits_list,
             mesh,
```
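
`profile_fn` comes from `benchmarks.utils` and receives the compiled function followed by its arguments. For readers without that helper, a rough stand-in using `torch.profiler` directly could look like the sketch below; the `profile_call` name and trace path are placeholders, not the actual helper.

```python
import torch
from torch.profiler import ProfilerActivity, profile


def profile_call(fn, *args, trace_path: str = "a2a_trace.json"):
    # Record CPU + CUDA activity for one invocation and dump a Chrome trace.
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        fn(*args)
    prof.export_chrome_trace(trace_path)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


# e.g. profile_call(mxfp8_a2a_sync_compiled, x, labels, output_splits_list, input_splits_list, mesh)
```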

torchao/prototype/moe_training/kernels/mxfp8/comms.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -11,6 +11,7 @@
     blockwise_barrier,
     sync_threads,
 )
+from torchao.prototype.mx_formats.config import ScaleCalculationMode
 from torchao.prototype.mx_formats.mx_tensor import to_dtype, to_mx
 
 
```
```diff
@@ -468,14 +469,15 @@ def forward(
         """
         Dynamically quantizes input to mxfp8, performs all-to-all, then dequantizes output back to original precision.
         Requires d2h sync to get input_splits and output_splits on host, as required by torch.distributed.all_to_all_single API.
+        Uses RCEIL scaling mode for quantization.
         """
-
         # Quantize input
         block_size = 32
         input_scales, input_data = to_mx(
             input,
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
+            scaling_mode=ScaleCalculationMode.RCEIL,
         )
 
         # Dispatch data (async)
```
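
Background on the new `scaling_mode` argument: mxfp8 stores one power-of-two (E8M0) scale per block of 32 values, derived from the block's absolute maximum. As I read it, `RCEIL` rounds that scale up so the scaled values never exceed the float8_e4m3fn maximum, whereas floor-style rounding can saturate the largest element of a block. A simplified, hypothetical illustration of that rounding choice (not torchao's actual `to_mx` implementation):

```python
import math

import torch

E4M3_MAX = 448.0  # largest finite float8_e4m3fn magnitude
BLOCK_SIZE = 32


def block_scale(amax: float, mode: str) -> float:
    # Choose a power-of-two scale for one block from its absolute max.
    exp = math.log2(amax / E4M3_MAX) if amax > 0 else -127.0
    if mode == "rceil":
        return 2.0 ** math.ceil(exp)  # round up: amax / scale <= E4M3_MAX
    return 2.0 ** math.floor(exp)  # round down: largest value may saturate


block = torch.randn(BLOCK_SIZE) * 1000
amax = block.abs().max().item()
for mode in ("floor", "rceil"):
    scale = block_scale(amax, mode)
    fits = (block / scale).abs().max().item() <= E4M3_MAX
    print(f"{mode}: scale={scale}, fits without clamping={fits}")
```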
```diff
@@ -531,6 +533,7 @@ def backward(ctx, grad_output_hp):
             grad_output_hp,
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
+            scaling_mode=ScaleCalculationMode.RCEIL,
         )
 
         # Dispatch data (async)
```
```diff
@@ -550,8 +553,8 @@ def backward(ctx, grad_output_hp):
         )
 
         # Explicitly wait since the a2a ops are async
-        grad_input_scales = torch.ops._c10d_functional.wait_tensor(grad_input_scales)
         grad_input_data = torch.ops._c10d_functional.wait_tensor(grad_input_data)
+        grad_input_scales = torch.ops._c10d_functional.wait_tensor(grad_input_scales)
 
         hp_dtype = grad_output_hp.dtype
         lowp_dtype = grad_input_data.dtype
```
