Commit 5844f90

[wip] speed up nvfp4 triton kernel
Summary:

Test Plan:
```bash
pytest test/prototype/mx_formats/test_nvfp4_tensor.py -s -x -k test_triton_nvfp4_quantize_equivalence
```

Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 35f9d1e
ghstack-comment-id: 3416381534
Pull-Request: #3202
1 parent: ca99c1c

File tree: 3 files changed, +81 -24 lines

benchmarks/mx_formats/cast_bench.py

Lines changed: 3 additions & 3 deletions

@@ -257,12 +257,12 @@ def run(
 
     elif mode == "dim0_nvfp4":
         to_nvfp4_reference_c = torch.compile(to_nvfp4_reference)
-        y_d0, s_d0 = to_nvfp4_reference_c(x, use_triton_kernel=False)
+        y_d0, s_d0 = to_nvfp4_reference_c(x)
 
         for _ in range(2):
-            __ = to_nvfp4_reference_c(x, use_triton_kernel=False)
+            __ = to_nvfp4_reference_c(x)
         time_us = benchmark_cuda_function_in_microseconds(
-            lambda x: to_nvfp4_reference_c(x, use_triton_kernel=False),
+            lambda x: to_nvfp4_reference_c(x),
             x,
         )
         assert y_d0.dtype == torch.uint8
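For context on the benchmark change above: the reference path is compiled with `torch.compile`, warmed up twice, and then timed with the repo's `benchmark_cuda_function_in_microseconds` helper. A minimal sketch of that warm-up-then-time pattern, using plain `torch.cuda.Event` as a hypothetical stand-in for the helper (whose implementation is not shown in this diff), could look like:

```python
import torch

def time_cuda_us(fn, *args, warmup: int = 2, iters: int = 10) -> float:
    # Hypothetical stand-in for benchmark_cuda_function_in_microseconds:
    # warm up, then average wall time over `iters` launches using CUDA events.
    for _ in range(warmup):
        fn(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) * 1e3 / iters  # elapsed_time is in ms

# usage mirroring the diff (assumes x and the compiled to_nvfp4_reference_c exist):
# time_us = time_cuda_us(lambda t: to_nvfp4_reference_c(t), x)
```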

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 24 additions & 5 deletions

@@ -359,13 +359,17 @@ def test_nvfp4_swizzled_scales_get_scales_method():
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize(
-    "M", [128, 256, 512, 1024, 100, 200, 384], ids=lambda m: f"M{m}"
+    # "M", [128, 256, 512, 1024, 100, 200, 384], ids=lambda m: f"M{m}"
+    "M", [256, ], ids=lambda m: f"M{m}"
 )
-@pytest.mark.parametrize("N", [64, 128, 256, 512, 32, 96, 160], ids=lambda n: f"N{n}")
+# @pytest.mark.parametrize("N", [64, 128, 256, 512, 32, 96, 160], ids=lambda n: f"N{n}")
+@pytest.mark.parametrize("N", [128], ids=lambda n: f"N{n}")
 @pytest.mark.parametrize(
-    "use_per_tensor_scale", [False, True], ids=["block_scale", "tensor_scale"]
+    # "use_per_tensor_scale", [False, True], ids=["block_scale", "tensor_scale"]
+    "use_per_tensor_scale", [False, ], ids=["block_scale"]
 )
-@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"])
+# @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"])
+@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
 @pytest.mark.skipif(
     not is_sm_at_least_100(), reason="requires sm100+ for raw intrinsics"
 )
@@ -394,7 +398,20 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
         use_triton_kernel=True,
     )
 
-    torch.testing.assert_close(nvfp4_pt.scale.flatten(), nvfp4_triton.scale.flatten())
+    # print(nvfp4_triton.scale)
+
+    s00 = nvfp4_pt.scale.reshape(2, -1, 16)[0].float()
+    s01 = nvfp4_pt.scale.reshape(2, -1, 16)[1].float()
+    s10 = nvfp4_triton.scale.reshape(2, -1, 16)[0].float()
+    s11 = nvfp4_triton.scale.reshape(2, -1, 16)[1].float()
+    # print(s00.sum(), s01.sum(), s10.sum(), s11.sum())
+
+    s0 = nvfp4_pt.scale.reshape(-1, 32 * 16).float().sum(dim=1)
+    s1 = nvfp4_triton.scale.reshape(-1, 32 * 16).float().sum(dim=1)
+    print('\n', s0)
+    print(s1)
+
+    # breakpoint()
     pt_unpacked = unpack_uint4(nvfp4_pt.qdata)
     triton_unpacked = unpack_uint4(nvfp4_triton.qdata)
     torch.testing.assert_close(
@@ -404,6 +421,8 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
         rtol=0,
     )
 
+    torch.testing.assert_close(nvfp4_pt.scale.flatten(), nvfp4_triton.scale.flatten())
+
     x_pt_dequant = nvfp4_pt.dequantize(dtype)
     x_triton_dequant = nvfp4_triton.dequantize(dtype)
 
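The temporary prints added to the test sum the scales in chunks of 32 * 16 = 512 elements, i.e. one swizzled scale tile per 128 x 64 data block, to see where the Triton output starts diverging from the PyTorch reference. A hedged sketch of the same idea as a reusable check (hypothetical helper, not part of the test file):

```python
from typing import Optional

import torch

def first_mismatching_scale_tile(
    scale_ref: torch.Tensor, scale_test: torch.Tensor
) -> Optional[int]:
    # Compare two swizzled nvfp4 scale tensors tile by tile
    # (one swizzled tile = 32 x 16 = 512 e4m3 values) and return the index
    # of the first tile whose contents differ, or None if all tiles match.
    # Assumes both tensors have the same number of elements, divisible by 512.
    ref = scale_ref.float().flatten().reshape(-1, 32 * 16)
    test = scale_test.float().flatten().reshape(-1, 32 * 16)
    bad = torch.nonzero((ref != test).any(dim=1), as_tuple=False)
    return int(bad[0]) if bad.numel() > 0 else None

# e.g. first_mismatching_scale_tile(nvfp4_pt.scale, nvfp4_triton.scale)
```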

torchao/prototype/mx_formats/kernels.py

Lines changed: 54 additions & 16 deletions

@@ -1437,20 +1437,33 @@ def quantize_nvfp4_triton_kernel(
     s_ptr,
     stride_xm,
     stride_xn,
+    stride_sm,
+    stride_sn,
     M,
     N,
     USE_TENSOR_SCALE: tl.constexpr,
     MASK_SCALES: tl.constexpr,
+    ROW_TILE_SIZE: tl.constexpr,
+    COL_TILE_SIZE: tl.constexpr,
 ):
+    """
+    1. a single block of data is shaped [128, 64] unpacked or [128, 32] packed
+    2. the corresponding single unswizzled block of scales is shaped [128, 4]
+    3. the corresponding single swizzled block of scales is shaped [32, 16]
+    """
+
     F4_E2M1_MAX = 6.0
     F8E4M3_MAX = 448.0
     E4M3_EPS = 1.5258789e-05
 
+    NUM_ROW_INNER_TILES: tl.constexpr = ROW_TILE_SIZE // 128
+    NUM_COL_INNER_TILES: tl.constexpr = COL_TILE_SIZE // 64
+
     pid_m = tl.program_id(1)
     pid_n = tl.program_id(0)
 
-    offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
-    offs_n = pid_n * 64 + tl.arange(0, 64)[None, :]
+    offs_m = pid_m * ROW_TILE_SIZE + tl.arange(0, ROW_TILE_SIZE)[:, None]
+    offs_n = pid_n * COL_TILE_SIZE + tl.arange(0, COL_TILE_SIZE)[None, :]
     if MASK_SCALES:
         mask = (offs_m < M) & (offs_n < N)
         other = 0.0
@@ -1459,11 +1472,11 @@ def quantize_nvfp4_triton_kernel(
         other = None
     x = tl.load(
         x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other
-    )  # [128, 64]
-    x_blocks = x.to(tl.float32).reshape(128, 4, 16)  # [128, 4, 16]
+    )  # [ROW_TILE_SIZE, COL_TILE_SIZE]
+    x_blocks = x.to(tl.float32).reshape(ROW_TILE_SIZE, 4, 16)  # [-1, 4, 16]
 
     # Compute block-wise scales
-    block_amax = tl.max(x_blocks.abs(), axis=2)  # [128, 4]
+    block_amax = tl.max(x_blocks.abs(), axis=2)  # [-1, 4]
 
     if USE_TENSOR_SCALE:
         # Two-level scaling: quantize block scales with per-tensor scale
@@ -1501,21 +1514,37 @@ def quantize_nvfp4_triton_kernel(
         scales,
         0.0,
     )
-    packed_scales = scales.reshape(4, 32, 4).permute(1, 0, 2).reshape(32, 16)
-    offs_m = tl.arange(0, 32)[:, None]
-    offs_n = tl.arange(0, 16)[None, :]
+    # packed_scales = scales.reshape(4, 32, 4).permute(1, 0, 2).reshape(32, 16)
+    packed_scales = scales.reshape(NUM_ROW_INNER_TILES, 4, 32, 4).permute(0, 2, 1, 3).reshape(NUM_ROW_INNER_TILES * 32, 16)
+    scale_offs_m = tl.arange(0, 32 * NUM_ROW_INNER_TILES)[:, None]
+    scale_offs_n = tl.arange(0, 16)[None, :]
+    # packed_scales = tl.arange(0, 32 * NUM_ROW_INNER_TILES * 16 * NUM_COL_INNER_TILES).reshape(NUM_ROW_INNER_TILES * 32, NUM_COL_INNER_TILES * 16).to(tl.float32)
+
+    # TODO write me
+    scale_elements_per_outer_tile = (min(ROW_TILE_SIZE, M) // 128 * 32) * 16
+
+    # scale_offs = (scale_offs_m * 16 + scale_offs_n)
+    # TODO(next): debug here, offsets or masks are probably not correct here
+    scale_offs = (scale_offs_m * 16 + scale_offs_n)
     tl.store(
         s_ptr
-        + (pid_m * tl.num_programs(0) + pid_n) * (32 * 16)
-        + offs_m * 16
-        + offs_n,
+        # + (pid_m * tl.num_programs(0) + pid_n) * (NUM_ROW_INNER_TILES * 32 * 16)
+        # + (pid_m * tl.num_programs(0) + pid_n) * (NUM_ROW_INNER_TILES * 32 * 16)
+        # + (pid_m * tl.num_programs(0) + pid_n) * (1 * 32 * 16)
+        + (pid_m * tl.num_programs(0) + pid_n) * scale_elements_per_outer_tile
+        + scale_offs,
         packed_scales,
+        mask=(scale_offs < scale_elements_per_outer_tile),
     )
 
     # Convert to FP4
-    x_fp4x2 = convert_fp32_to_fp4_packed(x_blocks.reshape(128, 32, 2).split())
-    offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
-    offs_n = pid_n * 32 + tl.arange(0, 32)[None, :]
+    x_fp4x2 = convert_fp32_to_fp4_packed(
+        x_blocks.reshape(ROW_TILE_SIZE, 32, 2).split()
+    )
+    offs_m = pid_m * ROW_TILE_SIZE + tl.arange(0, ROW_TILE_SIZE)[:, None]
+    offs_n = (
+        pid_n * (COL_TILE_SIZE // 2) + tl.arange(0, COL_TILE_SIZE // 2)[None, :]
+    )
     if MASK_SCALES:
         mask = (offs_m < M) & (offs_n < N // 2)
     else:
@@ -1537,7 +1566,7 @@ def triton_quantize_nvfp4(
        Tuple[torch.Tensor, torch.Tensor]: Quantized tensor and scales tensor in swizzled layout.
 
    Note:
-        Since VLLM does not use dyanmo guards we need to make this a custom op
+        Since VLLM does not use dynamo guards we need to make this a custom op
        to avoid the triton kernel being invoked w/ the wrong use of `MASK_SCALES`
    """
    # reshape to 2d
@@ -1557,11 +1586,16 @@
 
    # mask out scales to 0 if we are not aligned to 128 x 64
    MASK_SCALES = M % 128 != 0 or N % 64 != 0
+    # MASK_SCALES = True
 
    xq = x.new_empty(M, N // 2, dtype=torch.uint8)
    scales = x.new_empty(padded_rows, padded_cols, dtype=torch.float8_e4m3fn)
+    # scales.view(torch.uint8).fill_(45)
+
+    ROW_TILE_SIZE = 128 * 2
+    COL_TILE_SIZE = 64
 
-    grid = (triton.cdiv(N, 64), triton.cdiv(M, 128))
+    grid = (triton.cdiv(N, COL_TILE_SIZE), triton.cdiv(M, ROW_TILE_SIZE))
 
    if per_tensor_scale is None:
        # Don't allocate tensor, we just steal this since it won't be used in kernel
@@ -1578,10 +1612,14 @@
        scales,
        x.stride(0),
        x.stride(1),
+        scales.stride(0),
+        scales.stride(1),
        M,
        N,
        USE_TENSOR_SCALE=use_tensor_scale,
        MASK_SCALES=MASK_SCALES,
+        ROW_TILE_SIZE=ROW_TILE_SIZE,
+        COL_TILE_SIZE=COL_TILE_SIZE,
    )
 
    # reshape back to original shape
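The layout change at the heart of the kernel diff is the scale swizzle: per the new docstring, each [128, 64] data tile yields a [128, 4] unswizzled scale tile that is rearranged into a [32, 16] swizzled tile, and the WIP version stacks ROW_TILE_SIZE // 128 such inner tiles per program. A short PyTorch sketch that mirrors the kernel's reshape/permute (illustration only, not the Triton kernel itself; the helper name is made up):

```python
import torch

def swizzle_scale_tile(scales: torch.Tensor, num_row_inner_tiles: int = 1) -> torch.Tensor:
    # Mirrors the kernel's
    #   packed_scales = scales.reshape(n, 4, 32, 4).permute(0, 2, 1, 3).reshape(n * 32, 16)
    # i.e. [num_row_inner_tiles * 128, 4] unswizzled -> [num_row_inner_tiles * 32, 16] swizzled.
    n = num_row_inner_tiles
    assert scales.shape == (n * 128, 4)
    return scales.reshape(n, 4, 32, 4).permute(0, 2, 1, 3).reshape(n * 32, 16)

# With ROW_TILE_SIZE = 256 and COL_TILE_SIZE = 64 as in the diff, each program covers
# two inner row tiles and stores one [64, 16] swizzled scale tile.
unswizzled = torch.arange(256 * 4, dtype=torch.float32).reshape(256, 4)
print(swizzle_scale_tile(unswizzled, num_row_inner_tiles=2).shape)  # torch.Size([64, 16])
```

The launch grid in `triton_quantize_nvfp4` correspondingly becomes `(triton.cdiv(N, COL_TILE_SIZE), triton.cdiv(M, ROW_TILE_SIZE))`, so with ROW_TILE_SIZE = 256 each program now handles two of the original 128-row tiles.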
