1+ /* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+ Licensed under the Apache License, Version 2.0 (the "License");
4+ you may not use this file except in compliance with the License.
5+ You may obtain a copy of the License at
6+
7+ https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+ Unless required by applicable law or agreed to in writing, software
10+ distributed under the License is distributed on an "AS IS" BASIS,
11+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ See the License for the specific language governing permissions and
13+ limitations under the License.
14+ ==============================================================================*/
15+
16+ #include " npu_qwen2_vision_encoder_layer_impl.h"
17+
18+ #include < glog/logging.h>
19+ #include < mstx/ms_tools_ext.h>
20+
21+ #include < iostream>
22+ #include < map>
23+
24+ #include " torch_npu/csrc/core/npu/NPUCachingAllocator.h"
25+ #include " torch_npu/csrc/core/npu/NPUException.h"
26+ #include " xllm_kernels/models/qwen3_vl/qwen3_vl_encoder.h"
27+
28+ namespace xllm {
29+ namespace layer {
30+
31+ enum VisionEncoderLayerTensorId : int {
32+ IN_INPUT_NORM_WEIGHT = 0 ,
33+ IN_INPUT_NORM_BIAS,
34+ IN_POST_NORM_WEIGHT,
35+ IN_POST_NORM_BIAS,
36+ IN_QKV_WEIGHT,
37+ IN_QKV_BIAS,
38+ IN_WATTENTION_OUT_WEIGHT,
39+ IN_WATTENTION_OUT_BIAS,
40+ IN_LINEAR_FC1_WEIGHT,
41+ IN_LINEAR_FC1_BIAS,
42+ IN_LINEAR_FC2_WEIGHT,
43+ IN_LINEAR_FC2_BIAS,
44+ IN_VISION_Q_WEIGHT,
45+ IN_VISION_Q_BIAS,
46+ IN_VISION_K_WEIGHT,
47+ IN_VISION_K_BIAS,
48+ IN_VISION_V_WEIGHT,
49+ IN_VISION_V_BIAS
50+ };
51+
52+ const uint64_t WEIGHT_COUNT_PER_LAYER = 18 ;
53+
54+ static std::vector<std::pair<int , std::string>> WEIGHT_MAPPING = {
55+ {IN_INPUT_NORM_WEIGHT, " norm1.weight" },
56+ {IN_INPUT_NORM_BIAS, " norm1.bias" },
57+ {IN_POST_NORM_WEIGHT, " norm2.weight" },
58+ {IN_POST_NORM_BIAS, " norm2.bias" },
59+ {IN_QKV_WEIGHT, " attn.qkv.weight" },
60+ {IN_QKV_BIAS, " attn.qkv.bias" },
61+ {IN_WATTENTION_OUT_WEIGHT, " attn.proj.weight" },
62+ {IN_WATTENTION_OUT_BIAS, " attn.proj.bias" },
63+ {IN_LINEAR_FC1_WEIGHT, " mlp.fc1.weight" },
64+ {IN_LINEAR_FC1_BIAS, " mlp.fc1.bias" },
65+ {IN_LINEAR_FC2_WEIGHT, " mlp.fc2.weight" },
66+ {IN_LINEAR_FC2_BIAS, " mlp.fc2.bias" }};
67+
68+ // {weight,dim}
69+ static std::map<int , int > WEIGHT_SHARD = {
70+ {IN_WATTENTION_OUT_WEIGHT, 1 },
71+ {IN_LINEAR_FC1_WEIGHT, 0 },
72+ {IN_LINEAR_FC1_BIAS, 0 },
73+ {IN_LINEAR_FC2_WEIGHT, 1 },
74+ };
75+
76+ void NpuQwen2VisionEncoderLayerImpl::param_from_args (
77+ atb_speed::qwen::VisionEncoderLayerParam& param,
78+ const ModelArgs& args,
79+ const ParallelArgs& parallel_args) {
80+ param.isBF16 = args.dtype () == " bfloat16" ;
81+ param.rmsNormEps = args.rms_norm_eps ();
82+ param.worldSize = parallel_args.world_size ();
83+ param.numAttentionHeadsPerRank =
84+ args.mm_num_attention_heads () / param.worldSize ;
85+ param.hiddenSizePerAttentionHead =
86+ args.mm_hidden_size () / args.mm_num_attention_heads ();
87+ std::optional<long int > optionalValue = args.mm_num_attention_heads ();
88+ param.numKeyValueHeadsPerRank =
89+ static_cast <int >(optionalValue.value ()) / param.worldSize ;
90+ param.rank = parallel_args.rank ();
91+ param.backend = " lccl" ;
92+ param.enableLogN = false ;
93+ }
94+
95+ NpuQwen2VisionEncoderLayerImpl::NpuQwen2VisionEncoderLayerImpl (
96+ const ModelContext& context)
97+ : NpuBaseLayer(context) {
98+ auto model_args = context.get_model_args ();
99+ auto parallel_args = context.get_parallel_args ();
100+ auto options = context.get_tensor_options ();
101+ param_from_args (encode_param_, model_args, parallel_args);
102+ at_weight_tensors_.resize (WEIGHT_COUNT_PER_LAYER);
103+ atb_weight_tensors_.resize (WEIGHT_COUNT_PER_LAYER);
104+ dtype_ = c10::typeMetaToScalarType (options.dtype ());
105+ device_id_ = options.device ().index ();
106+ placeholder_ = atb_speed::Utils::AtTensor2Tensor (
107+ torch::zeros ({1 }).to (device_).to (dtype_));
108+ at_placeholder_ = torch::zeros ({1 }).to (device_).to (dtype_);
109+ for (int i = 0 ; i < WEIGHT_COUNT_PER_LAYER; ++i) {
110+ at_weight_tensors_[i] = torch::zeros ({1 }).to (options);
111+ }
112+ }
113+
114+ void NpuQwen2VisionEncoderLayerImpl::verify_loaded_weights () const {
115+ for (const auto & [index, name] : WEIGHT_MAPPING) {
116+ CHECK (at_weight_tensors_[index].sizes () != std::vector<int64_t >({1 }))
117+ << " weight is not loaded for " << name;
118+ }
119+ }
120+
121+ void NpuQwen2VisionEncoderLayerImpl::merge_loaded_weights () {
122+ // spilt pack qkv weight when enable tp
123+ get_weights_col_packed_qkv ();
124+ if (encode_param_.worldSize > 1 ) {
125+ // merge qkv weight
126+ auto new_qkv_weight = torch::cat ({at_weight_tensors_[IN_VISION_Q_WEIGHT],
127+ at_weight_tensors_[IN_VISION_K_WEIGHT],
128+ at_weight_tensors_[IN_VISION_V_WEIGHT]},
129+ 0 );
130+ at_weight_tensors_[IN_QKV_WEIGHT] = new_qkv_weight;
131+ at_weight_tensors_[IN_VISION_Q_WEIGHT] = torch::zeros ({1 }).to (device_);
132+ at_weight_tensors_[IN_VISION_K_WEIGHT] = torch::zeros ({1 }).to (device_);
133+ at_weight_tensors_[IN_VISION_V_WEIGHT] = torch::zeros ({1 }).to (device_);
134+
135+ // merge qkv bias
136+ auto new_qkv_bias = torch::cat ({at_weight_tensors_[IN_VISION_Q_BIAS],
137+ at_weight_tensors_[IN_VISION_K_BIAS],
138+ at_weight_tensors_[IN_VISION_V_BIAS]},
139+ 0 );
140+ at_weight_tensors_[IN_QKV_BIAS] = new_qkv_bias;
141+ at_weight_tensors_[IN_VISION_Q_BIAS] = torch::zeros ({1 }).to (device_);
142+ at_weight_tensors_[IN_VISION_K_BIAS] = torch::zeros ({1 }).to (device_);
143+ at_weight_tensors_[IN_VISION_V_BIAS] = torch::zeros ({1 }).to (device_);
144+ }
145+ c10_npu::NPUCachingAllocator::emptyCache ();
146+ for (int i = 0 ; i < WEIGHT_COUNT_PER_LAYER; ++i) {
147+ atb_weight_tensors_[i] =
148+ atb_speed::Utils::AtTensor2Tensor (at_weight_tensors_[i]);
149+ }
150+
151+ init_layer ();
152+ }
153+ // tp spilt weight
154+ void NpuQwen2VisionEncoderLayerImpl::get_weights_col_packed_qkv () {
155+ int rank = encode_param_.rank ;
156+ int worldSize = encode_param_.worldSize ;
157+ // split qkv weight
158+ qkv_weight = torch::chunk (at_weight_tensors_[IN_QKV_WEIGHT], 3 , 0 );
159+ qkv_bias = torch::chunk (at_weight_tensors_[IN_QKV_BIAS], 3 , 0 );
160+ // weight
161+ at_weight_tensors_[IN_VISION_Q_WEIGHT] =
162+ (qkv_weight[0 ].chunk (worldSize, 0 ))[rank];
163+ at_weight_tensors_[IN_VISION_K_WEIGHT] =
164+ (qkv_weight[1 ].chunk (worldSize, 0 ))[rank];
165+ at_weight_tensors_[IN_VISION_V_WEIGHT] =
166+ (qkv_weight[2 ].chunk (worldSize, 0 ))[rank];
167+ // bias
168+ at_weight_tensors_[IN_VISION_Q_BIAS] =
169+ (qkv_bias[0 ].chunk (worldSize, 0 ))[rank];
170+ at_weight_tensors_[IN_VISION_K_BIAS] =
171+ (qkv_bias[1 ].chunk (worldSize, 0 ))[rank];
172+ at_weight_tensors_[IN_VISION_V_BIAS] =
173+ (qkv_bias[2 ].chunk (worldSize, 0 ))[rank];
174+ }
175+
176+ void NpuQwen2VisionEncoderLayerImpl::load_state_dict (
177+ const StateDict& state_dict) {
178+ for (const auto & [index, name] : WEIGHT_MAPPING) {
179+ if (WEIGHT_SHARD.find (index) != WEIGHT_SHARD.end ()) {
180+ set_weight (state_dict, name, index, WEIGHT_SHARD[index]);
181+ } else {
182+ set_weight (state_dict, name, index);
183+ }
184+ }
185+ }
186+
187+ int64_t NpuQwen2VisionEncoderLayerImpl::init_layer () {
188+ name_ = " qwen2_encoder_layer" ;
189+ model_name_ = " qwen2_vl" ;
190+ CHECK_OPERATION_STATUS_RETURN (init_node (encode_node_, encode_param_));
191+ return atb::NO_ERROR;
192+ }
193+
194+ int64_t NpuQwen2VisionEncoderLayerImpl::init_node (
195+ atb_speed::Model::Node& node,
196+ atb_speed::qwen::VisionEncoderLayerParam& param) {
197+ atb::Operation* operation = nullptr ;
198+ atb_speed::qwen::Qwen3VL_EncoderLayer (param, &operation);
199+ node.operation .reset (operation);
200+ if (node.operation == nullptr ) {
201+ LOG (ERROR) << " node.operation is null" ;
202+ return -1 ;
203+ }
204+ if (node.operation ->GetInputNum () < 1 ) {
205+ LOG (ERROR) << " Can not resize number which is smaller than 1" ;
206+ return -1 ;
207+ }
208+ node.inTensors .resize (node.operation ->GetInputNum ());
209+ node.outTensors .resize (1 );
210+ size_t inTensorId = 1 ;
211+
212+ for (size_t weightTensorId = 0 ; weightTensorId < WEIGHT_COUNT_PER_LAYER;
213+ ++weightTensorId) {
214+ node.inTensors .at (weightTensorId) = &atb_weight_tensors_[weightTensorId];
215+ }
216+
217+ node.variantPack .inTensors .reserve (node.inTensors .size ());
218+ node.variantPack .inTensors .resize (node.inTensors .size ());
219+ node.variantPack .outTensors .reserve (1 );
220+ node.variantPack .outTensors .resize (1 );
221+ return atb::NO_ERROR;
222+ }
223+
224+ torch::Tensor NpuQwen2VisionEncoderLayerImpl::forward (
225+ torch::Tensor& x,
226+ torch::Tensor& cos_pos,
227+ torch::Tensor& sin_pos,
228+ torch::Tensor& cu_seqlen,
229+ std::vector<int >& cu_seqlen_vec,
230+ ModelInputParams& input_params,
231+ int node_id,
232+ aclrtEvent* event,
233+ std::atomic<bool >* event_flag) {
234+ atb::Status st;
235+
236+ build_node_variant_pack (encode_node_,
237+ x,
238+ cos_pos,
239+ sin_pos,
240+ cu_seqlen,
241+ cu_seqlen_vec,
242+ input_params,
243+ true );
244+ // mstxRangeEnd(id);
245+ st = execute_node (encode_node_, node_id);
246+ LOG_IF (FATAL, st != 0 ) << model_name_
247+ << " excute encode layer fail, error code: " << st;
248+ return x;
249+ }
250+
251+ void NpuQwen2VisionEncoderLayerImpl::build_node_variant_pack (
252+ atb_speed::Model::Node& node,
253+ torch::Tensor& x,
254+ torch::Tensor& cos_pos,
255+ torch::Tensor& sin_pos,
256+ torch::Tensor& cu_seqlen,
257+ std::vector<int >& cu_seqlen_vec,
258+ ModelInputParams& input_params,
259+ bool is_prefill) {
260+ internal_tensors_ = atb_speed::Utils::AtTensor2Tensor (x);
261+
262+ node.variantPack .inTensors .at (WEIGHT_COUNT_PER_LAYER) = internal_tensors_;
263+ node.variantPack .inTensors .at (WEIGHT_COUNT_PER_LAYER + 1 ) =
264+ atb_speed::Utils::AtTensor2Tensor (cos_pos);
265+ node.variantPack .inTensors .at (WEIGHT_COUNT_PER_LAYER + 2 ) =
266+ atb_speed::Utils::AtTensor2Tensor (sin_pos);
267+ node.variantPack .inTensors .at (WEIGHT_COUNT_PER_LAYER + 3 ) =
268+ atb_speed::Utils::AtTensor2Tensor (cu_seqlen);
269+ node.variantPack .inTensors .at (WEIGHT_COUNT_PER_LAYER + 3 ).hostData =
270+ cu_seqlen_vec.data ();
271+
272+ for (size_t i = 0 ; i < WEIGHT_COUNT_PER_LAYER; ++i) {
273+ CHECK_THROW (node.inTensors .at (i) == nullptr ,
274+ model_name_ << " inTensor " << i << " is NULL" );
275+ node.variantPack .inTensors .at (i) = *node.inTensors .at (i);
276+ // LOG(INFO) << model_name_ << "inTensors[" << i << "]:"
277+ // << atb_speed::TensorUtil::TensorToString(
278+ // node.variantPack.inTensors.at(i));
279+ }
280+
281+ node.variantPack .outTensors .at (0 ) = internal_tensors_;
282+ }
283+
284+ } // namespace layer
285+ } // namespace xllm
0 commit comments