Commit 00eaea9

[wip] speed up nvfp4 triton kernel

Summary:

Test Plan:

```bash
pytest test/prototype/mx_formats/test_nvfp4_tensor.py -s -x -k test_triton_nvfp4_quantize_equivalence
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: bfe1b82
ghstack-comment-id: 3416381534
Pull-Request: #3202

1 parent 16da39f

2 files changed (+21, -11 lines)

benchmarks/mx_formats/cast_bench.py (3 additions, 3 deletions)

```diff
@@ -257,12 +257,12 @@ def run(
 
     elif mode == "dim0_nvfp4":
         to_nvfp4_reference_c = torch.compile(to_nvfp4_reference)
-        y_d0, s_d0 = to_nvfp4_reference_c(x, use_triton_kernel=False)
+        y_d0, s_d0 = to_nvfp4_reference_c(x)
 
         for _ in range(2):
-            __ = to_nvfp4_reference_c(x, use_triton_kernel=False)
+            __ = to_nvfp4_reference_c(x)
         time_us = benchmark_cuda_function_in_microseconds(
-            lambda x: to_nvfp4_reference_c(x, use_triton_kernel=False),
+            lambda x: to_nvfp4_reference_c(x),
             x,
         )
         assert y_d0.dtype == torch.uint8
```
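
The benchmark warms up the compiled reference twice and then times it with a CUDA-event helper. As a rough illustration of what `benchmark_cuda_function_in_microseconds` does (this body is an assumption; the real helper lives in the repo's benchmarking utils and may average over multiple iterations):

```python
import torch

def benchmark_cuda_function_in_microseconds(f, *args):
    # Hypothetical stand-in: time a single call with CUDA events.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    f(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) * 1e3  # elapsed_time is in ms; convert to us
```
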

torchao/prototype/mx_formats/kernels.py (18 additions, 8 deletions)

```diff
@@ -1441,6 +1441,8 @@ def quantize_nvfp4_triton_kernel(
     N,
     USE_TENSOR_SCALE: tl.constexpr,
     MASK_SCALES: tl.constexpr,
+    ROW_TILE_SIZE: tl.constexpr,
+    COL_TILE_SIZE: tl.constexpr,
 ):
     F4_E2M1_MAX = 6.0
     F8E4M3_MAX = 448.0
```
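
Making the tile shape a `tl.constexpr` means Triton specializes and recompiles the kernel per tile size instead of baking in the old hardcoded 128x64 tile, which is the usual prerequisite for autotuning over tile shapes. A minimal sketch of the pattern, using a toy copy kernel rather than the nvfp4 kernel:

```python
import triton
import triton.language as tl

@triton.jit
def copy_kernel(x_ptr, y_ptr, n, BLOCK_SIZE: tl.constexpr):
    # tl.arange requires a compile-time extent, so BLOCK_SIZE must be
    # constexpr; each distinct value triggers a separate compilation.
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)
```
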
```diff
@@ -1449,8 +1451,8 @@ def quantize_nvfp4_triton_kernel(
     pid_m = tl.program_id(1)
     pid_n = tl.program_id(0)
 
-    offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
-    offs_n = pid_n * 64 + tl.arange(0, 64)[None, :]
+    offs_m = pid_m * ROW_TILE_SIZE + tl.arange(0, ROW_TILE_SIZE)[:, None]
+    offs_n = pid_n * COL_TILE_SIZE + tl.arange(0, COL_TILE_SIZE)[None, :]
     if MASK_SCALES:
         mask = (offs_m < M) & (offs_n < N)
         other = 0.0
```
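
The `[:, None]` / `[None, :]` indexing broadcasts two 1-D ranges into a 2-D grid of offsets, one per element of the `ROW_TILE_SIZE x COL_TILE_SIZE` tile. A NumPy analogue of the same broadcast (tiny sizes standing in for 128 and 64):

```python
import numpy as np

ROW_TILE_SIZE, COL_TILE_SIZE = 4, 2
offs_m = np.arange(ROW_TILE_SIZE)[:, None]  # shape (4, 1)
offs_n = np.arange(COL_TILE_SIZE)[None, :]  # shape (1, 2)
# Broadcasting yields a (4, 2) grid of element offsets, mirroring the
# kernel's offs_m * stride_xm + offs_n * stride_xn address math.
offsets = offs_m * COL_TILE_SIZE + offs_n
```
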
```diff
@@ -1460,10 +1462,10 @@ def quantize_nvfp4_triton_kernel(
     x = tl.load(
         x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other
     )  # [128, 64]
-    x_blocks = x.to(tl.float32).reshape(128, 4, 16)  # [128, 4, 16]
+    x_blocks = x.to(tl.float32).reshape(ROW_TILE_SIZE, 4, 16)  # [-1, 4, 16]
 
     # Compute block-wise scales
-    block_amax = tl.max(x_blocks.abs(), axis=2)  # [128, 4]
+    block_amax = tl.max(x_blocks.abs(), axis=2)  # [-1, 4]
 
     if USE_TENSOR_SCALE:
         # Two-level scaling: quantize block scales with per-tensor scale
```
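
For intuition: each tile row is viewed as 4 contiguous blocks of 16 elements (the nvfp4 scaling block size), and `block_amax` feeds the per-block fp8 scale. A hedged NumPy sketch of the single-level scale computation, using the `F4_E2M1_MAX = 6.0` constant from the kernel (the kernel's exact clamping/rounding and the two-level tensor-scale path are omitted):

```python
import numpy as np

F4_E2M1_MAX = 6.0  # largest magnitude representable in fp4 e2m1
x_blocks = np.random.randn(128, 4, 16).astype(np.float32)
block_amax = np.abs(x_blocks).max(axis=2)               # (128, 4)
block_scale = np.maximum(block_amax / F4_E2M1_MAX, 1e-12)
x_scaled = x_blocks / block_scale[..., None]            # values now within +/-6.0
```
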
```diff
@@ -1513,9 +1515,13 @@ def quantize_nvfp4_triton_kernel(
     )
 
     # Convert to FP4
-    x_fp4x2 = convert_fp32_to_fp4_packed(x_blocks.reshape(128, 32, 2).split())
-    offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
-    offs_n = pid_n * 32 + tl.arange(0, 32)[None, :]
+    x_fp4x2 = convert_fp32_to_fp4_packed(
+        x_blocks.reshape(ROW_TILE_SIZE, 32, 2).split()
+    )
+    offs_m = pid_m * ROW_TILE_SIZE + tl.arange(0, ROW_TILE_SIZE)[:, None]
+    offs_n = (
+        pid_n * (COL_TILE_SIZE // 2) + tl.arange(0, COL_TILE_SIZE // 2)[None, :]
+    )
     if MASK_SCALES:
         mask = (offs_m < M) & (offs_n < N // 2)
     else:
```
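
Because two fp4 values are packed into each output byte, the packed tensor has `N // 2` columns, so the store offsets advance by `COL_TILE_SIZE // 2`; with the current 128x64 tile this reproduces the old hardcoded `pid_n * 32`. A NumPy stand-in for the nibble packing (`convert_fp32_to_fp4_packed` is a Triton helper; the low/high nibble order here is an assumption):

```python
import numpy as np

def pack_fp4_pairs(lo: np.ndarray, hi: np.ndarray) -> np.ndarray:
    # Inputs are already-encoded 4-bit integer codes. Each uint8 output
    # holds two of them, halving the column count of the packed tensor.
    return ((lo & 0xF) | ((hi & 0xF) << 4)).astype(np.uint8)
```
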
```diff
@@ -1537,7 +1543,7 @@ def triton_quantize_nvfp4(
         Tuple[torch.Tensor, torch.Tensor]: Quantized tensor and scales tensor in swizzled layout.
 
     Note:
-        Since VLLM does not use dyanmo guards we need to make this a custom op
+        Since VLLM does not use dynamo guards we need to make this a custom op
         to avoid the triton kernel being invoked w/ the wrong use of `MASK_SCALES`
     """
     # reshape to 2d
```
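
For background on the note above: wrapping the kernel as a custom op makes it opaque to `torch.compile`, so a trace captured without dynamo guards (as in vLLM) cannot bake in the wrong `MASK_SCALES` branch. A generic sketch of the registration pattern, not torchao's actual op (names and bodies here are hypothetical):

```python
import torch

@torch.library.custom_op("mylib::scale_twice", mutates_args=())
def scale_twice(x: torch.Tensor) -> torch.Tensor:
    # A real body would dispatch to the triton kernel; the compiler
    # treats the whole op as a black box.
    return x * 2.0

@scale_twice.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Shape/dtype propagation for tracing without running the kernel.
    return torch.empty_like(x)
```
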
```diff
@@ -1571,6 +1577,8 @@ def triton_quantize_nvfp4(
         tensor_scale_ptr = per_tensor_scale
         use_tensor_scale = True
 
+    ROW_TILE_SIZE = 128
+    COL_TILE_SIZE = 64
     quantize_nvfp4_triton_kernel[grid](
         x,
         tensor_scale_ptr,
```
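
The launch grid (defined earlier in this function, outside the hunk) has to agree with these tile sizes: one program per tile, with axis 0 walking columns since the kernel reads `pid_n` from `tl.program_id(0)`. In this wip commit the sizes stay at the previous 128x64, so the existing grid remains valid. A sketch of the expected coupling (the actual `grid` expression is not shown in this diff, so this is an assumption):

```python
import triton

M, N = 4096, 8192                       # example 2D input shape
ROW_TILE_SIZE, COL_TILE_SIZE = 128, 64
grid = (triton.cdiv(N, COL_TILE_SIZE), triton.cdiv(M, ROW_TILE_SIZE))
```
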
```diff
@@ -1582,6 +1590,8 @@ def triton_quantize_nvfp4(
         N,
         USE_TENSOR_SCALE=use_tensor_scale,
         MASK_SCALES=MASK_SCALES,
+        ROW_TILE_SIZE=ROW_TILE_SIZE,
+        COL_TILE_SIZE=COL_TILE_SIZE,
     )
 
     # reshape back to original shape
```
