
Commit 95308fd

ZhiweiYan-96 and zejunchen-zejun authored and committed
clean code
1 parent e484e3c commit 95308fd

File tree

1 file changed: +23 −47 lines changed


vllm/v1/attention/backends/mla/common.py

Lines changed: 23 additions & 47 deletions
@@ -233,8 +233,6 @@
 from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import batched_gemm_afp4wfp4_pre_quant
 
 
-ENABLE_FP4=True
-
 class QueryLenSupport(Enum):
     """Defines the level of query length support for an attention backend's
     decode pipeline.
@@ -1221,7 +1219,6 @@ def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
         # Convert from (B, N, L) to (N, B, L)
         # x [num_heads, batch_size, kv_lora_rank]
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
-        print("[Unified Path]", "out shape:", out.shape, "out dtype:", out.dtype)
 
         if self.W_V.dtype == torch.uint8:
             out = out.view(-1, self.num_heads, self.v_head_dim)
@@ -1231,13 +1228,9 @@ def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
             out_buffer = torch.empty(
                 x.shape[0], # num_heads
                 x.shape[1], # batchsize
-                self.W_V.shape[2] * 2, # v
+                self.W_V.shape[1], # v
                 device=x.device,
                 dtype=torch.bfloat16)
-            print("In _v_up_proj:")
-            print("x.shape:", x.shape, " self.W_V.shape:", self.W_V.shape,
-                  "out_buffer.shape:", out_buffer.shape, " out.shape:",
-                  out.shape)
             batched_gemm_afp4wfp4_pre_quant(
                 x,
                 self.W_V,
@@ -1246,7 +1239,6 @@ def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
                 out_buffer
             )
             out_buffer = out_buffer.transpose(0, 1) # [batchsize, num_heads, v]
-            #out = out.transpose(0, 1) # [num_heads, batch_size, v]
             out.copy_(out_buffer)
         elif is_rocm_aiter_fp8bmm_enabled() and (not ENABLE_FP4):
             out = out.view(-1, self.num_heads, self.v_head_dim)
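
The out_buffer change above follows from the repacked W_V layout [num_heads, v_head_dim, kv_lora_rank // 2]: the kernel's output feature dimension is W_V.shape[1] rather than W_V.shape[2] * 2. Below is a minimal shape-only sketch under assumed example dimensions; the aiter kernel call is left as a comment, not executed.

import torch

# Example dimensions (illustrative assumptions, not read from any model config).
num_heads, batch_size, kv_lora_rank, v_head_dim = 16, 4, 512, 128

# x: bf16 activations laid out as (num_heads, batch_size, kv_lora_rank)
x = torch.empty(num_heads, batch_size, kv_lora_rank, dtype=torch.bfloat16)
# W_V: mxfp4 weights, two 4-bit values packed per uint8 along the last dim
W_V = torch.empty(num_heads, v_head_dim, kv_lora_rank // 2, dtype=torch.uint8)

# Output buffer: (num_heads, batch_size, v_head_dim); last dim comes from W_V.shape[1].
out_buffer = torch.empty(x.shape[0], x.shape[1], W_V.shape[1], dtype=torch.bfloat16)
# batched_gemm_afp4wfp4_pre_quant(x, W_V, W_V_scale, torch.bfloat16, out_buffer)

# Back to (batch_size, num_heads, v_head_dim) before copying into `out`.
out_buffer = out_buffer.transpose(0, 1)
print(out_buffer.shape)  # torch.Size([4, 16, 128])
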
@@ -1580,11 +1572,8 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
                 return dequant_weights.T
             return layer.weight
 
-        print("self.kv_b_proj:", self.kv_b_proj.weight.shape)
-        print("self.qk_nope_head_dim:", self.qk_nope_head_dim, "self.v_head_dim:", self.v_head_dim)
-
-        if self.kv_b_proj.weight.dtype == torch.uint8 and ENABLE_FP4: # mxfp4 elemnts packed in a byte
-            # self.kv_b_proj [num_heads * (qk_nope_head_dim + v_head_dim), q_lora_rank]
+        if self.kv_b_proj.weight.dtype == torch.uint8: # mxfp4 elemnts packed in a byte
+            # kv_b_proj [num_heads * (qk_nope_head_dim + v_head_dim), q_lora_rank]
             kv_b_proj_weight = self.kv_b_proj.weight.T
             kv_b_proj_weight = kv_b_proj_weight.reshape(
                 self.kv_lora_rank,
@@ -1596,24 +1585,25 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
             )
             # W_K [self.kv_lora_rank, num_heads, qk_nope_head_dim // 2] -> [num_heads, kv_lora_rank, qk_nope_head_dim //2]
             self.W_K = W_UK.transpose(0, 1)
-            # W_V [kv_lora_rank, num_heads, v_head_dim // 2] -> [num_heads, v_head_dim // 2, kv_lora_rank]
-            #self.W_V = W_UV.permute(1, 2, 0)
-            self.W_V = W_UV.transpose(0, 1)
+            # W_V [kv_lora_rank, num_heads, v_head_dim // 2] -> [num_heads, v_head_dim, kv_lora_rank // 2]
+            # Alway pack at the last dimension, need check acc here.
+            self.W_V = W_UV.permute(1, 2, 0)
+            self.W_V = self.W_V.reshape(self.num_heads, self.v_head_dim, self.kv_lora_rank // 2)
 
-            # split w_scale
             kv_b_proj_weight_sc = self.kv_b_proj.weight_scale
-            print("kv_b_proj_weight_sc.shape:", kv_b_proj_weight_sc.shape)
-            # Shape should be [num_headsx(qk_nope_head_dim+v_head_dim), kv_lora_rank // 32]
-            kv_b_proj_weight_sc = self.kv_b_proj.weight_scale.T.reshape(
-                self.kv_lora_rank,
+            # kv_b_proj_weight_sc: [num_heads x (qk_nope_head_dim+v_head_dim), kv_lora_rank // 32]
+
+            # Obtain W_V_Scale first
+            W_scale = self.kv_b_proj.weight_scale.view(
                 self.num_heads,
-                self.qk_nope_head_dim // 32 + self.v_head_dim // 32
-            )
-            # self.W_K_scale [kv_lora_rank, num_heads, qk_nope_head_dim //32]
-            self.W_K_scale, self.W_V_scale = kv_b_proj_weight_sc.split(
-                [self.qk_nope_head_dim // 32, self.v_head_dim // 32], dim=-1)
-            self.W_K_scale = self.W_K_scale.transpose(0, 1)
-            self.W_V_scale = self.W_V_scale.permute(1, 2, 0)
+                self.qk_nope_head_dim + self.v_head_dim,
+                self.kv_lora_rank // 32)
+            self.W_K_scale, self.W_V_scale = W_scale.split([self.qk_nope_head_dim, self.v_head_dim], dim=1)
+
+            # Obtain W_K_scale
+            self.W_K_scale = self.W_K_scale.view(self.num_heads, self.qk_nope_head_dim//32, self.kv_lora_rank)
+            self.W_K_scale = self.W_K_scale.permute(0, 2, 1)
+
             max_batch_size = 1024 # [ToDo] Find the optimal upper limit
             pre_compilation_list = list(range(1, max_batch_size + 1))
             if is_global_first_rank():
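
For reference, a minimal shape-only sketch of the scale reshaping introduced above, using assumed example MLA dimensions (num_heads=16, qk_nope_head_dim=128, v_head_dim=128, kv_lora_rank=512, mxfp4 group size 32); this is an illustration, not the actual weight-loading code.

import torch

num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 16, 128, 128, 512

# Per-group mxfp4 scales as stored by kv_b_proj:
# [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank // 32]
weight_scale = torch.empty(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank // 32)

W_scale = weight_scale.view(num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank // 32)
W_K_scale, W_V_scale = W_scale.split([qk_nope_head_dim, v_head_dim], dim=1)

# W_V_scale keeps the [num_heads, v_head_dim, kv_lora_rank // 32] layout.
print(W_V_scale.shape)  # torch.Size([16, 128, 16])

# W_K_scale is reinterpreted and permuted to [num_heads, kv_lora_rank, qk_nope_head_dim // 32],
# mirroring W_K's [num_heads, kv_lora_rank, qk_nope_head_dim // 2] layout.
W_K_scale = W_K_scale.view(num_heads, qk_nope_head_dim // 32, kv_lora_rank).permute(0, 2, 1)
print(W_K_scale.shape)  # torch.Size([16, 512, 4])
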
@@ -1623,24 +1613,17 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
                     total=max_batch_size,
                 )
             for m in pre_compilation_list:
-                #print("Pre-Compiling first kernel", flush=True)
-                #print("self.W_K.shape:", self.W_K.shape, flush=True)
-                #print("self.W_K_scale.shape:", self.W_K_scale.shape, flush=True)
                 # [ num_heads, m, qk_nope_head_dim // 2 * 2]
                 x = torch.empty(
                     (self.W_K.shape[0], m, self.W_K.shape[2] * 2),
                     dtype=torch.bfloat16,
                     device=self.W_K.device,
                 )
-                #print("x.shape:", x.shape, flush=True)
                 # x shape [ num_heads, m , qk_nope_head_dim //2 * 2]
                 # W_K shape [num_heads, kv_lora_ranks, qk_nope_head_dim //2]
                 out = torch.empty(
                     x.shape[0], x.shape[1], self.W_K.shape[1], device=x.device, dtype=torch.bfloat16
                 )
-                #print("out.shape:", out.shape, flush=True)
-
-                # self.W_K [kv_lora_rank, num_heads, qk_nope_head_dim //32]
 
                 batched_gemm_afp4wfp4_pre_quant(
                     x,
@@ -1650,31 +1633,26 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
                     out
                 )
 
-                print("Pre-Compiling second kernel", flush=True)
                 ## x [ num_heads, m, kv_lora_rank]
                 x = torch.empty(
-                    (self.W_V.shape[0], m, self.W_V.shape[2] * 2),
+                    (self.W_V.shape[0], m, self.W_V.shape[2] ** 2),
                     dtype=torch.bfloat16,
                     device=self.W_V.device,
                 )
-                print("x.shape:", x.shape, flush=True)
                 ## [num_heads, m, kv_lora_rank] x [ num_heads, v_head_dim // 2, kv_lora_rank]
                 ## [num_heads, m, v_head_dim //2]
                 out = torch.empty(
-                    x.shape[0], x.shape[1], self.W_K.shape[1], device=x.device, dtype=torch.bfloat16)
-                print("out.shape:", out.shape, flush=True)
-                print("self.W_V.shape:", self.W_V.shape, flush=True)
-                print("self.W_V_scale.shape:", self.W_V_scale.shape, flush=True)
+                    x.shape[0], x.shape[1], self.W_V.shape[1], device=x.device, dtype=torch.bfloat16)
                 batched_gemm_afp4wfp4_pre_quant(
                     x,
                     self.W_V,
                     self.W_V_scale,
                     torch.bfloat16,
                     out
                 )
+            # Early return, the left is for fp8 scenario.
             return
 
-
         # we currently do not have quantized bmm's which are needed for
         # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
         # the bmm's in 16-bit, the extra memory overhead of this is fairly low
@@ -2108,7 +2086,7 @@ def forward(
                 decode_pe_padded.copy_(decode_q_pe)
                 decode_q_pe = decode_pe_padded
 
-            if self.kv_b_proj.weight.dtype == torch.uint8 and ENABLE_FP4:
+            if self.kv_b_proj.weight.dtype == torch.uint8:
                 decode_ql_nope = torch.empty(decode_q_nope.shape[0], decode_q_nope.shape[1], self.W_K.shape[1], dtype=torch.bfloat16, device=decode_q_nope.device)
                 batched_gemm_afp4wfp4_pre_quant(
                     decode_q_nope,
@@ -2118,7 +2096,6 @@ def forward(
                     decode_ql_nope
                 )
                 decode_ql_nope = decode_ql_nope.transpose(0, 1)
-                print("[FP4 Path] decode_ql_nope.shape:", decode_ql_nope.shape, "dtype:", decode_ql_nope.dtype)
             elif is_rocm_aiter_fp8bmm_enabled():
                 # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
                 decode_ql_nope = aiter_triton_fp8_bmm(
@@ -2128,7 +2105,6 @@ def forward(
                     group_size=128,
                     transpose_bm=True,
                 )
-                print("[FP8 Path] decode_ql_nope.shape:", decode_ql_nope.shape, "dtype:", decode_ql_nope.dtype)
             else:
                 # Pads the head_dim if necessary (for the underlying kernel)
                 N, B, P = decode_q_nope.shape
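
Similarly, the FP4 branch in the decode path above projects decode_q_nope through W_K with the same kernel convention (output feature dim taken from W_K.shape[1]). A shape-only sketch under the same assumed example dimensions; the aiter kernel call is a comment only.

import torch

num_heads, batch_size, qk_nope_head_dim, kv_lora_rank = 16, 4, 128, 512

# decode_q_nope: (num_heads, batch_size, qk_nope_head_dim) in bf16
decode_q_nope = torch.empty(num_heads, batch_size, qk_nope_head_dim, dtype=torch.bfloat16)
# W_K: (num_heads, kv_lora_rank, qk_nope_head_dim // 2), mxfp4 packed in uint8
W_K = torch.empty(num_heads, kv_lora_rank, qk_nope_head_dim // 2, dtype=torch.uint8)

decode_ql_nope = torch.empty(
    decode_q_nope.shape[0], decode_q_nope.shape[1], W_K.shape[1], dtype=torch.bfloat16)
# batched_gemm_afp4wfp4_pre_quant(decode_q_nope, W_K, W_K_scale, torch.bfloat16, decode_ql_nope)

# (num_heads, batch_size, kv_lora_rank) -> (batch_size, num_heads, kv_lora_rank)
decode_ql_nope = decode_ql_nope.transpose(0, 1)
print(decode_ql_nope.shape)  # torch.Size([4, 16, 512])
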

0 commit comments
