#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Adapted from vllm/model_executor/models/qwen2_5_vl.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Callable, Iterable, Optional, Set, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_npu
from einops import rearrange
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
    Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2_5_vl import (
    Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed,
    Qwen2_5_VisionTransformer, Qwen2_5_VLDummyInputsBuilder,
    Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLMultiModalProcessor,
    Qwen2_5_VLProcessingInfo)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.multimodal import MULTIMODAL_REGISTRY

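# Attention head dims strictly between these bounds are padded up to
# MAX_PAD_SIZE (e.g. Qwen2.5-VL's vision head dim of 1280 / 16 = 80 becomes
# 128), presumably to satisfy the shape constraints of the fused NPU
# flash-attention operator used below.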
MIN_PAD_SIZE = 64
MAX_PAD_SIZE = 128


class AscendQwen2_5_VisionAttention(Qwen2_5_VisionAttention):

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(
            embed_dim,
            num_heads,
            projection_size,
            quant_config,
            prefix,
        )
        self.embed_dim = embed_dim
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
        if MIN_PAD_SIZE < self.hidden_size_per_attention_head < MAX_PAD_SIZE:
            self.hidden_size_per_attention_head = MAX_PAD_SIZE

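    # Rotary embedding is applied with torch_npu.npu_rotary_mul using the
    # padded cos/sin tables built by AscendQwen2_5_VisionTransformer, and
    # attention runs through the fused unpadded flash-attention kernel with
    # q/k/v flattened to (batch * seq, num_heads, head_dim).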
    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                   for x in (q, k, v))
        q = torch_npu.npu_rotary_mul(q, cos, sin)
        k = torch_npu.npu_rotary_mul(k, cos, sin)

        q, k, v = [
            rearrange(x, "b s h d -> (b s) h d").contiguous()
            for x in (q, k, v)
        ]

        context_layer = torch.empty_like(q)

        # operator requires pta version >= 2.5.1.dev20250226
        torch_npu._npu_flash_attention_unpad(
            query=q,
            key=k,
            value=v,
            seq_len=cu_seqlens,
            scale_value=self.hidden_size_per_attention_head**-0.5,
            num_heads=self.num_attention_heads_per_partition,
            num_kv_heads=self.num_attention_heads_per_partition,
            out=context_layer)

        context_layer = rearrange(context_layer,
                                  "(b s) h d -> s b (h d)",
                                  b=batch_size).contiguous()

        output, _ = self.proj(context_layer)
        return output


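# Identical to the upstream vision block except that attention is the Ascend
# variant above and forward() threads the precomputed cos/sin tables through
# to it.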
class AscendQwen2_5_VisionBlock(Qwen2_5_VisionBlock):

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_hidden_dim: int,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer,
                         quant_config, prefix)
        self.attn = AscendQwen2_5_VisionAttention(embed_dim=dim,
                                                  num_heads=num_heads,
                                                  projection_size=dim,
                                                  quant_config=quant_config,
                                                  prefix=f"{prefix}.attn")

    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
                cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(
            self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin)

        x = x + self.mlp(self.norm2(x))
        return x


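# Computes the patch projection as a single matmul against the flattened
# projection weight, which for non-overlapping patches should match the
# upstream Conv3d-based patch embed while keeping the projection as a plain
# matmul on the NPU.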
class AscendQwen2_5_VisionPatchEmbed(Qwen2_5_VisionPatchEmbed):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.matmul(
            self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1))
        return x


class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):

    def __init__(
        self,
        vision_config: Qwen2_5_VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        interleaved=False,
    ) -> None:
        super().__init__(vision_config, norm_eps, quant_config, prefix)
        norm_layer = partial(RMSNorm, eps=norm_eps)
        self.interleaved = interleaved
        self.patch_embed = AscendQwen2_5_VisionPatchEmbed(
            patch_size=vision_config.patch_size,
            temporal_patch_size=vision_config.temporal_patch_size,
            in_channels=vision_config.in_channels,
            hidden_size=self.hidden_size,
        )
        self.blocks = nn.ModuleList([
            AscendQwen2_5_VisionBlock(
                dim=self.hidden_size,
                num_heads=self.num_heads,
                mlp_hidden_dim=vision_config.intermediate_size,
                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
                norm_layer=norm_layer,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{layer_idx}")
            for layer_idx in range(vision_config.depth)
        ])
        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            self.hidden_size, self.num_heads)

        if MIN_PAD_SIZE < self.hidden_size_per_attention_head < MAX_PAD_SIZE:
            self.origin_hidden_size_per_attention_head = \
                self.hidden_size_per_attention_head
            self.half_origin_hidden_size_per_attention_head = \
                self.hidden_size_per_attention_head // 2
            self.half_pad_hidden_size_per_attention_head = (
                MAX_PAD_SIZE - self.hidden_size_per_attention_head) // 2
            self.hidden_size_per_attention_head = MAX_PAD_SIZE

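    # Precompute the cos/sin tables consumed by npu_rotary_mul: each rotary
    # half is zero-padded to the padded head dim and then duplicated, either
    # contiguously (non-interleaved) or element-wise (interleaved), to cover
    # the full rotary dimension.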
    def cal_cos_sin(self, rotary_pos_emb):
        cos = rotary_pos_emb.cos()  # [seqlen, rotary_dim / 2]
        sin = rotary_pos_emb.sin()
        cos = torch.nn.functional.pad(
            cos, (0, self.half_pad_hidden_size_per_attention_head))
        sin = torch.nn.functional.pad(
            sin, (0, self.half_pad_hidden_size_per_attention_head))

        if not self.interleaved:
            cos_new = torch.cat((cos, cos), dim=-1)
            sin_new = torch.cat((sin, sin), dim=-1)
        else:
            cos_new = rearrange(torch.stack((cos, cos), dim=-1),
                                "... d two -> ... (d two)",
                                two=2)
            sin_new = rearrange(torch.stack((sin, sin), dim=-1),
                                "... d two -> ... (d two)",
                                two=2)
        cos_new = cos_new.reshape(1, -1, 1,
                                  self.hidden_size_per_attention_head)
        sin_new = sin_new.reshape(1, -1, 1,
                                  self.hidden_size_per_attention_head)
        return cos_new, sin_new

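    # The helpers below zero-pad checkpoint parameters to the padded head dim:
    # each head's first and second halves are padded separately so the layout
    # matches the cos/sin tables above, with the fused qkv weight/bias padded
    # along their output rows and the output projection along its input
    # columns.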
    def pad_qkv_bias(self, bias):
        first_half = bias.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head
        )[:, :, :self.half_origin_hidden_size_per_attention_head]
        second_half = bias.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head
        )[:, :, self.half_origin_hidden_size_per_attention_head:]
        first_half_padded = torch.nn.functional.pad(
            first_half, (0, self.half_pad_hidden_size_per_attention_head))
        second_half_padded = torch.nn.functional.pad(
            second_half, (0, self.half_pad_hidden_size_per_attention_head))
        bias_padded = torch.cat([first_half_padded, second_half_padded], dim=2)
        bias_final = bias_padded.reshape(-1)
        return bias_final

    def pad_qkv_weight(self, data):
        qkv_weight_first_half = data.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size
        )[:, :, :self.half_origin_hidden_size_per_attention_head, :]
        qkv_weight_second_half = data.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size
        )[:, :, self.half_origin_hidden_size_per_attention_head:, :]

        qkv_weight_first_half_padded = torch.nn.functional.pad(
            qkv_weight_first_half,
            (0, 0, 0, self.half_pad_hidden_size_per_attention_head))
        qkv_weight_second_half_padded = torch.nn.functional.pad(
            qkv_weight_second_half,
            (0, 0, 0, self.half_pad_hidden_size_per_attention_head))
        qkv_weight_padded = torch.cat(
            [qkv_weight_first_half_padded, qkv_weight_second_half_padded],
            dim=2)
        qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size)
        return qkv_weight_final

    def pad_proj_weight(self, data):
        out_weight = torch.nn.functional.pad(
            data.reshape(self.hidden_size, -1,
                         self.half_origin_hidden_size_per_attention_head),
            (0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape(
                self.hidden_size, -1)
        return out_weight

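    # Mirrors the upstream weight loading, with the addition that attn.proj
    # and attn.qkv parameters are zero-padded in place right after loading so
    # they match the padded head dim.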
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                if "attn.proj.weight" in name:
                    param.data = self.pad_proj_weight(param.data)
                if "attn.qkv.weight" in name:
                    param.data = self.pad_qkv_weight(param.data)
                if "attn.qkv.bias" in name:
                    param.data = self.pad_qkv_bias(param.data)
            loaded_params.add(name)
        return loaded_params

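    # Note: despite their names, cu_seqlens and cu_window_seqlens below hold
    # per-sequence (per-window) lengths rather than cumulative offsets, which
    # is the form passed to the attention kernel's seq_len argument.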
    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        # compute cu_seqlens
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                             grid_thw[:, 0]).cpu().to(
                                                 torch.int32)

        # patchify
        x = self.patch_embed(x)

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # windowed attention
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=x.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
        cu_window_seqlens = torch.diff(cu_window_seqlens).cpu().to(torch.int32)
        seq_len, _ = x.size()
        x = x.reshape(seq_len // self.spatial_merge_unit,
                      self.spatial_merge_unit, -1)
        x = x[window_index, :, :]
        x = x.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)

        cos, sin = self.cal_cos_sin(rotary_pos_emb)

        # transformer blocks: full attention for the configured block indexes,
        # windowed attention otherwise
        x = x.unsqueeze(1)
        for layer_num, blk in enumerate(self.blocks):
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens
            x = blk(x, cu_seqlens=cu_seqlens_now, cos=cos, sin=sin)

        # adapter
        x = self.merger(x)
        reverse_indices = torch.argsort(window_index)
        x = x[reverse_indices, :]
        return x


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2_5_VLMultiModalProcessor,
    info=Qwen2_5_VLProcessingInfo,
    dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
class AscendQwen2_5_VLForConditionalGeneration(
        Qwen2_5_VLForConditionalGeneration):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.visual = AscendQwen2_5_VisionTransformer(
            vision_config=config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            quant_config=self._maybe_ignore_quant_config(quant_config),
            prefix=maybe_prefix(prefix, "visual"),
        )