Skip to content

Commit 930e640

Browse files
authored
[CPU] MoE Kernel (#25958)
CPU MoE Kernel ``` name: SwigluMoEBlock, quant_bits: 0, dtype: FP32, batch: 1, seq_len: 16, max_diff: 2.682209014892578e-07 .name: SwigluMoEBlock, quant_bits: 0, dtype: FP32, batch: 1, seq_len: 32, max_diff: 2.980232238769531e-07 .name: SwigluMoEBlock, quant_bits: 0, dtype: FP32, batch: 2, seq_len: 16, max_diff: 2.980232238769531e-07 .name: SwigluMoEBlock, quant_bits: 0, dtype: FP32, batch: 2, seq_len: 32, max_diff: 4.172325134277344e-07 .MoE CPU kernel time: 15.721677541732786 ms . ---------------------------------------------------------------------- Ran 5 tests in 30.217s ```
1 parent 9ca0d69 commit 930e640

File tree

9 files changed

+1613
-124
lines changed

9 files changed

+1613
-124
lines changed

docs/OperatorKernels.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,7 @@ Do not modify directly.*
566566
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)|
567567
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(float), tensor(float16), tensor(uint8)<br/> **T4** = tensor(int32)|
568568
|MaxpoolWithMask|*in* X:**T**<br> *in* M:**tensor(int32)**<br> *out* Y:**T**|1+|**T** = tensor(float)|
569+
|MoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float)|
569570
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**QK**|1+|**T** = tensor(float)|
570571
|MurmurHash3|*in* X:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(int32), tensor(uint32)|
571572
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|

onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
108108
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm);
109109
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, QMoE);
110110
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QMoE);
111+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MoE);
111112
// ******** End: Quantization ******************* //
112113

113114
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
@@ -275,6 +276,7 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) {
275276
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm)>,
276277
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, QMoE)>,
277278
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QMoE)>,
279+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MoE)>,
278280
};
279281

280282
for (auto& function_table_entry : function_table) {

onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include "core/common/common.h"
77
#include "core/framework/tensor_shape.h"
88
#include "core/framework/op_kernel.h"
9-
#include "contrib_ops/cpu/moe/moe_helper.h"
9+
#include "moe_helper.h"
1010
#include <limits>
1111

1212
namespace onnxruntime {

onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc

Lines changed: 605 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/common/common.h"
7+
#include "core/framework/op_kernel.h"
8+
#include "contrib_ops/cpu/moe/moe_base_cpu.h"
9+
10+
namespace onnxruntime {
11+
namespace contrib {
12+
13+
template <typename T>
14+
class MoE final : public OpKernel, public MoEBaseCPU {
15+
public:
16+
explicit MoE(const OpKernelInfo& op_kernel_info);
17+
Status Compute(OpKernelContext* context) const override;
18+
19+
private:
20+
Status ComputeMoE(const OpKernelContext* context,
21+
const Tensor* input,
22+
const Tensor* router_probs,
23+
const Tensor* fc1_experts_weights,
24+
const Tensor* fc1_experts_bias,
25+
const Tensor* fc2_experts_weights,
26+
const Tensor* fc2_experts_bias,
27+
Tensor* output) const;
28+
29+
Status ProcessExpertBatch(const T* input_tokens,
30+
const int64_t* token_expert_ids,
31+
const float* token_weights,
32+
int64_t num_tokens,
33+
int64_t expert_id,
34+
const T* fc1_weights,
35+
const T* fc1_bias,
36+
const T* fc2_weights,
37+
const T* fc2_bias,
38+
T* output_buffer,
39+
int64_t hidden_size,
40+
int64_t inter_size,
41+
T* fc1_output_buffer,
42+
T* activation_output_buffer) const;
43+
44+
Status ComputeGEMM(const T* A, const T* B, T* C,
45+
int64_t M, int64_t K, int64_t N,
46+
bool transpose_B = false) const;
47+
48+
void ApplyActivationVectorized(T* data, int64_t size) const;
49+
void ApplySwiGLUVectorized(const T* input, T* output, int64_t size) const;
50+
};
51+
52+
} // namespace contrib
53+
} // namespace onnxruntime

onnxruntime/contrib_ops/cpu/moe/moe_utils.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,18 @@ void ApplySwiGLUActivation(const float* input_data, float* output_data, int64_t
3737
gate_val = std::min(gate_val, clamp_limit);
3838
linear_val = std::clamp(linear_val, -clamp_limit, clamp_limit);
3939

40+
// Use numerically stable sigmoid computation (matches CUDA kernel behavior)
4041
float sigmoid_arg = activation_alpha * gate_val;
41-
float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg));
42-
float swish_out = gate_val * sigmoid_out;
42+
float sigmoid_out;
43+
if (sigmoid_arg > 0) {
44+
float exp_neg = std::exp(-sigmoid_arg);
45+
sigmoid_out = 1.0f / (1.0f + exp_neg);
46+
} else {
47+
float exp_pos = std::exp(sigmoid_arg);
48+
sigmoid_out = exp_pos / (1.0f + exp_pos);
49+
}
4350

51+
float swish_out = gate_val * sigmoid_out;
4452
output_data[i] = swish_out * (linear_val + activation_beta);
4553
}
4654
} else {

onnxruntime/test/contrib_ops/moe_test.cc

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1690,6 +1690,97 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int8) {
16901690
#endif
16911691
}
16921692

1693+
// Test for CPU MoE implementation
1694+
static void RunMoECpuTest(const std::vector<float>& input, const std::vector<float>& router_probs,
1695+
const std::vector<float>& fc1_experts_weights, const std::vector<float>& fc2_experts_weights,
1696+
const std::vector<float>& fc3_experts_weights, const std::vector<float>& fc1_experts_bias,
1697+
const std::vector<float>& fc2_experts_bias, const std::vector<float>& output_data, int num_rows,
1698+
int num_experts, int hidden_size, int inter_size, std::string activation_type,
1699+
int normalize_routing_weights = 1, int top_k = 1) {
1700+
OpTester tester("MoE", 1, onnxruntime::kMSDomain);
1701+
tester.AddAttribute<int64_t>("k", static_cast<int64_t>(top_k));
1702+
tester.AddAttribute<std::string>("activation_type", activation_type);
1703+
tester.AddAttribute<int64_t>("normalize_routing_weights", static_cast<int64_t>(normalize_routing_weights));
1704+
1705+
bool is_swiglu = (activation_type == "swiglu");
1706+
1707+
if (is_swiglu) {
1708+
tester.AddAttribute<int64_t>("swiglu_fusion", static_cast<int64_t>(1));
1709+
tester.AddAttribute<float>("activation_beta", 1.0f);
1710+
}
1711+
1712+
std::vector<int64_t> input_dims = {num_rows, hidden_size};
1713+
std::vector<int64_t> router_probs_dims = {num_rows, num_experts};
1714+
1715+
int64_t fc1_output_size = is_swiglu ? (2 * inter_size) : inter_size;
1716+
1717+
std::vector<int64_t> fc1_experts_weights_dims = {num_experts, hidden_size, fc1_output_size};
1718+
std::vector<int64_t> fc2_experts_weights_dims = {num_experts, inter_size, hidden_size};
1719+
std::vector<int64_t> fc3_experts_weights_dims = fc1_experts_weights_dims;
1720+
std::vector<int64_t> fc1_experts_bias_dims = {num_experts, fc1_output_size};
1721+
std::vector<int64_t> fc2_experts_bias_dims = {num_experts, hidden_size};
1722+
std::vector<int64_t> output_dims = {num_rows, hidden_size};
1723+
1724+
tester.AddInput<float>("input", input_dims, input);
1725+
tester.AddInput<float>("router_probs", router_probs_dims, router_probs);
1726+
tester.AddInput<float>("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights);
1727+
if (!fc1_experts_bias.empty()) {
1728+
tester.AddInput<float>("fc1_experts_bias", fc1_experts_bias_dims, fc1_experts_bias);
1729+
} else {
1730+
tester.AddOptionalInputEdge<float>();
1731+
}
1732+
tester.AddInput<float>("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights);
1733+
if (!fc2_experts_bias.empty()) {
1734+
tester.AddInput<float>("fc2_experts_bias", fc2_experts_bias_dims, fc2_experts_bias);
1735+
} else {
1736+
tester.AddOptionalInputEdge<float>();
1737+
}
1738+
if (!fc3_experts_weights.empty()) {
1739+
tester.AddInput<float>("fc3_experts_weights", fc3_experts_weights_dims, fc3_experts_weights);
1740+
} else {
1741+
tester.AddOptionalInputEdge<float>();
1742+
}
1743+
tester.AddOptionalInputEdge<float>(); // fc3_experts_bias
1744+
1745+
tester.AddOutput<float>("output", output_dims, output_data);
1746+
tester.SetOutputTolerance(0.05f);
1747+
1748+
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
1749+
execution_providers.push_back(DefaultCpuExecutionProvider());
1750+
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
1751+
}
1752+
1753+
TEST(MoETest, MoECpuTest_BasicSwiGLU) {
1754+
int num_rows = 2;
1755+
int num_experts = 2;
1756+
int hidden_size = 4;
1757+
int inter_size = 8;
1758+
1759+
// Simple test data
1760+
const std::vector<float> input = {
1761+
1.0f, 2.0f, 3.0f, 4.0f,
1762+
5.0f, 6.0f, 7.0f, 8.0f};
1763+
1764+
const std::vector<float> router_probs = {
1765+
0.8f, 0.2f,
1766+
0.3f, 0.7f};
1767+
1768+
const std::vector<float> fc1_experts_weights(num_experts * hidden_size * (2 * inter_size), 0.1f);
1769+
1770+
const std::vector<float> fc2_experts_weights(num_experts * inter_size * hidden_size, 0.1f);
1771+
1772+
const std::vector<float> fc3_experts_weights = {}; // No FC3
1773+
const std::vector<float> fc1_experts_bias = {}; // No bias
1774+
const std::vector<float> fc2_experts_bias = {}; // No bias
1775+
1776+
const std::vector<float> output_data = {
1777+
1.169694f, 1.169694f, 1.169694f, 1.169694f,
1778+
6.970291f, 6.970291f, 6.970291f, 6.970291f};
1779+
1780+
RunMoECpuTest(input, router_probs, fc1_experts_weights, fc2_experts_weights,
1781+
fc3_experts_weights, fc1_experts_bias, fc2_experts_bias, output_data,
1782+
num_rows, num_experts, hidden_size, inter_size, "swiglu");
1783+
}
16931784
#endif
16941785

16951786
} // namespace test

0 commit comments

Comments
 (0)