@@ -118,7 +118,7 @@ class AscendMetadata:
     query_start_loc: torch.Tensor
     query_lens: torch.Tensor
     seq_lens: torch.Tensor
-    seq_lens_list: list
+    seq_lens_list: Optional[list[int]]
     # Maximum query length in the batch. None for decoding.
     max_query_len: Optional[int] = None
     # (num_tokens,). The indices of the token slots that input tokens will be
@@ -168,8 +168,9 @@ def build(self,
         seq_lens = common_attn_metadata.seq_lens
         # TODO: Refactor these two param to common metadata in runners,
         # preparing for the hybrid KV groups feature
-        query_lens = common_attn_metadata.query_lens if common_attn_metadata.query_lens is not None else self.runner.query_lens
-        seq_lens_list = common_attn_metadata.seq_lens_list if common_attn_metadata.seq_lens_list is not None else self.runner.seq_lens_list
+        query_lens = common_attn_metadata.query_lens or self.runner.query_lens
+        # Since FIA for GQA is not active now, we temporarily silence it
+        seq_lens_list = common_attn_metadata.seq_lens_list

         slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
         attn_mask = self.runner.attn_mask
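
A minimal side note on the fallback rewritten above, sketched with hypothetical list values (the real `query_lens` may be a different container type): `x or y` falls back to `y` whenever `x` is None or empty, whereas the previous `x if x is not None else y` fell back only on None.

    # Hypothetical standalone illustration of the `or` fallback used in build().
    def pick_query_lens(metadata_query_lens, runner_query_lens):
        # Falls back when the metadata value is None or empty;
        # the old `is not None` form fell back only on None.
        return metadata_query_lens or runner_query_lens

    assert pick_query_lens(None, [3, 4]) == [3, 4]    # same result as the old check
    assert pick_query_lens([], [3, 4]) == [3, 4]      # empty list now also falls back
    assert pick_query_lens([1, 2], [3, 4]) == [1, 2]
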
@@ -193,8 +194,8 @@ def build_dummy_metadata(self, num_actual_tokens, num_reqs,
                              num_scheduled_tokens, attn_state):
         if attn_state == AscendAttentionState.DecodeOnly:
             # NOTE: We only need to pay attention to seq_lens_list and block_table here
-            common_attn_metadata = CommonAttentionMetadata(seq_lens_list=[2] *
-                                                           num_reqs)
+            common_attn_metadata = CommonAttentionMetadata(
+                seq_lens=torch.empty_like(self.runner.seq_lens_cpu).fill_(2))

             block_table = self.runner.input_batch.block_table[0].block_table
             block_table[:num_reqs, 0] = torch.arange(1,
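
For reference, a small self-contained sketch (with an arbitrary placeholder standing in for `self.runner.seq_lens_cpu`) of what the new dummy `seq_lens` construction produces: a tensor matching `seq_lens_cpu` in shape, dtype and device, with every entry set to 2.

    import torch

    # Placeholder for self.runner.seq_lens_cpu; shape and dtype chosen arbitrarily.
    seq_lens_cpu = torch.zeros(8, dtype=torch.int32)

    # Same shape/dtype/device as seq_lens_cpu, every element filled with 2,
    # mirroring the dummy decode-only metadata built above.
    dummy_seq_lens = torch.empty_like(seq_lens_cpu).fill_(2)

    assert dummy_seq_lens.shape == seq_lens_cpu.shape
    assert bool((dummy_seq_lens == 2).all())
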
@@ -349,82 +350,42 @@ def forward(
                     scale_value=self.scale,
                     out=output)
             elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-                if self.full_graph:
-                    graph_params = get_graph_params()
-                    q = query.view(num_tokens, -1, self.hidden_size)
-                    k = self.key_cache.view(  # type: ignore
-                        -1, self.block_size,
-                        self.num_kv_heads * self.head_size)
-                    v = self.value_cache.view(  # type: ignore
-                        -1, self.block_size,
-                        self.num_kv_heads * self.head_size)
-                    actual_seq_lens = attn_metadata.seq_lens_list
-                    attn_args = {
-                        "query": q,
-                        "key": k,
-                        "value": v,
-                        "actual_seq_lengths_kv": actual_seq_lens,
-                        "block_table": attn_metadata.block_tables,
-                        "num_heads": self.num_heads,
-                        "scale": self.scale,
-                        "input_layout": "BSH",
-                        "num_key_value_heads": self.num_kv_heads,
-                        "block_size": self.block_size,
-                    }
-
-                    # Prepare tensors for attention output
-                    # TODO: Refactor this to step-level instead of layer-level
-                    attn_output = torch.empty(num_tokens,
-                                              1,
-                                              self.hidden_size,
-                                              dtype=output.dtype,
-                                              device=output.device)
-                    softmax_lse = torch.empty(num_tokens,
-                                              dtype=output.dtype,
-                                              device=output.device)
-
-                    # Get workspace from cache or calculate it if not present.
-                    workspace = graph_params.workspaces.get(num_tokens)
-                    if workspace is None:
-                        workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                            **attn_args)
-                        graph_params.workspaces[num_tokens] = workspace
-
-                    forward_context = get_forward_context()
-                    if not forward_context.capturing:
-                        # Execute attention kernel directly in non-capturing mode
-                        torch.ops.npu.npu_fused_infer_attention_score.out(
-                            workspace=workspace,
-                            out=[attn_output, softmax_lse],
-                            **attn_args)
-                    else:
-                        # Handle graph capturing mode
-                        stream = torch_npu.npu.current_stream()
-
-                        event = torch.npu.ExternalEvent()
-                        event.wait(stream)
-                        event.reset(stream)
-                        graph_params.events[num_tokens].append(event)
-
-                        graph_params.attn_params[num_tokens].append(
-                            (q, k, v, actual_seq_lens,
-                             attn_metadata.block_tables, self.num_heads,
-                             self.scale, self.num_kv_heads, attn_output,
-                             softmax_lse))
-
-                        torch.npu.graph_task_group_begin(stream)
-                        torch.ops.npu.npu_fused_infer_attention_score.out(
-                            workspace=workspace,
-                            out=[attn_output, softmax_lse],
-                            **attn_args)
-                        handle = torch.npu.graph_task_group_end(stream)
-                        graph_params.handles[num_tokens].append(handle)
-
-                    # Reshape output to match the expected format
-                    output.copy_(
-                        attn_output.view(num_tokens, self.num_heads,
-                                         self.head_size))
+                graph_params = get_graph_params()
+
+                forward_context = get_forward_context()
+                if not forward_context.capturing:
+                    torch_npu._npu_paged_attention(
+                        query=query,
+                        key_cache=self.key_cache,
+                        value_cache=self.value_cache,
+                        num_kv_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale_value=self.scale,
+                        block_table=attn_metadata.block_tables,
+                        context_lens=attn_metadata.seq_lens,
+                        out=output)
                 else:
+                    # Handle graph capturing mode
+                    stream = torch_npu.npu.current_stream()
+
+                    event = torch.npu.ExternalEvent()
+                    event.wait(stream)
+                    event.reset(stream)
+                    graph_params.events[num_tokens].append(event)
+
+                    graph_params.attn_params[num_tokens].append((
+                        query,
+                        self.key_cache,
+                        self.value_cache,
+                        self.num_kv_heads,
+                        self.num_heads,
+                        self.scale,
+                        attn_metadata.block_tables,
+                        attn_metadata.seq_lens,
+                        output,
+                    ))
+
+                    torch.npu.graph_task_group_begin(stream)
                     torch_npu._npu_paged_attention(
                         query=query,
                         key_cache=self.key_cache,
@@ -435,6 +396,8 @@ def forward(
                         block_table=attn_metadata.block_tables,
                         context_lens=attn_metadata.seq_lens,
                         out=output)
+                    handle = torch.npu.graph_task_group_end(stream)
+                    graph_params.handles[num_tokens].append(handle)
             # Normal V1 situation.
             else:
                 # use chunked prefill for head size 192 scenario, like deepseek
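
To make the capture branch above easier to follow, here is a condensed sketch of the pattern it implements, using only the calls that appear in the diff (it assumes an Ascend environment with torch_npu available, and the per-token-count `graph_params` bookkeeping from this file): in capture mode the paged-attention launch is wrapped in a graph task group, and the external event, kernel arguments and task-group handle are stashed so the captured graph can be updated and replayed later.

    import torch
    import torch_npu

    # Condensed sketch of the capture-mode bookkeeping; `graph_params` and
    # `attn_args` are assumed to follow the shapes used in the diff above.
    def capture_paged_attention(graph_params, num_tokens, attn_args, output):
        stream = torch_npu.npu.current_stream()

        # External event that the replay path can wait on and reset later.
        event = torch.npu.ExternalEvent()
        event.wait(stream)
        event.reset(stream)
        graph_params.events[num_tokens].append(event)

        # Stash the launch arguments so they can be patched before replay.
        graph_params.attn_params[num_tokens].append((attn_args, output))

        # Record the kernel launch as a task group and keep its handle.
        torch.npu.graph_task_group_begin(stream)
        torch_npu._npu_paged_attention(**attn_args, out=output)
        handle = torch.npu.graph_task_group_end(stream)
        graph_params.handles[num_tokens].append(handle)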