diff --git a/benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py b/benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
index 237a26c3e1..28a5bb87a2 100644
--- a/benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
+++ b/benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
@@ -7,7 +7,7 @@
 #
 # To run these benchmarks, use the following command:
 #
-# torchrun --nproc-per-node=8 --local-ranks-filter=0 benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
+# torchrun --nproc-per-node=4 --local-ranks-filter=0 benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
 #
 #######################################################################
 import argparse
@@ -24,6 +24,7 @@
     all_to_all_single,
     all_to_all_single_autograd,
 )
+from torch.nn import functional as F
 from tqdm import tqdm
 
 from benchmarks.utils import profile_fn
@@ -66,23 +67,13 @@ def get_configs() -> List[ExperimentConfig]:
     return configs
 
 
-# Copy/paste a2a impls added in https://github.com/pytorch/torchtitan/pull/1765
-def default_a2a_dispatch(
+def default_a2a_fwd_bwd(
     routed_input: torch.Tensor,
+    labels: torch.Tensor,
     output_splits_list: list[int],
     input_splits_list: list[int],
     device_mesh: DeviceMesh,
 ):
-    """
-    Default implementation of all-to-all dispatch. Incurs device-to-host sync.
-
-    Returns:
-        routed_input: the local tokens after all-to-all dispatch
-        input_splits: the input splits for all-to-all dispatch
-        output_splits: the output splits for all-to-all dispatch
-        num_tokens_per_expert_group: the number of tokens per EP rank after all-to-all dispatch
-    """
-    # perform all-to-all
     routed_input = all_to_all_single_autograd(
         routed_input,
         output_splits_list,
@@ -90,9 +81,39 @@ def default_a2a_dispatch(
         device_mesh.get_group(),
     )
     routed_input = torch.ops._c10d_functional.wait_tensor(routed_input)
+
+    loss = F.mse_loss(routed_input, labels)
+    loss.backward()
+
+    torch.cuda.synchronize()
     return routed_input
 
 
+def mxfp8_a2a_fwd_bwd(
+    routed_input: torch.Tensor,
+    labels: torch.Tensor,
+    output_splits_list: list[int],
+    input_splits_list: list[int],
+    device_mesh: DeviceMesh,
+):
+    routed_input = to_mxfp8_a2a_dequant(
+        routed_input,
+        output_splits_list,
+        input_splits_list,
+        device_mesh.get_group(),
+    )
+
+    loss = F.mse_loss(routed_input, labels)
+    loss.backward()
+    torch.cuda.synchronize()
+    return routed_input
+
+
+# Compile target funcs
+default_a2a_sync_compiled = torch.compile(default_a2a_fwd_bwd)
+mxfp8_a2a_sync_compiled = torch.compile(mxfp8_a2a_fwd_bwd)
+
+
 def run_experiment(
     config: ExperimentConfig, args: argparse.Namespace
 ) -> ExperimentResult:
@@ -101,8 +122,9 @@ def run_experiment(
         (batch_size * seq_len, dim),
         dtype=torch.bfloat16,
         device=device,
+        requires_grad=True,
     )
-    ref_x = x.detach().clone()
+    ref_x = x.detach().clone().requires_grad_(True)
 
     # Set up device mesh
     mesh = init_device_mesh("cuda", (dist.get_world_size(),))
@@ -121,24 +143,27 @@ def warmup(func_no_args):
     )
     input_splits_list, output_splits_list = get_split_lists(input_splits, mesh)
 
-    # Compile target funcs
-    default_a2a_dispatch_c = torch.compile(default_a2a_dispatch)
-    to_mxfp8_a2a_dequant_c = torch.compile(to_mxfp8_a2a_dequant)
+    # Generate labels
+    labels_shape = (sum(output_splits_list), dim)
+    labels = x.new_ones(*labels_shape)
 
     # Bench default a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
     warmup(
-        lambda: default_a2a_dispatch_c(
-            ref_x, output_splits_list, input_splits_list, mesh
+        lambda: default_a2a_sync_compiled(
+            ref_x, labels, output_splits_list, input_splits_list, mesh
         )
     )
     start_sec = time.perf_counter()
-    default_a2a_dispatch_c(ref_x, output_splits_list, input_splits_list, mesh)
+    default_a2a_sync_compiled(
+        ref_x, labels, output_splits_list, input_splits_list, mesh
+    )
     end_sec = time.perf_counter()
     bf16_ms = (end_sec - start_sec) * 1e3
     if args.profile:
         profile_fn(
-            default_a2a_dispatch_c,
+            default_a2a_sync_compiled,
             ref_x,
+            labels,
             output_splits_list,
             input_splits_list,
             mesh,
@@ -148,16 +173,19 @@ def warmup(func_no_args):
 
     # Bench mxfp8 sync a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
     warmup(
-        lambda: to_mxfp8_a2a_dequant_c(x, output_splits_list, input_splits_list, mesh)
+        lambda: mxfp8_a2a_sync_compiled(
+            x, labels, output_splits_list, input_splits_list, mesh
+        )
     )
     start_sec = time.perf_counter()
-    to_mxfp8_a2a_dequant_c(x, output_splits_list, input_splits_list, mesh)
+    mxfp8_a2a_sync_compiled(x, labels, output_splits_list, input_splits_list, mesh)
     end_sec = time.perf_counter()
     mxfp8_ms = (end_sec - start_sec) * 1e3
     if args.profile:
         profile_fn(
-            to_mxfp8_a2a_dequant_c,
+            mxfp8_a2a_sync_compiled,
             x,
+            labels,
             output_splits_list,
             input_splits_list,
             mesh,
diff --git a/torchao/prototype/moe_training/kernels/mxfp8/comms.py b/torchao/prototype/moe_training/kernels/mxfp8/comms.py
index 7430010d3a..7c6999fbf1 100644
--- a/torchao/prototype/moe_training/kernels/mxfp8/comms.py
+++ b/torchao/prototype/moe_training/kernels/mxfp8/comms.py
@@ -11,6 +11,7 @@
     blockwise_barrier,
     sync_threads,
 )
+from torchao.prototype.mx_formats.config import ScaleCalculationMode
 from torchao.prototype.mx_formats.mx_tensor import to_dtype, to_mx
 
 
@@ -468,14 +469,15 @@ def forward(
         """
         Dynamically quantizes input to mxfp8, performs all-to-all, then dequantizes output back to original precision.
         Requires d2h sync to get input_splits and output_splits on host, as required by torch.distributed.all_to_all_single API.
+        Uses RCEIL scaling mode for quantization.
         """
 
-        # Quantize input
         block_size = 32
         input_scales, input_data = to_mx(
             input,
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
+            scaling_mode=ScaleCalculationMode.RCEIL,
         )
 
         # Dispatch data (async)
@@ -531,6 +533,7 @@ def backward(ctx, grad_output_hp):
             grad_output_hp,
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
+            scaling_mode=ScaleCalculationMode.RCEIL,
         )
 
         # Dispatch data (async)
@@ -550,8 +553,8 @@ def backward(ctx, grad_output_hp):
         )
 
         # Explicitly wait since the a2a ops are async
-        grad_input_scales = torch.ops._c10d_functional.wait_tensor(grad_input_scales)
         grad_input_data = torch.ops._c10d_functional.wait_tensor(grad_input_data)
+        grad_input_scales = torch.ops._c10d_functional.wait_tensor(grad_input_scales)
 
         hp_dtype = grad_output_hp.dtype
         lowp_dtype = grad_input_data.dtype