@@ -233,8 +233,6 @@
 from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import batched_gemm_afp4wfp4_pre_quant


-ENABLE_FP4 = True
-
 class QueryLenSupport(Enum):
     """Defines the level of query length support for an attention backend's
     decode pipeline.
@@ -1221,7 +1219,6 @@ def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
         # Convert from (B, N, L) to (N, B, L)
         # x [num_heads, batch_size, kv_lora_rank]
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
-        print("[Unified Path]", "out shape:", out.shape, "out dtype:", out.dtype)

         if self.W_V.dtype == torch.uint8:
             out = out.view(-1, self.num_heads, self.v_head_dim)
@@ -1231,13 +1228,9 @@ def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
             out_buffer = torch.empty(
                 x.shape[0],  # num_heads
                 x.shape[1],  # batch_size
-                self.W_V.shape[2] * 2,  # v
+                self.W_V.shape[1],  # v_head_dim
                 device=x.device,
                 dtype=torch.bfloat16)
-            print("In _v_up_proj:")
-            print("x.shape:", x.shape, " self.W_V.shape:", self.W_V.shape,
-                  "out_buffer.shape:", out_buffer.shape, " out.shape:",
-                  out.shape)
             batched_gemm_afp4wfp4_pre_quant(
                 x,
                 self.W_V,
@@ -1246,7 +1239,6 @@ def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
                 out_buffer
             )
             out_buffer = out_buffer.transpose(0, 1)  # [batch_size, num_heads, v]
-            # out = out.transpose(0, 1)  # [num_heads, batch_size, v]
             out.copy_(out_buffer)
         elif is_rocm_aiter_fp8bmm_enabled() and (not ENABLE_FP4):
             out = out.view(-1, self.num_heads, self.v_head_dim)
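
Note: the fp4 branch above replaces the dense V up-projection with a packed-weight batched GEMM. A minimal bf16 sketch of the same contraction, with hypothetical sizes; W_V_dense stands in for the unpacked mxfp4 weight, which in the diff is uint8 of shape [num_heads, v_head_dim, kv_lora_rank // 2] plus per-group scales:

    import torch

    num_heads, batch_size, kv_lora_rank, v_head_dim = 16, 4, 512, 128

    # x arrives as (B, N, L) and is viewed/transposed to (N, B, L), as in the diff.
    x = torch.randn(batch_size, num_heads, kv_lora_rank, dtype=torch.bfloat16)
    x = x.view(-1, num_heads, kv_lora_rank).transpose(0, 1)      # (N, B, L)

    # Dense stand-in for W_V.
    W_V_dense = torch.randn(num_heads, v_head_dim, kv_lora_rank, dtype=torch.bfloat16)

    # (N, B, L) @ (N, L, v) -> (N, B, v); W_V.shape[1] == v_head_dim, which is
    # why out_buffer's last dimension is self.W_V.shape[1] above.
    out_buffer = torch.bmm(x, W_V_dense.transpose(1, 2))
    out = out_buffer.transpose(0, 1)                             # (B, N, v)
    assert out.shape == (batch_size, num_heads, v_head_dim)
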
@@ -1580,11 +1572,8 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
                 return dequant_weights.T
             return layer.weight

-        print("self.kv_b_proj:", self.kv_b_proj.weight.shape)
-        print("self.qk_nope_head_dim:", self.qk_nope_head_dim, "self.v_head_dim:", self.v_head_dim)
-
-        if self.kv_b_proj.weight.dtype == torch.uint8 and ENABLE_FP4:  # mxfp4 elemnts packed in a byte
-            # self.kv_b_proj [num_heads * (qk_nope_head_dim + v_head_dim), q_lora_rank]
+        if self.kv_b_proj.weight.dtype == torch.uint8:  # mxfp4: two elements packed per byte
+            # kv_b_proj [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
             kv_b_proj_weight = self.kv_b_proj.weight.T
             kv_b_proj_weight = kv_b_proj_weight.reshape(
                 self.kv_lora_rank,
@@ -1596,24 +1585,25 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
             )
             # W_K [kv_lora_rank, num_heads, qk_nope_head_dim // 2] -> [num_heads, kv_lora_rank, qk_nope_head_dim // 2]
             self.W_K = W_UK.transpose(0, 1)
-            # W_V [kv_lora_rank, num_heads, v_head_dim // 2] -> [num_heads, v_head_dim // 2, kv_lora_rank]
-            #self.W_V = W_UV.permute(1, 2, 0)
-            self.W_V = W_UV.transpose(0, 1)
+            # W_V [kv_lora_rank, num_heads, v_head_dim // 2] -> [num_heads, v_head_dim, kv_lora_rank // 2]
+            # Always pack along the last dimension; accuracy still needs to be verified here.
+            self.W_V = W_UV.permute(1, 2, 0)
+            self.W_V = self.W_V.reshape(self.num_heads, self.v_head_dim, self.kv_lora_rank // 2)

-            # split w_scale
             kv_b_proj_weight_sc = self.kv_b_proj.weight_scale
-            print(" kv_b_proj_weight_sc.shape:", kv_b_proj_weight_sc.shape)
-            # Shape should be [num_heads x (qk_nope_head_dim + v_head_dim), kv_lora_rank // 32]
-            kv_b_proj_weight_sc = self.kv_b_proj.weight_scale.T.reshape(
-                self.kv_lora_rank,
+            # kv_b_proj_weight_sc: [num_heads x (qk_nope_head_dim + v_head_dim), kv_lora_rank // 32]
+
+            # Obtain W_V_scale first
+            W_scale = self.kv_b_proj.weight_scale.view(
                 self.num_heads,
-                self.qk_nope_head_dim // 32 + self.v_head_dim // 32
-            )
-            # self.W_K_scale [kv_lora_rank, num_heads, qk_nope_head_dim // 32]
-            self.W_K_scale, self.W_V_scale = kv_b_proj_weight_sc.split(
-                [self.qk_nope_head_dim // 32, self.v_head_dim // 32], dim=-1)
-            self.W_K_scale = self.W_K_scale.transpose(0, 1)
-            self.W_V_scale = self.W_V_scale.permute(1, 2, 0)
+                self.qk_nope_head_dim + self.v_head_dim,
+                self.kv_lora_rank // 32)
+            self.W_K_scale, self.W_V_scale = W_scale.split([self.qk_nope_head_dim, self.v_head_dim], dim=1)
+
+            # Obtain W_K_scale
+            self.W_K_scale = self.W_K_scale.view(self.num_heads, self.qk_nope_head_dim // 32, self.kv_lora_rank)
+            self.W_K_scale = self.W_K_scale.permute(0, 2, 1)
+
             max_batch_size = 1024  # [ToDo] Find the optimal upper limit
             pre_compilation_list = list(range(1, max_batch_size + 1))
             if is_global_first_rank():
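
Note: a shape-only sketch of the scale handling above, assuming the mxfp4 convention used in this path (two fp4 values per uint8 along kv_lora_rank, one scale per 32-element group); all sizes below are hypothetical:

    import torch

    num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 16, 128, 128, 512

    # kv_b_proj.weight: packed fp4, two values per byte along kv_lora_rank.
    weight = torch.zeros(num_heads * (qk_nope_head_dim + v_head_dim),
                         kv_lora_rank // 2, dtype=torch.uint8)
    # kv_b_proj.weight_scale: one scale per 32-element group along kv_lora_rank.
    weight_scale = torch.zeros(num_heads * (qk_nope_head_dim + v_head_dim),
                               kv_lora_rank // 32, dtype=torch.uint8)

    # Mirrors the view/split in the diff.
    W_scale = weight_scale.view(num_heads,
                                qk_nope_head_dim + v_head_dim,
                                kv_lora_rank // 32)
    W_K_scale, W_V_scale = W_scale.split([qk_nope_head_dim, v_head_dim], dim=1)
    print(W_K_scale.shape)  # torch.Size([16, 128, 16])
    print(W_V_scale.shape)  # torch.Size([16, 128, 16])
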
@@ -1623,24 +1613,17 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
                     total=max_batch_size,
                 )
             for m in pre_compilation_list:
-                #print("Pre-Compiling first kernel", flush=True)
-                #print("self.W_K.shape:", self.W_K.shape, flush=True)
-                #print("self.W_K_scale.shape:", self.W_K_scale.shape, flush=True)
                 # [num_heads, m, qk_nope_head_dim // 2 * 2]
                 x = torch.empty(
                     (self.W_K.shape[0], m, self.W_K.shape[2] * 2),
                     dtype=torch.bfloat16,
                     device=self.W_K.device,
                 )
-                #print("x.shape:", x.shape, flush=True)
                 # x shape [num_heads, m, qk_nope_head_dim // 2 * 2]
                 # W_K shape [num_heads, kv_lora_rank, qk_nope_head_dim // 2]
                 out = torch.empty(
                     x.shape[0], x.shape[1], self.W_K.shape[1], device=x.device, dtype=torch.bfloat16
                 )
-                #print("out.shape:", out.shape, flush=True)
-
-                # self.W_K [kv_lora_rank, num_heads, qk_nope_head_dim // 32]

                 batched_gemm_afp4wfp4_pre_quant(
                     x,
@@ -1650,31 +1633,26 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
                     out
                 )

-                print("Pre-Compiling second kernel", flush=True)
                 ## x [num_heads, m, kv_lora_rank]
                 x = torch.empty(
-                    (self.W_V.shape[0], m, self.W_V.shape[2] * 2),
+                    (self.W_V.shape[0], m, self.W_V.shape[2] * 2),  # [num_heads, m, kv_lora_rank]
                     dtype=torch.bfloat16,
                     device=self.W_V.device,
                 )
-                print("x.shape:", x.shape, flush=True)
                 ## [num_heads, m, kv_lora_rank] x [num_heads, v_head_dim // 2, kv_lora_rank]
                 ## [num_heads, m, v_head_dim // 2]
                 out = torch.empty(
-                    x.shape[0], x.shape[1], self.W_K.shape[1], device=x.device, dtype=torch.bfloat16)
-                print("out.shape:", out.shape, flush=True)
-                print("self.W_V.shape:", self.W_V.shape, flush=True)
-                print("self.W_V_scale.shape:", self.W_V_scale.shape, flush=True)
+                    x.shape[0], x.shape[1], self.W_V.shape[1], device=x.device, dtype=torch.bfloat16)
                 batched_gemm_afp4wfp4_pre_quant(
                     x,
                     self.W_V,
                     self.W_V_scale,
                     torch.bfloat16,
                     out
                 )
+            # Early return; the rest of this function handles the fp8 scenario.
             return

-
         # we currently do not have quantized bmm's which are needed for
         # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
         # the bmm's in 16-bit, the extra memory overhead of this is fairly low
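
Note: the loop above calls the Triton kernel once per decode batch size so that JIT compilation and autotuning happen at weight-loading time instead of on the first request. A minimal sketch of the warm-up pattern with a stand-in callable (the real code passes the packed W_K / W_V and their scales to batched_gemm_afp4wfp4_pre_quant):

    import torch

    def dummy_batched_gemm(x, w, w_scale, out_dtype, out):
        # Stand-in for the aiter Triton kernel; each distinct (num_heads, m, K)
        # problem size would trigger a separate compilation in the real kernel.
        out.zero_()

    num_heads, kv_lora_rank, v_head_dim, max_batch_size = 16, 512, 128, 8
    W_V = torch.zeros(num_heads, v_head_dim, kv_lora_rank // 2, dtype=torch.uint8)
    W_V_scale = torch.zeros(num_heads, v_head_dim, kv_lora_rank // 32, dtype=torch.uint8)

    for m in range(1, max_batch_size + 1):
        x = torch.empty(num_heads, m, kv_lora_rank, dtype=torch.bfloat16)
        out = torch.empty(num_heads, m, v_head_dim, dtype=torch.bfloat16)
        dummy_batched_gemm(x, W_V, W_V_scale, torch.bfloat16, out)
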
@@ -2108,7 +2086,7 @@ def forward(
                 decode_pe_padded.copy_(decode_q_pe)
                 decode_q_pe = decode_pe_padded

-            if self.kv_b_proj.weight.dtype == torch.uint8 and ENABLE_FP4:
+            if self.kv_b_proj.weight.dtype == torch.uint8:
                 decode_ql_nope = torch.empty(decode_q_nope.shape[0], decode_q_nope.shape[1], self.W_K.shape[1], dtype=torch.bfloat16, device=decode_q_nope.device)
                 batched_gemm_afp4wfp4_pre_quant(
                     decode_q_nope,
@@ -2118,7 +2096,6 @@ def forward(
                     decode_ql_nope
                 )
                 decode_ql_nope = decode_ql_nope.transpose(0, 1)
-                print("[FP4 Path] decode_ql_nope.shape:", decode_ql_nope.shape, "dtype:", decode_ql_nope.dtype)
             elif is_rocm_aiter_fp8bmm_enabled():
                 # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
                 decode_ql_nope = aiter_triton_fp8_bmm(
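
Note: both decode branches above compute the same projection, (N, B, P) x (N, P, L) -> (N, B, L) -> (B, N, L); the fp4 branch feeds the packed W_K and W_K_scale to the Triton kernel, while the fp8 branch uses aiter_triton_fp8_bmm. A dense bf16 reference of that contraction with hypothetical sizes:

    import torch

    N, B, P, L = 16, 4, 128, 512  # num_heads, batch, qk_nope_head_dim, kv_lora_rank
    decode_q_nope = torch.randn(N, B, P, dtype=torch.bfloat16)
    W_UK = torch.randn(N, P, L, dtype=torch.bfloat16)  # dense stand-in for W_K

    decode_ql_nope = torch.bmm(decode_q_nope, W_UK)    # (N, B, L)
    decode_ql_nope = decode_ql_nope.transpose(0, 1)    # (B, N, L)
    assert decode_ql_nope.shape == (B, N, L)
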
@@ -2128,7 +2105,6 @@ def forward(
                     group_size=128,
                     transpose_bm=True,
                 )
-                print("[FP8 Path] decode_ql_nope.shape:", decode_ql_nope.shape, "dtype:", decode_ql_nope.dtype)
             else:
                 # Pads the head_dim if necessary (for the underlying kernel)
                 N, B, P = decode_q_nope.shape