[Feature Request] trtllm_batch_decode_with_kv_cache_mla Batch Invariant support #2107

@yewentao256

Hi community, I'm wondering whether trtllm_batch_decode_with_kv_cache_mla already supports batch-invariant execution, or whether support is planned?

If it is supported, upstream vLLM could call the function once on the whole batch instead of using a per-request for loop like this:

    num = q.shape[0]
    outs = []
    for i in range(num):
        qi = q[i : i + 1]
        bt_i = attn_metadata.decode.block_table[i : i + 1]
        sl_i = attn_metadata.decode.seq_lens[i : i + 1]
        oi = trtllm_batch_decode_with_kv_cache_mla(
            query=qi,
            kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
            workspace_buffer=self._workspace_buffer,
            qk_nope_head_dim=self.qk_nope_head_dim,
            kv_lora_rank=self.kv_lora_rank,
            qk_rope_head_dim=self.qk_rope_head_dim,
            block_tables=bt_i,
            seq_lens=sl_i,
            max_seq_len=attn_metadata.max_seq_len,
            bmm1_scale=self.bmm1_scale,
            bmm2_scale=self.bmm2_scale,
        )
        outs.append(oi)
    o = torch.cat(outs, dim=0)
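To be explicit about the property in question: by "batch invariant" I mean that each request's output is exactly the same whether it is decoded alone or together with other requests. Below is a minimal pure-Python sketch of that check. The toy kernels and the `is_batch_invariant` helper are hypothetical illustrations, not FlashInfer or vLLM APIs; the "batch-dependent" kernel only simulates a reduction whose order changes with batch size.

```python
from typing import Callable, List

Row = List[float]
Kernel = Callable[[List[Row]], List[float]]


def _reduce(row: Row, reverse: bool) -> float:
    # Sum one row in a fixed order; floating-point addition is not
    # associative, so the order can change the last bits of the result.
    acc = 0.0
    for x in (reversed(row) if reverse else row):
        acc += x
    return acc


def invariant_sum(rows: List[Row]) -> List[float]:
    # Toy kernel (hypothetical): same reduction order for any batch size.
    return [_reduce(r, reverse=False) for r in rows]


def batch_dependent_sum(rows: List[Row]) -> List[float]:
    # Toy kernel (hypothetical): picks a different reduction order when the
    # batch has more than one row, mimicking a batch-size-dependent strategy.
    reverse = len(rows) > 1
    return [_reduce(r, reverse=reverse) for r in rows]


def is_batch_invariant(kernel: Kernel, rows: List[Row]) -> bool:
    # Compare one batched call against per-row calls using exact equality.
    batched = kernel(rows)
    one_by_one = [kernel([r])[0] for r in rows]
    return batched == one_by_one


rows = [[0.1, 0.2, 0.3], [1.0, 2.0, 3.0]]
print(is_batch_invariant(invariant_sum, rows))
print(is_batch_invariant(batch_dependent_sum, rows))
```

If the real kernel passed an equivalent check on tensors, the loop above could presumably be replaced by a single call on the full `q`, `block_table`, and `seq_lens`.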

Thanks for any answers / help!
