25 changes: 9 additions & 16 deletions benchmarks/prototype/moe_training/bench_2d_3d_grouped_gemm.py
@@ -216,12 +216,12 @@ def bench_fp8_rowwise_grouped_mm(A, B_t, offs) -> float:


def bench_mxfp8_grouped_mm(A, B_t, offs, block_size=32) -> float:
# A_mx shape: (M, K)
# A_scale shape: (M, K//block_size)
# A_fp8 shape: (M, K)
# A_scales shape: (M, K//block_size)
A_scales, A_fp8 = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)

# B_mx shape: (E, N, K)
# B_scale shape: (E, N, K//block_size)
# B_fp8 shape: (E, N, K)
# B_scales shape: (E, N, K//block_size)
B_scales, B_fp8 = to_mx(
B_t.transpose(-2, -1),
elem_dtype=torch.float8_e4m3fn,
@@ -230,26 +230,19 @@ def bench_mxfp8_grouped_mm(A, B_t, offs, block_size=32) -> float:

# Convert scales for each group to blocked format.
Mg, K = A_fp8.shape
A_scales_blocked, starting_row_after_padding = torch_to_blocked_2d_M_groups(
A_scales, offs, K
A_scales_blocked, _ = torch_to_blocked_2d_M_groups(
A_scales, offs, block_size=block_size
)
B_scales_blocked = torch_to_blocked_per_group_3d(B_scales)

# From this, we compute `group_sizes` and `starting_row_after_padding`:
# group_sizes = [32, 32, 64]
# starting_row_after_padding = [0, 32, 64, 128]
zero = torch.tensor([0], dtype=offs.dtype, device=offs.device)
group_sizes = torch.diff(offs, prepend=zero).to(torch.int64)

# Run the grouped mm
mxfp8_us = benchmark_cuda_function_in_microseconds(
torch.ops.fbgemm.mx8mx8bf16_grouped_stacked,
torch._scaled_grouped_mm,
A_fp8,
B_fp8,
B_fp8.transpose(-2, -1),
A_scales_blocked,
B_scales_blocked,
group_sizes,
starting_row_after_padding=starting_row_after_padding,
offs=offs,
)
return mxfp8_us

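For orientation, here is a self-contained sketch of the call path the updated benchmark exercises. The import paths are inferred from the file locations in this PR, the shapes and group offsets are made up, and `torch._scaled_grouped_mm` with MXFP8 inputs requires a recent PyTorch build plus a GPU that supports it, so treat this as illustrative rather than canonical:

```python
import torch

from torchao.prototype.moe_training.kernels.mxfp8.quant import (
    torch_to_blocked_2d_M_groups,
    torch_to_blocked_per_group_3d,
)
from torchao.prototype.mx_formats.mx_tensor import to_mx

block_size = 32
E, Mg, K, N = 4, 4096, 5120, 8192  # experts, total tokens, hidden dims
A = torch.randn(Mg, K, dtype=torch.bfloat16, device="cuda")
B_t = torch.randn(E, K, N, dtype=torch.bfloat16, device="cuda")
# Cumulative end offsets of each token group along Mg (multiples of block_size).
offs = torch.tensor([1024, 2048, 3072, 4096], dtype=torch.int32, device="cuda")

# Quantize activations (Mg, K) and weights (E, N, K) to MXFP8 with e8m0 block scales.
A_scales, A_fp8 = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
B_scales, B_fp8 = to_mx(
    B_t.transpose(-2, -1), elem_dtype=torch.float8_e4m3fn, block_size=block_size
)

# Convert the scales to the blocked layout the grouped GEMM expects.
A_scales_blocked, _ = torch_to_blocked_2d_M_groups(A_scales, offs, block_size=block_size)
B_scales_blocked = torch_to_blocked_per_group_3d(B_scales)

out = torch._scaled_grouped_mm(
    A_fp8,
    B_fp8.transpose(-2, -1),
    A_scales_blocked,
    B_scales_blocked,
    offs=offs,
)
```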
@@ -51,6 +51,7 @@ def get_configs() -> List[ExperimentConfig]:
block_size = 32
input_shapes = [
(16640, 5120 // block_size),
(131072, 5120 // block_size),
]
num_groups = [16]
configs = []
@@ -78,18 +79,21 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
)

Mg, K = input_shape
input_group_offsets = generate_jagged_offs(num_groups, Mg, multiple_of=32)
block_size = 32
input_group_offsets = generate_jagged_offs(num_groups, Mg, multiple_of=block_size)

# bench torch
compiled_run_torch = torch.compile(torch_to_blocked_2d_M_groups)
torch_out_scales, torch_group_offs = compiled_run_torch(
input_tensor, input_group_offsets, K
input_tensor,
input_group_offsets,
block_size=block_size,
)
torch_time_us = benchmark_cuda_function_in_microseconds(
compiled_run_torch,
input_tensor,
input_group_offsets,
K,
block_size=block_size,
)

# bench triton
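For reference, a minimal version of the torch path this benchmark now times with the `block_size` keyword argument; the offsets are hard-coded here instead of coming from `generate_jagged_offs`, and the e8m0 scales are stood in by raw uint8 values since the rearrangement is layout-only:

```python
import torch

from torchao.prototype.moe_training.kernels.mxfp8.quant import torch_to_blocked_2d_M_groups

block_size = 32
num_groups, Mg, K = 16, 16640, 5120
# Per-block e8m0 scales for an (Mg, K) activation tensor, stood in by uint8 values.
input_tensor = torch.randint(
    0, 255, (Mg, K // block_size), dtype=torch.uint8, device="cuda"
)
# Cumulative group end offsets; every group size is a multiple of block_size.
group_sizes = [1024] * (num_groups - 1) + [1280]  # sums to 16640
input_group_offsets = torch.tensor(group_sizes, device="cuda").cumsum(0).to(torch.int32)

compiled_ref = torch.compile(torch_to_blocked_2d_M_groups)
blocked_scales, group_start_rows = compiled_ref(
    input_tensor, input_group_offsets, block_size=block_size
)
```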
2 changes: 1 addition & 1 deletion test/prototype/moe_training/test_kernels.py
@@ -230,7 +230,7 @@ def test_triton_mx_block_rearrange_2d_M_groups(

# torch reference
ref_out_scales, _ = torch_to_blocked_2d_M_groups(
e8m0_scales, input_group_offsets, k, block_size=block_size
e8m0_scales, input_group_offsets, block_size=block_size
)

# triton kernel
11 changes: 5 additions & 6 deletions torchao/prototype/moe_training/kernels/mxfp8/quant.py
@@ -15,7 +15,7 @@


def torch_to_blocked_2d_M_groups(
x_scales: Tensor, group_offs: Tensor, K: int, block_size: int = 32
x_scales: Tensor, group_offs: Tensor, block_size: int = 32
) -> Tuple[Tensor, Tensor]:
"""
Convert scales to blocked format for a 2D tensor (input activations / token groups),
@@ -34,14 +34,14 @@ def torch_to_blocked_2d_M_groups(

assert x_scales.ndim == 2, "x_scales must be 2D"
assert block_size == 32, "Only block_size=32 is supported for now"
total_M, _ = x_scales.shape
total_M, scale_cols = x_scales.shape
num_groups = group_offs.shape[0]

# Each group will require a variable amount of padding, so to avoid a d2h sync caused by iterating over each group,
# the Triton kernel will use an upper bound of 128 padding rows per group.
# (This torch impl is used as a reference for correctness, so we must match the triton kernel's impl).
total_M_padded = total_M + num_groups * 128
blocked_scales = x_scales.new_zeros(total_M_padded, K // block_size)
blocked_scales = x_scales.new_zeros(total_M_padded, scale_cols)
start_row_after_padding_list = [0]
group_start_idx = 0
for i, group_end_idx in enumerate(group_offs.tolist()):
@@ -56,8 +56,7 @@ def torch_to_blocked_2d_M_groups(
group_scales_blocked = to_blocked(group_scales)

# Calculate the start row after padding
scaling_groups_per_row = K // block_size
rows_for_group = group_scales_blocked.numel() // scaling_groups_per_row
rows_for_group = group_scales_blocked.numel() // scale_cols
new_start_row = prev_start_row_after_padding + rows_for_group
start_row_after_padding_list.append(new_start_row)

@@ -67,7 +66,7 @@ def torch_to_blocked_2d_M_groups(
prev_start_row_after_padding : prev_start_row_after_padding
+ group_rows_padded,
:,
] = group_scales_blocked.reshape(-1, K // block_size)
] = group_scales_blocked.reshape(-1, scale_cols)

# Update next group start index
group_start_idx = group_end_idx
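To make the padding behavior concrete, here is a toy run of the reference (not part of the diff); the scale dtype is faked as uint8, and the start-row values are what the shown logic implies, since `to_blocked` pads each group's rows up to a multiple of 128:

```python
import torch

from torchao.prototype.moe_training.kernels.mxfp8.quant import torch_to_blocked_2d_M_groups

block_size = 32
# Three token groups of 32, 32, and 64 rows; scale_cols = K // block_size = 4.
x_scales = torch.randint(0, 255, (128, 4), dtype=torch.uint8)
group_offs = torch.tensor([32, 64, 128], dtype=torch.int32)

blocked, start_rows = torch_to_blocked_2d_M_groups(
    x_scales, group_offs, block_size=block_size
)
# Each group occupies 128 rows in the blocked layout, so the groups start at
# rows 0, 128, and 256, with the padded end at 384.
print(start_rows)  # expected: 0, 128, 256, 384
```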
1 change: 0 additions & 1 deletion torchao/prototype/mx_formats/utils.py
@@ -68,7 +68,6 @@ def to_blocked(input_matrix, use_triton_kernel: bool = False) -> Tensor:
# Rearrange the blocks
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)

return rearranged.flatten()

