Skip to content

Commit 0778f7b

Browse files
committed
feat: enhance Qwen3-MoE to support TP settings beyond 4.
1 parent 6b52dd4 commit 0778f7b

File tree

2 files changed

+28
-8
lines changed

2 files changed

+28
-8
lines changed

xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp

Lines changed: 27 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -233,6 +233,8 @@ NpuQwen3MoeDecoderLayerImpl::NpuQwen3MoeDecoderLayerImpl(
233233
CHECK_EQ(parallel_args.world_size(), dp_size_ * dp_local_tp_size_);
234234
dp_local_tp_rank_ = parallel_args.rank() % dp_local_tp_size_;
235235

236+
n_kv_heads_ = static_cast<int32_t>(model_args.n_kv_heads().value());
237+
236238
param_from_args(prefill_param_, model_args, parallel_args, true);
237239
param_from_args(decode_param_, model_args, parallel_args, false);
238240
initialize_tensors(options);
@@ -345,8 +347,8 @@ void NpuQwen3MoeDecoderLayerImpl::initialize_basic_parameters(
345347
param.rmsnormQKNorm = true;
346348
param.hiddenSizePerAttentionHead = args.head_dim();
347349
std::optional<long int> optionalValue = args.n_kv_heads();
348-
param.numKeyValueHeadsPerRank =
349-
static_cast<int>(optionalValue.value()) / parallel_args.world_size();
350+
param.numKeyValueHeadsPerRank = std::max(
351+
1, static_cast<int>(optionalValue.value()) / parallel_args.world_size());
350352
param.numAttentionHeadsPerRank = args.n_heads() / dp_local_tp_size_;
351353

352354
param.attnLinearTransposeType = {1, -1, -1, 1, -1, -1};
@@ -390,8 +392,16 @@ void NpuQwen3MoeDecoderLayerImpl::initialize_mlp_parameters(
390392
void NpuQwen3MoeDecoderLayerImpl::initialize_parallel_parameters(
391393
atb_speed::qwen::MoeDecoderLayerParam& param,
392394
const ParallelArgs& parallel_args) {
393-
param.lmHeadLocalTp = 0;
395+
param.lmHeadLocalTp = dp_local_tp_size_;
394396
param.mapping = parallel_args.mapping();
397+
param.tensorParallelInfo = {parallel_args.rank(),
398+
parallel_args.world_size(),
399+
FLAGS_communication_backend,
400+
FLAGS_rank_tablefile,
401+
nullptr,
402+
""};
403+
404+
param.PrintParam();
395405
param.maxDecodeDpTokenSize = 0; // TODO
396406
}
397407

@@ -543,13 +553,22 @@ void NpuQwen3MoeDecoderLayerImpl::process_general_weights(
543553
const int index = get_mapped_index(name, weight_mapping);
544554
const bool is_sharded = shard_map.count(index);
545555
torch::Tensor tmp_tensor;
556+
int32_t tp_rank = dp_local_tp_rank_;
557+
int32_t tp_size = dp_local_tp_size_;
558+
559+
if (index == IN_QKV_WEIGHT_1 || index == IN_QKV_WEIGHT_2 ||
560+
index == IN_QKV_BIAS_1 || index == IN_QKV_BIAS_2 ||
561+
index == IN_QKV_DESCALE_1 || index == IN_QKV_DESCALE_2) {
562+
if (n_kv_heads_ < dp_local_tp_size_) {
563+
int32_t repeat_times = (dp_local_tp_size_ / n_kv_heads_);
546564

565+
tp_rank = tp_rank / repeat_times;
566+
tp_size = n_kv_heads_;
567+
}
568+
}
547569
if (is_sharded) {
548-
tmp_tensor = get_sharded_tensor(state_dict,
549-
name,
550-
shard_map.at(index),
551-
dp_local_tp_rank_,
552-
dp_local_tp_size_)
570+
tmp_tensor = get_sharded_tensor(
571+
state_dict, name, shard_map.at(index), tp_rank, tp_size)
553572
.to(device_);
554573
} else {
555574
tmp_tensor = tensor.to(device_);

xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -190,6 +190,7 @@ class NpuQwen3MoeDecoderLayerImpl : public NpuBaseLayer {
190190
int32_t start_expert_id_;
191191
int32_t end_expert_id_;
192192
int32_t ep_rank_;
193+
int32_t n_kv_heads_;
193194

194195
int32_t dp_size_;
195196
int32_t dp_local_tp_size_;

0 commit comments

Comments
 (0)