@@ -1441,6 +1441,8 @@ def quantize_nvfp4_triton_kernel(
     N,
     USE_TENSOR_SCALE: tl.constexpr,
     MASK_SCALES: tl.constexpr,
+    ROW_TILE_SIZE: tl.constexpr,
+    COL_TILE_SIZE: tl.constexpr,
 ):
     F4_E2M1_MAX = 6.0
     F8E4M3_MAX = 448.0
@@ -1449,8 +1451,8 @@ def quantize_nvfp4_triton_kernel(
     pid_m = tl.program_id(1)
     pid_n = tl.program_id(0)
 
-    offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
-    offs_n = pid_n * 64 + tl.arange(0, 64)[None, :]
+    offs_m = pid_m * ROW_TILE_SIZE + tl.arange(0, ROW_TILE_SIZE)[:, None]
+    offs_n = pid_n * COL_TILE_SIZE + tl.arange(0, COL_TILE_SIZE)[None, :]
     if MASK_SCALES:
         mask = (offs_m < M) & (offs_n < N)
         other = 0.0
@@ -1460,10 +1462,10 @@ def quantize_nvfp4_triton_kernel(
     x = tl.load(
         x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other
     )  # [128, 64]
-    x_blocks = x.to(tl.float32).reshape(128, 4, 16)  # [128, 4, 16]
+    x_blocks = x.to(tl.float32).reshape(ROW_TILE_SIZE, 4, 16)  # [-1, 4, 16]
 
     # Compute block-wise scales
-    block_amax = tl.max(x_blocks.abs(), axis=2)  # [128, 4]
+    block_amax = tl.max(x_blocks.abs(), axis=2)  # [-1, 4]
 
     if USE_TENSOR_SCALE:
         # Two-level scaling: quantize block scales with per-tensor scale
@@ -1513,9 +1515,13 @@ def quantize_nvfp4_triton_kernel(
     )
 
     # Convert to FP4
-    x_fp4x2 = convert_fp32_to_fp4_packed(x_blocks.reshape(128, 32, 2).split())
-    offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
-    offs_n = pid_n * 32 + tl.arange(0, 32)[None, :]
+    x_fp4x2 = convert_fp32_to_fp4_packed(
+        x_blocks.reshape(ROW_TILE_SIZE, 32, 2).split()
+    )
+    offs_m = pid_m * ROW_TILE_SIZE + tl.arange(0, ROW_TILE_SIZE)[:, None]
+    offs_n = (
+        pid_n * (COL_TILE_SIZE // 2) + tl.arange(0, COL_TILE_SIZE // 2)[None, :]
+    )
     if MASK_SCALES:
         mask = (offs_m < M) & (offs_n < N // 2)
     else:
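
For reference, a rough pure-PyTorch sketch of the per-block scale math this kernel computes. The body of the USE_TENSOR_SCALE branch is elided from this diff, so the two-level step and the exact clamping below are assumptions based only on the constants visible above; the function name is hypothetical:

import torch

F4_E2M1_MAX = 6.0   # max magnitude representable in FP4 E2M1
F8E4M3_MAX = 448.0  # max magnitude representable in FP8 E4M3

def nvfp4_block_scales_ref(x: torch.Tensor, per_tensor_scale=None) -> torch.Tensor:
    # x: [M, N] with N divisible by 16; mirrors the kernel's reshape into
    # 16-element blocks and the per-block amax reduction.
    blocks = x.float().reshape(x.shape[0], -1, 16)  # [M, N//16, 16]
    block_amax = blocks.abs().amax(dim=2)           # [M, N//16]
    scales = block_amax / F4_E2M1_MAX               # ideal FP4 block scales
    if per_tensor_scale is not None:
        # assumed two-level scaling: express block scales relative to the
        # per-tensor scale before quantizing them to FP8 E4M3
        scales = scales / per_tensor_scale
    return scales.clamp(max=F8E4M3_MAX).to(torch.float8_e4m3fn)
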
@@ -1537,7 +1543,7 @@ def triton_quantize_nvfp4(
         Tuple[torch.Tensor, torch.Tensor]: Quantized tensor and scales tensor in swizzled layout.
 
     Note:
-        Since VLLM does not use dyanmo guards we need to make this a custom op
+        Since VLLM does not use dynamo guards we need to make this a custom op
         to avoid the triton kernel being invoked w/ the wrong use of `MASK_SCALES`
     """
     # reshape to 2d
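
The Note above is the motivation for wrapping this function as a custom op. A hypothetical registration sketch, assuming the torch.library.custom_op decorator; the actual registration used here is not shown in this diff, and the wrapper name is made up for illustration:

import torch
from torch.library import custom_op

@custom_op("torchao::triton_quantize_nvfp4", mutates_args=())
def _triton_quantize_nvfp4(
    x: torch.Tensor, per_tensor_scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical wrapper: a custom op is opaque to dynamo, so MASK_SCALES
    # is decided eagerly from the runtime shapes on every call instead of
    # being baked into a guard-free compiled trace.
    ...
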
@@ -1571,6 +1577,8 @@ def triton_quantize_nvfp4(
         tensor_scale_ptr = per_tensor_scale
         use_tensor_scale = True
 
+    ROW_TILE_SIZE = 128
+    COL_TILE_SIZE = 64
     quantize_nvfp4_triton_kernel[grid](
         x,
         tensor_scale_ptr,
@@ -1582,6 +1590,8 @@ def triton_quantize_nvfp4(
         N,
         USE_TENSOR_SCALE=use_tensor_scale,
         MASK_SCALES=MASK_SCALES,
+        ROW_TILE_SIZE=ROW_TILE_SIZE,
+        COL_TILE_SIZE=COL_TILE_SIZE,
     )
 
     # reshape back to original shape
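
One consequence of parameterizing the tile sizes: the launch grid (computed before the call shown above, and not included in this diff) has to be derived from the same constants. A minimal sketch, following the pid_n-along-columns / pid_m-along-rows convention visible in the kernel:

import triton

# axis 0 tiles columns (pid_n), axis 1 tiles rows (pid_m), matching
# tl.program_id(0) / tl.program_id(1) in quantize_nvfp4_triton_kernel
grid = (triton.cdiv(N, COL_TILE_SIZE), triton.cdiv(M, ROW_TILE_SIZE))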