Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 38 additions & 9 deletions xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.

#include <gflags/gflags.h>

#include <unordered_set>

#include "common/global_flags.h"

namespace xllm {
Expand Down Expand Up @@ -233,6 +235,8 @@ NpuQwen3MoeDecoderLayerImpl::NpuQwen3MoeDecoderLayerImpl(
CHECK_EQ(parallel_args.world_size(), dp_size_ * dp_local_tp_size_);
dp_local_tp_rank_ = parallel_args.rank() % dp_local_tp_size_;

n_kv_heads_ = static_cast<int32_t>(model_args.n_kv_heads().value());

param_from_args(prefill_param_, model_args, parallel_args, true);
param_from_args(decode_param_, model_args, parallel_args, false);
initialize_tensors(options);
Expand Down Expand Up @@ -345,8 +349,8 @@ void NpuQwen3MoeDecoderLayerImpl::initialize_basic_parameters(
param.rmsnormQKNorm = true;
param.hiddenSizePerAttentionHead = args.head_dim();
std::optional<long int> optionalValue = args.n_kv_heads();
param.numKeyValueHeadsPerRank =
static_cast<int>(optionalValue.value()) / parallel_args.world_size();
param.numKeyValueHeadsPerRank = std::max(
1, static_cast<int>(optionalValue.value()) / parallel_args.world_size());
param.numAttentionHeadsPerRank = args.n_heads() / dp_local_tp_size_;

param.attnLinearTransposeType = {1, -1, -1, 1, -1, -1};
Expand Down Expand Up @@ -390,8 +394,15 @@ void NpuQwen3MoeDecoderLayerImpl::initialize_mlp_parameters(
// Populates the parallelism-related fields of `param` for this decoder layer.
//
// @param param          Layer parameter struct to fill in; modified in place.
// @param parallel_args  Source of the mapping, rank and world size.
//
// Fields written:
//   - lmHeadLocalTp:        this layer's dp_local_tp_size_ member.
//   - mapping:              copied from parallel_args.mapping().
//   - tensorParallelInfo:   built from parallel_args.rank(),
//                           parallel_args.world_size(), and the
//                           FLAGS_communication_backend / FLAGS_rank_tablefile
//                           flags; the two trailing slots are nullptr and "".
//   - maxDecodeDpTokenSize: left at 0 for now (see TODO below).
void NpuQwen3MoeDecoderLayerImpl::initialize_parallel_parameters(
    atb_speed::qwen::MoeDecoderLayerParam& param,
    const ParallelArgs& parallel_args) {
  // Assigned exactly once: a preceding `= 0` store would be dead because it
  // is overwritten immediately.
  param.lmHeadLocalTp = dp_local_tp_size_;
  param.mapping = parallel_args.mapping();
  param.tensorParallelInfo = {parallel_args.rank(),
                              parallel_args.world_size(),
                              FLAGS_communication_backend,
                              FLAGS_rank_tablefile,
                              nullptr,
                              ""};

  param.maxDecodeDpTokenSize = 0;  // TODO
}

Expand Down Expand Up @@ -543,13 +554,31 @@ void NpuQwen3MoeDecoderLayerImpl::process_general_weights(
const int index = get_mapped_index(name, weight_mapping);
const bool is_sharded = shard_map.count(index);
torch::Tensor tmp_tensor;

int32_t tp_rank = dp_local_tp_rank_;
int32_t tp_size = dp_local_tp_size_;

static const std::unordered_set<int> qkv_tensor_indices = {IN_QKV_WEIGHT_1,
IN_QKV_WEIGHT_2,
IN_QKV_BIAS_1,
IN_QKV_BIAS_2,
IN_QKV_DESCALE_1,
IN_QKV_DESCALE_2,
IN_QKV_OFFSET_1,
IN_QKV_OFFSET_2,
IN_QKV_SCALE_1,
IN_QKV_SCALE_2};

if (qkv_tensor_indices.count(index) > 0) {
if (n_kv_heads_ < dp_local_tp_size_) {
int32_t repeat_times = (dp_local_tp_size_ / n_kv_heads_);

tp_rank = tp_rank / repeat_times;
tp_size = n_kv_heads_;
}
}
if (is_sharded) {
tmp_tensor = get_sharded_tensor(state_dict,
name,
shard_map.at(index),
dp_local_tp_rank_,
dp_local_tp_size_)
tmp_tensor = get_sharded_tensor(
state_dict, name, shard_map.at(index), tp_rank, tp_size)
.to(device_);
} else {
tmp_tensor = tensor.to(device_);
Expand Down
1 change: 1 addition & 0 deletions xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ class NpuQwen3MoeDecoderLayerImpl : public NpuBaseLayer {
int32_t start_expert_id_;
int32_t end_expert_id_;
int32_t ep_rank_;
int32_t n_kv_heads_;

int32_t dp_size_;
int32_t dp_local_tp_size_;
Expand Down