Hi community, I'm wondering whether trtllm_batch_decode_with_kv_cache_mla already supports batch-invariant behavior, or whether support is planned?
If it is supported, upstream vLLM could call the function once on the whole batch instead of using a per-request for loop like this:

num = q.shape[0]
outs = []
# Decode one request at a time so each result is independent of batch size.
for i in range(num):
    qi = q[i : i + 1]
    bt_i = attn_metadata.decode.block_table[i : i + 1]
    sl_i = attn_metadata.decode.seq_lens[i : i + 1]
    oi = trtllm_batch_decode_with_kv_cache_mla(
        query=qi,
        kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
        workspace_buffer=self._workspace_buffer,
        qk_nope_head_dim=self.qk_nope_head_dim,
        kv_lora_rank=self.kv_lora_rank,
        qk_rope_head_dim=self.qk_rope_head_dim,
        block_tables=bt_i,
        seq_lens=sl_i,
        max_seq_len=attn_metadata.max_seq_len,
        bmm1_scale=self.bmm1_scale,
        bmm2_scale=self.bmm2_scale,
    )
    outs.append(oi)
o = torch.cat(outs, dim=0)

Thanks for any answers / help!
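
To make the question concrete, here is a minimal sketch of what I mean by "batch invariant": each request's output should be bit-identical whether it is decoded alone or together with other requests. Everything below is a hypothetical stand-in (`check_batch_invariance`, `invariant_decode`, and `order_dependent_decode` are illustrative names, not FlashInfer or vLLM APIs), and plain Python floats are used instead of tensors so it runs anywhere.

```python
from typing import Callable, List, Sequence

Row = List[float]
DecodeFn = Callable[[Sequence[Row]], List[float]]


def check_batch_invariance(decode_fn: DecodeFn, batch: Sequence[Row]) -> bool:
    """Return True if every row gives the same result batched and alone."""
    batched = decode_fn(batch)
    for i in range(len(batch)):
        # Same slicing pattern as the per-request loop: a batch of size 1.
        single = decode_fn(batch[i : i + 1])
        if single[0] != batched[i]:
            return False
    return True


def _accumulate(values: Sequence[float]) -> float:
    # Explicit left-to-right accumulation so the rounding order is fixed.
    acc = 0.0
    for v in values:
        acc += v
    return acc


def invariant_decode(rows: Sequence[Row]) -> List[float]:
    """Toy kernel (assumption): same reduction order for any batch size."""
    return [_accumulate(r) for r in rows]


def order_dependent_decode(rows: Sequence[Row]) -> List[float]:
    """Toy kernel (assumption): reduction order changes with batch size."""
    if len(rows) == 1:
        return [_accumulate(r) for r in rows]
    return [_accumulate(list(reversed(r))) for r in rows]


# Hypothetical inputs chosen so that floating-point rounding depends on order.
batch = [[1.0, 1e17, -1e17], [2.0, 3.0, 4.0]]
print(check_batch_invariance(invariant_decode, batch))
print(check_batch_invariance(order_dependent_decode, batch))
```

With a real kernel, the same comparison could be done by calling the function once on the full batch and once per request, then comparing the tensors for exact equality rather than with a tolerance.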