tests/singlecard/test_aclgraph.py (4 changes: 2 additions & 2 deletions)
@@ -35,8 +35,8 @@
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
reason="aclgraph only support on v1")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("full_graph", [False])
@pytest.mark.parametrize("max_tokens", [12])
@pytest.mark.parametrize("full_graph", [True, False])
def test_models(
model: str,
max_tokens: int,
vllm_ascend/attention/attention_v1.py (123 changes: 43 additions & 80 deletions)
@@ -118,7 +118,7 @@ class AscendMetadata:
query_start_loc: torch.Tensor
query_lens: torch.Tensor
seq_lens: torch.Tensor
seq_lens_list: list
seq_lens_list: Optional[list[int]]
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None
# (num_tokens,). The indices of the token slots that input tokens will be
@@ -168,8 +168,9 @@ def build(self,
seq_lens = common_attn_metadata.seq_lens
# TODO: Refactor these two param to common metadata in runners,
# preparing for the hybrid KV groups feature
query_lens = common_attn_metadata.query_lens if common_attn_metadata.query_lens is not None else self.runner.query_lens
seq_lens_list = common_attn_metadata.seq_lens_list if common_attn_metadata.seq_lens_list is not None else self.runner.seq_lens_list
query_lens = common_attn_metadata.query_lens or self.runner.query_lens
        # FIA for GQA is not active for now, so temporarily take seq_lens_list from the common metadata only (no runner fallback)
seq_lens_list = common_attn_metadata.seq_lens_list

slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
attn_mask = self.runner.attn_mask
@@ -193,8 +194,8 @@ def build_dummy_metadata(self, num_actual_tokens, num_reqs,
num_scheduled_tokens, attn_state):
if attn_state == AscendAttentionState.DecodeOnly:
# NOTE: We only need to pay attention to seq_lens_list and block_table here
common_attn_metadata = CommonAttentionMetadata(seq_lens_list=[2] *
num_reqs)
common_attn_metadata = CommonAttentionMetadata(
seq_lens=torch.empty_like(self.runner.seq_lens_cpu).fill_(2))

block_table = self.runner.input_batch.block_table[0].block_table
block_table[:num_reqs, 0] = torch.arange(1,
@@ -349,82 +350,42 @@ def forward(
scale_value=self.scale,
out=output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
if self.full_graph:
graph_params = get_graph_params()
q = query.view(num_tokens, -1, self.hidden_size)
k = self.key_cache.view( # type: ignore
-1, self.block_size,
self.num_kv_heads * self.head_size)
v = self.value_cache.view( # type: ignore
-1, self.block_size,
self.num_kv_heads * self.head_size)
actual_seq_lens = attn_metadata.seq_lens_list
attn_args = {
"query": q,
"key": k,
"value": v,
"actual_seq_lengths_kv": actual_seq_lens,
"block_table": attn_metadata.block_tables,
"num_heads": self.num_heads,
"scale": self.scale,
"input_layout": "BSH",
"num_key_value_heads": self.num_kv_heads,
"block_size": self.block_size,
}

# Prepare tensors for attention output
# TODO: Refactor this to step-level instead of layer-level
attn_output = torch.empty(num_tokens,
1,
self.hidden_size,
dtype=output.dtype,
device=output.device)
softmax_lse = torch.empty(num_tokens,
dtype=output.dtype,
device=output.device)

# Get workspace from cache or calculate it if not present.
workspace = graph_params.workspaces.get(num_tokens)
if workspace is None:
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
**attn_args)
graph_params.workspaces[num_tokens] = workspace

forward_context = get_forward_context()
if not forward_context.capturing:
# Execute attention kernel directly in non-capturing mode
torch.ops.npu.npu_fused_infer_attention_score.out(
workspace=workspace,
out=[attn_output, softmax_lse],
**attn_args)
else:
# Handle graph capturing mode
stream = torch_npu.npu.current_stream()

event = torch.npu.ExternalEvent()
event.wait(stream)
event.reset(stream)
graph_params.events[num_tokens].append(event)

graph_params.attn_params[num_tokens].append(
(q, k, v, actual_seq_lens,
attn_metadata.block_tables, self.num_heads,
self.scale, self.num_kv_heads, attn_output,
softmax_lse))

torch.npu.graph_task_group_begin(stream)
torch.ops.npu.npu_fused_infer_attention_score.out(
workspace=workspace,
out=[attn_output, softmax_lse],
**attn_args)
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)

# Reshape output to match the expected format
output.copy_(
attn_output.view(num_tokens, self.num_heads,
self.head_size))
graph_params = get_graph_params()

forward_context = get_forward_context()
if not forward_context.capturing:
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
else:
# Handle graph capturing mode
stream = torch_npu.npu.current_stream()

event = torch.npu.ExternalEvent()
event.wait(stream)
event.reset(stream)
graph_params.events[num_tokens].append(event)

graph_params.attn_params[num_tokens].append((
query,
self.key_cache,
self.value_cache,
self.num_kv_heads,
self.num_heads,
self.scale,
attn_metadata.block_tables,
attn_metadata.seq_lens,
output,
))

torch.npu.graph_task_group_begin(stream)
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
@@ -435,6 +396,8 @@ def forward(
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
# Normal V1 situation.
else:
# use chunked prefill for head size 192 scenario, like deepseek
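The new DecodeOnly branch above drops the full-graph fused-infer-attention path and issues a single `_npu_paged_attention` launch that either runs eagerly or is recorded into an ACL graph task group. Because the capture-mode call is split across two hunks in this diff, the following is a condensed sketch of the whole branch; the `decode_paged_attention` wrapper, its `layer` argument (standing in for `self`), and the explicit `capturing` flag and `graph_params` parameter are illustrative names, while every `torch.npu`/`torch_npu` call is taken from the diff itself.

```python
import torch
import torch_npu


def decode_paged_attention(layer, query, output, attn_metadata, num_tokens,
                           graph_params, capturing):
    """Condensed sketch of the new DecodeOnly branch (illustrative wrapper;
    `layer` stands in for `self`, and `graph_params` would come from
    get_graph_params() in the real code)."""
    if not capturing:
        # Outside ACL graph capture: launch the paged-attention kernel eagerly.
        torch_npu._npu_paged_attention(
            query=query,
            key_cache=layer.key_cache,
            value_cache=layer.value_cache,
            num_kv_heads=layer.num_kv_heads,
            num_heads=layer.num_heads,
            scale_value=layer.scale,
            block_table=attn_metadata.block_tables,
            context_lens=attn_metadata.seq_lens,
            out=output)
        return

    # Graph capturing mode: gate the launch on an external event, record it
    # inside a graph task group, and stash the event, the kernel arguments,
    # and the task-group handle keyed by num_tokens so the piecewise backend
    # can patch the call with fresh metadata before every replay.
    stream = torch_npu.npu.current_stream()

    event = torch.npu.ExternalEvent()
    event.wait(stream)
    event.reset(stream)
    graph_params.events[num_tokens].append(event)

    graph_params.attn_params[num_tokens].append(
        (query, layer.key_cache, layer.value_cache, layer.num_kv_heads,
         layer.num_heads, layer.scale, attn_metadata.block_tables,
         attn_metadata.seq_lens, output))

    torch.npu.graph_task_group_begin(stream)
    torch_npu._npu_paged_attention(
        query=query,
        key_cache=layer.key_cache,
        value_cache=layer.value_cache,
        num_kv_heads=layer.num_kv_heads,
        num_heads=layer.num_heads,
        scale_value=layer.scale,
        block_table=attn_metadata.block_tables,
        context_lens=attn_metadata.seq_lens,
        out=output)
    handle = torch.npu.graph_task_group_end(stream)
    graph_params.handles[num_tokens].append(handle)
```

The stashed tuple mirrors the kernel's keyword arguments, which is what lets the piecewise backend below re-issue the same call with refreshed `block_tables` and `seq_lens` at replay time.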
vllm_ascend/compilation/piecewise_backend.py (39 changes: 22 additions & 17 deletions)
Expand Up @@ -23,6 +23,7 @@

import torch
import torch.fx as fx
import torch_npu
import vllm.envs as envs
from vllm.compilation.backends import VllmBackend
from vllm.compilation.counter import compilation_counter
@@ -126,29 +127,33 @@ def check_for_ending_compilation(self):

def update_attn_params(self, graph_params, forward_context, runtime_shape):
for layer_idx in range(len(graph_params.handles[runtime_shape])):
query, key, value, actual_seq_lens, block_table, num_heads, scale, num_kv_heads, output, softmax_lse = graph_params.attn_params[
runtime_shape][layer_idx]
(
query,
key_cache,
value_cache,
num_kv_heads,
num_heads,
scale,
block_table,
seq_lens,
output,
) = graph_params.attn_params[runtime_shape][layer_idx]
block_table = forward_context.attn_metadata.block_tables
actual_seq_lens = forward_context.attn_metadata.seq_lens_list
seq_lens = forward_context.attn_metadata.seq_lens

with torch.npu.stream(self.update_stream):
torch.npu.graph_task_update_begin(
self.update_stream,
graph_params.handles[runtime_shape][layer_idx])
torch.ops.npu.npu_fused_infer_attention_score.out(
query,
key,
value,
workspace=graph_params.workspaces[runtime_shape],
actual_seq_lengths_kv=actual_seq_lens,
block_table=block_table,
num_heads=num_heads,
scale=scale,
input_layout="BSH",
num_key_value_heads=num_kv_heads,
block_size=128,
out=[output, softmax_lse],
)
torch_npu._npu_paged_attention(query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output)
torch.npu.graph_task_update_end(self.update_stream)

graph_params.events[runtime_shape][layer_idx].record(
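At replay time, the rewritten `update_attn_params` re-issues each captured paged-attention launch with the current step's `block_tables` and `seq_lens` inside a `graph_task_update_begin`/`graph_task_update_end` window on the backend's update stream, then records the per-layer event the captured graph waits on. A minimal sketch under the same caveats (an illustrative free function instead of the backend method, `update_stream` passed explicitly, and the argument to the final `record()` inferred, since the diff cuts that call off):

```python
import torch
import torch_npu


def update_attn_params(graph_params, forward_context, runtime_shape,
                       update_stream):
    """Sketch of the replay-time refresh; a free function standing in for the
    piecewise-backend method, with `update_stream` passed explicitly."""
    for layer_idx in range(len(graph_params.handles[runtime_shape])):
        (query, key_cache, value_cache, num_kv_heads, num_heads, scale,
         block_table, seq_lens, output) = (
             graph_params.attn_params[runtime_shape][layer_idx])

        # Swap in the block tables and sequence lengths of the current step.
        block_table = forward_context.attn_metadata.block_tables
        seq_lens = forward_context.attn_metadata.seq_lens

        with torch.npu.stream(update_stream):
            torch.npu.graph_task_update_begin(
                update_stream,
                graph_params.handles[runtime_shape][layer_idx])
            torch_npu._npu_paged_attention(query=query,
                                           key_cache=key_cache,
                                           value_cache=value_cache,
                                           num_kv_heads=num_kv_heads,
                                           num_heads=num_heads,
                                           scale_value=scale,
                                           block_table=block_table,
                                           context_lens=seq_lens,
                                           out=output)
            torch.npu.graph_task_update_end(update_stream)

        # The diff truncates the record() call; recording on the update stream
        # is an assumption here.
        graph_params.events[runtime_shape][layer_idx].record(update_stream)
```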