@@ -233,6 +233,8 @@ NpuQwen3MoeDecoderLayerImpl::NpuQwen3MoeDecoderLayerImpl(
233233 CHECK_EQ (parallel_args.world_size (), dp_size_ * dp_local_tp_size_);
234234 dp_local_tp_rank_ = parallel_args.rank () % dp_local_tp_size_;
235235
236+ n_kv_heads_ = static_cast <int32_t >(model_args.n_kv_heads ().value ());
237+
236238 param_from_args (prefill_param_, model_args, parallel_args, true );
237239 param_from_args (decode_param_, model_args, parallel_args, false );
238240 initialize_tensors (options);
@@ -345,8 +347,8 @@ void NpuQwen3MoeDecoderLayerImpl::initialize_basic_parameters(
345347 param.rmsnormQKNorm = true ;
346348 param.hiddenSizePerAttentionHead = args.head_dim ();
347349 std::optional<long int > optionalValue = args.n_kv_heads ();
348- param.numKeyValueHeadsPerRank =
349- static_cast <int >(optionalValue.value ()) / parallel_args.world_size ();
350+ param.numKeyValueHeadsPerRank = std::max (
351+ 1 , static_cast <int >(optionalValue.value ()) / parallel_args.world_size () );
350352 param.numAttentionHeadsPerRank = args.n_heads () / dp_local_tp_size_;
351353
352354 param.attnLinearTransposeType = {1 , -1 , -1 , 1 , -1 , -1 };
@@ -390,8 +392,16 @@ void NpuQwen3MoeDecoderLayerImpl::initialize_mlp_parameters(
// Rendered diff of initialize_parallel_parameters. The function fills the
// parallel/communication fields of `param` from `parallel_args`, the member
// dp_local_tp_size_, and two FLAGS_* globals.
// Reading the prefixes: a fused number such as 390392 is old line 390 / new
// line 392 (unchanged context); `NNN-` marks a line removed by the patch and
// `NNN+` a line added by it.
390392void NpuQwen3MoeDecoderLayerImpl::initialize_parallel_parameters (
391393 atb_speed::qwen::MoeDecoderLayerParam& param,
392394 const ParallelArgs& parallel_args) {
// The patch swaps the constant 0 for dp_local_tp_size_ as the value assigned
// to lmHeadLocalTp.
393- param.lmHeadLocalTp = 0 ;
395+ param.lmHeadLocalTp = dp_local_tp_size_ ;
394396 param.mapping = parallel_args.mapping ();
// Added by the patch: tensorParallelInfo is brace-initialized from the rank,
// the world size, the two FLAGS_* values, a nullptr and a string literal.
// NOTE(review): the meaning of each positional field is not visible in this
// view -- confirm the order against the struct declaration.
// NOTE(review): the last literal appears as a one-space string in this
// rendering; it may be an empty string in the real file -- confirm.
397+ param.tensorParallelInfo = {parallel_args.rank (),
398+ parallel_args.world_size (),
399+ FLAGS_communication_backend,
400+ FLAGS_rank_tablefile,
401+ nullptr ,
402+ " " };
403+
// NOTE(review): PrintParam() is invoked before maxDecodeDpTokenSize is
// assigned on the next line, so its output presumably does not reflect that
// assignment -- confirm the ordering is intended and whether this call is
// meant to remain in the final code.
404+ param.PrintParam ();
395405 param.maxDecodeDpTokenSize = 0 ; // TODO
396406}
397407
@@ -543,13 +553,22 @@ void NpuQwen3MoeDecoderLayerImpl::process_general_weights(
543553 const int index = get_mapped_index (name, weight_mapping);
544554 const bool is_sharded = shard_map.count (index);
545555 torch::Tensor tmp_tensor;
556+ int32_t tp_rank = dp_local_tp_rank_;
557+ int32_t tp_size = dp_local_tp_size_;
558+
559+ if (index == IN_QKV_WEIGHT_1 || index == IN_QKV_WEIGHT_2 ||
560+ index == IN_QKV_BIAS_1 || index == IN_QKV_BIAS_2 ||
561+ index == IN_QKV_DESCALE_1 || index == IN_QKV_DESCALE_2) {
562+ if (n_kv_heads_ < dp_local_tp_size_) {
563+ int32_t repeat_times = (dp_local_tp_size_ / n_kv_heads_);
546564
565+ tp_rank = tp_rank / repeat_times;
566+ tp_size = n_kv_heads_;
567+ }
568+ }
547569 if (is_sharded) {
548- tmp_tensor = get_sharded_tensor (state_dict,
549- name,
550- shard_map.at (index),
551- dp_local_tp_rank_,
552- dp_local_tp_size_)
570+ tmp_tensor = get_sharded_tensor (
571+ state_dict, name, shard_map.at (index), tp_rank, tp_size)
553572 .to (device_);
554573 } else {
555574 tmp_tensor = tensor.to (device_);
0 commit comments