17 changes: 12 additions & 5 deletions tests/ut/attention/test_attention_v1.py
@@ -101,7 +101,7 @@ def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
max_query_len=5,
decode_token_per_req=torch.tensor([1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)),
slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((10, 10)),
@@ -134,7 +134,7 @@ def test_build_chunked_prefill(self, mock_ascend_attention_state,
max_query_len=6,
decode_token_per_req=torch.tensor([1, 1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)),
slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((15, 15)),
@@ -165,7 +165,7 @@ def test_build_non_310p(self, mock_is_310p, mock_ascend_metadata):
max_query_len=6,
decode_token_per_req=torch.tensor([1, 1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)),
slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((15, 15)),
@@ -378,10 +378,12 @@ def test_forward_prefill_cache_hit(self, mock_flash_attention_qlens,
mock_flash_attention_qlens.assert_called_once()
assert output.shape == (10, 8 * 64)

@patch('vllm_ascend.attention.attention_v1.get_forward_context')
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention')
def test_forward_decode_only(self, mock_paged_attention,
mock_npu_reshape_and_cache):
mock_npu_reshape_and_cache,
mock_get_forward_context):
"""Test forward pass in DecodeOnly state"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
@@ -395,6 +397,8 @@ def test_forward_decode_only(self, mock_paged_attention,
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant

mock_get_forward_context.return_value = MagicMock(capturing=False)

output = self.impl.forward(layer,
query,
key,
@@ -435,12 +439,13 @@ def test_forward_decode_only_swa(self, mock_fused_infer_attention_score,
mock_fused_infer_attention_score.assert_called_once()
assert output.shape == (10, 8 * 64)

@patch('vllm_ascend.attention.attention_v1.get_forward_context')
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention')
@patch('torch_npu.npu_fused_infer_attention_score')
def test_forward_decode_only_swa_seq_len_mismatch(
self, mock_fused_infer_attention_score, mock_paged_attention,
mock_npu_reshape_and_cache):
mock_npu_reshape_and_cache, mock_get_forward_context):
"""Test forward pass in DecodeOnly state when seq)len_mismatch"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
@@ -457,6 +462,8 @@ def test_forward_decode_only_swa_seq_len_mismatch(
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
64), 1)

mock_get_forward_context.return_value = MagicMock(capturing=False)

output = self.impl_swa.forward(self.layer_no_quant,
query,
key,
2 changes: 1 addition & 1 deletion tests/ut/torchair/test_torchair_mla.py
@@ -463,7 +463,7 @@ def test_build_decode(self, mock_ascend_config):
max_query_len=1,
decode_token_per_req=torch.tensor([1, 1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)),
slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([1, 1]),
attn_mask=torch.ones((15, 15)),
95 changes: 77 additions & 18 deletions vllm_ascend/attention/attention_v1.py
@@ -31,13 +31,15 @@
is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import cdiv, direct_register_custom_op
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec

from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.ops.attention import vanilla_chunked_prefill
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
nd_to_nz_2d, nd_to_nz_spec)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16,
get_graph_params, is_310p, nd_to_nz_2d,
nd_to_nz_spec)


def wait_for_kv_layer_from_connector(layer_name: str):
@@ -197,6 +199,12 @@ class AscendMetadata:


class AscendAttentionMetadataBuilder:
# Does this backend/builder support CUDA Graphs for attention (default: no).
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
# Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch.
reorder_batch_threshold: ClassVar[int] = 1

def __init__(
@@ -221,7 +229,7 @@ def build(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
model: Optional[nn.Module] = None,
):
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
@@ -231,11 +239,7 @@ def build(
block_table = common_attn_metadata.block_table_tensor
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
slot_mapping = common_attn_metadata.slot_mapping_cpu[:
num_actual_tokens].to(
self.device,
non_blocking=
True)
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
attn_mask = common_attn_metadata.attn_mask
attn_state = common_attn_metadata.attn_state
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
@@ -268,6 +272,24 @@
is_only_prefill=common_attn_metadata.is_only_prefill)
return attn_metadata

def build_for_graph_capture(
self,
common_attn_metadata: AscendCommonAttentionMetadata,
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
):
if attn_state == AscendAttentionState.DecodeOnly:
attn_metadata = self.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
else:
raise NotImplementedError(
"Currently we only support building dummy metadata for DecodeOnly state"
)

attn_metadata.attn_state = attn_state
return attn_metadata


class AscendAttentionBackendImpl(AttentionImpl):

@@ -406,16 +428,53 @@ def _forward_decode_only(

output = output.view(batch_size, self.num_heads, self.head_size)
else:
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
graph_params = get_graph_params()
forward_context: ForwardContext = get_forward_context()
num_tokens = query.shape[0]
if forward_context.capturing:
stream = torch_npu.npu.current_stream()

event = torch.npu.ExternalEvent()
event.wait(stream)
event.reset(stream)
graph_params.events[num_tokens].append(event)

graph_params.attn_params[num_tokens].append((
query,
self.key_cache,
self.value_cache,
self.num_kv_heads,
self.num_heads,
self.scale,
attn_metadata.block_tables,
attn_metadata.seq_lens,
output,
))

torch.npu.graph_task_group_begin(stream)
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
else:
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
return output

def _forward_v1_style(
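For context on how the state recorded above is meant to be consumed: during capture, each decode-only paged-attention launch is wrapped in an NPU task group, and its arguments, an ExternalEvent, and the task-group handle are stored in graph_params keyed by the token count. At replay time an updater can rewrite just that attention task inside the captured graph with the new seq_lens/block_tables and then record the event the graph is waiting on. The sketch below is illustrative only and not part of this diff; the helper name and the torch.npu.graph_task_update_begin/graph_task_update_end calls are assumptions mirroring the capture-side graph_task_group_begin/graph_task_group_end shown above.

    import torch
    import torch_npu

    from vllm_ascend.utils import get_graph_params


    def update_decode_attn_params(update_stream, forward_context, num_tokens):
        """Hypothetical replay-side updater for FULL_DECODE_ONLY graphs."""
        graph_params = get_graph_params()
        # Assumes forward_context.attn_metadata is a per-layer dict, as in vLLM v1.
        with torch.npu.stream(update_stream):
            for key, param, handle, event in zip(
                    forward_context.attn_metadata,
                    graph_params.attn_params[num_tokens],
                    graph_params.handles[num_tokens],
                    graph_params.events[num_tokens],
            ):
                (query, key_cache, value_cache, num_kv_heads, num_heads,
                 scale, block_table, seq_lens, output) = param
                # Only the per-step values change between replays.
                block_table = forward_context.attn_metadata[key].block_tables
                seq_lens = forward_context.attn_metadata[key].seq_lens

                # Assumed update-side APIs, mirroring graph_task_group_begin/end.
                torch.npu.graph_task_update_begin(update_stream, handle)
                torch_npu._npu_paged_attention(
                    query=query,
                    key_cache=key_cache,
                    value_cache=value_cache,
                    num_kv_heads=num_kv_heads,
                    num_heads=num_heads,
                    scale_value=scale,
                    block_table=block_table,
                    context_lens=seq_lens,
                    out=output)
                torch.npu.graph_task_update_end(update_stream)
                # Unblock the event.wait() that was baked into the captured graph.
                event.record(update_stream)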
6 changes: 1 addition & 5 deletions vllm_ascend/attention/mla_v1.py
@@ -292,11 +292,7 @@ def build(
device = self.device

block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
slot_mapping = common_attn_metadata.slot_mapping_cpu[:
num_actual_tokens].to(
device,
non_blocking=
True)
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
input_positions = common_attn_metadata.positions[:
num_actual_tokens].long(
)
2 changes: 1 addition & 1 deletion vllm_ascend/attention/utils.py
@@ -41,7 +41,7 @@ class AscendCommonAttentionMetadata:

block_table_tensor: torch.Tensor

slot_mapping_cpu: torch.Tensor
slot_mapping: torch.Tensor

actual_seq_lengths_q: list[int]

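The rename also changes the contract: slot_mapping now holds a device-resident tensor owned by the model runner, which the metadata builders slice directly, instead of a CPU tensor copied to the device on every build() call (see the builder changes above). A minimal sketch of the difference, with illustrative shapes; the "npu" device string assumes torch_npu is installed on an Ascend host:

    import torch
    import torch_npu  # registers the "npu" device; assumption: running on Ascend

    num_actual_tokens = 16

    # Before: host tensor, sliced and copied to the device inside every build()
    slot_mapping_cpu = torch.arange(32, dtype=torch.int32)
    slot_mapping_old = slot_mapping_cpu[:num_actual_tokens].to("npu", non_blocking=True)

    # After: the runner keeps one device tensor; build() only takes a zero-copy view
    slot_mapping = torch.arange(32, dtype=torch.int32, device="npu")
    slot_mapping_new = slot_mapping[:num_actual_tokens]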
1 change: 1 addition & 0 deletions vllm_ascend/compilation/acl_graph.py
@@ -147,6 +147,7 @@ def __call__(self, *args, **kwargs):
patch("torch.npu.empty_cache", lambda: None))

# mind-exploding: carefully manage the reference and memory.
forward_context.capturing = True
with torch.npu.graph(aclgraph, pool=self.graph_pool):
# `output` is managed by pytorch's aclgraph pool
output = self.runnable(*args, **kwargs)
41 changes: 29 additions & 12 deletions vllm_ascend/platform.py
@@ -179,23 +179,13 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

compilation_config.cudagraph_num_of_warmups = 1

# TODO: make vllm support oot platform to set `compilation_config.cudagraph_mode`
# if cudagraph_mode is not explicitly set by users, set default value
if compilation_config.level == CompilationLevel.PIECEWISE:
compilation_config.cudagraph_mode = \
CUDAGraphMode.PIECEWISE
elif compilation_config.level not in [
if compilation_config.level not in [
CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE
]:
logger.warning(
"NPU does not support %s compilation level. Setting CUDAGraphMode to NONE",
compilation_config.level)
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
else:
logger.warning(
"compilation_config.level = CompilationLevel.NO_COMPILATION is set, Setting CUDAGraphMode to NONE"
)
compilation_config.cudagraph_mode = CUDAGraphMode.NONE

# set CUDAGraphMode to None when torchair is enabled, no matter what compilation_config.level is.
if ascend_config.torchair_graph_config.enabled:
@@ -221,7 +211,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
compilation_config.level = CompilationLevel.NO_COMPILATION
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
# TODO: Currently MLA does not support FULL_DECODE_ONLY; remove the second condition
# once MLA is supported.
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE or (
compilation_config.cudagraph_mode
== CUDAGraphMode.FULL_DECODE_ONLY and model_config is not None
and model_config.use_mla):
logger.info(
"PIECEWISE compilation enabled on NPU. use_inductor not supported - "
"using only ACL Graph mode")
@@ -233,6 +228,24 @@
"vllm.unified_ascend_attention_with_output", "vllm.mla_forward"
])
update_aclgraph_sizes(vllm_config)
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
logger.info(
"FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
"using only ACL Graph mode")
compilation_config.use_inductor = False
warning_message = """\033[91m
**********************************************************************************
* WARNING: You have enabled the *full graph* feature.
* This is an early experimental stage and may involve various unknown issues.
* A known problem is that capturing too many batch sizes can lead to OOM
* (Out of Memory) errors or inference hangs. If you encounter such issues,
* consider reducing `gpu_memory_utilization` or manually specifying a smaller
* batch size for graph capture.
* For more details, please refer to:
* https://docs.vllm.ai/en/stable/configuration/conserving_memory.html#reduce-cuda-graphs
**********************************************************************************\033[0m
"""
logger.warning(warning_message)
else:
logger.info(
"%s cudagraph_mode is not support on NPU. falling back to NONE",
@@ -379,3 +392,7 @@ def stateless_init_device_torch_dist_pg(
@classmethod
def support_hybrid_kv_cache(cls) -> bool:
return True

@classmethod
def support_static_graph_mode(cls) -> bool:
return True
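Together with the FULL_DECODE_ONLY branch above, support_static_graph_mode lets the full-graph decode path be opted into through vLLM's compilation config. A hedged usage sketch; the model name is a placeholder, and passing cudagraph_mode as a string assumes the installed vLLM version parses it into CUDAGraphMode:

    from vllm import LLM, SamplingParams

    # Opt into full-graph capture for decode-only batches on NPU (experimental,
    # per the warning emitted above).
    llm = LLM(
        model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
        compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY"},
    )
    outputs = llm.generate(["Hello, Ascend!"], SamplingParams(max_tokens=32))

If capture runs out of memory or hangs, the warning above suggests lowering gpu_memory_utilization or capturing fewer batch sizes.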
4 changes: 2 additions & 2 deletions vllm_ascend/spec_decode/eagle_proposer.py
@@ -347,7 +347,7 @@ def _get_eagle_atten_dict(
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
block_table_tensor=self.runner.input_batch.block_table[0].
get_device_tensor(),
slot_mapping_cpu=self.runner.slot_mapping_cpu,
slot_mapping=self.runner.slot_mapping,
positions=self.runner.positions,
attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,
@@ -434,7 +434,7 @@ def _propose(
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
block_table_tensor=self.runner.input_batch.block_table[0].
get_device_tensor(),
slot_mapping_cpu=target_slot_mapping,
slot_mapping=target_slot_mapping,
positions=target_positions,
attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,
2 changes: 1 addition & 1 deletion vllm_ascend/spec_decode/mtp_proposer.py
@@ -385,7 +385,7 @@ def _propose(
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
block_table_tensor=self.runner.input_batch.block_table[0].
get_device_tensor(),
slot_mapping_cpu=target_slot_mapping,
slot_mapping=target_slot_mapping,
positions=target_positions,
attn_mask=self.runner.attn_mask,
spec_attn_mask=self.runner.spec_attn_mask,
8 changes: 2 additions & 6 deletions vllm_ascend/torchair/torchair_attention.py
@@ -175,7 +175,7 @@ def build(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
model: Optional[nn.Module] = None,
):
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
@@ -185,11 +185,7 @@ def build(
block_table[:num_reqs])

seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
slot_mapping = common_attn_metadata.slot_mapping_cpu[:
num_actual_tokens].to(
self.device,
non_blocking=
True)
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
attn_mask = common_attn_metadata.attn_mask

attn_state = common_attn_metadata.attn_state
6 changes: 1 addition & 5 deletions vllm_ascend/torchair/torchair_mla.py
@@ -400,11 +400,7 @@ def build(
device = self.device

block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
slot_mapping = common_attn_metadata.slot_mapping_cpu[:
num_actual_tokens].to(
device,
non_blocking=
True)
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
input_positions = common_attn_metadata.positions[:
num_actual_tokens].long(
)