Commit 1e97f00

[mxfp8 moe training] compute prefix sum of group sizes inside kernel instead of precomputing
stack-info: PR: #3285, branch: danielvegamyhre/stack/82
1 parent f856d36 commit 1e97f00

File tree: 4 files changed (+42, -46 lines)
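The deleted compute_blocked_scale_offsets_for_*_groups helpers precomputed, on the host, where each group's scales start in the padded blocked layout; this commit folds that prefix sum into the Triton kernels themselves. For orientation, here is a rough host-side equivalent of the math being moved (an illustrative sketch, not the library API; per the kernel code below, padding_size is 128 rows for M-groups and 4 columns for K-groups):

    import torch

    def blocked_group_starts(group_end_offsets: torch.Tensor, padding_size: int) -> torch.Tensor:
        # group_end_offsets[i] is the cumulative end index of group i (like `offs` below).
        zero = torch.zeros(1, dtype=group_end_offsets.dtype, device=group_end_offsets.device)
        group_sizes = torch.diff(group_end_offsets, prepend=zero)
        # Round each group's size up to a multiple of padding_size.
        padded_sizes = ((group_sizes + padding_size - 1) // padding_size) * padding_size
        # Exclusive prefix sum: group i starts where padded groups 0..i-1 end.
        return torch.cumsum(padded_sizes, dim=0) - padded_sizes

    # blocked_group_starts(torch.tensor([40, 72, 200]), 128) -> tensor([0, 128, 256])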

test/prototype/moe_training/test_kernels.py

Lines changed: 0 additions & 13 deletions
@@ -21,8 +21,6 @@
     triton_fp8_per_group_rowwise_scales,
 )
 from torchao.prototype.moe_training.kernels.mxfp8 import (
-    compute_blocked_scale_offsets_for_K_groups,
-    compute_blocked_scale_offsets_for_M_groups,
     torch_to_blocked_2d_K_groups,
     torch_to_blocked_2d_M_groups,
     torch_to_blocked_per_group_3d,
@@ -236,13 +234,9 @@ def test_triton_mx_block_rearrange_2d_M_groups(
     )

     # triton kernel
-    _, output_group_offsets = compute_blocked_scale_offsets_for_M_groups(
-        input_group_offsets
-    )
     triton_out_scales = triton_mx_block_rearrange_2d_M_groups(
         e8m0_scales,
         input_group_offsets,
-        output_group_offsets,
     )
     assert torch.allclose(ref_out_scales, triton_out_scales, atol=0, rtol=0), (
         "blocked scales not equal"
@@ -306,16 +300,9 @@ def test_triton_mx_block_rearrange_2d_K_groups(
     )

     # triton kernel
-    _, output_group_offsets = compute_blocked_scale_offsets_for_K_groups(
-        scale_group_offsets
-    )
-    assert torch.equal(output_group_offsets, ref_start_cols_after_padding), (
-        "output scale group start offsets not equal"
-    )
     triton_out_scales = triton_mx_block_rearrange_2d_K_groups(
         e8m0_scales,
         scale_group_offsets,
-        output_group_offsets,
     )
     assert torch.equal(ref_out_scales, triton_out_scales), "blocked scales not equal"
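With the offsets derived in-kernel, callers drop the precompute step entirely. Schematically, the before/after call pattern exercised by these tests (names as in this diff):

    # before: blocked offsets precomputed on the host, then passed in
    _, output_group_offsets = compute_blocked_scale_offsets_for_M_groups(input_group_offsets)
    out = triton_mx_block_rearrange_2d_M_groups(e8m0_scales, input_group_offsets, output_group_offsets)

    # after: the kernel computes the prefix sum internally
    out = triton_mx_block_rearrange_2d_M_groups(e8m0_scales, input_group_offsets)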

torchao/prototype/moe_training/kernels/mxfp8/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -1,6 +1,4 @@
 from torchao.prototype.moe_training.kernels.mxfp8.quant import (
-    compute_blocked_scale_offsets_for_K_groups,  # noqa: F401
-    compute_blocked_scale_offsets_for_M_groups,  # noqa: F401
     mxfp8_quantize_cuda_3d,  # noqa: F401
     torch_to_blocked_2d_K_groups,  # noqa: F401
     torch_to_blocked_2d_M_groups,  # noqa: F401

torchao/prototype/moe_training/kernels/mxfp8/quant.py

Lines changed: 40 additions & 17 deletions
@@ -223,7 +223,6 @@ def compute_blocked_scale_offsets_for_K_groups(
 def triton_mx_block_rearrange_2d_M_groups(
     scales_tensor: torch.Tensor,
     input_group_end_offsets: torch.Tensor,
-    output_group_start_offsets: torch.Tensor,
 ) -> torch.Tensor:
     """
     Rearranges an E8M0 tensor scale to block-scaled swizzle format,
@@ -275,15 +274,14 @@ def triton_mx_block_rearrange_2d_M_groups(
         scales_tensor.stride(1),
         rows,
         cols,
-        num_groups,
         # Original offsets (to read from)
         input_group_end_offsets,
         # Output scales tensor and group offsets after padding (to write to)
         output.view(torch.uint8),
         output.stride(0),
-        output_group_start_offsets,
         output_stride_per_block,
         output_stride_per_row_of_blocks,
+        num_groups=num_groups,
         BLOCK_ROWS=BLOCK_ROWS,
         BLOCK_COLS=BLOCK_COLS,
     )
@@ -297,13 +295,12 @@ def triton_scale_swizzle_M_groups(
     scales_stride_dim1,
     scale_rows,
     scale_cols,
-    num_groups,
     orig_offsets,  # (num_groups,)
     output_scales_ptr,
     output_scales_stride_dim0,
-    output_scales_group_offsets,  # (num_groups,)
     output_stride_per_block,
     output_stride_per_row_of_blocks,
+    num_groups: tl.constexpr,
     BLOCK_ROWS: tl.constexpr,
     BLOCK_COLS: tl.constexpr,
 ):
@@ -316,10 +313,13 @@ def triton_scale_swizzle_M_groups(
     input_group_end_row = tl.load(
         orig_offsets + group_pid, mask=group_pid < num_groups, other=0
     )
-    # Output scales start row we will begin writing to
-    output_group_start_row = tl.load(
-        output_scales_group_offsets + group_pid, mask=group_pid < num_groups, other=0
+
+    # Calculate this group's start row after blocked format padding, by doing a prefix sum
+    # of each previous group's padded size.
+    output_group_start_row = _blocked_group_start_idx(
+        group_pid, orig_offsets, num_groups, 128
     )
+
     # Calculate destination indices for each row and col in block swizzled layout.
     # We can reuse this swizzle transformation on each block of data we read.
     row_offs = tl.arange(0, BLOCK_ROWS)[:, None]
@@ -489,7 +489,6 @@ def triton_scale_swizzle_per_group_3d(
 def triton_mx_block_rearrange_2d_K_groups(
     scales_tensor: torch.Tensor,
     input_group_end_offsets: torch.Tensor,
-    output_group_start_offsets: torch.Tensor,
 ) -> torch.Tensor:
     """
     Rearranges an E8M0 tensor scale to block-scaled swizzle format on a per group basis,
@@ -538,13 +537,10 @@ def triton_mx_block_rearrange_2d_K_groups(
         rows,
         cols,
         padded_rows,
-        num_groups,
-        # Original offsets (to read from)
         input_group_end_offsets,
-        # Output scales tensor and group offsets after padding (to write to)
         output.view(torch.uint8),
-        output_group_start_offsets,
         output_stride_per_block,
+        num_groups=num_groups,
         BLOCK_ROWS=BLOCK_ROWS,
         BLOCK_COLS=BLOCK_COLS,
         DEBUG=False,
@@ -560,11 +556,10 @@ def triton_scale_swizzle_2d_K_groups(
     scale_rows,
     scale_cols,
     padded_rows,
-    num_groups,
     orig_offsets,  # (num_groups,)
     output_scales_ptr,
-    output_scales_group_offsets,  # (num_groups,)
     output_stride_per_block,
+    num_groups: tl.constexpr,
     BLOCK_ROWS: tl.constexpr,
     BLOCK_COLS: tl.constexpr,
     DEBUG: tl.constexpr = False,
@@ -578,8 +573,11 @@ def triton_scale_swizzle_2d_K_groups(
     )
     input_group_end_col = tl.load(orig_offsets + group_pid)

-    # Output scales start row we will begin writing to
-    output_group_start_col = tl.load(output_scales_group_offsets + group_pid)
+    # Calculate this group's start col after blocked format padding, by doing a prefix sum
+    # of each previous group's padded size.
+    output_group_start_col = _blocked_group_start_idx(
+        group_pid, orig_offsets, num_groups, 4
+    )

     row_offs = tl.arange(0, BLOCK_ROWS)[:, None]
     col_offs = tl.arange(0, BLOCK_COLS)[None, :]
@@ -651,6 +649,31 @@ def _dest_indices_for_block(
     return dest_indices_flat


+@triton.jit
+def _blocked_group_start_idx(
+    group_pid,
+    orig_offsets,
+    num_groups: tl.constexpr,
+    padding_size: tl.constexpr,
+):
+    """Prefix sum to compute the start index of a given group."""
+    offsets = tl.load(orig_offsets + tl.arange(0, num_groups))
+    prev_offsets = tl.load(
+        orig_offsets + tl.arange(0, num_groups) - 1,
+        mask=tl.arange(0, num_groups) > 0,
+        other=0,
+    )
+    group_sizes = tl.where(
+        tl.arange(0, num_groups) > 0,
+        offsets - prev_offsets,
+        offsets,
+    )
+    padded_sizes = tl.cdiv(group_sizes, padding_size) * padding_size
+    prefix_mask = tl.arange(0, num_groups) < group_pid
+    group_start_idx = tl.sum(tl.where(prefix_mask, padded_sizes, 0))
+    return group_start_idx
+
+
 mxfp8_cuda_extension_available = False
 if is_sm_at_least_100():
     try:
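As a sanity check on _blocked_group_start_idx: with hypothetical offsets and the M-groups padding of 128, each program instance's computation reduces to the following plain-Python arithmetic (illustrative values; the real offsets come from `offs` at runtime):

    orig_offsets = [40, 72, 200]    # cumulative group end rows
    padding_size = 128
    group_sizes = [40, 72 - 40, 200 - 72]                                  # [40, 32, 128]
    padded = [((s + padding_size - 1) // padding_size) * padding_size
              for s in group_sizes]                                        # [128, 128, 128]
    # start row for group_pid == 2: sum of padded sizes of groups 0 and 1
    assert sum(padded[:2]) == 256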

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 2 additions & 14 deletions
@@ -17,8 +17,6 @@
     triton_fp8_rowwise_3d_transpose_rhs,
 )
 from torchao.prototype.moe_training.kernels.mxfp8 import (
-    compute_blocked_scale_offsets_for_K_groups,
-    compute_blocked_scale_offsets_for_M_groups,
     mxfp8_quantize_cuda_3d,
     triton_mx_block_rearrange_2d_K_groups,
     triton_mx_block_rearrange_2d_M_groups,
@@ -329,13 +327,9 @@ def forward(
         )

         # Convert scales to blocked format for 2d-3d grouped mm
-        _, blocked_scales_group_offsets_2d3d = (
-            compute_blocked_scale_offsets_for_M_groups(offs)
-        )
         A_scales_blocked = triton_mx_block_rearrange_2d_M_groups(
             A_scale,
             offs,
-            blocked_scales_group_offsets_2d3d,
         )
         B_scales_blocked = triton_mx_block_rearrange_per_group_3d(B_scales)

@@ -350,7 +344,7 @@ def forward(
             out_dtype=out_dtype,
         )

-        ctx.save_for_backward(A, B_t, offs, blocked_scales_group_offsets_2d3d)
+        ctx.save_for_backward(A, B_t, offs)
         ctx.block_size = block_size
         ctx.out_dtype = out_dtype
         ctx.emulated = emulated
@@ -359,7 +353,7 @@ def forward(

     @staticmethod
     def backward(ctx, grad_out: torch.Tensor):
-        A, B_t, offs, blocked_scales_group_offsets_2d3d = ctx.saved_tensors
+        A, B_t, offs = ctx.saved_tensors
         block_size = ctx.block_size
         out_dtype = ctx.out_dtype
         use_triton_for_dim0_cast = ctx.use_triton_for_dim0_cast
@@ -390,7 +384,6 @@ def backward(ctx, grad_out: torch.Tensor):
         grad_out_scales_blocked = triton_mx_block_rearrange_2d_M_groups(
             grad_out_scale,
             offs,
-            blocked_scales_group_offsets_2d3d,
         )
         B_scales_blocked = triton_mx_block_rearrange_per_group_3d(B_scales)

@@ -436,18 +429,13 @@ def backward(ctx, grad_out: torch.Tensor):

         # Convert scales to blocked format for 2d-2d grouped mm
         scale_group_offsets = offs // block_size
-        _, blocked_scale_group_offsets = compute_blocked_scale_offsets_for_K_groups(
-            scale_group_offsets
-        )
         grad_out_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
             grad_out_t_scales,
             scale_group_offsets,
-            blocked_scale_group_offsets,
         )
         A_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
             A_t_scales,
             scale_group_offsets,
-            blocked_scale_group_offsets,
         )

         # grad_B_t = scaled grouped mm of (N,total_M) @ (total_M,K) = (E,N,K)
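One knock-on effect in this file: because the blocked offsets are never materialized on the host anymore, forward no longer has to stash them for backward, so ctx.save_for_backward carries one less tensor. The apparent tradeoff is that every kernel program instance now redoes the O(num_groups) prefix sum rather than reading a precomputed value; for typical MoE group counts that work is likely negligible next to the swizzle's memory traffic.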
