diff --git a/vllm/attention/backends/dual_chunk_flash_attn.py b/vllm/attention/backends/dual_chunk_flash_attn.py
index e108646e7ffb..fa6f3f1b39cc 100644
--- a/vllm/attention/backends/dual_chunk_flash_attn.py
+++ b/vllm/attention/backends/dual_chunk_flash_attn.py
@@ -1055,7 +1055,6 @@ def _dual_chunk_flash_attn_prefill_func(
                 v_states_intra,
                 softmax_scale=softmax_scale,
                 causal=True,
-                block_table=block_table,
                 stage="intra",
                 vertical_indices=vertical_buffer,
                 slash_indices=slash_buffer,
@@ -1070,7 +1069,6 @@ def _dual_chunk_flash_attn_prefill_func(
                 v_states_intra,
                 softmax_scale=softmax_scale,
                 causal=True,
-                block_table=block_table,
                 stage="intra",
                 vertical_indices=intra_vertical_indices,
                 slash_indices=intra_slash_indices,
@@ -1085,7 +1083,6 @@ def _dual_chunk_flash_attn_prefill_func(
                 v_states_succ,
                 softmax_scale=softmax_scale,
                 causal=False,
-                block_table=block_table,
                 stage="succ",
                 vertical_indices=succ_vertical_buffer,
                 slash_indices=succ_slash_buffer,
@@ -1100,7 +1097,6 @@ def _dual_chunk_flash_attn_prefill_func(
                 v_states_succ,
                 softmax_scale=softmax_scale,
                 causal=False,
-                block_table=block_table,
                 stage="succ",
                 vertical_indices=succ_vertical_indices,
                 slash_indices=succ_slash_indices,
@@ -1115,7 +1111,6 @@ def _dual_chunk_flash_attn_prefill_func(
                 v_states_inter,
                 softmax_scale=softmax_scale,
                 causal=False,
-                block_table=block_table,
                 stage="inter",
                 vertical_indices=inter_vertical_buffer,
                 slash_indices=inter_slash_buffer,
@@ -1130,7 +1125,6 @@ def _dual_chunk_flash_attn_prefill_func(
                 v_states_inter,
                 softmax_scale=softmax_scale,
                 causal=False,
-                block_table=block_table,
                 stage="inter",
                 vertical_indices=inter_vertical_indices,
                 slash_indices=inter_slash_indices,
@@ -1151,7 +1145,6 @@ def _do_flash_attn(
         value_states: torch.Tensor,
         softmax_scale: float,
         causal: bool = True,
-        block_table: torch.Tensor = None,
         max_seqlen_k: Optional[int] = None,
         stage: str = "intra",
         vertical_indices: Optional[torch.Tensor] = None,
@@ -1230,7 +1223,6 @@ def _do_flash_attn(
                                      device=query_states.device),
             max_seqlen_k=max_seqlen_k,
             causal=causal,
-            block_table=block_table.unsqueeze(0),
             return_softmax_lse=True,
         )
         softmax_lse = softmax_lse.view(q_len, q_heads, 1).transpose(0,