Skip to content

Commit 7cd1a74

Browse files
[mxfp8 moe training] compute prefix sum of group sizes inside kernel instead of precomputing
stack-info: PR: #3285, branch: danielvegamyhre/stack/82
1 parent f856d36 commit 7cd1a74

File tree

2 files changed

+42
-31
lines changed

2 files changed

+42
-31
lines changed

torchao/prototype/moe_training/kernels/mxfp8/quant.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,6 @@ def compute_blocked_scale_offsets_for_K_groups(
223223
def triton_mx_block_rearrange_2d_M_groups(
224224
scales_tensor: torch.Tensor,
225225
input_group_end_offsets: torch.Tensor,
226-
output_group_start_offsets: torch.Tensor,
227226
) -> torch.Tensor:
228227
"""
229228
Rearranges an E8M0 tensor scale to block-scaled swizzle format,
@@ -275,35 +274,58 @@ def triton_mx_block_rearrange_2d_M_groups(
275274
scales_tensor.stride(1),
276275
rows,
277276
cols,
278-
num_groups,
279277
# Original offsets (to read from)
280278
input_group_end_offsets,
281279
# Output scales tensor and group offsets after padding (to write to)
282280
output.view(torch.uint8),
283281
output.stride(0),
284-
output_group_start_offsets,
285282
output_stride_per_block,
286283
output_stride_per_row_of_blocks,
284+
num_groups=num_groups,
287285
BLOCK_ROWS=BLOCK_ROWS,
288286
BLOCK_COLS=BLOCK_COLS,
289287
)
290288
return output
291289

292290

291+
@triton.jit
292+
def _blocked_group_start_idx(
293+
group_pid,
294+
orig_offsets,
295+
num_groups: tl.constexpr,
296+
padding_size: tl.constexpr,
297+
):
298+
"""Prefix sum to compute the start index of a given group."""
299+
offsets = tl.load(orig_offsets + tl.arange(0, num_groups))
300+
group_sizes = tl.where(
301+
tl.arange(0, num_groups) > 0,
302+
offsets
303+
- tl.load(
304+
orig_offsets + tl.arange(0, num_groups) - 1,
305+
mask=tl.arange(0, num_groups) > 0,
306+
other=0,
307+
),
308+
offsets,
309+
)
310+
padded_sizes = tl.cdiv(group_sizes, padding_size) * padding_size
311+
prefix_mask = tl.arange(0, num_groups) < group_pid
312+
group_start_idx = tl.sum(tl.where(prefix_mask, padded_sizes, 0))
313+
return group_start_idx
314+
315+
293316
@triton.jit
294317
def triton_scale_swizzle_M_groups(
295318
scales_ptr, # (M, K//block_size)
296319
scales_stride_dim0,
297320
scales_stride_dim1,
298321
scale_rows,
299322
scale_cols,
300-
num_groups,
301323
orig_offsets, # (num_groups,)
302324
output_scales_ptr,
303325
output_scales_stride_dim0,
304-
output_scales_group_offsets, # (num_groups,)
305326
output_stride_per_block,
306327
output_stride_per_row_of_blocks,
328+
num_groups: tl.constexpr,
307329
BLOCK_ROWS: tl.constexpr,
308330
BLOCK_COLS: tl.constexpr,
309331
):
@@ -316,10 +338,13 @@ def triton_scale_swizzle_M_groups(
316338
input_group_end_row = tl.load(
317339
orig_offsets + group_pid, mask=group_pid < num_groups, other=0
318340
)
319-
# Output scales start row we will begin writing to
320-
output_group_start_row = tl.load(
321-
output_scales_group_offsets + group_pid, mask=group_pid < num_groups, other=0
341+
342+
# Calculate this group's start row after blocked format padding, by doing a prefix sum
343+
# of each previous group's padded size.
344+
output_group_start_row = _blocked_group_start_idx(
345+
group_pid, orig_offsets, num_groups, 128
322346
)
347+
323348
# Calculate destination indices for each row and col in block swizzled layout.
324349
# We can reuse this swizzle transformation on each block of data we read.
325350
row_offs = tl.arange(0, BLOCK_ROWS)[:, None]
@@ -489,7 +514,6 @@ def triton_scale_swizzle_per_group_3d(
489514
def triton_mx_block_rearrange_2d_K_groups(
490515
scales_tensor: torch.Tensor,
491516
input_group_end_offsets: torch.Tensor,
492-
output_group_start_offsets: torch.Tensor,
493517
) -> torch.Tensor:
494518
"""
495519
Rearranges an E8M0 tensor scale to block-scaled swizzle format on a per group basis,
@@ -538,13 +562,10 @@ def triton_mx_block_rearrange_2d_K_groups(
538562
rows,
539563
cols,
540564
padded_rows,
541-
num_groups,
542-
# Original offsets (to read from)
543565
input_group_end_offsets,
544-
# Output scales tensor and group offsets after padding (to write to)
545566
output.view(torch.uint8),
546-
output_group_start_offsets,
547567
output_stride_per_block,
568+
num_groups=num_groups,
548569
BLOCK_ROWS=BLOCK_ROWS,
549570
BLOCK_COLS=BLOCK_COLS,
550571
DEBUG=False,
@@ -560,11 +581,10 @@ def triton_scale_swizzle_2d_K_groups(
560581
scale_rows,
561582
scale_cols,
562583
padded_rows,
563-
num_groups,
564584
orig_offsets, # (num_groups,)
565585
output_scales_ptr,
566-
output_scales_group_offsets, # (num_groups,)
567586
output_stride_per_block,
587+
num_groups: tl.constexpr,
568588
BLOCK_ROWS: tl.constexpr,
569589
BLOCK_COLS: tl.constexpr,
570590
DEBUG: tl.constexpr = False,
@@ -578,8 +598,11 @@ def triton_scale_swizzle_2d_K_groups(
578598
)
579599
input_group_end_col = tl.load(orig_offsets + group_pid)
580600

581-
# Output scales start row we will begin writing to
582-
output_group_start_col = tl.load(output_scales_group_offsets + group_pid)
601+
# Calculate this group's start col after blocked format padding, by doing a prefix sum
602+
# of each previous group's padded size.
603+
output_group_start_col = _blocked_group_start_idx(
604+
group_pid, orig_offsets, num_groups, 4
605+
)
583606

584607
row_offs = tl.arange(0, BLOCK_ROWS)[:, None]
585608
col_offs = tl.arange(0, BLOCK_COLS)[None, :]

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
triton_fp8_rowwise_3d_transpose_rhs,
1818
)
1919
from torchao.prototype.moe_training.kernels.mxfp8 import (
20-
compute_blocked_scale_offsets_for_K_groups,
21-
compute_blocked_scale_offsets_for_M_groups,
2220
mxfp8_quantize_cuda_3d,
2321
triton_mx_block_rearrange_2d_K_groups,
2422
triton_mx_block_rearrange_2d_M_groups,
@@ -329,13 +327,9 @@ def forward(
329327
)
330328

331329
# Convert scales to blocked format for 2d-3d grouped mm
332-
_, blocked_scales_group_offsets_2d3d = (
333-
compute_blocked_scale_offsets_for_M_groups(offs)
334-
)
335330
A_scales_blocked = triton_mx_block_rearrange_2d_M_groups(
336331
A_scale,
337332
offs,
338-
blocked_scales_group_offsets_2d3d,
339333
)
340334
B_scales_blocked = triton_mx_block_rearrange_per_group_3d(B_scales)
341335

@@ -350,7 +344,7 @@ def forward(
350344
out_dtype=out_dtype,
351345
)
352346

353-
ctx.save_for_backward(A, B_t, offs, blocked_scales_group_offsets_2d3d)
347+
ctx.save_for_backward(A, B_t, offs)
354348
ctx.block_size = block_size
355349
ctx.out_dtype = out_dtype
356350
ctx.emulated = emulated
@@ -359,7 +353,7 @@ def forward(
359353

360354
@staticmethod
361355
def backward(ctx, grad_out: torch.Tensor):
362-
A, B_t, offs, blocked_scales_group_offsets_2d3d = ctx.saved_tensors
356+
A, B_t, offs = ctx.saved_tensors
363357
block_size = ctx.block_size
364358
out_dtype = ctx.out_dtype
365359
use_triton_for_dim0_cast = ctx.use_triton_for_dim0_cast
@@ -390,7 +384,6 @@ def backward(ctx, grad_out: torch.Tensor):
390384
grad_out_scales_blocked = triton_mx_block_rearrange_2d_M_groups(
391385
grad_out_scale,
392386
offs,
393-
blocked_scales_group_offsets_2d3d,
394387
)
395388
B_scales_blocked = triton_mx_block_rearrange_per_group_3d(B_scales)
396389

@@ -436,18 +429,13 @@ def backward(ctx, grad_out: torch.Tensor):
436429

437430
# Convert scales to blocked format for 2d-2d grouped mm
438431
scale_group_offsets = offs // block_size
439-
_, blocked_scale_group_offsets = compute_blocked_scale_offsets_for_K_groups(
440-
scale_group_offsets
441-
)
442432
grad_out_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
443433
grad_out_t_scales,
444434
scale_group_offsets,
445-
blocked_scale_group_offsets,
446435
)
447436
A_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
448437
A_t_scales,
449438
scale_group_offsets,
450-
blocked_scale_group_offsets,
451439
)
452440

453441
# grad_B_t = scaled grouped mm of (N,total_M) @ (total_M,K) = (E,N,K)

0 commit comments

Comments (0)