 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
+from vllm.config import get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import direct_register_custom_op
 from vllm.v1.core.sched.output import SchedulerOutput
 
+from vllm_ascend.attention.utils import \
+    AscendCommonAttentionMetadata as CommonAttentionMetadata
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, get_graph_params, is_310p,
                                nd_to_nz_2d, nd_to_nz_spec)
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
@@ -132,7 +135,7 @@ class AscendMetadata:
     # tokens + new tokens (is None if it is a decoding).
     # (batch_size,)
     seq_lens: torch.Tensor = None
-
+    seq_lens_list: list
     query_start_loc: torch.Tensor = None
     query_lens: torch.Tensor = None
     # Maximum query length in the batch (None for decoding).
@@ -167,6 +170,7 @@ def build(self,
               num_reqs,
               num_actual_tokens,
               max_query_len,
+              common_attn_metadata: CommonAttentionMetadata,
               enable_dbo_across_dp: bool = False,
               is_only_prefill: bool = False):
 
@@ -175,15 +179,16 @@ def build(self,
         block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
             block_table[:num_reqs])
 
-        query_lens = self.runner.query_lens
-        seq_lens = self.runner.seq_lens_cpu[:num_reqs]
-        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
-            self.runner.device, non_blocking=True)
+        query_start_loc = common_attn_metadata.query_start_loc
+        seq_lens = common_attn_metadata.seq_lens
+        # TODO: Refactor these two params into the common metadata in runners,
+        # preparing for the hybrid KV groups feature
+        query_lens = common_attn_metadata.query_lens if common_attn_metadata.query_lens is not None else self.runner.query_lens
+        seq_lens_list = common_attn_metadata.seq_lens_list if common_attn_metadata.seq_lens_list is not None else self.runner.seq_lens_list
+
+        slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
         attn_mask = self.runner.attn_mask
         attn_state = self.runner.attn_state
-        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
-        query_start_loc = query_start_loc_cpu.to(self.runner.device,
-                                                 non_blocking=True)
 
         if is_310p():
             if attn_state == AscendAttentionState.PrefillNoCache:
@@ -201,6 +206,7 @@ def build(self,
             query_start_loc=query_start_loc,
             query_lens=query_lens,
             seq_lens=seq_lens,
+            seq_lens_list=seq_lens_list,
             max_query_len=max_query_len,
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
@@ -209,6 +215,34 @@ def build(self,
             is_only_prefill=is_only_prefill)
         return attn_metadata
 
+    def build_dummy_metadata(self, num_actual_tokens, num_reqs,
+                             num_scheduled_tokens, attn_state):
+        if attn_state == AscendAttentionState.DecodeOnly:
+            # NOTE: We only need to pay attention to seq_lens_list and block_table here
+            common_attn_metadata = CommonAttentionMetadata(seq_lens_list=[2] *
+                                                           num_reqs)
+
+            block_table = self.runner.input_batch.block_table[0].block_table
+            block_table[:num_reqs, 0] = torch.arange(1,
+                                                     num_reqs + 1,
+                                                     device=block_table.device,
+                                                     dtype=block_table.dtype)
+
+            attn_metadata = self.build(
+                num_reqs=num_reqs,
+                num_actual_tokens=num_actual_tokens,
+                max_query_len=num_scheduled_tokens.max(),
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+        else:
+            raise NotImplementedError(
+                "Currently we only support building dummy metadata for DecodeOnly state"
+            )
+
+        attn_metadata.attn_state = attn_state
+        return attn_metadata
+
 
 class AscendAttentionBackendImpl(AttentionImpl):
 
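For context (not part of the diff): a minimal, hypothetical sketch of how a model runner could call `build_dummy_metadata` when warming up or capturing a decode-only graph. The `capture_decode_only_graphs` helper, the `runner`/`builder` objects, and the `_dummy_run` call are assumptions for illustration only, not the actual runner code.

```python
# Hypothetical usage sketch (not part of this PR); names are assumptions.
import torch

from vllm_ascend.attention.attention_v1 import AscendAttentionState


def capture_decode_only_graphs(runner, builder, capture_sizes):
    for num_reqs in sorted(capture_sizes, reverse=True):
        # One decode token per request in the dummy batch.
        num_scheduled_tokens = torch.ones(num_reqs, dtype=torch.int32)
        attn_metadata = builder.build_dummy_metadata(
            num_actual_tokens=num_reqs,
            num_reqs=num_reqs,
            num_scheduled_tokens=num_scheduled_tokens,
            attn_state=AscendAttentionState.DecodeOnly,
        )
        # The runner would then execute a dummy forward pass under graph
        # capture using this metadata (the method name is an assumption).
        runner._dummy_run(num_reqs, attn_metadata=attn_metadata)
```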
@@ -245,6 +279,10 @@ def __init__(
         self.key_cache = None
         self.value_cache = None
 
+        vllm_config = get_current_vllm_config()
+        self.full_graph = vllm_config.compilation_config.full_cuda_graph
+        self.block_size = vllm_config.cache_config.block_size
+
     def forward(
         self,
         layer: AttentionLayer,
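The two config reads above gate the new decode path: `full_cuda_graph` comes from vLLM's compilation config and `block_size` from the cache config. Below is a hedged sketch of how a user might enable full-graph capture so that `self.full_graph` evaluates true; the exact key spelling depends on the vLLM version and the model name is a placeholder.

```python
# Sketch only: the compilation_config key name is an assumption tied to the
# installed vLLM version.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
    compilation_config={"full_cuda_graph": True},
)
```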
@@ -369,20 +407,96 @@ def forward(
                 scale_value=self.scale,
                 out=output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            if is_310p():
-                # # seq_lens_tensor needs to be transferred to the device for 310P
-                attn_metadata.seq_lens = \
-                    attn_metadata.seq_lens.to(device=query.device)
-            torch_npu._npu_paged_attention(
-                query=query,
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                block_table=attn_metadata.block_tables,
-                context_lens=attn_metadata.seq_lens,
-                out=output)
+            if self.full_graph:
+                graph_params = get_graph_params()
+                q = query.view(num_tokens, -1, self.hidden_size)
+                k = self.key_cache.view(  # type: ignore
+                    -1, self.block_size,
+                    self.num_kv_heads * self.head_size)
+                v = self.value_cache.view(  # type: ignore
+                    -1, self.block_size,
+                    self.num_kv_heads * self.head_size)
+                actual_seq_lens = attn_metadata.seq_lens_list
+                attn_args = {
+                    "query": q,
+                    "key": k,
+                    "value": v,
+                    "actual_seq_lengths_kv": actual_seq_lens,
+                    "block_table": attn_metadata.block_tables,
+                    "num_heads": self.num_heads,
+                    "scale": self.scale,
+                    "input_layout": "BSH",
+                    "num_key_value_heads": self.num_kv_heads,
+                    "block_size": self.block_size,
+                }
+
+                # Prepare tensors for attention output
+                # TODO: Refactor this to step-level instead of layer-level
+                attn_output = torch.empty(num_tokens,
+                                          1,
+                                          self.hidden_size,
+                                          dtype=output.dtype,
+                                          device=output.device)
+                softmax_lse = torch.empty(num_tokens,
+                                          dtype=output.dtype,
+                                          device=output.device)
+
+                # Get workspace from cache or calculate it if not present.
+                workspace = graph_params.workspaces.get(num_tokens)
+                if workspace is None:
+                    workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                        **attn_args)
+                    graph_params.workspaces[num_tokens] = workspace
+
+                forward_context = get_forward_context()
+                if not forward_context.capturing:
+                    # Execute attention kernel directly in non-capturing mode
+                    torch.ops.npu.npu_fused_infer_attention_score.out(
+                        workspace=workspace,
+                        out=[attn_output, softmax_lse],
+                        **attn_args)
+                else:
+                    # Handle graph capturing mode
+                    stream = torch_npu.npu.current_stream()
+
+                    event = torch.npu.ExternalEvent()
+                    event.wait(stream)
+                    event.reset(stream)
+                    graph_params.events[num_tokens].append(event)
+
+                    graph_params.attn_params[num_tokens].append(
+                        (q, k, v, actual_seq_lens,
+                         attn_metadata.block_tables, self.num_heads,
+                         self.scale, self.num_kv_heads, attn_output,
+                         softmax_lse))
+
+                    torch.npu.graph_task_group_begin(stream)
+                    torch.ops.npu.npu_fused_infer_attention_score.out(
+                        workspace=workspace,
+                        out=[attn_output, softmax_lse],
+                        **attn_args)
+                    handle = torch.npu.graph_task_group_end(stream)
+                    graph_params.handles[num_tokens].append(handle)
+
+                # Reshape output to match the expected format
+                output.copy_(
+                    attn_output.view(num_tokens, self.num_heads,
+                                     self.head_size))
+            else:
+                if is_310p():
+                    # seq_lens_tensor needs to be transferred to the device for 310P
+                    attn_metadata.seq_lens = \
+                        attn_metadata.seq_lens.to(device=query.device)
+                torch_npu._npu_paged_attention(
+                    query=query,
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    block_table=attn_metadata.block_tables,
+                    context_lens=attn_metadata.seq_lens,
+                    out=output)
         # Normal V1 situation.
         else:
             # use chunked prefill for head size 192 scenario, like deepseek
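The full-graph branch keeps per-batch-size bookkeeping in the object returned by `get_graph_params()`: cached kernel workspaces, external events, the captured attention arguments, and task-group handles, all keyed by `num_tokens`. Its real definition lives in `vllm_ascend.utils`; the sketch below only reconstructs the shape implied by the calls above and is an assumption, not the actual implementation.

```python
# Hypothetical shape of the get_graph_params() container, inferred from its
# usage in the diff (workspaces.get(...), events[n].append(...), etc.).
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, DefaultDict, Dict, List


@dataclass
class GraphParams:
    # Cached kernel workspaces, keyed by number of decode tokens.
    workspaces: Dict[int, Any] = field(default_factory=dict)
    # External events recorded while capturing, per token count.
    events: DefaultDict[int, List[Any]] = field(
        default_factory=lambda: defaultdict(list))
    # Saved attention arguments so a replay step can refresh them in place.
    attn_params: DefaultDict[int, List[Any]] = field(
        default_factory=lambda: defaultdict(list))
    # Task-group handles returned by graph_task_group_end().
    handles: DefaultDict[int, List[Any]] = field(
        default_factory=lambda: defaultdict(list))


_GRAPH_PARAMS = GraphParams()


def get_graph_params() -> GraphParams:
    return _GRAPH_PARAMS
```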