Skip to content

Commit 494df3f

Browse files
[mxfp8 moe training] update benchmarks and tests; simplify per group blocked swizzle ref function
stack-info: PR: #3286, branch: danielvegamyhre/stack/83
1 parent 01374eb commit 494df3f

File tree

5 files changed

+22
-27
lines changed

5 files changed

+22
-27
lines changed

benchmarks/prototype/moe_training/bench_2d_3d_grouped_gemm.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -216,12 +216,12 @@ def bench_fp8_rowwise_grouped_mm(A, B_t, offs) -> float:
216216

217217

218218
def bench_mxfp8_grouped_mm(A, B_t, offs, block_size=32) -> float:
219-
# A_mx shape: (M, K)
220-
# A_scale shape: (M, K//block_size)
219+
# A_fp8 shape: (M, K)
220+
# A_scales shape: (M, K//block_size)
221221
A_scales, A_fp8 = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
222222

223-
# B_mx shape: (E, N, K)
224-
# B_scale shape: (E, N, K//block_size)
223+
# B_fp8 shape: (E, N, K)
224+
# B_scales shape: (E, N, K//block_size)
225225
B_scales, B_fp8 = to_mx(
226226
B_t.transpose(-2, -1),
227227
elem_dtype=torch.float8_e4m3fn,
@@ -230,26 +230,19 @@ def bench_mxfp8_grouped_mm(A, B_t, offs, block_size=32) -> float:
230230

231231
# Convert scales for each group to blocked format.
232232
Mg, K = A_fp8.shape
233-
A_scales_blocked, starting_row_after_padding = torch_to_blocked_2d_M_groups(
234-
A_scales, offs, K
233+
A_scales_blocked, _ = torch_to_blocked_2d_M_groups(
234+
A_scales, offs, block_size=block_size
235235
)
236236
B_scales_blocked = torch_to_blocked_per_group_3d(B_scales)
237237

238-
# From this, we compute `group_sizes` and `starting_row_after_padding`:
239-
# group_sizes = [32, 32, 64]
240-
# starting_row_after_padding = [0, 32, 64, 128]
241-
zero = torch.tensor([0], dtype=offs.dtype, device=offs.device)
242-
group_sizes = torch.diff(offs, prepend=zero).to(torch.int64)
243-
244238
# Run the grouped mm
245239
mxfp8_us = benchmark_cuda_function_in_microseconds(
246-
torch.ops.fbgemm.mx8mx8bf16_grouped_stacked,
240+
torch._scaled_grouped_mm,
247241
A_fp8,
248-
B_fp8,
242+
B_fp8.transpose(-2, -1),
249243
A_scales_blocked,
250244
B_scales_blocked,
251-
group_sizes,
252-
starting_row_after_padding=starting_row_after_padding,
245+
offs=offs,
253246
)
254247
return mxfp8_us
255248

benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_M_groups.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def get_configs() -> List[ExperimentConfig]:
5151
block_size = 32
5252
input_shapes = [
5353
(16640, 5120 // block_size),
54+
(131072, 5120 // block_size),
5455
]
5556
num_groups = [16]
5657
configs = []
@@ -78,18 +79,21 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
7879
)
7980

8081
Mg, K = input_shape
81-
input_group_offsets = generate_jagged_offs(num_groups, Mg, multiple_of=32)
82+
block_size = 32
83+
input_group_offsets = generate_jagged_offs(num_groups, Mg, multiple_of=block_size)
8284

8385
# bench torch
8486
compiled_run_torch = torch.compile(torch_to_blocked_2d_M_groups)
8587
torch_out_scales, torch_group_offs = compiled_run_torch(
86-
input_tensor, input_group_offsets, K
88+
input_tensor,
89+
input_group_offsets,
90+
block_size=block_size,
8791
)
8892
torch_time_us = benchmark_cuda_function_in_microseconds(
8993
compiled_run_torch,
9094
input_tensor,
9195
input_group_offsets,
92-
K,
96+
block_size=block_size,
9397
)
9498

9599
# bench triton

test/prototype/moe_training/test_kernels.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def test_triton_mx_block_rearrange_2d_M_groups(
230230

231231
# torch reference
232232
ref_out_scales, _ = torch_to_blocked_2d_M_groups(
233-
e8m0_scales, input_group_offsets, k, block_size=block_size
233+
e8m0_scales, input_group_offsets, block_size=block_size
234234
)
235235

236236
# triton kernel

torchao/prototype/moe_training/kernels/mxfp8/quant.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616

1717
def torch_to_blocked_2d_M_groups(
18-
x_scales: Tensor, group_offs: Tensor, K: int, block_size: int = 32
18+
x_scales: Tensor, group_offs: Tensor, block_size: int = 32
1919
) -> Tuple[Tensor, Tensor]:
2020
"""
2121
Convert scales to blocked format for a 2D tensor (input activations / token groups),
@@ -34,14 +34,14 @@ def torch_to_blocked_2d_M_groups(
3434

3535
assert x_scales.ndim == 2, "x_scales must be 2D"
3636
assert block_size == 32, "Only block_size=32 is supported for now"
37-
total_M, _ = x_scales.shape
37+
total_M, scale_cols = x_scales.shape
3838
num_groups = group_offs.shape[0]
3939

4040
# Each group will require a variable amount of padding, so to avoid d2h sync causing by iterating over each group,
4141
# the Triton kernenl will use an upper bound of adding 128 padding rows to each group.
4242
# (This torch impl is used as a reference for correctness, so we must match the triton kernel's impl).
4343
total_M_padded = total_M + num_groups * 128
44-
blocked_scales = x_scales.new_zeros(total_M_padded, K // block_size)
44+
blocked_scales = x_scales.new_zeros(total_M_padded, scale_cols)
4545
start_row_after_padding_list = [0]
4646
group_start_idx = 0
4747
for i, group_end_idx in enumerate(group_offs.tolist()):
@@ -56,8 +56,7 @@ def torch_to_blocked_2d_M_groups(
5656
group_scales_blocked = to_blocked(group_scales)
5757

5858
# Calculate the start row after padding
59-
scaling_groups_per_row = K // block_size
60-
rows_for_group = group_scales_blocked.numel() // scaling_groups_per_row
59+
rows_for_group = group_scales_blocked.numel() // scale_cols
6160
new_start_row = prev_start_row_after_padding + rows_for_group
6261
start_row_after_padding_list.append(new_start_row)
6362

@@ -67,7 +66,7 @@ def torch_to_blocked_2d_M_groups(
6766
prev_start_row_after_padding : prev_start_row_after_padding
6867
+ group_rows_padded,
6968
:,
70-
] = group_scales_blocked.reshape(-1, K // block_size)
69+
] = group_scales_blocked.reshape(-1, scale_cols)
7170

7271
# Update next group start index
7372
group_start_idx = group_end_idx

torchao/prototype/mx_formats/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ def to_blocked(input_matrix, use_triton_kernel: bool = False) -> Tensor:
6868
# Rearrange the blocks
6969
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
7070
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
71-
7271
return rearranged.flatten()
7372

7473

0 commit comments

Comments
 (0)