diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py
index 54be152100..b286f81fa4 100644
--- a/tests/ut/attention/test_attention_v1.py
+++ b/tests/ut/attention/test_attention_v1.py
@@ -101,7 +101,7 @@ def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
             max_query_len=5,
             decode_token_per_req=torch.tensor([1, 1]),
             block_table_tensor=torch.zeros((10, 10)),
-            slot_mapping_cpu=torch.tensor(range(20)),
+            slot_mapping=torch.tensor(range(20)),
             actual_seq_lengths_q=torch.tensor([0, 1]),
             positions=torch.tensor([10, 10]),
             attn_mask=torch.ones((10, 10)),
@@ -134,7 +134,7 @@ def test_build_chunked_prefill(self, mock_ascend_attention_state,
             max_query_len=6,
             decode_token_per_req=torch.tensor([1, 1, 1]),
             block_table_tensor=torch.zeros((10, 10)),
-            slot_mapping_cpu=torch.tensor(range(20)),
+            slot_mapping=torch.tensor(range(20)),
             actual_seq_lengths_q=torch.tensor([0, 1, 2]),
             positions=torch.tensor([10, 10]),
             attn_mask=torch.ones((15, 15)),
@@ -165,7 +165,7 @@ def test_build_non_310p(self, mock_is_310p, mock_ascend_metadata):
             max_query_len=6,
             decode_token_per_req=torch.tensor([1, 1, 1]),
             block_table_tensor=torch.zeros((10, 10)),
-            slot_mapping_cpu=torch.tensor(range(20)),
+            slot_mapping=torch.tensor(range(20)),
             actual_seq_lengths_q=torch.tensor([0, 1, 2]),
             positions=torch.tensor([10, 10]),
             attn_mask=torch.ones((15, 15)),
@@ -378,10 +378,12 @@ def test_forward_prefill_cache_hit(self, mock_flash_attention_qlens,
         mock_flash_attention_qlens.assert_called_once()
         assert output.shape == (10, 8 * 64)

+    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu._npu_paged_attention')
     def test_forward_decode_only(self, mock_paged_attention,
-                                 mock_npu_reshape_and_cache):
+                                 mock_npu_reshape_and_cache,
+                                 mock_get_forward_context):
         """Test forward pass in DecodeOnly state"""
         query = torch.randn(10, 8 * 64)
         key = torch.randn(10, 8 * 64)
@@ -395,6 +397,8 @@ def test_forward_decode_only(self, mock_paged_attention,
         metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
         layer = self.layer_no_quant

+        mock_get_forward_context.return_value = MagicMock(capturing=False)
+
         output = self.impl.forward(layer,
                                    query,
                                    key,
@@ -435,12 +439,13 @@ def test_forward_decode_only_swa(self, mock_fused_infer_attention_score,
         mock_fused_infer_attention_score.assert_called_once()
         assert output.shape == (10, 8 * 64)

+    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu._npu_paged_attention')
     @patch('torch_npu.npu_fused_infer_attention_score')
     def test_forward_decode_only_swa_seq_len_mismatch(
             self, mock_fused_infer_attention_score, mock_paged_attention,
-            mock_npu_reshape_and_cache):
+            mock_npu_reshape_and_cache, mock_get_forward_context):
         """Test forward pass in DecodeOnly state when seq)len_mismatch"""
         query = torch.randn(10, 8 * 64)
         key = torch.randn(10, 8 * 64)
@@ -457,6 +462,8 @@ def test_forward_decode_only_swa_seq_len_mismatch(
         mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
                                                                     64), 1)

+        mock_get_forward_context.return_value = MagicMock(capturing=False)
+
         output = self.impl_swa.forward(self.layer_no_quant,
                                        query,
                                        key,
diff --git a/tests/ut/torchair/test_torchair_mla.py b/tests/ut/torchair/test_torchair_mla.py
index 0e0150c880..ec8ddfdfa7 100644
--- a/tests/ut/torchair/test_torchair_mla.py
+++ b/tests/ut/torchair/test_torchair_mla.py
@@ -463,7 +463,7 @@ def test_build_decode(self, mock_ascend_config):
             max_query_len=1,
             decode_token_per_req=torch.tensor([1, 1, 1]),
             block_table_tensor=torch.zeros((10, 10)),
-            slot_mapping_cpu=torch.tensor(range(20)),
+            slot_mapping=torch.tensor(range(20)),
             actual_seq_lengths_q=torch.tensor([0, 1, 2]),
             positions=torch.tensor([1, 1]),
             attn_mask=torch.ones((15, 15)),
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 10a2f6a416..2e5382d94d 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -31,13 +31,15 @@
                               is_v1_kv_transfer_group)
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import cdiv, direct_register_custom_op
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec

 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
-                               nd_to_nz_2d, nd_to_nz_spec)
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16,
+                               get_graph_params, is_310p, nd_to_nz_2d,
+                               nd_to_nz_spec)


 def wait_for_kv_layer_from_connector(layer_name: str):
@@ -197,6 +199,12 @@ class AscendMetadata:


 class AscendAttentionMetadataBuilder:
+    # Does this backend/builder support CUDA Graphs for attention (default: no).
+    cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+    # Does this backend/builder reorder the batch?
+    # If not, set this to None. Otherwise set it to the query
+    # length that will be pulled into the front of the batch.
     reorder_batch_threshold: ClassVar[int] = 1

     def __init__(
@@ -221,7 +229,7 @@ def build(
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
+        model: Optional[nn.Module] = None,
     ):
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
@@ -231,11 +239,7 @@ def build(
         block_table = common_attn_metadata.block_table_tensor
         query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
-        slot_mapping = common_attn_metadata.slot_mapping_cpu[:
-                                                             num_actual_tokens].to(
-                                                                 self.device,
-                                                                 non_blocking=
-                                                                 True)
+        slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
         attn_mask = common_attn_metadata.attn_mask
         attn_state = common_attn_metadata.attn_state
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
@@ -268,6 +272,24 @@ def build(
             is_only_prefill=common_attn_metadata.is_only_prefill)
         return attn_metadata

+    def build_for_graph_capture(
+        self,
+        common_attn_metadata: AscendCommonAttentionMetadata,
+        attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
+    ):
+        if attn_state == AscendAttentionState.DecodeOnly:
+            attn_metadata = self.build(
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+        else:
+            raise NotImplementedError(
+                "Currently we only support building dummy metadata for DecodeOnly state"
+            )
+
+        attn_metadata.attn_state = attn_state
+        return attn_metadata
+

 class AscendAttentionBackendImpl(AttentionImpl):

@@ -406,16 +428,53 @@ def _forward_decode_only(
             output = output.view(batch_size, self.num_heads,
                                  self.head_size)
         else:
-            torch_npu._npu_paged_attention(
-                query=query,
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                block_table=attn_metadata.block_tables,
-                context_lens=attn_metadata.seq_lens,
-                out=output)
+            graph_params = get_graph_params()
+            forward_context: ForwardContext = get_forward_context()
+            num_tokens = query.shape[0]
+            if forward_context.capturing:
+                stream = torch_npu.npu.current_stream()
+
+                event = torch.npu.ExternalEvent()
+                event.wait(stream)
+                event.reset(stream)
+                graph_params.events[num_tokens].append(event)
+
+                graph_params.attn_params[num_tokens].append((
+                    query,
+                    self.key_cache,
+                    self.value_cache,
+                    self.num_kv_heads,
+                    self.num_heads,
+                    self.scale,
+                    attn_metadata.block_tables,
+                    attn_metadata.seq_lens,
+                    output,
+                ))
+
+                torch.npu.graph_task_group_begin(stream)
+                torch_npu._npu_paged_attention(
+                    query=query,
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    block_table=attn_metadata.block_tables,
+                    context_lens=attn_metadata.seq_lens,
+                    out=output)
+                handle = torch.npu.graph_task_group_end(stream)
+                graph_params.handles[num_tokens].append(handle)
+            else:
+                torch_npu._npu_paged_attention(
+                    query=query,
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    block_table=attn_metadata.block_tables,
+                    context_lens=attn_metadata.seq_lens,
+                    out=output)
         return output

     def _forward_v1_style(
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index aa2c818b29..bae0ceced3 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -292,11 +292,7 @@ def build(
         device = self.device

         block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
-        slot_mapping = common_attn_metadata.slot_mapping_cpu[:
-                                                             num_actual_tokens].to(
-                                                                 device,
-                                                                 non_blocking=
-                                                                 True)
+        slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
         input_positions = common_attn_metadata.positions[:
                                                          num_actual_tokens].long(
                                                          )
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
index 65af109799..01cf4ea6d4 100644
--- a/vllm_ascend/attention/utils.py
+++ b/vllm_ascend/attention/utils.py
@@ -41,7 +41,7 @@ class AscendCommonAttentionMetadata:

     block_table_tensor: torch.Tensor

-    slot_mapping_cpu: torch.Tensor
+    slot_mapping: torch.Tensor

     actual_seq_lengths_q: list[int]

diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index cc124480ce..d88ee342fa 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -147,6 +147,7 @@ def __call__(self, *args, **kwargs):
                 patch("torch.npu.empty_cache", lambda: None))

             # mind-exploding: carefully manage the reference and memory.
+            forward_context.capturing = True
             with torch.npu.graph(aclgraph, pool=self.graph_pool):
                 # `output` is managed by pytorch's aclgraph pool
                 output = self.runnable(*args, **kwargs)
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index efc1c42422..7104621a97 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -179,23 +179,13 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

             compilation_config.cudagraph_num_of_warmups = 1

-        # TODO: make vllm support oot platform to set `compilation_config.cudagraph_mode`
-        # if cudagraph_mode is not explicitly set by users, set default value
-        if compilation_config.level == CompilationLevel.PIECEWISE:
-            compilation_config.cudagraph_mode = \
-                CUDAGraphMode.PIECEWISE
-        elif compilation_config.level not in [
+        if compilation_config.level not in [
                 CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE
         ]:
             logger.warning(
                 "NPU does not support %s compilation level. Setting CUDAGraphMode to NONE",
                 compilation_config.level)
             compilation_config.cudagraph_mode = CUDAGraphMode.NONE
-        else:
-            logger.warning(
-                "compilation_config.level = CompilationLevel.NO_COMPILATION is set, Setting CUDAGraphMode to NONE"
-            )
-            compilation_config.cudagraph_mode = CUDAGraphMode.NONE

         # set CUDAGraphMode to None when torchair is enabled, no mather what compilation_config.level is.
         if ascend_config.torchair_graph_config.enabled:
@@ -221,7 +211,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

         if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
             compilation_config.level = CompilationLevel.NO_COMPILATION
-        elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
+        # TODO: MLA does not support FULL_DECODE_ONLY yet; remove the second
+        # condition once MLA is supported.
+        elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE or (
+                compilation_config.cudagraph_mode
+                == CUDAGraphMode.FULL_DECODE_ONLY and model_config is not None
+                and model_config.use_mla):
             logger.info(
                 "PIECEWISE compilation enabled on NPU. use_inductor not supported - "
                 "using only ACL Graph mode")
@@ -233,6 +228,24 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                     "vllm.unified_ascend_attention_with_output", "vllm.mla_forward"
                 ])
             update_aclgraph_sizes(vllm_config)
+        elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+            logger.info(
+                "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
+                "using only ACL Graph mode")
+            compilation_config.use_inductor = False
+            warning_message = """\033[91m
+            **********************************************************************************
+            * WARNING: You have enabled the *full graph* feature.
+            * This is an early experimental stage and may involve various unknown issues.
+            * A known problem is that capturing too many batch sizes can lead to OOM
+            * (Out of Memory) errors or inference hangs. If you encounter such issues,
+            * consider reducing `gpu_memory_utilization` or manually specifying a smaller
+            * batch size for graph capture.
+            * For more details, please refer to:
+            * https://docs.vllm.ai/en/stable/configuration/conserving_memory.html#reduce-cuda-graphs
+            **********************************************************************************\033[0m
+            """
+            logger.warning(warning_message)
         else:
             logger.info(
                 "%s cudagraph_mode is not support on NPU. falling back to NONE",
                compilation_config.cudagraph_mode)
@@ -379,3 +392,7 @@ def stateless_init_device_torch_dist_pg(
     @classmethod
     def support_hybrid_kv_cache(cls) -> bool:
         return True
+
+    @classmethod
+    def support_static_graph_mode(cls) -> bool:
+        return True
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index 9fda89a97c..9184bde058 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -347,7 +347,7 @@ def _get_eagle_atten_dict(
             actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
             block_table_tensor=self.runner.input_batch.block_table[0].
             get_device_tensor(),
-            slot_mapping_cpu=self.runner.slot_mapping_cpu,
+            slot_mapping=self.runner.slot_mapping,
             positions=self.runner.positions,
             attn_mask=self.runner.attn_mask,
             spec_attn_mask=self.runner.spec_attn_mask,
@@ -434,7 +434,7 @@ def _propose(
             actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
             block_table_tensor=self.runner.input_batch.block_table[0].
             get_device_tensor(),
-            slot_mapping_cpu=target_slot_mapping,
+            slot_mapping=target_slot_mapping,
             positions=target_positions,
             attn_mask=self.runner.attn_mask,
             spec_attn_mask=self.runner.spec_attn_mask,
diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index edd80d8bde..800b57f2bb 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -385,7 +385,7 @@ def _propose(
             actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
             block_table_tensor=self.runner.input_batch.block_table[0].
             get_device_tensor(),
-            slot_mapping_cpu=target_slot_mapping,
+            slot_mapping=target_slot_mapping,
             positions=target_positions,
             attn_mask=self.runner.attn_mask,
             spec_attn_mask=self.runner.spec_attn_mask,
diff --git a/vllm_ascend/torchair/torchair_attention.py b/vllm_ascend/torchair/torchair_attention.py
index ba354fd622..9f1b40e5e1 100644
--- a/vllm_ascend/torchair/torchair_attention.py
+++ b/vllm_ascend/torchair/torchair_attention.py
@@ -175,7 +175,7 @@ def build(
         self,
         common_prefix_len: int,
         common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
+        model: Optional[nn.Module] = None,
     ):
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
@@ -185,11 +185,7 @@ def build(
             block_table[:num_reqs])

         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
-        slot_mapping = common_attn_metadata.slot_mapping_cpu[:
-                                                             num_actual_tokens].to(
-                                                                 self.device,
-                                                                 non_blocking=
-                                                                 True)
+        slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
         attn_mask = common_attn_metadata.attn_mask
         attn_state = common_attn_metadata.attn_state

diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py
index 3f54fdb36d..995173a9fe 100644
--- a/vllm_ascend/torchair/torchair_mla.py
+++ b/vllm_ascend/torchair/torchair_mla.py
@@ -400,11 +400,7 @@ def build(
         device = self.device

         block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
-        slot_mapping = common_attn_metadata.slot_mapping_cpu[:
-                                                             num_actual_tokens].to(
-                                                                 device,
-                                                                 non_blocking=
-                                                                 True)
+        slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
         input_positions = common_attn_metadata.positions[:
                                                          num_actual_tokens].long(
                                                          )
diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py
index 0c715fd245..c92fa4e789 100644
--- a/vllm_ascend/torchair/torchair_model_runner.py
+++ b/vllm_ascend/torchair/torchair_model_runner.py
@@ -121,12 +121,14 @@ def _sync_metadata_across_dp(
         return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo

-    def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn):
+    def _build_attention_metadata(self, with_prefill, num_reqs, num_tokens,
+                                  max_query_len, force_attention):
         # NOTE: If torchair graph mode and not with_prefill,
         # we can't skip_attn, it will cause graph recompile.
         if with_prefill or self.enable_shared_expert_dp:
             attn_metadata = super()._build_attention_metadata(
-                with_prefill, num_reqs, skip_attn)
+                with_prefill, num_reqs, num_tokens, max_query_len,
+                force_attention)
         else:
             common_attn_metadata = TorchairCommonAttentionMetadata(
                 num_reqs=num_reqs,
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 06e1a2bb8c..451112d9e0 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -22,6 +22,7 @@
 import math
 import os
 from contextlib import contextmanager, nullcontext
+from dataclasses import dataclass
 from enum import Enum
 from threading import Lock
 from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
@@ -634,3 +635,34 @@ def npu_stream_switch(target_stream: torch.npu.Stream,
         return nullcontext()
     assert target_stream is not None
     return torch.npu.stream(target_stream)
+
+
+@dataclass
+class GraphParams:
+    events: dict[int, list[torch.npu.ExternalEvent]]
+    workspaces: dict[int, torch.Tensor]
+    handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
+    attn_params: dict[int, list[tuple]]
+
+
+_graph_params: Optional[GraphParams] = None
+
+
+def set_graph_params(aclgraph_capture_sizes: set[int]):
+    global _graph_params
+    if _graph_params is not None:
+        raise ValueError("Graph parameters have already been set!")
+    _graph_params = GraphParams(
+        {size: []
+         for size in aclgraph_capture_sizes},
+        {size: None
+         for size in aclgraph_capture_sizes},
+        {size: []
+         for size in aclgraph_capture_sizes},
+        {size: []
+         for size in aclgraph_capture_sizes},
+    )
+
+
+def get_graph_params():
+    return _graph_params
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 937c911706..c908cd0e3e 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -69,8 +69,8 @@
                         LazyLoader, cdiv, get_dtype_size,
                         is_pin_memory_available)
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
-from vllm.v1.attention.backends.utils import \
-    reorder_batch_to_split_decodes_and_prefills
+from vllm.v1.attention.backends.utils import (
+    AttentionCGSupport, reorder_batch_to_split_decodes_and_prefills)
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                         KVCacheConfig, KVCacheSpec, MambaSpec)
@@ -111,8 +111,9 @@
 from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                AscendSocVersion, ProfileExecuteDuration,
-                               get_ascend_soc_version, is_310p,
-                               lmhead_tp_enable, vllm_version_is)
+                               get_ascend_soc_version, get_graph_params,
+                               is_310p, lmhead_tp_enable, set_graph_params,
+                               vllm_version_is)
 from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch

 if TYPE_CHECKING:
@@ -337,6 +338,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.seq_lens = torch.zeros(self.max_num_reqs,
                                     dtype=torch.int32,
                                     device=self.device)
+        self.slot_mapping = torch.zeros(self.max_num_tokens,
+                                        dtype=torch.int32,
+                                        device=self.device)

         self.uses_mrope = self.model_config.uses_mrope
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
@@ -1204,7 +1208,7 @@ def _prepare_inputs(
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> tuple[dict[str, Any], torch.Tensor, np.ndarray, int, torch.Tensor,
               int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor],
-              Optional[torch.Tensor], Optional[torch.Tensor]]:
+              Optional[torch.Tensor], Optional[torch.Tensor], int]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
@@ -1457,11 +1461,10 @@ def _prepare_inputs(
             blk_table_tensor = blk_table.get_device_tensor()
             slot_mapping = blk_table.slot_mapping_cpu[:
                                                       total_num_scheduled_tokens]
-            self.slot_mapping_cpu[:total_num_scheduled_tokens].copy_(
-                slot_mapping)
-            # # Fill unused with -1. Needed for reshape_and_cache in full cuda
-            # # graph mode.
-            # blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
+            self.slot_mapping[:total_num_scheduled_tokens].copy_(
+                slot_mapping[:total_num_scheduled_tokens],
+                non_blocking=True,
+            )

             # Make AscendCommonAttentionMetadata
             common_attn_metadata = AscendCommonAttentionMetadata(
@@ -1474,7 +1477,7 @@ def _prepare_inputs(
                 actual_seq_lengths_q=self.actual_seq_lengths_q,
                 # TODO: change this to the right block table for linear attn
                 block_table_tensor=blk_table_tensor[:num_reqs],
-                slot_mapping_cpu=self.slot_mapping_cpu,
+                slot_mapping=self.slot_mapping,
                 num_computed_tokens_cpu=num_computed_tokens_cpu,
                 positions=self.positions,
                 attn_mask=self.attn_mask,
@@ -1531,7 +1534,8 @@ def _prepare_inputs(
         return (attn_metadata, positions, num_scheduled_tokens,
                 num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens,
                 logits_indices, spec_decode_metadata,
-                input_ids, inputs_embeds, intermediate_tensors)
+                input_ids, inputs_embeds, intermediate_tensors,
+                max_num_scheduled_tokens)

     def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
                                              maybe_padded_num_tokens,
@@ -1545,6 +1549,13 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
             intermediate_tensors=intermediate_tensors,
             inputs_embeds=inputs_embeds,
         )
+
+        forward_context = get_forward_context()
+        if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
+            graph_params = get_graph_params()
+            self.update_attn_params(graph_params, forward_context,
+                                    positions.shape[0])
+
         if get_forward_context().flashcomm_v1_enabled:
             hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
             pad_size = get_forward_context().pad_size
@@ -1552,6 +1563,44 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
                 hidden_states = hidden_states[:-pad_size, :]
         return hidden_states

+    def update_attn_params(self, graph_params, forward_context, runtime_shape):
+        # FIXME: Behold! We are using a temporary hack here to update the args
+        # for each layer's attention op in the graph.
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                num_heads,
+                scale,
+                block_table,
+                seq_lens,
+                output,
+            ) = param
+            # block_table = forward_context.attn_metadata[key].block_tables
+            seq_lens = forward_context.attn_metadata[key].seq_lens
+
+            with torch.npu.stream(self.update_stream):
+                torch.npu.graph_task_update_begin(self.update_stream, handle)
+                torch_npu._npu_paged_attention(query=query,
+                                               key_cache=key_cache,
+                                               value_cache=value_cache,
+                                               num_kv_heads=num_kv_heads,
+                                               num_heads=num_heads,
+                                               scale_value=scale,
+                                               block_table=block_table,
+                                               context_lens=seq_lens,
+                                               out=output)
+                torch.npu.graph_task_update_end(self.update_stream)
+
+            event.record(self.update_stream)
+
     def _build_attn_state(self, num_reqs, num_scheduled_tokens,
                           num_valid_tokens):
         ascend_config = get_ascend_config()
@@ -1868,8 +1917,9 @@ def execute_model(
         (attn_metadata, positions, num_scheduled_tokens_np, num_input_tokens,
          num_tokens_across_dp, maybe_padded_num_tokens, logits_indices,
          spec_decode_metadata, input_ids, inputs_embeds,
-         intermediate_tensors) = (self._prepare_inputs(
-             scheduler_output, intermediate_tensors))
+         intermediate_tensors,
+         max_query_len) = (self._prepare_inputs(scheduler_output,
+                                                intermediate_tensors))

         if self.dynamic_eplb:
             self.eplb_updator.take_update_info_from_eplb_process()
@@ -1877,8 +1927,11 @@ def execute_model(
         moe_comm_method = self._select_moe_comm_method(num_input_tokens,
                                                        self.with_prefill)

+        uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
+            scheduler_output.total_num_scheduled_tokens
+            == self.input_batch.num_reqs * max_query_len)
         batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
-                                           uniform_decode=False)
+                                           uniform_decode=uniform_decode)
         aclgraph_runtime_mode, batch_descriptor = \
             self.aclgraph_dispatcher.dispatch(batch_descriptor)

@@ -2197,12 +2250,54 @@ def get_finished_kv_transfer(
                 scheduler_output.finished_req_ids)
         return None, None

-    def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn):
-        if skip_attn:
-            attn_metadata = None
-        else:
-            # TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata
-            attn_metadata = None
+    def _build_attention_metadata(self, create_mixed_batch, num_reqs,
+                                  num_tokens, max_query_len, force_attention):
+        attn_metadata: Optional[dict[str, Any]] = None
+
+        if force_attention:
+            attn_metadata = {}
+
+            if create_mixed_batch:
+                raise NotImplementedError(
+                    "force_attention=True is not supported for mixed batches.")
+            else:
+                seq_lens = self.model_config.max_model_len
+                self.seq_lens_np[:num_reqs] = seq_lens
+                self.seq_lens_np[num_reqs:] = 0
+
+            num_computed_tokens_cpu = (
+                self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
+
+            for kv_cache_group_id, kv_cache_group_spec in enumerate(
+                    self.kv_cache_config.kv_cache_groups):
+                block_table_tensor = self.input_batch.block_table[
+                    kv_cache_group_id].get_device_tensor()
+                common_attn_metadata = AscendCommonAttentionMetadata(
+                    query_start_loc=self.query_start_loc[:num_reqs + 1],
+                    query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
+                                                                 1],
+                    seq_lens_cpu=self.seq_lens_cpu,
+                    seq_lens=self.seq_lens_cpu[:num_reqs],
+                    num_reqs=num_reqs,
+                    num_actual_tokens=num_tokens,
+                    actual_seq_lengths_q=self.actual_seq_lengths_q,
+                    block_table_tensor=block_table_tensor[:num_reqs],
+                    slot_mapping=self.slot_mapping,
+                    num_computed_tokens_cpu=num_computed_tokens_cpu,
+                    max_query_len=max_query_len,
+                    decode_token_per_req=self.decode_token_per_req,
+                )
+
+                for attn_group in self.attn_groups[kv_cache_group_id]:
+                    if vllm_version_is("0.10.2"):
+                        builder = attn_group.metadata_builder
+                    else:
+                        builder = attn_group.get_metadata_builder()
+                    attn_metadata_i = builder.build_for_graph_capture(
+                        common_attn_metadata)
+                    for layer_name in kv_cache_group_spec.layer_names:
+                        attn_metadata[layer_name] = attn_metadata_i
+
         return attn_metadata

     def _generate_dummy_run_hidden_states(self, with_prefill,
@@ -2231,12 +2326,8 @@ def _dummy_run(
     ) -> torch.Tensor:
         # only support eager mode and piecewise graph now
         assert aclgraph_runtime_mode in {
-            CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE
+            CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
         }
-        if force_attention:
-            raise RuntimeError(
-                "Capturing attention in aclgraph is unexpected, because full graph is not supported now"
-            )

         # Padding for DP
         (num_tokens, num_tokens_across_dp, with_prefill,
@@ -2292,9 +2383,13 @@ def _dummy_run(
             if self.is_kv_producer:
                 with_prefill = True

-        attn_metadata = self._build_attention_metadata(with_prefill,
-                                                       num_reqs,
-                                                       skip_attn=True)
+        attn_metadata = self._build_attention_metadata(
+            with_prefill,
+            num_reqs,
+            num_tokens,
+            max_query_len,
+            force_attention,
+        )

         if not self.in_profile_run and self.dynamic_eplb:
             self.eplb_updator.forward_before()
@@ -2533,6 +2628,14 @@ def load_model(self) -> None:
             logger.info("Loading model weights took %.4f GB",
                         m.consumed_memory / float(2**30))

+        # wrap the model with full graph wrapper if needed.
+        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
+            self.update_stream = torch.npu.Stream()
+            set_graph_params(self.compilation_config.cudagraph_capture_sizes)
+            self.model = ACLGraphWrapper(self.model,
+                                         self.vllm_config,
+                                         runtime_mode=CUDAGraphMode.FULL)
+
     def _convert_torch_format(self, tensor):
         tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
         return tensor
@@ -3093,9 +3196,78 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         return kv_cache_spec

     def initialize_aclgraph_capture(self) -> None:
-        # TODO: Add check of AttentionCGSupport and cudagraph_mode.decode_mode when full graph is supported
-        # Trigger aclgraph dispatching keys initialization here (after
-        # initializing attn backends).
+        min_ag_support = AttentionCGSupport.ALWAYS
+        min_ag_builder_name = None
+
+        for attn_group in self._attn_group_iterator():
+            if vllm_version_is("0.10.2"):
+                builder = attn_group.metadata_builder
+            else:
+                builder = attn_group.get_metadata_builder()
+            if builder.cudagraph_support.value < min_ag_support.value:
+                min_ag_support = builder.cudagraph_support
+                min_ag_builder_name = builder.__class__.__name__
+
+        # This is an imitation of compilation_config.splitting_ops_contain_attention()
+        splitting_ops_contain_attention = (
+            self.compilation_config.splitting_ops is not None
+            and all(op in self.compilation_config.splitting_ops for op in [
+                "vllm.unified_ascend_attention_with_output",
+                "vllm.mla_forward",
+            ]))
+
+        # Flexibly resolve the aclgraph mode
+        aclgraph_mode = self.compilation_config.cudagraph_mode
+        # check that full graph is supported for mixed batches
+        if aclgraph_mode.mixed_mode() == CUDAGraphMode.FULL \
+                and min_ag_support != AttentionCGSupport.ALWAYS:
+            msg = (f"ACLGraphMode.{aclgraph_mode.name} is not supported "
+                   f"with {min_ag_builder_name} backend (support: "
+                   f"{min_ag_support})")
+            if min_ag_support == AttentionCGSupport.NEVER:
+                # if no full graph variant is supported, just raise it.
+                msg += "; please try cudagraph_mode=PIECEWISE, and "\
+                    "make sure compilation level is piecewise"
+                raise ValueError(msg)
+
+            # attempt to resolve the full graph related mode
+            if splitting_ops_contain_attention:
+                msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE"
+                aclgraph_mode = self.compilation_config.cudagraph_mode = (
+                    CUDAGraphMode.FULL_AND_PIECEWISE)
+            else:
+                msg += "; setting cudagraph_mode=FULL_DECODE_ONLY"
+                aclgraph_mode = self.compilation_config.cudagraph_mode = (
+                    CUDAGraphMode.FULL_DECODE_ONLY)
+            logger.warning(msg)
+
+        # check whether spec-decode + decode full graphs are supported
+        if (aclgraph_mode.decode_mode() == CUDAGraphMode.FULL
+                and self.uniform_decode_query_len > 1 and min_ag_support.value
+                < AttentionCGSupport.UNIFORM_BATCH.value):
+            msg = (f"CUDAGraphMode.{aclgraph_mode.name} is not supported"
+                   f" with spec-decode for attention backend "
+                   f"{min_ag_builder_name} (support: {min_ag_support})")
+            if splitting_ops_contain_attention:
+                msg += "; setting cudagraph_mode=PIECEWISE"
+                aclgraph_mode = self.compilation_config.cudagraph_mode = \
+                    CUDAGraphMode.PIECEWISE
+            else:
+                msg += "; setting cudagraph_mode=NONE"
+                aclgraph_mode = self.compilation_config.cudagraph_mode = \
+                    CUDAGraphMode.NONE
+            logger.warning(msg)
+
+        # double check that full graphs are still supported if requested,
+        # even after automatic downgrades
+        if aclgraph_mode.has_full_cudagraphs() \
+                and min_ag_support == AttentionCGSupport.NEVER:
+            raise ValueError(f"CUDAGraphMode.{aclgraph_mode.name} is not "
+                             f"supported with {min_ag_builder_name} backend ("
+                             f"support:{min_ag_support}) "
+                             "; please try cudagraph_mode=PIECEWISE, "
+                             "and make sure compilation level is piecewise")
+
         self.aclgraph_dispatcher.initialize_cudagraph_keys(
             self.compilation_config.cudagraph_mode,
             self.uniform_decode_query_len)
@@ -3104,10 +3276,15 @@ def _capture_aclgraphs(self, compilation_cases: list[int],
                            aclgraph_runtime_mode: CUDAGraphMode,
                            uniform_decode: bool):
         assert aclgraph_runtime_mode != CUDAGraphMode.NONE and \
-            aclgraph_runtime_mode in [CUDAGraphMode.PIECEWISE]
+            aclgraph_runtime_mode in [CUDAGraphMode.FULL,
+                                      CUDAGraphMode.PIECEWISE]

         # Only rank 0 should print progress bar during capture
         if is_global_first_rank():
+            logger.info(
+                "Starting to capture ACL graphs for cases: %s, "
+                "mode: %s, uniform_decode: %s", compilation_cases,
+                aclgraph_runtime_mode.name, uniform_decode)
             compilation_cases = tqdm(
                 compilation_cases,
                 disable=not self.load_config.use_tqdm_on_load,
@@ -3129,6 +3306,7 @@ def _capture_aclgraphs(self, compilation_cases: list[int],
                                     uniform_decode=uniform_decode)
             self._dummy_run(num_tokens,
                             aclgraph_runtime_mode=aclgraph_runtime_mode,
+                            force_attention=force_attention,
                             uniform_decode=uniform_decode)

     def _capture_model(self):
@@ -3155,6 +3333,21 @@ def _capture_model(self):
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 uniform_decode=False)

+        if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
+                aclgraph_mode.separate_routine():
+            max_num_tokens = self.scheduler_config.max_num_seqs * \
+                self.uniform_decode_query_len
+            decode_cudagraph_batch_sizes = [
+                x for x in self.aclgraph_batch_sizes if x <= max_num_tokens
+                and x >= self.uniform_decode_query_len
+            ]
+            compilation_cases_decode = list(
+                reversed(decode_cudagraph_batch_sizes))
+            self._capture_aclgraphs(
+                compilation_cases=compilation_cases_decode,
+                aclgraph_runtime_mode=CUDAGraphMode.FULL,
+                uniform_decode=True)
+
         # Disable aclgraph capturing globally, so any unexpected aclgraph
         # capturing will be detected and raise an error after here.
         # Note: We don't put it into graph_capture context manager because
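
Illustrative sketch (not part of the patch): a minimal, NPU-free model of how the GraphParams registry introduced in vllm_ascend/utils.py is meant to flow. set_graph_params() seeds one slot per ACL graph capture size, the attention backend records its per-shape call arguments while a graph is being captured, and the runner later looks those arguments up by runtime shape to re-issue the kernel with fresh metadata. The torch_npu-specific event and task-group handle types are replaced with plain Python stand-ins here so the snippet runs anywhere; everything except GraphParams, set_graph_params, and get_graph_params is hypothetical illustration rather than the library's API.

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class GraphParams:
    # Per-capture-size bookkeeping, mirroring the dataclass added in vllm_ascend/utils.py.
    events: dict[int, list[Any]]          # ExternalEvent stand-ins
    workspaces: dict[int, Any]            # workspace tensors (unused in this sketch)
    handles: dict[int, list[Any]]         # task-group handle stand-ins
    attn_params: dict[int, list[tuple]]   # saved paged-attention arguments


_graph_params: Optional[GraphParams] = None


def set_graph_params(aclgraph_capture_sizes: set[int]) -> None:
    """Seed one empty slot per capture size; may only be called once."""
    global _graph_params
    if _graph_params is not None:
        raise ValueError("Graph parameters have already been set!")
    _graph_params = GraphParams(
        {size: [] for size in aclgraph_capture_sizes},
        {size: None for size in aclgraph_capture_sizes},
        {size: [] for size in aclgraph_capture_sizes},
        {size: [] for size in aclgraph_capture_sizes},
    )


def get_graph_params() -> Optional[GraphParams]:
    return _graph_params


# Capture time: the backend records the arguments it used for this token count.
set_graph_params({1, 2, 4, 8})
params = get_graph_params()
num_tokens = 4
params.attn_params[num_tokens].append(("query", "key_cache", "seq_lens", "output"))
params.handles[num_tokens].append(object())   # stand-in for an NPU task-group handle
params.events[num_tokens].append(object())    # stand-in for an ExternalEvent

# Replay time: the runner walks the recorded entries for the runtime shape and
# would re-submit the kernel with updated metadata (e.g. the new seq_lens).
for args, handle, event in zip(params.attn_params[num_tokens],
                               params.handles[num_tokens],
                               params.events[num_tokens]):
    assert handle is not None and event is not None
    print("would update paged attention with:", args)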