Commit cbf6277

[mxfp8 moe training] add triton kernel for mxfp8 quantization along dim0
stack-info: PR: #3128, branch: danielvegamyhre/stack/75
1 parent cd21d0e commit cbf6277

File tree: 3 files changed, +229 −2 lines changed

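A minimal usage sketch of the op this commit adds. The input shape below is an assumed example; the call signature, return order, and dtypes come from the `triton_to_mxfp8_dim0` definition further down in this diff.

import torch

from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0

# Assumed example input: any 2D bf16 CUDA tensor whose last dim is divisible
# by the inner block size (32), per the asserts in the new op.
x = torch.randn(4096, 8192, dtype=torch.bfloat16, device="cuda")

# Rowwise (dim0) mxfp8 cast: fp8 e4m3 data plus one e8m0 scale per 1x32 block.
data, scales = triton_to_mxfp8_dim0(x, inner_block_size=32)

assert data.dtype == torch.float8_e4m3fn
assert scales.dtype == torch.float8_e8m0fnu
assert data.shape == (4096, 8192)
assert scales.shape == (4096, 8192 // 32)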

benchmarks/mx_formats/cast_bench.py

Lines changed: 18 additions & 0 deletions

@@ -13,6 +13,7 @@
 from torchao.prototype.mx_formats.config import ScaleCalculationMode
 from torchao.prototype.mx_formats.kernels import (
+    triton_to_mxfp8_dim0,
     triton_to_mxfp8_dim1,
 )
 from torchao.prototype.mx_formats.mx_tensor import to_mx
@@ -97,6 +98,7 @@ def run(
         "dim0_mxfp8_floor",
         "dim0_mxfp4_floor",
         "dim0_mxfp8_rceil",
+        "dim0_mxfp8_triton_floor",
         "dim1_mxfp8_floor",
         "dim1_mxfp8_rceil",
         "dim1_mxfp8_triton_floor",
@@ -222,6 +224,22 @@ def run(
         bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
 
+    elif mode == "dim0_mxfp8_triton_floor":
+        y_d0, s_d0 = triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE)
+
+        for _ in range(2):
+            __ = triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE)
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE),
+            x,
+            BLOCK_SIZE,
+        )
+        assert y_d0.dtype == torch.float8_e4m3fn
+        assert s_d0.dtype == torch.float8_e8m0fnu
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
     elif mode == "dim1_mxfp8_floor":
         to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
         y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)

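The new benchmark branch reports achieved memory bandwidth the same way as the existing dim0 modes: bytes read (bf16 input, 2 bytes/element) plus bytes written (fp8 data plus e8m0 scales, 1 byte/element each) divided by kernel time. A small worked sketch; the (16384, 16384) shape and the 100 us timing are assumed example numbers, not measured results.

M, K, BLOCK_SIZE = 16384, 16384, 32
bytes_per_el_bf16, bytes_per_el_fp8 = 2, 1

time_us = 100.0  # assumed kernel time in microseconds

bytes_r = M * K * bytes_per_el_bf16                             # bf16 input read
bytes_w = (M * K + M * (K // BLOCK_SIZE)) * bytes_per_el_fp8    # fp8 data + e8m0 scales written
bps = (bytes_r + bytes_w) / (time_us / 1e6)                     # bytes per second

print(f"{bps / 1e9:.1f} GB/s")  # ~8136.9 GB/s for these assumed numbers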
test/prototype/mx_formats/test_kernels.py

Lines changed: 33 additions & 0 deletions

@@ -37,6 +37,7 @@
     pack_uint6,
     triton_f6_e2m3_to_bf16,
     triton_f6_e3m2_to_bf16,
+    triton_to_mxfp8_dim0,
     triton_to_mxfp8_dim1,
     triton_to_mxfp8_dim1_reference,
     unpack_uint4,
@@ -431,6 +432,23 @@ def test_fp6_e3m2_pack_unpack():
     assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)
 
 
+def triton_to_mxfp8_dim0_reference(
+    x_hp: torch.Tensor, block_size
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    A reference version of `triton_to_mxfp8_dim0` for rowwise quantization.
+    """
+    from torchao.prototype.mx_formats.mx_tensor import to_mx
+
+    # cast across dim0 (rowwise) - no transpose needed
+    scale_e8m0_dim0, x_hp_d0_normalized = to_mx(x_hp, torch.float8_e4m3fn, block_size)
+    scale_e8m0_dim0 = scale_e8m0_dim0.view(torch.float8_e8m0fnu)
+    return (
+        x_hp_d0_normalized,
+        scale_e8m0_dim0,
+    )
+
+
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(
     not is_sm_at_least_89(),
@@ -446,6 +464,21 @@ def test_triton_mxfp8_dim1_randn(M, K):
     torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
 
 
+@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
+@pytest.mark.skipif(
+    not is_sm_at_least_89(),
+    reason="float8 in triton requires CUDA capability 8.9 or greater",
+)
+@pytest.mark.parametrize("M", (256, 2048, 131072))
+@pytest.mark.parametrize("K", (256, 5120, 7168))
+def test_triton_mxfp8_dim0_randn(M, K):
+    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
+    x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
+    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
+    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize(
     "shape",

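The new test checks the Triton kernel bitwise against this `to_mx`-based reference. A minimal sketch of the dim0-vs-dim1 distinction the reference comment alludes to; the dim1 transpose pattern here is illustrative only and is not the repo's `triton_to_mxfp8_dim1_reference`.

import torch

from torchao.prototype.mx_formats.mx_tensor import to_mx

x = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")

# dim0 (rowwise): scaling groups are 32 contiguous elements along the last
# dim, so `to_mx` can consume the input as-is, with no transpose needed.
s_d0, q_d0 = to_mx(x, torch.float8_e4m3fn, 32)

# dim1 (columnwise), sketched for contrast: groups run down each column, so a
# reference would transpose first, cast, then transpose the result back.
s_d1, q_d1 = to_mx(x.t().contiguous(), torch.float8_e4m3fn, 32)
q_d1 = q_d1.t()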
torchao/prototype/mx_formats/kernels.py

Lines changed: 178 additions & 2 deletions

@@ -829,6 +829,8 @@ def _(uint8_data):
     import triton.language as tl
     from torch.library import triton_op, wrap_triton
 
+    print("importing triton ops")
+
     @triton.jit
     def _triton_calculate_scale(x, axis):
         # There is no good support for accessing globals from a jit'ed triton
@@ -891,13 +893,13 @@ def _get_mxfp8_dim1_kernel_autotune_configs():
 
     @triton.autotune(
         configs=_get_mxfp8_dim1_kernel_autotune_configs(),
-        key=["n_rows", "n_cols", "INNER_BLOCK_SIZE"],
+        key=["n_cols", "INNER_BLOCK_SIZE"],
     )
     @triton.jit
     def to_mxfp8_dim1_kernel(
         x_ptr,  # pointer to input tensor
         output_col_major_ptr,  # pointer to column-major output tensor (column-normalized)
-        col_scale_ptr,  # pointer to store column-wise maximum absolute values
+        col_scale_ptr,  # pointer to store scales
         n_rows,  # number of rows in the tensor
         n_cols,  # number of columns in the tensor
         ROW_TILE_SIZE: tl.constexpr,
@@ -1038,6 +1040,174 @@ def to_mxfp8_dim1_kernel(
         # TODO(future): mask this store
         tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)
 
+    @triton.autotune(
+        configs=_get_mxfp8_dim1_kernel_autotune_configs(),
+        key=["n_cols", "INNER_BLOCK_SIZE"],
+    )
+    @triton.jit
+    def to_mxfp8_dim0_kernel(
+        x_ptr,  # pointer to input tensor
+        output_ptr,  # pointer to output tensor (row-normalized)
+        row_scale_ptr,  # pointer to store row-wise maximum absolute values
+        n_rows,  # number of rows in the tensor
+        n_cols,  # number of columns in the tensor
+        ROW_TILE_SIZE: tl.constexpr,
+        COL_TILE_SIZE: tl.constexpr,
+        INNER_BLOCK_SIZE: tl.constexpr,  # should be 32 for MX
+    ):
+        """
+        Quantizes a high precision tensor to mxfp8 rowwise (1x32 scaling granularity).
+
+        This is the counterpart to to_mxfp8_dim1_kernel which does columnwise quantization.
+        Instead of transposing and scaling across columns, this kernel scales across rows.
+        """
+
+        BLOCKS_PER_COL_TILE: tl.constexpr = COL_TILE_SIZE // INNER_BLOCK_SIZE
+
+        # Get program ID
+        pid_row = tl.program_id(0)
+        pid_col = tl.program_id(1)
+
+        # Calculate starting row and column for this tile
+        start_row = pid_row * ROW_TILE_SIZE
+        start_col = pid_col * COL_TILE_SIZE
+
+        # Create offsets for the block
+        row_offsets = tl.arange(0, ROW_TILE_SIZE)
+        col_offsets = tl.arange(0, COL_TILE_SIZE)
+
+        # Compute global row/col positions
+        rows = start_row + row_offsets[:, None]
+        cols = start_col + col_offsets[None, :]
+
+        # Create masks for out-of-bounds accesses
+        row_mask = rows < n_rows
+        col_mask = cols < n_cols
+        mask = row_mask & col_mask
+
+        # Compute memory offsets for row-major layout (rows, cols)
+        row_major_offsets = (rows * n_cols + cols).to(tl.int32)
+
+        # Load the entire block in a single operation
+        # shape: (ROW_TILE_SIZE, COL_TILE_SIZE)
+        x_block = tl.load(x_ptr + row_major_offsets, mask=mask)
+
+        # Reshape to inner tile size for rowwise scaling
+        # shape: (ROW_TILE_SIZE, COL_TILE_SIZE) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
+        x_block_r = x_block.reshape(
+            ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE
+        )
+
+        # Calculate the absolute values of elements in the block
+        x_block_abs_r = tl.abs(x_block_r)
+
+        # Find the maximum absolute value for each row (across columns)
+        # shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
+        row_scale_r, row_scale_e8m0_r = _triton_calculate_scale(x_block_abs_r, axis=1)
+
+        # Divide each row by scale
+        # Broadcasting row_scale to match x_block's shape
+        # x_block_r shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
+        # row_scale shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, 1)
+        row_normalized_r = x_block_r / row_scale_r[:, None]
+
+        # Reshape back to original tile size
+        row_normalized = tl.reshape(row_normalized_r, ROW_TILE_SIZE, COL_TILE_SIZE)
+
+        # Quantize to float8
+        row_normalized = row_normalized.to(tl.float8e4nv)
+
+        # Store the row-normalized result in row-major format
+        tl.store(output_ptr + row_major_offsets, row_normalized, mask=mask)
+
+        # For rowwise quantization, scale tensor has shape (n_rows, n_cols // INNER_BLOCK_SIZE)
+        # Calculate base offset for this tile's scales
+        scales_per_row = n_cols // INNER_BLOCK_SIZE
+
+        # Create row and column indices for scale storage
+        scale_row_indices = tl.arange(0, ROW_TILE_SIZE)[:, None] + (
+            pid_row * ROW_TILE_SIZE
+        )
+        scale_col_indices = tl.arange(0, BLOCKS_PER_COL_TILE)[None, :] + (
+            pid_col * BLOCKS_PER_COL_TILE
+        )
+
+        # Calculate linear indices into scale tensor
+        scale_offsets = scale_row_indices * scales_per_row + scale_col_indices
+
+        # Create masks for valid scale indices
+        scale_row_mask = scale_row_indices < n_rows
+        scale_col_mask = scale_col_indices < scales_per_row
+        scale_mask = scale_row_mask & scale_col_mask
+
+        # Reshape scale values to match the (row, block) layout
+        row_scale_e8m0_2d = row_scale_e8m0_r.reshape(ROW_TILE_SIZE, BLOCKS_PER_COL_TILE)
+
+        # Store the scales with proper masking
+        tl.store(row_scale_ptr + scale_offsets, row_scale_e8m0_2d, mask=scale_mask)
+
@triton_op("torchao::triton_to_mxfp8_dim0", mutates_args={})
1150+
def triton_to_mxfp8_dim0(
1151+
x: torch.Tensor, inner_block_size: int = 32
1152+
) -> Tuple[torch.Tensor, torch.Tensor]:
1153+
"""
1154+
Input:
1155+
* `x` - input tensor, in row major memory layout
1156+
* `inner_block_size` - size of tiles to scale across, default is 32 for MX recipes
1157+
1158+
Output:
1159+
* `output`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim0 (rowwise)
1160+
* `row_scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim0
1161+
"""
1162+
assert x.is_contiguous(), "`x` must be contiguous"
1163+
assert inner_block_size <= 32
1164+
1165+
# Reshape tensor to 2d if necessary and get shape
1166+
x_orig_shape = x.shape
1167+
x = x.reshape(-1, x.shape[-1])
1168+
n_rows, n_cols = x.shape
1169+
1170+
assert n_cols % inner_block_size == 0, (
1171+
"columns must be divisible by inner block size"
1172+
)
1173+
1174+
# Create output tensors
1175+
output = torch.empty(
1176+
(n_rows, n_cols), dtype=torch.float8_e4m3fn, device=x.device
1177+
)
1178+
1179+
# Create scale tensors for rowwise scaling
1180+
row_scale = torch.empty(
1181+
(n_rows, n_cols // inner_block_size),
1182+
dtype=torch.uint8,
1183+
device=x.device,
1184+
)
1185+
1186+
# Calculate grid dimensions based on tile size
1187+
grid = lambda META: (
1188+
triton.cdiv(n_rows, META["ROW_TILE_SIZE"]),
1189+
triton.cdiv(n_cols, META["COL_TILE_SIZE"]),
1190+
)
1191+
1192+
# Launch the kernel
1193+
wrap_triton(to_mxfp8_dim0_kernel)[grid](
1194+
x_ptr=x,
1195+
output_ptr=output,
1196+
row_scale_ptr=row_scale,
1197+
n_rows=n_rows,
1198+
n_cols=n_cols,
1199+
INNER_BLOCK_SIZE=inner_block_size,
1200+
)
1201+
1202+
# Reshape output back to original shape
1203+
output = output.reshape(x_orig_shape)
1204+
row_scale = row_scale.reshape(*x_orig_shape[:-1], row_scale.shape[-1])
1205+
1206+
return (
1207+
output,
1208+
row_scale.view(torch.float8_e8m0fnu),
1209+
)
1210+
10411211
@triton_op("torchao::triton_to_mxfp8_dim1", mutates_args={})
10421212
def triton_to_mxfp8_dim1(
10431213
x: torch.Tensor, inner_block_size: int = 32
@@ -1459,6 +1629,12 @@ def _(scale_tensor):
         return scale_tensor.new_empty((padded_rows, padded_cols))
 else:
 
+    def triton_to_mxfp8_dim0(
+        x: torch.Tensor,
+        inner_block_size=32,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise AssertionError("needs torch version 2.8+ and triton")
+
     def triton_to_mxfp8_dim1(
         x, inner_block_size=32
     ) -> Tuple[torch.Tensor, torch.Tensor]:

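A CPU-only sketch of the scale layout the new kernel produces, mirroring the reshape and scale-offset math in `to_mxfp8_dim0_kernel` above. The tile sizes here are assumed example values, not the autotuned configs.

import torch

n_rows, n_cols, INNER_BLOCK_SIZE = 8, 128, 32
ROW_TILE_SIZE, COL_TILE_SIZE = 4, 64          # assumed example tile sizes
BLOCKS_PER_COL_TILE = COL_TILE_SIZE // INNER_BLOCK_SIZE

x = torch.randn(n_rows, n_cols)

# One scale per 1x32 block along dim0 (rowwise), laid out row-major:
scales_per_row = n_cols // INNER_BLOCK_SIZE   # 4 scales per row here
amax = x.abs().reshape(n_rows, scales_per_row, INNER_BLOCK_SIZE).amax(dim=-1)
assert amax.shape == (n_rows, scales_per_row)

# Inside a (ROW_TILE_SIZE, COL_TILE_SIZE) tile the kernel flattens to
# (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE), so each row of the
# reshaped view is exactly one scaling group.
tile = x[:ROW_TILE_SIZE, :COL_TILE_SIZE]
tile_r = tile.reshape(ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
assert torch.equal(
    tile_r.abs().amax(dim=-1).reshape(ROW_TILE_SIZE, BLOCKS_PER_COL_TILE),
    amax[:ROW_TILE_SIZE, :BLOCKS_PER_COL_TILE],
)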