
Commit 6a43eb5

[mxfp8 moe training] use mxfp8 dim0 kernel in mxfp8 all2all; simplify mxfp8 dim0 kernel
1 parent: 168d4b7

2 files changed: 28 additions & 39 deletions


torchao/prototype/moe_training/kernels/mxfp8/comms.py

Lines changed: 5 additions & 9 deletions

@@ -11,7 +11,7 @@
     blockwise_barrier,
     sync_threads,
 )
-from torchao.prototype.mx_formats.config import ScaleCalculationMode
+from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0
 from torchao.prototype.mx_formats.mx_tensor import to_dtype, to_mx


@@ -473,11 +473,9 @@ def forward(
         """
         # Quantize input
         block_size = 32
-        input_scales, input_data = to_mx(
+        input_data, input_scales = triton_to_mxfp8_dim0(
             input,
-            elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
-            scaling_mode=ScaleCalculationMode.RCEIL,
+            inner_block_size=block_size,
         )

         # Dispatch data (async)

@@ -529,11 +527,9 @@ def backward(ctx, grad_output_hp):

         # Quantize grad_output
         block_size = 32
-        grad_out_scales, grad_out_data = to_mx(
+        grad_out_data, grad_out_scales = triton_to_mxfp8_dim0(
             grad_output_hp,
-            elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
-            scaling_mode=ScaleCalculationMode.RCEIL,
+            inner_block_size=block_size,
         )

         # Dispatch data (async)
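
For reference, a minimal sketch of what this call-site swap amounts to (the input shape, device, and variable names below are illustrative assumptions, not part of the commit): the eager to_mx path returns (scales, data) and takes elem_dtype, block_size, and scaling_mode, while the fused triton_to_mxfp8_dim0 kernel returns (data, scales) and only takes inner_block_size, so both the unpacking order and the keyword arguments change at the call site.

import torch

from torchao.prototype.mx_formats.config import ScaleCalculationMode
from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0
from torchao.prototype.mx_formats.mx_tensor import to_mx

x = torch.randn(128, 8192, dtype=torch.bfloat16, device="cuda")  # illustrative shape
block_size = 32

# Old path: eager MX quantization, returns (scales, data)
old_scales, old_data = to_mx(
    x,
    elem_dtype=torch.float8_e4m3fn,
    block_size=block_size,
    scaling_mode=ScaleCalculationMode.RCEIL,
)

# New path: fused Triton dim0 kernel, returns (data, scales)
new_data, new_scales = triton_to_mxfp8_dim0(x, inner_block_size=block_size)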

torchao/prototype/mx_formats/kernels.py

Lines changed: 23 additions & 30 deletions

@@ -1040,7 +1040,7 @@ def to_mxfp8_dim1_kernel(

 @triton.autotune(
     configs=_get_mxfp8_dim1_kernel_autotune_configs(),
-    key=["n_rows", "n_cols", "INNER_BLOCK_SIZE"],
+    key=["n_cols", "INNER_BLOCK_SIZE"],
 )
 @triton.jit
 def to_mxfp8_dim0_kernel(

@@ -1118,33 +1118,31 @@ def to_mxfp8_dim0_kernel(
     # Store the row-normalized result in row-major format
     tl.store(output_ptr + row_major_offsets, row_normalized, mask=mask)

-    # reshape row_scale_e8m0_r for proper storage
-    # shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
-    row_scale_e8m0 = row_scale_e8m0_r.reshape(ROW_TILE_SIZE * BLOCKS_PER_COL_TILE)
+    # For rowwise quantization, the scale tensor has shape (n_rows, n_cols // INNER_BLOCK_SIZE)
+    # Number of scale elements per row of the input
+    scales_per_row = n_cols // INNER_BLOCK_SIZE

-    row_scale_start_offsets = (
-        (pid_row * ROW_TILE_SIZE * (n_cols // COL_TILE_SIZE))
-        * BLOCKS_PER_COL_TILE  # number of blocks seen so far
-        + pid_col * BLOCKS_PER_COL_TILE  # increment BLOCKS_PER_COL_TILE
+    # Create row and column indices for scale storage
+    scale_row_indices = tl.arange(0, ROW_TILE_SIZE)[:, None] + (
+        pid_row * ROW_TILE_SIZE
+    )
+    scale_col_indices = tl.arange(0, BLOCKS_PER_COL_TILE)[None, :] + (
+        pid_col * BLOCKS_PER_COL_TILE
     )

-    row_scale_start_ptr = row_scale_ptr + row_scale_start_offsets
-
-    # calculate row_scale_indices
-    row_scale_indices = tl.arange(0, ROW_TILE_SIZE * BLOCKS_PER_COL_TILE)
+    # Calculate linear indices into the scale tensor
+    scale_offsets = scale_row_indices * scales_per_row + scale_col_indices

-    # How many values are in all the other rows for this col_pid, need to jump
-    # over them for every BLOCKS_PER_COL_TILE values
-    jump_vals_per_row = (n_cols - COL_TILE_SIZE) // INNER_BLOCK_SIZE
+    # Create masks for valid scale indices
+    scale_row_mask = scale_row_indices < n_rows
+    scale_col_mask = scale_col_indices < scales_per_row
+    scale_mask = scale_row_mask & scale_col_mask

-    # example transformation (specifics depend on tile sizes):
-    # [0, 1, 2, 3, 4, 5, 6, 7] -> [0, 1, 4, 5, 8, 9, 12, 13]
-    row_scale_indices = row_scale_indices + (
-        (row_scale_indices // BLOCKS_PER_COL_TILE) * jump_vals_per_row
-    )
+    # Reshape the scale values to the (ROW_TILE_SIZE, BLOCKS_PER_COL_TILE) tile layout
+    row_scale_e8m0_2d = row_scale_e8m0_r.reshape(ROW_TILE_SIZE, BLOCKS_PER_COL_TILE)

-    # Store the scales
-    tl.store(row_scale_start_ptr + row_scale_indices, row_scale_e8m0)
+    # Store the scales with proper masking
+    tl.store(row_scale_ptr + scale_offsets, row_scale_e8m0_2d, mask=scale_mask)

 @triton_op("torchao::triton_to_mxfp8_dim0", mutates_args={})
 def triton_to_mxfp8_dim0(

@@ -1167,14 +1165,9 @@ def triton_to_mxfp8_dim0(
     x = x.reshape(-1, x.shape[-1])
     n_rows, n_cols = x.shape

-    # Masking of loads and stores is not well tested yet, so for now enforce
-    # shapes which do not need masking. Note that this condition depends on max values of
-    # ROW_TILE_SIZE and COL_TILE_SIZE, which are autotuned above.
-    # TODO(future): implement and test masking and remove this restriction
-    max_row_tile_size = 128
-    max_col_tile_size = 128
-    assert n_rows % max_row_tile_size == 0, "unsupported"
-    assert n_cols % max_col_tile_size == 0, "unsupported"
+    assert n_cols % inner_block_size == 0, (
+        "columns must be divisible by inner block size"
+    )

     # Create output tensors
     output = torch.empty(
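
To see why the simpler 2D indexing also lets the old 128-divisibility asserts go away, here is a host-side sketch of the new scale-offset math (the input shape, tile sizes, and grid loop below are illustrative assumptions, not the autotuned kernel configuration): each program computes a (ROW_TILE_SIZE, BLOCKS_PER_COL_TILE) tile of linear offsets into a dense (n_rows, n_cols // INNER_BLOCK_SIZE) scale tensor and masks out-of-bounds rows and blocks, so the tiles cover every scale slot even when n_rows is not a multiple of the tile size.

import torch

# Illustrative shapes and tile sizes (assumptions, not the autotuned values)
n_rows, n_cols = 6, 96
INNER_BLOCK_SIZE = 32
ROW_TILE_SIZE, COL_TILE_SIZE = 4, 64
BLOCKS_PER_COL_TILE = COL_TILE_SIZE // INNER_BLOCK_SIZE
scales_per_row = n_cols // INNER_BLOCK_SIZE  # scale tensor is (n_rows, scales_per_row)

written = set()
for pid_row in range(-(-n_rows // ROW_TILE_SIZE)):  # ceil-div over row tiles
    for pid_col in range(-(-n_cols // COL_TILE_SIZE)):  # ceil-div over column tiles
        # Same index math as the kernel: 2D (row, block) indices -> linear offsets
        row_idx = torch.arange(ROW_TILE_SIZE)[:, None] + pid_row * ROW_TILE_SIZE
        col_idx = torch.arange(BLOCKS_PER_COL_TILE)[None, :] + pid_col * BLOCKS_PER_COL_TILE
        offsets = row_idx * scales_per_row + col_idx
        mask = (row_idx < n_rows) & (col_idx < scales_per_row)
        written.update(offsets[mask].tolist())

# Every scale slot is covered, even though n_rows % ROW_TILE_SIZE != 0
assert written == set(range(n_rows * scales_per_row))

Only n_cols % inner_block_size == 0 is still required, which is exactly the relaxed assertion in the hunk above.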
