diff --git a/tests/singlecard/test_aclgraph.py b/tests/singlecard/test_aclgraph.py
index fb02555956..f36e15473d 100644
--- a/tests/singlecard/test_aclgraph.py
+++ b/tests/singlecard/test_aclgraph.py
@@ -35,8 +35,8 @@
 @pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
                     reason="aclgraph only support on v1")
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("max_tokens", [32])
-@pytest.mark.parametrize("full_graph", [False])
+@pytest.mark.parametrize("max_tokens", [12])
+@pytest.mark.parametrize("full_graph", [True, False])
 def test_models(
     model: str,
     max_tokens: int,
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 3417bb87fb..f03d9f88e6 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -118,7 +118,7 @@ class AscendMetadata:
     query_start_loc: torch.Tensor
     query_lens: torch.Tensor
     seq_lens: torch.Tensor
-    seq_lens_list: list
+    seq_lens_list: Optional[list[int]]
     # Maximum query length in the batch. None for decoding.
     max_query_len: Optional[int] = None
     # (num_tokens,). The indices of the token slots that input tokens will be
@@ -168,8 +168,9 @@ def build(self,
         seq_lens = common_attn_metadata.seq_lens
         # TODO: Refactor these two param to common metadata in runners,
         # preparing for the hybrid KV groups feature
-        query_lens = common_attn_metadata.query_lens if common_attn_metadata.query_lens is not None else self.runner.query_lens
-        seq_lens_list = common_attn_metadata.seq_lens_list if common_attn_metadata.seq_lens_list is not None else self.runner.seq_lens_list
+        query_lens = common_attn_metadata.query_lens or self.runner.query_lens
+        # Since FIA for GQA is not active now, we temporarily silence it
+        seq_lens_list = common_attn_metadata.seq_lens_list
         slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
         attn_mask = self.runner.attn_mask
 
@@ -193,8 +194,8 @@ def build_dummy_metadata(self, num_actual_tokens, num_reqs,
                              num_scheduled_tokens, attn_state):
         if attn_state == AscendAttentionState.DecodeOnly:
             # NOTE: We only need to pay attention to seq_lens_list and block_table here
-            common_attn_metadata = CommonAttentionMetadata(seq_lens_list=[2] *
-                                                           num_reqs)
+            common_attn_metadata = CommonAttentionMetadata(
+                seq_lens=torch.empty_like(self.runner.seq_lens_cpu).fill_(2))
 
             block_table = self.runner.input_batch.block_table[0].block_table
             block_table[:num_reqs, 0] = torch.arange(1,
@@ -349,82 +350,42 @@ def forward(
                 scale_value=self.scale,
                 out=output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            if self.full_graph:
-                graph_params = get_graph_params()
-                q = query.view(num_tokens, -1, self.hidden_size)
-                k = self.key_cache.view(  # type: ignore
-                    -1, self.block_size,
-                    self.num_kv_heads * self.head_size)
-                v = self.value_cache.view(  # type: ignore
-                    -1, self.block_size,
-                    self.num_kv_heads * self.head_size)
-                actual_seq_lens = attn_metadata.seq_lens_list
-                attn_args = {
-                    "query": q,
-                    "key": k,
-                    "value": v,
-                    "actual_seq_lengths_kv": actual_seq_lens,
-                    "block_table": attn_metadata.block_tables,
-                    "num_heads": self.num_heads,
-                    "scale": self.scale,
-                    "input_layout": "BSH",
-                    "num_key_value_heads": self.num_kv_heads,
-                    "block_size": self.block_size,
-                }
-
-                # Prepare tensors for attention output
-                # TODO: Refactor this to step-level instead of layer-level
-                attn_output = torch.empty(num_tokens,
-                                          1,
-                                          self.hidden_size,
-                                          dtype=output.dtype,
-                                          device=output.device)
-                softmax_lse = torch.empty(num_tokens,
-                                          dtype=output.dtype,
-                                          device=output.device)
-
-                # Get workspace from cache or calculate it if not present.
-                workspace = graph_params.workspaces.get(num_tokens)
-                if workspace is None:
-                    workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                        **attn_args)
-                    graph_params.workspaces[num_tokens] = workspace
-
-                forward_context = get_forward_context()
-                if not forward_context.capturing:
-                    # Execute attention kernel directly in non-capturing mode
-                    torch.ops.npu.npu_fused_infer_attention_score.out(
-                        workspace=workspace,
-                        out=[attn_output, softmax_lse],
-                        **attn_args)
-                else:
-                    # Handle graph capturing mode
-                    stream = torch_npu.npu.current_stream()
-
-                    event = torch.npu.ExternalEvent()
-                    event.wait(stream)
-                    event.reset(stream)
-                    graph_params.events[num_tokens].append(event)
-
-                    graph_params.attn_params[num_tokens].append(
-                        (q, k, v, actual_seq_lens,
-                         attn_metadata.block_tables, self.num_heads,
-                         self.scale, self.num_kv_heads, attn_output,
-                         softmax_lse))
-
-                    torch.npu.graph_task_group_begin(stream)
-                    torch.ops.npu.npu_fused_infer_attention_score.out(
-                        workspace=workspace,
-                        out=[attn_output, softmax_lse],
-                        **attn_args)
-                    handle = torch.npu.graph_task_group_end(stream)
-                    graph_params.handles[num_tokens].append(handle)
-
-                # Reshape output to match the expected format
-                output.copy_(
-                    attn_output.view(num_tokens, self.num_heads,
-                                     self.head_size))
+            graph_params = get_graph_params()
+
+            forward_context = get_forward_context()
+            if not forward_context.capturing:
+                torch_npu._npu_paged_attention(
+                    query=query,
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    block_table=attn_metadata.block_tables,
+                    context_lens=attn_metadata.seq_lens,
+                    out=output)
             else:
+                # Handle graph capturing mode
+                stream = torch_npu.npu.current_stream()
+
+                event = torch.npu.ExternalEvent()
+                event.wait(stream)
+                event.reset(stream)
+                graph_params.events[num_tokens].append(event)
+
+                graph_params.attn_params[num_tokens].append((
+                    query,
+                    self.key_cache,
+                    self.value_cache,
+                    self.num_kv_heads,
+                    self.num_heads,
+                    self.scale,
+                    attn_metadata.block_tables,
+                    attn_metadata.seq_lens,
+                    output,
+                ))
+
+                torch.npu.graph_task_group_begin(stream)
                 torch_npu._npu_paged_attention(
                     query=query,
                     key_cache=self.key_cache,
@@ -435,6 +396,8 @@ def forward(
                     block_table=attn_metadata.block_tables,
                     context_lens=attn_metadata.seq_lens,
                     out=output)
+                handle = torch.npu.graph_task_group_end(stream)
+                graph_params.handles[num_tokens].append(handle)
         # Normal V1 situation.
         else:
             # use chunked prefill for head size 192 scenario, like deepseek
diff --git a/vllm_ascend/compilation/piecewise_backend.py b/vllm_ascend/compilation/piecewise_backend.py
index aafe639373..ca42554fed 100644
--- a/vllm_ascend/compilation/piecewise_backend.py
+++ b/vllm_ascend/compilation/piecewise_backend.py
@@ -23,6 +23,7 @@
 
 import torch
 import torch.fx as fx
+import torch_npu
 import vllm.envs as envs
 from vllm.compilation.backends import VllmBackend
 from vllm.compilation.counter import compilation_counter
@@ -126,29 +127,33 @@ def check_for_ending_compilation(self):
 
     def update_attn_params(self, graph_params, forward_context, runtime_shape):
         for layer_idx in range(len(graph_params.handles[runtime_shape])):
-            query, key, value, actual_seq_lens, block_table, num_heads, scale, num_kv_heads, output, softmax_lse = graph_params.attn_params[
-                runtime_shape][layer_idx]
+            (
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                num_heads,
+                scale,
+                block_table,
+                seq_lens,
+                output,
+            ) = graph_params.attn_params[runtime_shape][layer_idx]
             block_table = forward_context.attn_metadata.block_tables
-            actual_seq_lens = forward_context.attn_metadata.seq_lens_list
+            seq_lens = forward_context.attn_metadata.seq_lens
 
             with torch.npu.stream(self.update_stream):
                 torch.npu.graph_task_update_begin(
                     self.update_stream,
                     graph_params.handles[runtime_shape][layer_idx])
-                torch.ops.npu.npu_fused_infer_attention_score.out(
-                    query,
-                    key,
-                    value,
-                    workspace=graph_params.workspaces[runtime_shape],
-                    actual_seq_lengths_kv=actual_seq_lens,
-                    block_table=block_table,
-                    num_heads=num_heads,
-                    scale=scale,
-                    input_layout="BSH",
-                    num_key_value_heads=num_kv_heads,
-                    block_size=128,
-                    out=[output, softmax_lse],
-                )
+                torch_npu._npu_paged_attention(query=query,
+                                               key_cache=key_cache,
+                                               value_cache=value_cache,
+                                               num_kv_heads=num_kv_heads,
+                                               num_heads=num_heads,
+                                               scale_value=scale,
+                                               block_table=block_table,
+                                               context_lens=seq_lens,
+                                               out=output)
                 torch.npu.graph_task_update_end(self.update_stream)
 
                 graph_params.events[runtime_shape][layer_idx].record(
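
Note on the shared bookkeeping the two touched files rely on: both the capture branch in attention_v1.py and update_attn_params in piecewise_backend.py go through get_graph_params(), whose events, attn_params, and handles fields are keyed by the captured runtime shape (num_tokens) and hold one entry per attention layer; capture appends to these lists, and replay walks the same lists to refresh the per-step arguments (block table and sequence lengths) inside graph_task_update_begin/_end. The sketch below is a minimal, hypothetical stand-in for that bookkeeping, not the actual vllm_ascend GraphParams; the NPU graph-task and _npu_paged_attention calls are only described in comments.

    # Assumed structure for illustration only (not the real vllm_ascend GraphParams):
    # per-runtime-shape, per-layer bookkeeping for ACL graph capture and replay.
    from collections import defaultdict
    from dataclasses import dataclass, field
    from typing import Any

    @dataclass
    class GraphParams:
        # runtime shape (num_tokens) -> one entry per attention layer
        events: dict[int, list[Any]] = field(default_factory=lambda: defaultdict(list))
        attn_params: dict[int, list[tuple]] = field(default_factory=lambda: defaultdict(list))
        handles: dict[int, list[Any]] = field(default_factory=lambda: defaultdict(list))
        workspaces: dict[int, Any] = field(default_factory=dict)

    _GRAPH_PARAMS = GraphParams()

    def get_graph_params() -> GraphParams:
        return _GRAPH_PARAMS

    def record_layer(num_tokens: int, params: tuple, event: Any, handle: Any) -> None:
        """Capture time: remember what is needed to re-issue this layer's kernel."""
        gp = get_graph_params()
        gp.attn_params[num_tokens].append(params)
        gp.events[num_tokens].append(event)
        gp.handles[num_tokens].append(handle)

    def refreshed_args(num_tokens: int, block_table: Any, seq_lens: Any) -> list[tuple]:
        """Replay time: rebuild each layer's arguments with this step's metadata.

        The real update_attn_params does this between graph_task_update_begin/_end
        on a dedicated update stream, re-issuing _npu_paged_attention and then
        recording the layer's external event; here we only show the argument refresh.
        """
        gp = get_graph_params()
        args = []
        for layer_idx in range(len(gp.handles[num_tokens])):
            (query, key_cache, value_cache, num_kv_heads, num_heads, scale,
             _stale_block_table, _stale_seq_lens,
             output) = gp.attn_params[num_tokens][layer_idx]
            args.append((query, key_cache, value_cache, num_kv_heads, num_heads,
                         scale, block_table, seq_lens, output))
        return args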