76 changes: 52 additions & 24 deletions benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
@@ -7,7 +7,7 @@
#
# To run these benchmarks, use the following command:
#
# torchrun --nproc-per-node=8 --local-ranks-filter=0 benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
# torchrun --nproc-per-node=4 --local-ranks-filter=0 benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py
#
#######################################################################
import argparse
@@ -24,6 +24,7 @@
all_to_all_single,
all_to_all_single_autograd,
)
from torch.nn import functional as F
from tqdm import tqdm

from benchmarks.utils import profile_fn
@@ -66,33 +67,53 @@ def get_configs() -> List[ExperimentConfig]:
return configs


# Copy/paste a2a impls added in https://github.com/pytorch/torchtitan/pull/1765
def default_a2a_dispatch(
def default_a2a_fwd_bwd(
routed_input: torch.Tensor,
labels: torch.Tensor,
output_splits_list: list[int],
input_splits_list: list[int],
device_mesh: DeviceMesh,
):
"""
Default implementation of all-to-all dispatch. Incurs device-to-host sync.

Returns:
routed_input: the local tokens after all-to-all dispatch
input_splits: the input splits for all-to-all dispatch
output_splits: the output splits for all-to-all dispatch
num_tokens_per_expert_group: the number of tokens per EP rank after all-to-all dispatch
"""
# perform all-to-all
routed_input = all_to_all_single_autograd(
routed_input,
output_splits_list,
input_splits_list,
device_mesh.get_group(),
)
routed_input = torch.ops._c10d_functional.wait_tensor(routed_input)

loss = F.mse_loss(routed_input, labels)
loss.backward()

torch.cuda.synchronize()
return routed_input


def mxfp8_a2a_fwd_bwd(
routed_input: torch.Tensor,
labels: torch.Tensor,
output_splits_list: list[int],
input_splits_list: list[int],
device_mesh: DeviceMesh,
):
routed_input = to_mxfp8_a2a_dequant(
routed_input,
output_splits_list,
input_splits_list,
device_mesh.get_group(),
)

loss = F.mse_loss(routed_input, labels)
loss.backward()
torch.cuda.synchronize()
return routed_input


# Compile target funcs
default_a2a_sync_compiled = torch.compile(default_a2a_fwd_bwd)
mxfp8_a2a_sync_compiled = torch.compile(mxfp8_a2a_fwd_bwd)


def run_experiment(
config: ExperimentConfig, args: argparse.Namespace
) -> ExperimentResult:
@@ -101,8 +122,9 @@ def run_experiment(
(batch_size * seq_len, dim),
dtype=torch.bfloat16,
device=device,
requires_grad=True,
)
ref_x = x.detach().clone()
ref_x = x.detach().clone().requires_grad_(True)

# Set up device mesh
mesh = init_device_mesh("cuda", (dist.get_world_size(),))
@@ -121,24 +143,27 @@ def warmup(func_no_args):
)
input_splits_list, output_splits_list = get_split_lists(input_splits, mesh)

# Compile target funcs
default_a2a_dispatch_c = torch.compile(default_a2a_dispatch)
to_mxfp8_a2a_dequant_c = torch.compile(to_mxfp8_a2a_dequant)
# Generate labels
labels_shape = (sum(output_splits_list), dim)
labels = x.new_ones(*labels_shape)

# Bench default a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
warmup(
lambda: default_a2a_dispatch_c(
ref_x, output_splits_list, input_splits_list, mesh
lambda: default_a2a_sync_compiled(
ref_x, labels, output_splits_list, input_splits_list, mesh
)
)
start_sec = time.perf_counter()
default_a2a_dispatch_c(ref_x, output_splits_list, input_splits_list, mesh)
default_a2a_sync_compiled(
ref_x, labels, output_splits_list, input_splits_list, mesh
)
end_sec = time.perf_counter()
bf16_ms = (end_sec - start_sec) * 1e3
if args.profile:
profile_fn(
default_a2a_dispatch_c,
default_a2a_sync_compiled,
ref_x,
labels,
output_splits_list,
input_splits_list,
mesh,
@@ -148,16 +173,19 @@ def warmup(func_no_args):

# Bench mxfp8 sync a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
warmup(
lambda: to_mxfp8_a2a_dequant_c(x, output_splits_list, input_splits_list, mesh)
lambda: mxfp8_a2a_sync_compiled(
x, labels, output_splits_list, input_splits_list, mesh
)
)
start_sec = time.perf_counter()
to_mxfp8_a2a_dequant_c(x, output_splits_list, input_splits_list, mesh)
mxfp8_a2a_sync_compiled(x, labels, output_splits_list, input_splits_list, mesh)
end_sec = time.perf_counter()
mxfp8_ms = (end_sec - start_sec) * 1e3
if args.profile:
profile_fn(
to_mxfp8_a2a_dequant_c,
mxfp8_a2a_sync_compiled,
x,
labels,
output_splits_list,
input_splits_list,
mesh,
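For context, the timing pattern used twice above (warm up the compiled fwd+bwd function, then time a single call) can be condensed into one helper. This is an illustrative sketch only, not part of the PR: the name bench_a2a_fwd_bwd_ms and the warmup_iters parameter are hypothetical, and it assumes the process group and device mesh are already set up via the torchrun command in the file header.

import time


def bench_a2a_fwd_bwd_ms(
    fwd_bwd, x, labels, output_splits_list, input_splits_list, mesh, warmup_iters=2
):
    # Warm up so torch.compile compilation is excluded from the measurement.
    for _ in range(warmup_iters):
        fwd_bwd(x, labels, output_splits_list, input_splits_list, mesh)
    # The fwd/bwd functions above already end with torch.cuda.synchronize(),
    # so perf_counter brackets the full forward + backward + sync.
    start_sec = time.perf_counter()
    fwd_bwd(x, labels, output_splits_list, input_splits_list, mesh)
    end_sec = time.perf_counter()
    return (end_sec - start_sec) * 1e3  # milliseconds, matching bf16_ms / mxfp8_ms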
7 changes: 5 additions & 2 deletions torchao/prototype/moe_training/kernels/mxfp8/comms.py
@@ -11,6 +11,7 @@
blockwise_barrier,
sync_threads,
)
from torchao.prototype.mx_formats.config import ScaleCalculationMode
from torchao.prototype.mx_formats.mx_tensor import to_dtype, to_mx


@@ -468,14 +469,15 @@ def forward(
"""
Dynamically quantizes input to mxfp8, performs all-to-all, then dequantizes output back to original precision.
Requires d2h sync to get input_splits and output_splits on host, as required by torch.distributed.all_to_all_single API.
Uses RCEIL scaling mode for quantization.
"""

# Quantize input
block_size = 32
input_scales, input_data = to_mx(
input,
elem_dtype=torch.float8_e4m3fn,
block_size=block_size,
scaling_mode=ScaleCalculationMode.RCEIL,
Contributor (review comment): nit: put scaling mode in the docblock
)

# Dispatch data (async)
@@ -531,6 +533,7 @@ def backward(ctx, grad_output_hp):
grad_output_hp,
elem_dtype=torch.float8_e4m3fn,
block_size=block_size,
scaling_mode=ScaleCalculationMode.RCEIL,
)

# Dispatch data (async)
@@ -550,8 +553,8 @@ def backward(ctx, grad_output_hp):
)

# Explicitly wait since the a2a ops are async
grad_input_scales = torch.ops._c10d_functional.wait_tensor(grad_input_scales)
grad_input_data = torch.ops._c10d_functional.wait_tensor(grad_input_data)
grad_input_scales = torch.ops._c10d_functional.wait_tensor(grad_input_scales)

hp_dtype = grad_output_hp.dtype
lowp_dtype = grad_input_data.dtype
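As a standalone illustration of the quantization step this PR changes, the sketch below quantizes a bf16 tensor to mxfp8 with the RCEIL scaling mode, reusing the same to_mx call and imports that appear in the diff. It is a minimal sketch under the assumption that to_mx returns (scales, data) as it is unpacked in forward above; the dequantization path through to_dtype is omitted since its full signature is not shown in this diff.

import torch

from torchao.prototype.mx_formats.config import ScaleCalculationMode
from torchao.prototype.mx_formats.mx_tensor import to_mx

# Mirror the quantization in to_mxfp8_a2a_dequant.forward: one shared scale per
# contiguous block of 32 elements, computed with RCEIL rounding.
block_size = 32
x = torch.randn(16, 128, dtype=torch.bfloat16, device="cuda")
scales, data = to_mx(
    x,
    elem_dtype=torch.float8_e4m3fn,
    block_size=block_size,
    scaling_mode=ScaleCalculationMode.RCEIL,
)
# `data` is the fp8 payload and `scales` holds the per-block scales; these are
# what the kernel dispatches through the all-to-all before dequantizing.
print(data.dtype, data.shape, scales.shape)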