@@ -1437,20 +1437,33 @@ def quantize_nvfp4_triton_kernel(
     s_ptr,
     stride_xm,
     stride_xn,
+    stride_sm,
+    stride_sn,
     M,
     N,
     USE_TENSOR_SCALE: tl.constexpr,
     MASK_SCALES: tl.constexpr,
+    ROW_TILE_SIZE: tl.constexpr,
+    COL_TILE_SIZE: tl.constexpr,
 ):
+    """
+    1. a single block of data is shaped [128, 64] unpacked, or [128, 32] packed
+    2. the corresponding single unswizzled block of scales is shaped [128, 4]
+    3. the corresponding single swizzled block of scales is shaped [32, 16]
+    """
+
     F4_E2M1_MAX = 6.0
     F8E4M3_MAX = 448.0
     E4M3_EPS = 1.5258789e-05
 
+    NUM_ROW_INNER_TILES: tl.constexpr = ROW_TILE_SIZE // 128
+    NUM_COL_INNER_TILES: tl.constexpr = COL_TILE_SIZE // 64
+
     pid_m = tl.program_id(1)
     pid_n = tl.program_id(0)
 
-    offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
-    offs_n = pid_n * 64 + tl.arange(0, 64)[None, :]
+    offs_m = pid_m * ROW_TILE_SIZE + tl.arange(0, ROW_TILE_SIZE)[:, None]
+    offs_n = pid_n * COL_TILE_SIZE + tl.arange(0, COL_TILE_SIZE)[None, :]
     if MASK_SCALES:
         mask = (offs_m < M) & (offs_n < N)
         other = 0.0
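To make the new docstring's shape bookkeeping concrete, here is a host-side PyTorch sketch (illustrative only, not part of the kernel) of the single-tile scale swizzle: a [128, 4] block of unswizzled scales maps to the [32, 16] swizzled layout via the same reshape/permute the kernel uses.

import torch

# unswizzled scales for one [128, 64] data tile: one scale per 16 values
scales = torch.arange(128 * 4, dtype=torch.float32).reshape(128, 4)

# the swizzle used by the kernel's single-tile path
swizzled = scales.reshape(4, 32, 4).permute(1, 0, 2).reshape(32, 16)

# element (r, b) of the [128, 4] block lands at (r % 32, (r // 32) * 4 + b)
r, b = 77, 2
assert swizzled[r % 32, (r // 32) * 4 + b] == scales[r, b]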
@@ -1459,11 +1472,11 @@ def quantize_nvfp4_triton_kernel(
         other = None
     x = tl.load(
         x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other
-    )  # [128, 64]
-    x_blocks = x.to(tl.float32).reshape(128, 4, 16)  # [128, 4, 16]
+    )  # [ROW_TILE_SIZE, COL_TILE_SIZE]
+    x_blocks = x.to(tl.float32).reshape(ROW_TILE_SIZE, 4, 16)  # [-1, 4, 16]
 
     # Compute block-wise scales
-    block_amax = tl.max(x_blocks.abs(), axis=2)  # [128, 4]
+    block_amax = tl.max(x_blocks.abs(), axis=2)  # [-1, 4]
 
     if USE_TENSOR_SCALE:
         # Two-level scaling: quantize block scales with per-tensor scale
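For reference, the per-block scale math that follows block_amax can be mirrored on the host. This is a rough PyTorch analogue of the single-level path only (the USE_TENSOR_SCALE path additionally folds in the per-tensor scale), using the same constants as the kernel; the helper name is mine, not the repo's.

import torch

F4_E2M1_MAX = 6.0
F8E4M3_MAX = 448.0
E4M3_EPS = 1.5258789e-05

def blockwise_scales_ref(x_blocks: torch.Tensor) -> torch.Tensor:
    # x_blocks: [rows, 4, 16] fp32, as in the kernel after the reshape
    block_amax = x_blocks.abs().amax(dim=2)  # [rows, 4]
    scales = (block_amax / F4_E2M1_MAX).clamp(E4M3_EPS, F8E4M3_MAX)
    return scales.to(torch.float8_e4m3fn)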
@@ -1501,21 +1514,37 @@ def quantize_nvfp4_triton_kernel(
         scales,
         0.0,
     )
-    packed_scales = scales.reshape(4, 32, 4).permute(1, 0, 2).reshape(32, 16)
-    offs_m = tl.arange(0, 32)[:, None]
-    offs_n = tl.arange(0, 16)[None, :]
+    # single-tile version, kept for reference:
+    # packed_scales = scales.reshape(4, 32, 4).permute(1, 0, 2).reshape(32, 16)
+    packed_scales = (
+        scales.reshape(NUM_ROW_INNER_TILES, 4, 32, 4)
+        .permute(0, 2, 1, 3)
+        .reshape(NUM_ROW_INNER_TILES * 32, 16)
+    )
+    scale_offs_m = tl.arange(0, 32 * NUM_ROW_INNER_TILES)[:, None]
+    scale_offs_n = tl.arange(0, 16)[None, :]
+    # debug aid: overwrite packed_scales with a ramp to visualize the layout
+    # packed_scales = tl.arange(0, 32 * NUM_ROW_INNER_TILES * 16 * NUM_COL_INNER_TILES).reshape(NUM_ROW_INNER_TILES * 32, NUM_COL_INNER_TILES * 16).to(tl.float32)
+
+    # number of scale elements each outer tile contributes: every 128-row
+    # inner tile that fits within M yields one 32x16 swizzled scale block
+    scale_elements_per_outer_tile = (min(ROW_TILE_SIZE, M) // 128 * 32) * 16
+
+    # TODO(next): debug here, offsets or masks are probably not correct here
+    scale_offs = scale_offs_m * 16 + scale_offs_n
     tl.store(
         s_ptr
-        + (pid_m * tl.num_programs(0) + pid_n) * (32 * 16)
-        + offs_m * 16
-        + offs_n,
+        # earlier attempts, kept while debugging:
+        # + (pid_m * tl.num_programs(0) + pid_n) * (NUM_ROW_INNER_TILES * 32 * 16)
+        # + (pid_m * tl.num_programs(0) + pid_n) * (1 * 32 * 16)
+        + (pid_m * tl.num_programs(0) + pid_n) * scale_elements_per_outer_tile
+        + scale_offs,
         packed_scales,
+        mask=(scale_offs < scale_elements_per_outer_tile),
     )
 
     # Convert to FP4
-    x_fp4x2 = convert_fp32_to_fp4_packed(x_blocks.reshape(128, 32, 2).split())
-    offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
-    offs_n = pid_n * 32 + tl.arange(0, 32)[None, :]
+    x_fp4x2 = convert_fp32_to_fp4_packed(
+        x_blocks.reshape(ROW_TILE_SIZE, 32, 2).split()
+    )
+    offs_m = pid_m * ROW_TILE_SIZE + tl.arange(0, ROW_TILE_SIZE)[:, None]
+    offs_n = (
+        pid_n * (COL_TILE_SIZE // 2) + tl.arange(0, COL_TILE_SIZE // 2)[None, :]
+    )
     if MASK_SCALES:
         mask = (offs_m < M) & (offs_n < N // 2)
     else:
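The generalized swizzle above is just the original single-tile swizzle applied independently to each 128-row inner tile. A quick host-side check (PyTorch, illustrative) makes the equivalence explicit:

import torch

R = 2  # NUM_ROW_INNER_TILES when ROW_TILE_SIZE = 256
scales = torch.randn(R * 128, 4)

# generalized form from the kernel
multi = scales.reshape(R, 4, 32, 4).permute(0, 2, 1, 3).reshape(R * 32, 16)

# original single-tile swizzle applied per 128-row chunk, then stacked
chunks = [
    c.reshape(4, 32, 4).permute(1, 0, 2).reshape(32, 16)
    for c in scales.split(128, dim=0)
]
assert torch.equal(multi, torch.cat(chunks, dim=0))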
@@ -1537,7 +1566,7 @@ def triton_quantize_nvfp4(
         Tuple[torch.Tensor, torch.Tensor]: Quantized tensor and scales tensor in swizzled layout.
 
     Note:
-        Since VLLM does not use dyanmo guards we need to make this a custom op
+        Since vLLM does not use dynamo guards, we need to make this a custom op
         to avoid the triton kernel being invoked w/ the wrong use of `MASK_SCALES`
     """
     # reshape to 2d
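The Note's point can be illustrated with a minimal torch.library.custom_op sketch (the op name and body here are hypothetical, not the repo's actual registration): because MASK_SCALES is recomputed from the runtime shape inside the op, a guard-free caller such as vLLM cannot end up reusing a launch specialized for the wrong shape.

import torch

@torch.library.custom_op("demo::quantize_nvfp4", mutates_args=())
def quantize_nvfp4_demo(x: torch.Tensor) -> torch.Tensor:
    # recomputed per call; never baked into a compiled graph
    MASK_SCALES = x.shape[0] % 128 != 0 or x.shape[1] % 64 != 0
    # ... the real kernel launch, parameterized by MASK_SCALES, goes here ...
    return x.clone()

@quantize_nvfp4_demo.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)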
@@ -1557,11 +1586,16 @@ def triton_quantize_nvfp4(
 
     # mask out scales to 0 if we are not aligned to 128 x 64
     MASK_SCALES = M % 128 != 0 or N % 64 != 0
+    # debug: force the masked path
+    # MASK_SCALES = True
 
     xq = x.new_empty(M, N // 2, dtype=torch.uint8)
     scales = x.new_empty(padded_rows, padded_cols, dtype=torch.float8_e4m3fn)
+    # debug: poison the scales buffer with a sentinel to spot unwritten elements
+    # scales.view(torch.uint8).fill_(45)
+
+    ROW_TILE_SIZE = 128 * 2
+    COL_TILE_SIZE = 64
 
-    grid = (triton.cdiv(N, 64), triton.cdiv(M, 128))
+    grid = (triton.cdiv(N, COL_TILE_SIZE), triton.cdiv(M, ROW_TILE_SIZE))
 
     if per_tensor_scale is None:
         # Don't allocate tensor, we just steal this since it won't be used in kernel
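A quick worked example of the new launch grid (values are illustrative): with M = 1000 and N = 4096 and the tile sizes above, the grid is 64 programs across columns and 4 down rows, and MASK_SCALES is True because 1000 is not a multiple of 128.

import triton

M, N = 1000, 4096
ROW_TILE_SIZE, COL_TILE_SIZE = 128 * 2, 64

grid = (triton.cdiv(N, COL_TILE_SIZE), triton.cdiv(M, ROW_TILE_SIZE))
assert grid == (64, 4)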
@@ -1578,10 +1612,14 @@ def triton_quantize_nvfp4(
         scales,
         x.stride(0),
         x.stride(1),
+        scales.stride(0),
+        scales.stride(1),
         M,
         N,
         USE_TENSOR_SCALE=use_tensor_scale,
         MASK_SCALES=MASK_SCALES,
+        ROW_TILE_SIZE=ROW_TILE_SIZE,
+        COL_TILE_SIZE=COL_TILE_SIZE,
     )
 
     # reshape back to original shape
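For completeness, an assumed call pattern for the wrapper, inferred from the diff (the exact signature and argument names may differ in the repo):

import torch

x = torch.randn(256, 1024, dtype=torch.bfloat16, device="cuda")
xq, scales = triton_quantize_nvfp4(x)

# two FP4 values are packed per uint8, so the last dim halves
assert xq.dtype == torch.uint8 and xq.shape == (256, 512)
# block scales come back in the swizzled e4m3 layout
assert scales.dtype == torch.float8_e4m3fn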