
Commit 6ea5637

clean code

1 parent 677c715

File tree

1 file changed (+23 −47 lines)

vllm/v1/attention/backends/mla/common.py

Lines changed: 23 additions & 47 deletions
@@ -233,8 +233,6 @@
 from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import batched_gemm_afp4wfp4_pre_quant
 
 
-ENABLE_FP4=True
-
 class QueryLenSupport(Enum):
     """Defines the level of query length support for an attention backend's
     decode pipeline.
@@ -1199,7 +1197,6 @@ def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
         # Convert from (B, N, L) to (N, B, L)
         # x [num_heads, batch_size, kv_lora_rank]
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
-        print("[Unified Path]", "out shape:", out.shape, "out dtype:", out.dtype)
 
         if self.W_V.dtype == torch.uint8:
             out = out.view(-1, self.num_heads, self.v_head_dim)
@@ -1209,13 +1206,9 @@ def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
             out_buffer = torch.empty(
                 x.shape[0], # num_heads
                 x.shape[1], # batchsize
-                self.W_V.shape[2] * 2, # v
+                self.W_V.shape[1], # v
                 device=x.device,
                 dtype=torch.bfloat16)
-            print("In _v_up_proj:")
-            print("x.shape:", x.shape, " self.W_V.shape:", self.W_V.shape,
-                  "out_buffer.shape:", out_buffer.shape, " out.shape:",
-                  out.shape)
             batched_gemm_afp4wfp4_pre_quant(
                 x,
                 self.W_V,
@@ -1224,7 +1217,6 @@ def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
                 out_buffer
             )
             out_buffer = out_buffer.transpose(0, 1) # [batchsize, num_heads, v]
-            #out = out.transpose(0, 1) # [num_heads, batch_size, v]
             out.copy_(out_buffer)
         elif is_rocm_aiter_fp8bmm_enabled() and (not ENABLE_FP4):
             out = out.view(-1, self.num_heads, self.v_head_dim)
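
The `_v_up_proj` change above sizes the bf16 output buffer from `self.W_V.shape[1]` (the unpacked `v_head_dim`) instead of `self.W_V.shape[2] * 2`, because the packed MXFP4 axis is now last. The sketch below only replays that shape bookkeeping with dummy tensors; the sizes are invented, and the real `batched_gemm_afp4wfp4_pre_quant` call is left as a comment, so this is not the actual kernel invocation.

```python
import torch

# Illustrative sizes only; real values come from the model config.
num_heads, batch_size, kv_lora_rank, v_head_dim = 16, 4, 512, 128

# Latent input in the (N, B, L) layout the kernel consumes.
x = torch.randn(num_heads, batch_size, kv_lora_rank).to(torch.bfloat16)

# MXFP4 packs two 4-bit values per byte, so the packed W_V stores
# kv_lora_rank // 2 bytes along its last axis.
W_V_packed = torch.zeros(num_heads, v_head_dim, kv_lora_rank // 2,
                         dtype=torch.uint8)

# bf16 result buffer sized from W_V.shape[1] == v_head_dim, matching the diff.
out_buffer = torch.empty(x.shape[0], x.shape[1], W_V_packed.shape[1],
                         dtype=torch.bfloat16)

# batched_gemm_afp4wfp4_pre_quant(x, W_V_packed, W_V_scale, torch.bfloat16,
#                                 out_buffer)  # would fill out_buffer here
out_buffer.zero_()  # stand-in for the kernel output

out = torch.empty(batch_size, num_heads, v_head_dim, dtype=torch.bfloat16)
out.copy_(out_buffer.transpose(0, 1))  # (N, B, V) -> (B, N, V)
assert out.shape == (batch_size, num_heads, v_head_dim)
```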
@@ -1474,11 +1466,8 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
                 return dequant_weights.T
             return layer.weight
 
-        print("self.kv_b_proj:", self.kv_b_proj.weight.shape)
-        print("self.qk_nope_head_dim:", self.qk_nope_head_dim, "self.v_head_dim:", self.v_head_dim)
-
-        if self.kv_b_proj.weight.dtype == torch.uint8 and ENABLE_FP4: # mxfp4 elemnts packed in a byte
-            # self.kv_b_proj [num_heads * (qk_nope_head_dim + v_head_dim), q_lora_rank]
+        if self.kv_b_proj.weight.dtype == torch.uint8: # mxfp4 elements packed in a byte
+            # kv_b_proj [num_heads * (qk_nope_head_dim + v_head_dim), q_lora_rank]
             kv_b_proj_weight = self.kv_b_proj.weight.T
             kv_b_proj_weight = kv_b_proj_weight.reshape(
                 self.kv_lora_rank,
@@ -1490,24 +1479,25 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
             )
             # W_K [self.kv_lora_rank, num_heads, qk_nope_head_dim // 2] -> [num_heads, kv_lora_rank, qk_nope_head_dim //2]
             self.W_K = W_UK.transpose(0, 1)
-            # W_V [kv_lora_rank, num_heads, v_head_dim // 2] -> [num_heads, v_head_dim // 2, kv_lora_rank]
-            #self.W_V = W_UV.permute(1, 2, 0)
-            self.W_V = W_UV.transpose(0, 1)
+            # W_V [kv_lora_rank, num_heads, v_head_dim // 2] -> [num_heads, v_head_dim, kv_lora_rank // 2]
+            # Always pack along the last dimension; accuracy still needs to be checked here.
+            self.W_V = W_UV.permute(1, 2, 0)
+            self.W_V = self.W_V.reshape(self.num_heads, self.v_head_dim, self.kv_lora_rank // 2)
 
-            # split w_scale
             kv_b_proj_weight_sc = self.kv_b_proj.weight_scale
-            print("kv_b_proj_weight_sc.shape:", kv_b_proj_weight_sc.shape)
-            # Shape should be [num_headsx(qk_nope_head_dim+v_head_dim), kv_lora_rank // 32]
-            kv_b_proj_weight_sc = self.kv_b_proj.weight_scale.T.reshape(
-                self.kv_lora_rank,
+            # kv_b_proj_weight_sc: [num_heads x (qk_nope_head_dim + v_head_dim), kv_lora_rank // 32]
+
+            # Obtain W_V_scale first
+            W_scale = self.kv_b_proj.weight_scale.view(
                 self.num_heads,
-                self.qk_nope_head_dim // 32 + self.v_head_dim // 32
-            )
-            # self.W_K_scale [kv_lora_rank, num_heads, qk_nope_head_dim //32]
-            self.W_K_scale, self.W_V_scale = kv_b_proj_weight_sc.split(
-                [self.qk_nope_head_dim // 32, self.v_head_dim // 32], dim=-1)
-            self.W_K_scale = self.W_K_scale.transpose(0, 1)
-            self.W_V_scale = self.W_V_scale.permute(1, 2, 0)
+                self.qk_nope_head_dim + self.v_head_dim,
+                self.kv_lora_rank // 32)
+            self.W_K_scale, self.W_V_scale = W_scale.split([self.qk_nope_head_dim, self.v_head_dim], dim=1)
+
+            # Obtain W_K_scale
+            self.W_K_scale = self.W_K_scale.view(self.num_heads, self.qk_nope_head_dim // 32, self.kv_lora_rank)
+            self.W_K_scale = self.W_K_scale.permute(0, 2, 1)
+
             max_batch_size = 1024 # [ToDo] Find the optimal upper limit
             pre_compilation_list = list(range(1, max_batch_size + 1))
             if is_global_first_rank():
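
The weight-preparation hunk above re-lays-out both the packed MXFP4 value weight and its block scales. As a rough sanity check of the shape arithmetic only (sizes are invented, the tensors are zeros, and no dequantization is performed), the following sketch mirrors the permute/reshape of `W_UV` and the view/split of `weight_scale`:

```python
import torch

# Invented sizes for illustration.
num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 16, 128, 128, 512

# Packed W_UV as the diff receives it: (kv_lora_rank, num_heads, v_head_dim // 2).
W_UV = torch.zeros(kv_lora_rank, num_heads, v_head_dim // 2, dtype=torch.uint8)
# Move heads first, then reinterpret the packed nibbles so the packed axis is
# last: (num_heads, v_head_dim, kv_lora_rank // 2). This changes which axis the
# 4-bit pairs are read along, which is why the diff flags an accuracy check.
W_V = W_UV.permute(1, 2, 0).reshape(num_heads, v_head_dim, kv_lora_rank // 2)

# MXFP4 block scales: one scale byte per 32 elements along kv_lora_rank.
weight_scale = torch.zeros(num_heads * (qk_nope_head_dim + v_head_dim),
                           kv_lora_rank // 32, dtype=torch.uint8)
W_scale = weight_scale.view(num_heads,
                            qk_nope_head_dim + v_head_dim,
                            kv_lora_rank // 32)
W_K_scale, W_V_scale = W_scale.split([qk_nope_head_dim, v_head_dim], dim=1)

assert W_V.shape == (num_heads, v_head_dim, kv_lora_rank // 2)
assert W_K_scale.shape == (num_heads, qk_nope_head_dim, kv_lora_rank // 32)
assert W_V_scale.shape == (num_heads, v_head_dim, kv_lora_rank // 32)
```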
@@ -1517,24 +1507,17 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
                     total=max_batch_size,
                 )
             for m in pre_compilation_list:
-                #print("Pre-Compiling first kernel", flush=True)
-                #print("self.W_K.shape:", self.W_K.shape, flush=True)
-                #print("self.W_K_scale.shape:", self.W_K_scale.shape, flush=True)
                 # [ num_heads, m, qk_nope_head_dim // 2 * 2]
                 x = torch.empty(
                     (self.W_K.shape[0], m, self.W_K.shape[2] * 2),
                     dtype=torch.bfloat16,
                     device=self.W_K.device,
                 )
-                #print("x.shape:", x.shape, flush=True)
                 # x shape [ num_heads, m , qk_nope_head_dim //2 * 2]
                 # W_K shape [num_heads, kv_lora_ranks, qk_nope_head_dim //2]
                 out = torch.empty(
                     x.shape[0], x.shape[1], self.W_K.shape[1], device=x.device, dtype=torch.bfloat16
                 )
-                #print("out.shape:", out.shape, flush=True)
-
-                # self.W_K [kv_lora_rank, num_heads, qk_nope_head_dim //32]
 
                 batched_gemm_afp4wfp4_pre_quant(
                     x,
@@ -1544,31 +1527,26 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
                     out
                 )
 
-                print("Pre-Compiling second kernel", flush=True)
                 ## x [ num_heads, m, kv_lora_rank]
                 x = torch.empty(
-                    (self.W_V.shape[0], m, self.W_V.shape[2] * 2),
+                    (self.W_V.shape[0], m, self.W_V.shape[2] ** 2),
                     dtype=torch.bfloat16,
                     device=self.W_V.device,
                 )
-                print("x.shape:", x.shape, flush=True)
                 ## [num_heads, m, kv_lora_rank] x [ num_heads, v_head_dim // 2, kv_lora_rank]
                 ## [num_heads, m, v_head_dim //2]
                 out = torch.empty(
-                    x.shape[0], x.shape[1], self.W_K.shape[1], device=x.device, dtype=torch.bfloat16)
-                print("out.shape:", out.shape, flush=True)
-                print("self.W_V.shape:", self.W_V.shape, flush=True)
-                print("self.W_V_scale.shape:", self.W_V_scale.shape, flush=True)
+                    x.shape[0], x.shape[1], self.W_V.shape[1], device=x.device, dtype=torch.bfloat16)
                 batched_gemm_afp4wfp4_pre_quant(
                     x,
                     self.W_V,
                     self.W_V_scale,
                     torch.bfloat16,
                     out
                 )
+            # Early return; the rest is for the fp8 scenario.
             return
 
-
         # we currently do not have quantized bmm's which are needed for
         # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
         # the bmm's in 16-bit, the extra memory overhead of this is fairly low
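
The loop above pre-compiles the FP4 batched GEMM once per decode batch size up to `max_batch_size`, so Triton compilation happens at weight-loading time instead of on the first request of each batch size. Below is a generic, hedged sketch of that warm-up pattern; `run_kernel` is a hypothetical stand-in for `batched_gemm_afp4wfp4_pre_quant`, and the `(num_heads, out_dim, packed_k)` weight layout is assumed, not taken from the source.

```python
import torch
from typing import Callable

def warm_up_batched_gemm(W: torch.Tensor,
                         W_scale: torch.Tensor,
                         run_kernel: Callable,
                         max_batch_size: int = 1024) -> None:
    """Run the batched GEMM once per decode batch size 1..max_batch_size.

    `run_kernel` is a placeholder for the real Triton kernel; the inputs are
    uninitialized because only the shapes matter for compilation/autotuning.
    """
    num_heads, out_dim, packed_k = W.shape
    for m in range(1, max_batch_size + 1):
        # Packed FP4 stores two values per byte, hence the "* 2" on the K dim.
        x = torch.empty(num_heads, m, packed_k * 2,
                        dtype=torch.bfloat16, device=W.device)
        out = torch.empty(num_heads, m, out_dim,
                          dtype=torch.bfloat16, device=W.device)
        run_kernel(x, W, W_scale, torch.bfloat16, out)
```

A no-op callable (for example `lambda *args: None`) makes the sketch runnable without the aiter dependency.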
@@ -2002,7 +1980,7 @@ def forward(
             decode_pe_padded.copy_(decode_q_pe)
             decode_q_pe = decode_pe_padded
 
-            if self.kv_b_proj.weight.dtype == torch.uint8 and ENABLE_FP4:
+            if self.kv_b_proj.weight.dtype == torch.uint8:
                 decode_ql_nope = torch.empty(decode_q_nope.shape[0], decode_q_nope.shape[1], self.W_K.shape[1], dtype=torch.bfloat16, device=decode_q_nope.device)
                 batched_gemm_afp4wfp4_pre_quant(
                     decode_q_nope,
@@ -2012,7 +1990,6 @@ def forward(
                     decode_ql_nope
                 )
                 decode_ql_nope = decode_ql_nope.transpose(0, 1)
-                print("[FP4 Path] decode_ql_nope.shape:", decode_ql_nope.shape, "dtype:", decode_ql_nope.dtype)
             elif is_rocm_aiter_fp8bmm_enabled():
                 # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
                 decode_ql_nope = aiter_triton_fp8_bmm(
@@ -2022,7 +1999,6 @@ def forward(
                     group_size=128,
                     transpose_bm=True,
                 )
-                print("[FP8 Path] decode_ql_nope.shape:", decode_ql_nope.shape, "dtype:", decode_ql_nope.dtype)
             else:
                 # Pads the head_dim if necessary (for the underlying kernel)
                 N, B, P = decode_q_nope.shape
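
With the `ENABLE_FP4` flag removed, the decode-time dispatch in `forward` is driven purely by the weight dtype: packed `uint8` (MXFP4) weights take the FP4 batched-GEMM path, otherwise the ROCm aiter FP8 BMM path if enabled, otherwise the bf16 fallback. The minimal stand-in below shows only that branch order; the labels and helper are hypothetical, and the real branches call the kernels shown in the diff.

```python
import torch

def decode_q_proj_path(weight: torch.Tensor, fp8_bmm_enabled: bool) -> str:
    """Stand-in for the branch order in forward(); returns a label only."""
    if weight.dtype == torch.uint8:   # MXFP4: two 4-bit values packed per byte
        return "fp4"
    if fp8_bmm_enabled:               # mirrors is_rocm_aiter_fp8bmm_enabled()
        return "fp8"
    return "bf16"                     # unquantized fallback

assert decode_q_proj_path(torch.zeros(1, dtype=torch.uint8), True) == "fp4"
assert decode_q_proj_path(torch.zeros(1, dtype=torch.bfloat16), True) == "fp8"
assert decode_q_proj_path(torch.zeros(1, dtype=torch.bfloat16), False) == "bf16"
```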
