Skip to content

Commit 664124a

Browse files
[mxfp8 moe training] add triton kernel for mxfp8 quantization along dim0 (#3128)
stack-info: PR: #3128, branch: danielvegamyhre/stack/75
1 parent fdd8fe7 commit 664124a

File tree

3 files changed

+227
-2
lines changed

3 files changed

+227
-2
lines changed

benchmarks/mx_formats/cast_bench.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from torchao.prototype.mx_formats.config import ScaleCalculationMode
1515
from torchao.prototype.mx_formats.kernels import (
16+
triton_to_mxfp8_dim0,
1617
triton_to_mxfp8_dim1,
1718
)
1819
from torchao.prototype.mx_formats.mx_tensor import to_mx
@@ -97,6 +98,7 @@ def run(
9798
"dim0_mxfp8_floor",
9899
"dim0_mxfp4_floor",
99100
"dim0_mxfp8_rceil",
101+
"dim0_mxfp8_triton_floor",
100102
"dim1_mxfp8_floor",
101103
"dim1_mxfp8_rceil",
102104
"dim1_mxfp8_triton_floor",
@@ -222,6 +224,22 @@ def run(
222224
bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
223225
bps = (bytes_r + bytes_w) / (time_us / 1e6)
224226

227+
elif mode == "dim0_mxfp8_triton_floor":
228+
y_d0, s_d0 = triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE)
229+
230+
for _ in range(2):
231+
__ = triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE)
232+
time_us = benchmark_cuda_function_in_microseconds(
233+
lambda x, b: triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE),
234+
x,
235+
BLOCK_SIZE,
236+
)
237+
assert y_d0.dtype == torch.float8_e4m3fn
238+
assert s_d0.dtype == torch.float8_e8m0fnu
239+
bytes_r = x.numel() * bytes_per_el_bf16
240+
bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
241+
bps = (bytes_r + bytes_w) / (time_us / 1e6)
242+
225243
elif mode == "dim1_mxfp8_floor":
226244
to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
227245
y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)

test/prototype/mx_formats/test_kernels.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
pack_uint6,
3838
triton_f6_e2m3_to_bf16,
3939
triton_f6_e3m2_to_bf16,
40+
triton_to_mxfp8_dim0,
4041
triton_to_mxfp8_dim1,
4142
triton_to_mxfp8_dim1_reference,
4243
unpack_uint4,
@@ -451,6 +452,23 @@ def test_fp6_e3m2_pack_unpack():
451452
assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)
452453

453454

455+
def triton_to_mxfp8_dim0_reference(
456+
x_hp: torch.Tensor, block_size
457+
) -> tuple[torch.Tensor, torch.Tensor]:
458+
"""
459+
A reference version of `triton_to_mxfp8_dim0` for rowwise quantization.
460+
"""
461+
from torchao.prototype.mx_formats.mx_tensor import to_mx
462+
463+
# cast across dim0 (rowwise) - no transpose needed
464+
scale_e8m0_dim0, x_hp_d0_normalized = to_mx(x_hp, torch.float8_e4m3fn, block_size)
465+
scale_e8m0_dim0 = scale_e8m0_dim0.view(torch.float8_e8m0fnu)
466+
return (
467+
x_hp_d0_normalized,
468+
scale_e8m0_dim0,
469+
)
470+
471+
454472
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
455473
@pytest.mark.skipif(
456474
not is_sm_at_least_89(),
@@ -466,6 +484,21 @@ def test_triton_mxfp8_dim1_randn(M, K):
466484
torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
467485

468486

487+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
488+
@pytest.mark.skipif(
489+
not is_sm_at_least_89(),
490+
reason="float8 in triton requires CUDA capability 8.9 or greater",
491+
)
492+
@pytest.mark.parametrize("M", (256, 2048, 131072))
493+
@pytest.mark.parametrize("K", (256, 5120, 7168))
494+
def test_triton_mxfp8_dim0_randn(M, K):
495+
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
496+
x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
497+
x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
498+
torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
499+
torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
500+
501+
469502
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
470503
@pytest.mark.parametrize(
471504
"shape",

torchao/prototype/mx_formats/kernels.py

Lines changed: 176 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -891,13 +891,13 @@ def _get_mxfp8_dim1_kernel_autotune_configs():
891891

892892
@triton.autotune(
893893
configs=_get_mxfp8_dim1_kernel_autotune_configs(),
894-
key=["n_rows", "n_cols", "INNER_BLOCK_SIZE"],
894+
key=["n_cols", "INNER_BLOCK_SIZE"],
895895
)
896896
@triton.jit
897897
def to_mxfp8_dim1_kernel(
898898
x_ptr, # pointer to input tensor
899899
output_col_major_ptr, # pointer to column-major output tensor (column-normalized)
900-
col_scale_ptr, # pointer to store column-wise maximum absolute values
900+
col_scale_ptr, # pointer to store scales
901901
n_rows, # number of rows in the tensor
902902
n_cols, # number of columns in the tensor
903903
ROW_TILE_SIZE: tl.constexpr,
@@ -1038,6 +1038,174 @@ def to_mxfp8_dim1_kernel(
10381038
# TODO(future): mask this store
10391039
tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)
10401040

1041+
@triton.autotune(
1042+
configs=_get_mxfp8_dim1_kernel_autotune_configs(),
1043+
key=["n_cols", "INNER_BLOCK_SIZE"],
1044+
)
1045+
@triton.jit
1046+
def to_mxfp8_dim0_kernel(
1047+
x_ptr, # pointer to input tensor
1048+
output_ptr, # pointer to output tensor (row-normalized)
1049+
row_scale_ptr, # pointer to store row-wise maximum absolute values
1050+
n_rows, # number of rows in the tensor
1051+
n_cols, # number of columns in the tensor
1052+
ROW_TILE_SIZE: tl.constexpr,
1053+
COL_TILE_SIZE: tl.constexpr,
1054+
INNER_BLOCK_SIZE: tl.constexpr, # should be 32 for MX
1055+
):
1056+
"""
1057+
Quantizes a high precision tensor to mxfp8 rowwise (1x32 scaling granularity).
1058+
1059+
This is the counterpart to to_mxfp8_dim1_kernel which does columnwise quantization.
1060+
Instead of transposing and scaling across columns, this kernel scales across rows.
1061+
"""
1062+
1063+
BLOCKS_PER_COL_TILE: tl.constexpr = COL_TILE_SIZE // INNER_BLOCK_SIZE
1064+
1065+
# Get program ID
1066+
pid_row = tl.program_id(0)
1067+
pid_col = tl.program_id(1)
1068+
1069+
# Calculate starting row and column for this tile
1070+
start_row = pid_row * ROW_TILE_SIZE
1071+
start_col = pid_col * COL_TILE_SIZE
1072+
1073+
# Create offsets for the block
1074+
row_offsets = tl.arange(0, ROW_TILE_SIZE)
1075+
col_offsets = tl.arange(0, COL_TILE_SIZE)
1076+
1077+
# Compute global row/col positions
1078+
rows = start_row + row_offsets[:, None]
1079+
cols = start_col + col_offsets[None, :]
1080+
1081+
# Create masks for out-of-bounds accesses
1082+
row_mask = rows < n_rows
1083+
col_mask = cols < n_cols
1084+
mask = row_mask & col_mask
1085+
1086+
# Compute memory offsets for row-major layout (rows, cols)
1087+
row_major_offsets = (rows * n_cols + cols).to(tl.int32)
1088+
1089+
# Load the entire block in a single operation
1090+
# shape: (ROW_TILE_SIZE, COL_TILE_SIZE)
1091+
x_block = tl.load(x_ptr + row_major_offsets, mask=mask)
1092+
1093+
# Reshape to inner tile size for rowwise scaling
1094+
# shape: (ROW_TILE_SIZE, COL_TILE_SIZE) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
1095+
x_block_r = x_block.reshape(
1096+
ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE
1097+
)
1098+
1099+
# Calculate the absolute values of elements in the block
1100+
x_block_abs_r = tl.abs(x_block_r)
1101+
1102+
# Find the maximum absolute value for each row (across columns)
1103+
# shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
1104+
row_scale_r, row_scale_e8m0_r = _triton_calculate_scale(x_block_abs_r, axis=1)
1105+
1106+
# Divide each row by scale
1107+
# Broadcasting row_scale to match x_block's shape
1108+
# x_block_r shape (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
1109+
# row_scale shape (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, 1)
1110+
row_normalized_r = x_block_r / row_scale_r[:, None]
1111+
1112+
# Reshape back to original tile size
1113+
row_normalized = tl.reshape(row_normalized_r, ROW_TILE_SIZE, COL_TILE_SIZE)
1114+
1115+
# Quantize to float8
1116+
row_normalized = row_normalized.to(tl.float8e4nv)
1117+
1118+
# Store the row-normalized result in row-major format
1119+
tl.store(output_ptr + row_major_offsets, row_normalized, mask=mask)
1120+
1121+
# For rowwise quantization, scale tensor has shape (n_rows, n_cols // INNER_BLOCK_SIZE)
1122+
# Calculate base offset for this tile's scales
1123+
scales_per_row = n_cols // INNER_BLOCK_SIZE
1124+
1125+
# Create row and column indices for scale storage
1126+
scale_row_indices = tl.arange(0, ROW_TILE_SIZE)[:, None] + (
1127+
pid_row * ROW_TILE_SIZE
1128+
)
1129+
scale_col_indices = tl.arange(0, BLOCKS_PER_COL_TILE)[None, :] + (
1130+
pid_col * BLOCKS_PER_COL_TILE
1131+
)
1132+
1133+
# Calculate linear indices into scale tensor
1134+
scale_offsets = scale_row_indices * scales_per_row + scale_col_indices
1135+
1136+
# Create masks for valid scale indices
1137+
scale_row_mask = scale_row_indices < n_rows
1138+
scale_col_mask = scale_col_indices < scales_per_row
1139+
scale_mask = scale_row_mask & scale_col_mask
1140+
1141+
# Reshape scale values and masks to match the flattened layout
1142+
row_scale_e8m0_2d = row_scale_e8m0_r.reshape(ROW_TILE_SIZE, BLOCKS_PER_COL_TILE)
1143+
1144+
# Store the scales with proper masking
1145+
tl.store(row_scale_ptr + scale_offsets, row_scale_e8m0_2d, mask=scale_mask)
1146+
1147+
@triton_op("torchao::triton_to_mxfp8_dim0", mutates_args={})
1148+
def triton_to_mxfp8_dim0(
1149+
x: torch.Tensor, inner_block_size: int = 32
1150+
) -> Tuple[torch.Tensor, torch.Tensor]:
1151+
"""
1152+
Input:
1153+
* `x` - input tensor, in row major memory layout
1154+
* `inner_block_size` - size of tiles to scale across, default is 32 for MX recipes
1155+
1156+
Output:
1157+
* `output`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim0 (rowwise)
1158+
* `row_scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim0
1159+
"""
1160+
assert x.is_contiguous(), "`x` must be contiguous"
1161+
assert inner_block_size <= 32
1162+
1163+
# Reshape tensor to 2d if necessary and get shape
1164+
x_orig_shape = x.shape
1165+
x = x.reshape(-1, x.shape[-1])
1166+
n_rows, n_cols = x.shape
1167+
1168+
assert n_cols % inner_block_size == 0, (
1169+
"columns must be divisible by inner block size"
1170+
)
1171+
1172+
# Create output tensors
1173+
output = torch.empty(
1174+
(n_rows, n_cols), dtype=torch.float8_e4m3fn, device=x.device
1175+
)
1176+
1177+
# Create scale tensors for rowwise scaling
1178+
row_scale = torch.empty(
1179+
(n_rows, n_cols // inner_block_size),
1180+
dtype=torch.uint8,
1181+
device=x.device,
1182+
)
1183+
1184+
# Calculate grid dimensions based on tile size
1185+
grid = lambda META: (
1186+
triton.cdiv(n_rows, META["ROW_TILE_SIZE"]),
1187+
triton.cdiv(n_cols, META["COL_TILE_SIZE"]),
1188+
)
1189+
1190+
# Launch the kernel
1191+
wrap_triton(to_mxfp8_dim0_kernel)[grid](
1192+
x_ptr=x,
1193+
output_ptr=output,
1194+
row_scale_ptr=row_scale,
1195+
n_rows=n_rows,
1196+
n_cols=n_cols,
1197+
INNER_BLOCK_SIZE=inner_block_size,
1198+
)
1199+
1200+
# Reshape output back to original shape
1201+
output = output.reshape(x_orig_shape)
1202+
row_scale = row_scale.reshape(*x_orig_shape[:-1], row_scale.shape[-1])
1203+
1204+
return (
1205+
output,
1206+
row_scale.view(torch.float8_e8m0fnu),
1207+
)
1208+
10411209
@triton_op("torchao::triton_to_mxfp8_dim1", mutates_args={})
10421210
def triton_to_mxfp8_dim1(
10431211
x: torch.Tensor, inner_block_size: int = 32
@@ -1467,6 +1635,12 @@ def _(scale_tensor):
14671635
return scale_tensor.new_empty((padded_rows, padded_cols))
14681636
else:
14691637

1638+
def triton_to_mxfp8_dim0(
1639+
x: torch.Tensor,
1640+
inner_block_size=32,
1641+
) -> Tuple[torch.Tensor, torch.Tensor]:
1642+
raise AssertionError("needs torch version 2.8+ and triton")
1643+
14701644
def triton_to_mxfp8_dim1(
14711645
x, inner_block_size=32
14721646
) -> Tuple[torch.Tensor, torch.Tensor]:

0 commit comments

Comments
 (0)