 from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import batched_gemm_afp4wfp4_pre_quant
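+# Note (assumption inferred from the call sites in this file): the kernel is invoked as
+# batched_gemm_afp4wfp4_pre_quant(x_bf16, W_packed_fp4, W_scale, out_dtype, out); it
+# presumably quantizes the bf16 activation to fp4 on the fly and performs a batched
+# GEMM against the byte-packed fp4 weight, writing results into the preallocated out.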
 
 
-ENABLE_FP4 = True
-
 class QueryLenSupport(Enum):
     """Defines the level of query length support for an attention backend's
     decode pipeline.
@@ -1199,7 +1197,6 @@ def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
         # Convert from (B, N, L) to (N, B, L)
         # x [num_heads, batch_size, kv_lora_rank]
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
-        print("[Unified Path]", "out shape:", out.shape, "out dtype:", out.dtype)
 
         if self.W_V.dtype == torch.uint8:
             out = out.view(-1, self.num_heads, self.v_head_dim)
@@ -1209,13 +1206,9 @@ def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
             out_buffer = torch.empty(
                 x.shape[0],  # num_heads
                 x.shape[1],  # batch_size
-                self.W_V.shape[2] * 2,  # v
+                self.W_V.shape[1],  # v_head_dim
                 device=x.device,
                 dtype=torch.bfloat16)
1215- print ("In _v_up_proj:" )
1216- print ("x.shape:" , x .shape , " self.W_V.shape:" , self .W_V .shape ,
1217- "out_buffer.shape:" , out_buffer .shape , " out.shape:" ,
1218- out .shape )
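+            # Shape sketch: x [num_heads, batch_size, kv_lora_rank] (bf16, quantized to
+            # fp4 inside the kernel) x W_V [num_heads, v_head_dim, kv_lora_rank // 2
+            # packed fp4] -> out_buffer [num_heads, batch_size, v_head_dim] (bf16).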
             batched_gemm_afp4wfp4_pre_quant(
                 x,
                 self.W_V,
@@ -1224,7 +1217,6 @@ def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
                 out_buffer
             )
             out_buffer = out_buffer.transpose(0, 1)  # [batch_size, num_heads, v]
-            #out = out.transpose(0, 1)  # [num_heads, batch_size, v]
             out.copy_(out_buffer)
         elif is_rocm_aiter_fp8bmm_enabled() and (not ENABLE_FP4):
             out = out.view(-1, self.num_heads, self.v_head_dim)
@@ -1474,11 +1466,8 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
                 return dequant_weights.T
             return layer.weight
 
-        print("self.kv_b_proj:", self.kv_b_proj.weight.shape)
-        print("self.qk_nope_head_dim:", self.qk_nope_head_dim, "self.v_head_dim:", self.v_head_dim)
-
-        if self.kv_b_proj.weight.dtype == torch.uint8 and ENABLE_FP4:  # mxfp4 elements packed in a byte
-            # self.kv_b_proj [num_heads * (qk_nope_head_dim + v_head_dim), q_lora_rank]
+        if self.kv_b_proj.weight.dtype == torch.uint8:  # two mxfp4 elements packed per byte
+            # kv_b_proj weight: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
             kv_b_proj_weight = self.kv_b_proj.weight.T
             kv_b_proj_weight = kv_b_proj_weight.reshape(
                 self.kv_lora_rank,
@@ -1490,24 +1479,25 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
             )
             # W_K [self.kv_lora_rank, num_heads, qk_nope_head_dim // 2] -> [num_heads, kv_lora_rank, qk_nope_head_dim // 2]
             self.W_K = W_UK.transpose(0, 1)
-            # W_V [kv_lora_rank, num_heads, v_head_dim // 2] -> [num_heads, v_head_dim // 2, kv_lora_rank]
-            #self.W_V = W_UV.permute(1, 2, 0)
-            self.W_V = W_UV.transpose(0, 1)
+            # W_V [kv_lora_rank, num_heads, v_head_dim // 2] -> [num_heads, v_head_dim, kv_lora_rank // 2]
+            # Always pack along the last dimension; accuracy still needs to be checked here.
+            self.W_V = W_UV.permute(1, 2, 0)
+            self.W_V = self.W_V.reshape(self.num_heads, self.v_head_dim, self.kv_lora_rank // 2)
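+            # Note: mxfp4 packs two 4-bit values per uint8, so the packed last dim holds
+            # kv_lora_rank // 2 bytes (i.e. kv_lora_rank fp4 elements). The permute +
+            # reshape above reinterprets which axis carries the packing, hence the
+            # accuracy caveat.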
 
-            # split w_scale
             kv_b_proj_weight_sc = self.kv_b_proj.weight_scale
1499- print ( " kv_b_proj_weight_sc.shape:" , kv_b_proj_weight_sc . shape )
1500- # Shape should be [num_headsx(qk_nope_head_dim+v_head_dim), kv_lora_rank // 32]
1501- kv_b_proj_weight_sc = self . kv_b_proj . weight_scale . T . reshape (
1502- self .kv_lora_rank ,
1488+ # kv_b_proj_weight_sc: [num_heads x (qk_nope_head_dim+v_head_dim), kv_lora_rank // 32]
1489+
1490+ # Obtain W_V_Scale first
1491+ W_scale = self .kv_b_proj . weight_scale . view (
15031492 self .num_heads ,
1504- self .qk_nope_head_dim // 32 + self .v_head_dim // 32
1505- )
1506- # self.W_K_scale [kv_lora_rank, num_heads, qk_nope_head_dim //32]
1507- self .W_K_scale , self .W_V_scale = kv_b_proj_weight_sc .split (
1508- [self .qk_nope_head_dim // 32 , self .v_head_dim // 32 ], dim = - 1 )
1509- self .W_K_scale = self .W_K_scale .transpose (0 , 1 )
1510- self .W_V_scale = self .W_V_scale .permute (1 , 2 , 0 )
1493+ self .qk_nope_head_dim + self .v_head_dim ,
1494+ self .kv_lora_rank // 32 )
1495+ self .W_K_scale , self .W_V_scale = W_scale .split ([self .qk_nope_head_dim , self .v_head_dim ], dim = 1 )
1496+
1497+ # Obtain W_K_scale
1498+ self .W_K_scale = self .W_K_scale .view (self .num_heads , self .qk_nope_head_dim // 32 , self .kv_lora_rank )
1499+ self .W_K_scale = self .W_K_scale .permute (0 , 2 , 1 )
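+            # mxfp4 uses one shared (e8m0) scale per block of 32 elements, hence the
+            # // 32 factors. Resulting layouts, following the view/permute above:
+            #   W_K_scale: [num_heads, kv_lora_rank, qk_nope_head_dim // 32]
+            #   W_V_scale: [num_heads, v_head_dim, kv_lora_rank // 32]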
+
             max_batch_size = 1024  # [ToDo] Find the optimal upper limit
             pre_compilation_list = list(range(1, max_batch_size + 1))
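+            # Warm-up (assumed purpose, mirroring the fp8 pre-compilation loop): run both
+            # fp4 GEMM kernels once per batch size m on dummy tensors so Triton JIT
+            # specialization happens up front rather than at serve time.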
             if is_global_first_rank():
@@ -1517,24 +1507,17 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
                     total=max_batch_size,
                 )
             for m in pre_compilation_list:
1520- #print("Pre-Compiling first kernel", flush=True)
1521- #print("self.W_K.shape:", self.W_K.shape, flush=True)
1522- #print("self.W_K_scale.shape:", self.W_K_scale.shape, flush=True)
15231510 # [ num_heads, m, qk_nope_head_dim // 2 * 2]
15241511 x = torch .empty (
15251512 (self .W_K .shape [0 ], m , self .W_K .shape [2 ] * 2 ),
15261513 dtype = torch .bfloat16 ,
15271514 device = self .W_K .device ,
15281515 )
1529- #print("x.shape:", x.shape, flush=True)
15301516 # x shape [ num_heads, m , qk_nope_head_dim //2 * 2]
15311517 # W_K shape [num_heads, kv_lora_ranks, qk_nope_head_dim //2]
15321518 out = torch .empty (
15331519 x .shape [0 ], x .shape [1 ], self .W_K .shape [1 ], device = x .device , dtype = torch .bfloat16
15341520 )
1535- #print("out.shape:", out.shape, flush=True)
1536-
1537- # self.W_K [kv_lora_rank, num_heads, qk_nope_head_dim //32]
15381521
                 batched_gemm_afp4wfp4_pre_quant(
                     x,
@@ -1544,31 +1527,26 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
                     out
                 )
 
1547- print ("Pre-Compiling second kernel" , flush = True )
15481530 ## x [ num_heads, m, kv_lora_rank]
15491531 x = torch .empty (
1550- (self .W_V .shape [0 ], m , self .W_V .shape [2 ] * 2 ),
+                    (self.W_V.shape[0], m, self.W_V.shape[2] * 2),  # kv_lora_rank
                     dtype=torch.bfloat16,
                     device=self.W_V.device,
                 )
1554- print ("x.shape:" , x .shape , flush = True )
15551536 ## [num_heads, m, kv_lora_rank] x [ num_heads, v_head_dim // 2, kv_lora_rank]
15561537 ## [num_heads, m, v_head_dim //2]
                 out = torch.empty(
-                    x.shape[0], x.shape[1], self.W_K.shape[1], device=x.device, dtype=torch.bfloat16)
-                print("out.shape:", out.shape, flush=True)
-                print("self.W_V.shape:", self.W_V.shape, flush=True)
-                print("self.W_V_scale.shape:", self.W_V_scale.shape, flush=True)
+                    x.shape[0], x.shape[1], self.W_V.shape[1], device=x.device, dtype=torch.bfloat16)
                 batched_gemm_afp4wfp4_pre_quant(
                     x,
                     self.W_V,
                     self.W_V_scale,
                     torch.bfloat16,
                     out
                 )
+            # Early return; the code below handles the fp8 / unquantized scenario.
             return
 
-
         # we currently do not have quantized bmm's which are needed for
         # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
         # the bmm's in 16-bit, the extra memory overhead of this is fairly low
@@ -2002,7 +1980,7 @@ def forward(
             decode_pe_padded.copy_(decode_q_pe)
             decode_q_pe = decode_pe_padded
 
-            if self.kv_b_proj.weight.dtype == torch.uint8 and ENABLE_FP4:
+            if self.kv_b_proj.weight.dtype == torch.uint8:
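+                # Analogous to the fp8 path below: (N, B, P) x (N, P, L) -> (N, B, L),
+                # then transposed to (B, N, L) after the kernel call.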
                 decode_ql_nope = torch.empty(decode_q_nope.shape[0], decode_q_nope.shape[1], self.W_K.shape[1], dtype=torch.bfloat16, device=decode_q_nope.device)
                 batched_gemm_afp4wfp4_pre_quant(
                     decode_q_nope,
@@ -2012,7 +1990,6 @@ def forward(
                     decode_ql_nope
                 )
                 decode_ql_nope = decode_ql_nope.transpose(0, 1)
-                print("[FP4 Path] decode_ql_nope.shape:", decode_ql_nope.shape, "dtype:", decode_ql_nope.dtype)
             elif is_rocm_aiter_fp8bmm_enabled():
                 # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
                 decode_ql_nope = aiter_triton_fp8_bmm(
@@ -2022,7 +1999,6 @@ def forward(
                     group_size=128,
                     transpose_bm=True,
                 )
-                print("[FP8 Path] decode_ql_nope.shape:", decode_ql_nope.shape, "dtype:", decode_ql_nope.dtype)
             else:
                 # Pads the head_dim if necessary (for the underlying kernel)
                 N, B, P = decode_q_nope.shape