
Commit 5c2e1f8

jinzhen-lin authored and mgoin committed
[Kernel] moe wna16 cuda kernel (vllm-project#13321)
Signed-off-by: Jinzhen Lin <[email protected]>
Co-authored-by: mgoin <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
1 parent 402f07b commit 5c2e1f8

File tree

7 files changed: +698 −1 lines changed


CMakeLists.txt

Lines changed: 8 additions & 0 deletions
@@ -558,13 +558,21 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
 set(VLLM_MOE_EXT_SRC
   "csrc/moe/torch_bindings.cpp"
   "csrc/moe/moe_align_sum_kernels.cu"
+  "csrc/moe/moe_wna16.cu"
   "csrc/moe/topk_softmax_kernels.cu")

 set_gencode_flags_for_srcs(
   SRCS "${VLLM_MOE_EXT_SRC}"
   CUDA_ARCHS "${CUDA_ARCHS}")

 if(VLLM_GPU_LANG STREQUAL "CUDA")
+  set(VLLM_MOE_WNA16_SRC
+      "csrc/moe/moe_wna16.cu")
+
+  set_gencode_flags_for_srcs(
+    SRCS "${VLLM_MOE_WNA16_SRC}"
+    CUDA_ARCHS "${CUDA_ARCHS}")
+
   cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
   if (MARLIN_MOE_ARCHS)
     set(MARLIN_MOE_SRC

csrc/moe/moe_ops.h

Lines changed: 10 additions & 0 deletions
@@ -18,3 +18,13 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                               torch::Tensor sorted_token_ids,
                               torch::Tensor experts_ids,
                               torch::Tensor num_tokens_post_pad);
+
+torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
+                             torch::Tensor b_qweight, torch::Tensor b_scales,
+                             std::optional<torch::Tensor> b_qzeros,
+                             std::optional<torch::Tensor> topk_weights,
+                             torch::Tensor sorted_token_ids,
+                             torch::Tensor expert_ids,
+                             torch::Tensor num_tokens_post_pad, int64_t top_k,
+                             int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
+                             int64_t BLOCK_SIZE_K, int64_t bit);
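
The header only declares the entry point; the Python-facing registration lives in csrc/moe/torch_bindings.cpp, which is among the other files touched by this commit but is not shown in this excerpt. As a rough, hypothetical sketch only (the library namespace, schema string, and file layout are assumptions, not the commit's actual binding code), a custom op with this signature would typically be registered like so:

// Illustrative sketch, not taken from this commit: registering a CUDA
// implementation for an op with the moe_wna16_gemm signature declared above.
// The library namespace "moe_wna16_demo" is hypothetical.
#include <torch/library.h>

#include "moe_ops.h"

TORCH_LIBRARY(moe_wna16_demo, m) {
  // Optional tensors map to "Tensor?", int64_t maps to "int", and "Tensor!"
  // marks the output tensor as mutated in place. The exact schema used by
  // vLLM may differ.
  m.def(
      "moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
      "Tensor b_scales, Tensor? b_qzeros, Tensor? topk_weights, "
      "Tensor sorted_token_ids, Tensor expert_ids, "
      "Tensor num_tokens_post_pad, int top_k, int BLOCK_SIZE_M, "
      "int BLOCK_SIZE_N, int BLOCK_SIZE_K, int bit) -> Tensor");
  m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm);
}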

csrc/moe/moe_wna16.cu

Lines changed: 346 additions & 0 deletions
@@ -0,0 +1,346 @@
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>

#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "moe_wna16_utils.h"

#define DIVIDE(x, size) (((x) + (size) - 1) / (size))

template <typename scalar_t, int bit, int GROUPS>
__global__ void moe_wna16_gemm_kernel(
    const scalar_t* __restrict__ input, scalar_t* __restrict__ output,

    const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales,
    const uint32_t* __restrict__ qzeros,

    const float* __restrict__ topk_weights,
    const int32_t* __restrict__ sorted_token_ids,
    const int32_t* __restrict__ expert_ids,
    const int32_t* __restrict__ num_tokens_post_pad,

    uint16_t num_experts, uint16_t group_size, uint16_t top_k, uint32_t size_m,
    uint32_t size_n, uint32_t size_k, uint16_t BLOCK_SIZE_M,
    uint16_t BLOCK_SIZE_N, uint16_t BLOCK_SIZE_K, bool has_zp,
    bool mul_topk_weight) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
  if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
    return;
  } else {
#endif

    using Dtype = ScalarType<scalar_t>;
    using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;

    if (blockIdx.x * BLOCK_SIZE_M >= num_tokens_post_pad[0]) return;

    const int32_t offset_n = blockIdx.y * BLOCK_SIZE_N + threadIdx.x;
    const int32_t offset_k = blockIdx.z * BLOCK_SIZE_K;

    const int32_t expert_id = expert_ids[blockIdx.x];

    int32_t num_valid_tokens = 0;
    extern __shared__ uint16_t block_input_tmp[];
    scalar_t* block_input = reinterpret_cast<scalar_t*>(block_input_tmp);
    scalar_t2* block_input_half2 = reinterpret_cast<scalar_t2*>(block_input);

    // load BLOCK_SIZE_M * BLOCK_SIZE_K into shared memory
    for (int m = 0; m < BLOCK_SIZE_M; m++) {
      const int32_t offset_m = blockIdx.x * BLOCK_SIZE_M + m;
      const int32_t token_index = sorted_token_ids[offset_m];
      if (token_index / top_k >= size_m) break;

      num_valid_tokens = m + 1;
      if (blockIdx.z == 0 && offset_n < size_n)
        output[token_index * size_n + offset_n] = Dtype::int2num(0);

      if (expert_id != -1) {
        int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N);
        for (int i = 0; i < k_per_thread; i++) {
          int k = BLOCK_SIZE_N * i + threadIdx.x;
          if (k >= BLOCK_SIZE_K) break;
          if (offset_k + k >= size_k) break;

          // load input to shared memory
          // use a special layout to fit the layout of dequanted-weight
          int origin_k;
          if constexpr (bit == 4) {
            // [0, 4, 1, 5, 2, 6, 3, 7]
            int8_t order = (threadIdx.x % 2) * 4 + ((threadIdx.x % 8) / 2);
            origin_k = BLOCK_SIZE_N * i + threadIdx.x / 8 * 8 + order;
          } else {
            // [0, 2, 1, 3]
            int8_t order = (threadIdx.x % 2) * 2 + ((threadIdx.x % 4) / 2);
            origin_k = BLOCK_SIZE_N * i + threadIdx.x / 4 * 4 + order;
          }

          origin_k += token_index / top_k * size_k + blockIdx.z * BLOCK_SIZE_K;
          block_input[m * BLOCK_SIZE_K + k] = input[origin_k];
        }
      }
    }

    if (expert_id == -1) return;
    __syncthreads();
    if (threadIdx.x >= BLOCK_SIZE_N || offset_n >= size_n) return;

    float res[64];  // assume BLOCK_SIZE_M <= 64
    scalar_t2 res2;
    scalar_t2 scale_f2;
    scalar_t2 qzero_f2;

    // note that (size_n * size_k * expert_id) may greater than 2 ** 31
    constexpr int8_t pack_factor = 32 / bit;
    const uint64_t expert_offset = ((uint64_t)size_n) * size_k * expert_id;
    const uint32_t* expert_qweight = qweight + expert_offset / pack_factor;
    const scalar_t* expert_scales = scales + expert_offset / group_size;
    const uint32_t* expert_qzeros =
        qzeros + expert_offset / group_size / pack_factor;

    // load 4*int32 one time: 4 int32 = 128 bit = 1 float4
    // weight would be loaded in loop
    uint32_t expert_qweight_tmp[4];
    float4* expert_qweight_tmp_float4 =
        reinterpret_cast<float4*>(expert_qweight_tmp);

    // load all required scales one time
    scalar_t expert_scales_groups[GROUPS];
    int scales_offset_tmp =
        (offset_n * size_k + offset_k) / group_size / GROUPS;
    if constexpr (GROUPS == 1) {
      *expert_scales_groups = expert_scales[scales_offset_tmp];
    } else if constexpr (GROUPS == 2) {
      float* expert_scales_groups_tmp =
          reinterpret_cast<float*>(expert_scales_groups);
      *expert_scales_groups_tmp =
          reinterpret_cast<const float*>(expert_scales)[scales_offset_tmp];
    } else if constexpr (GROUPS == 4) {
      float2* expert_scales_groups_tmp =
          reinterpret_cast<float2*>(expert_scales_groups);
      *expert_scales_groups_tmp =
          reinterpret_cast<const float2*>(expert_scales)[scales_offset_tmp];
    } else if constexpr (GROUPS == 8) {
      float4* expert_scales_groups_tmp =
          reinterpret_cast<float4*>(expert_scales_groups);
      *expert_scales_groups_tmp =
          reinterpret_cast<const float4*>(expert_scales)[scales_offset_tmp];
    }

    // load all required qzeros one time
    uint8_t expert_qzeros_groups[GROUPS];
    if (!has_zp) {
      if constexpr (bit == 4) {
        qzero_f2 = Dtype::num2num2(Dtype::int2num(8));
      } else {
        qzero_f2 = Dtype::num2num2(Dtype::int2num(128));
      }
    } else {
      int qzeros_offset_tmp =
          (offset_n / (8 / bit)) * (size_k / group_size / GROUPS) +
          offset_k / group_size / GROUPS;
      if constexpr (GROUPS == 1) {
        uint8_t* expert_qzeros_groups_tmp =
            reinterpret_cast<uint8_t*>(expert_qzeros_groups);
        *expert_qzeros_groups_tmp =
            reinterpret_cast<const uint8_t*>(expert_qzeros)[qzeros_offset_tmp];
      } else if constexpr (GROUPS == 2) {
        uint16_t* expert_qzeros_groups_tmp =
            reinterpret_cast<uint16_t*>(expert_qzeros_groups);
        *expert_qzeros_groups_tmp =
            reinterpret_cast<const uint16_t*>(expert_qzeros)[qzeros_offset_tmp];
      } else if constexpr (GROUPS == 4) {
        uint32_t* expert_qzeros_groups_tmp =
            reinterpret_cast<uint32_t*>(expert_qzeros_groups);
        *expert_qzeros_groups_tmp =
            reinterpret_cast<const uint32_t*>(expert_qzeros)[qzeros_offset_tmp];
      } else if constexpr (GROUPS == 8) {
        uint64_t* expert_qzeros_groups_tmp =
            reinterpret_cast<uint64_t*>(expert_qzeros_groups);
        *expert_qzeros_groups_tmp =
            reinterpret_cast<const uint64_t*>(expert_qzeros)[qzeros_offset_tmp];
      }
    }

    for (int tmp_k = 0; tmp_k < BLOCK_SIZE_K / pack_factor; tmp_k++) {
      int k = offset_k + tmp_k * pack_factor;
      if (k >= size_k) break;
      const int32_t weight_offset = offset_n * size_k + k;

      if (tmp_k % 4 == 0) {
        *expert_qweight_tmp_float4 = reinterpret_cast<const float4*>(
            expert_qweight)[weight_offset / pack_factor / 4];
      }

      if (tmp_k % (group_size / pack_factor) == 0) {
        scalar_t scale_f =
            expert_scales_groups[tmp_k / (group_size / pack_factor)];
        scale_f2 = Dtype::num2num2(scale_f);

        if (has_zp) {
          uint8_t qzero =
              expert_qzeros_groups[tmp_k / (group_size / pack_factor)];
          if constexpr (bit == 4) {
            qzero = (qzero >> ((threadIdx.x % 2) * 4)) & 0xF;
          }
          qzero_f2 = Dtype::num2num2(Dtype::int2num(qzero));
        }
      }

      scalar_t2 weight_half2[16 / bit];
      dequant<scalar_t2, bit>(expert_qweight_tmp[tmp_k % 4], weight_half2);

      for (int m = 0; m < num_valid_tokens; m++) {
        res2 = {};

#pragma unroll
        for (int i = 0; i < 16 / bit; i++) {
          int32_t offset_input = m * BLOCK_SIZE_K / 2 + tmp_k * (16 / bit) + i;
          res2 = __hfma2(__hmul2(__hsub2(weight_half2[i], qzero_f2), scale_f2),
                         block_input_half2[offset_input], res2);
        }

        if (tmp_k == 0) {
          res[m] = Dtype::num2float(res2.x) + Dtype::num2float(res2.y);
        } else {
          res[m] += Dtype::num2float(res2.x) + Dtype::num2float(res2.y);
        }
      }
    }

    for (int m = 0; m < num_valid_tokens; ++m) {
      const int32_t token_index =
          sorted_token_ids[blockIdx.x * BLOCK_SIZE_M + m];
      if (mul_topk_weight) {
        res[m] *= topk_weights[token_index];
      }
      atomicAdd(&output[token_index * size_n + offset_n],
                Dtype::float2num(res[m]));
    }

#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
  }
#endif
}

template <typename scalar_t>
void run_moe_wna16_gemm(const scalar_t* input, scalar_t* output,
                        const uint32_t* b_qweight, const scalar_t* b_scales,
                        const uint32_t* b_qzeros, const float* topk_weights,
                        const int32_t* sorted_token_ids,
                        const int32_t* expert_ids,
                        const int32_t* num_tokens_post_pad, int num_experts,
                        int group_size, int num_token_blocks, int top_k,
                        int size_m, int size_n, int size_k, int BLOCK_SIZE_M,
                        int BLOCK_SIZE_N, int BLOCK_SIZE_K, int bit,
                        bool has_zp, bool mul_topk_weight) {
  dim3 blockDim, gridDim;
  blockDim.x = BLOCK_SIZE_N;
  blockDim.y = 1;
  blockDim.z = 1;
  gridDim.x = num_token_blocks;
  gridDim.y = DIVIDE(size_n, BLOCK_SIZE_N);
  gridDim.z = DIVIDE(size_k, BLOCK_SIZE_K);

  auto kernel = moe_wna16_gemm_kernel<scalar_t, 4, 1>;
  if (bit == 4) {
    if (BLOCK_SIZE_K / group_size == 2) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 4, 2>;
    } else if (BLOCK_SIZE_K / group_size == 4) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 4, 4>;
    } else if (BLOCK_SIZE_K / group_size == 8) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 4, 8>;
    }
  } else {
    if (BLOCK_SIZE_K / group_size == 1) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 8, 1>;
    } else if (BLOCK_SIZE_K / group_size == 2) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 8, 2>;
    } else if (BLOCK_SIZE_K / group_size == 4) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 8, 4>;
    } else if (BLOCK_SIZE_K / group_size == 8) {
      kernel = moe_wna16_gemm_kernel<scalar_t, 8, 8>;
    }
  }

  const int shared_mem_size = BLOCK_SIZE_M * BLOCK_SIZE_K * 2;
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  kernel<<<gridDim, blockDim, shared_mem_size, stream>>>(
      input, output, b_qweight, b_scales, b_qzeros, topk_weights,
      sorted_token_ids, expert_ids, num_tokens_post_pad, num_experts,
      group_size, top_k, size_m, size_n, size_k, BLOCK_SIZE_M, BLOCK_SIZE_N,
      BLOCK_SIZE_K, has_zp, mul_topk_weight);
}

torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                             torch::Tensor b_qweight, torch::Tensor b_scales,
                             std::optional<torch::Tensor> b_qzeros,
                             std::optional<torch::Tensor> topk_weights,
                             torch::Tensor sorted_token_ids,
                             torch::Tensor expert_ids,
                             torch::Tensor num_tokens_post_pad, int64_t top_k,
                             int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
                             int64_t BLOCK_SIZE_K, int64_t bit) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  auto options =
      torch::TensorOptions().dtype(input.dtype()).device(input.device());

  const int num_experts = b_qweight.size(0);
  const int size_m = input.size(0);
  const int size_n = b_qweight.size(1);
  const int size_k = input.size(1);
  const int group_size = size_k / b_scales.size(2);

  int64_t EM = sorted_token_ids.size(0);
  if (size_m <= BLOCK_SIZE_M) {
    EM = min(EM, size_m * BLOCK_SIZE_M * top_k);
  }
  const int num_token_blocks = (EM + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M;

  const uint32_t* b_qzeros_ptr;
  if (b_qzeros.has_value())
    b_qzeros_ptr = (const uint32_t*)b_qzeros.value().data_ptr<uint8_t>();
  const float* topk_weights_ptr;
  if (topk_weights.has_value())
    topk_weights_ptr = (const float*)topk_weights.value().data_ptr();

  int groups_per_block_row = BLOCK_SIZE_K / group_size;
  TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8");
  TORCH_CHECK(size_k % BLOCK_SIZE_K == 0,
              "size_k must divisible by BLOCK_SIZE_K");
  TORCH_CHECK(BLOCK_SIZE_K % group_size == 0,
              "BLOCK_SIZE_K must divisible by group_size");
  TORCH_CHECK(BLOCK_SIZE_M <= 64, "BLOCK_SIZE_M must less or equal to 64");
  TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 ||
                  groups_per_block_row == 4 || groups_per_block_row == 8,
              "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]");

  if (input.scalar_type() == at::ScalarType::Half) {
    run_moe_wna16_gemm<half>(
        (const half*)input.data_ptr<at::Half>(),
        (half*)output.data_ptr<at::Half>(),
        (const uint32_t*)b_qweight.data_ptr<uint8_t>(),
        (const half*)b_scales.data_ptr<at::Half>(), b_qzeros_ptr,
        topk_weights_ptr, sorted_token_ids.data_ptr<int32_t>(),
        expert_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
        num_experts, group_size, num_token_blocks, top_k, size_m, size_n,
        size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, bit,
        b_qzeros.has_value(), topk_weights.has_value());
  } else if (input.scalar_type() == at::ScalarType::BFloat16) {
    run_moe_wna16_gemm<nv_bfloat16>(
        (const nv_bfloat16*)input.data_ptr<at::BFloat16>(),
        (nv_bfloat16*)output.data_ptr<at::BFloat16>(),
        (const uint32_t*)b_qweight.data_ptr<uint8_t>(),
        (const nv_bfloat16*)b_scales.data_ptr<at::BFloat16>(), b_qzeros_ptr,
        topk_weights_ptr, sorted_token_ids.data_ptr<int32_t>(),
        expert_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
        num_experts, group_size, num_token_blocks, top_k, size_m, size_n,
        size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, bit,
        b_qzeros.has_value(), topk_weights.has_value());
  } else {
    TORCH_CHECK(false, "moe_wna16_gemm only supports bfloat16 and float16");
  }
  return output;
}
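
For orientation, the launcher run_moe_wna16_gemm maps the grid onto the problem as follows: gridDim.x walks the padded, expert-sorted token blocks (BLOCK_SIZE_M tokens each), gridDim.y tiles the output dimension N, gridDim.z tiles the reduction dimension K, and each block launches BLOCK_SIZE_N threads while staging BLOCK_SIZE_M * BLOCK_SIZE_K activations in shared memory at 2 bytes per fp16/bf16 element. The host-side sketch below simply replays that arithmetic for one hypothetical 4-bit, group_size=128 configuration; the concrete sizes are made up for illustration and are not part of the commit.

// Standalone sketch of the launch arithmetic used by run_moe_wna16_gemm above.
// All concrete sizes below are hypothetical.
#include <cstdio>

static int divide_up(int x, int size) { return (x + size - 1) / size; }

int main() {
  // Hypothetical quantization and tile configuration.
  const int bit = 4, group_size = 128;
  const int size_n = 4096, size_k = 4096;  // per-expert GEMM shape
  const int BLOCK_SIZE_M = 16, BLOCK_SIZE_N = 64, BLOCK_SIZE_K = 256;
  const int num_token_blocks = 32;  // hypothetical count of padded token blocks

  const int pack_factor = 32 / bit;                            // 8 int4 weights per uint32
  const int groups_per_block_row = BLOCK_SIZE_K / group_size;  // selects GROUPS = 2

  // Same grid/block/shared-memory arithmetic as the launcher.
  const int grid_x = num_token_blocks;
  const int grid_y = divide_up(size_n, BLOCK_SIZE_N);            // 4096 / 64  = 64
  const int grid_z = divide_up(size_k, BLOCK_SIZE_K);            // 4096 / 256 = 16
  const int block_x = BLOCK_SIZE_N;                              // one thread per output column
  const int shared_mem_bytes = BLOCK_SIZE_M * BLOCK_SIZE_K * 2;  // 8192 bytes

  std::printf("grid=(%d,%d,%d) block=(%d,1,1) smem=%dB pack_factor=%d groups=%d\n",
              grid_x, grid_y, grid_z, block_x, shared_mem_bytes, pack_factor,
              groups_per_block_row);
  return 0;
}

Because every K tile (gridDim.z) contributes a partial sum to the same output elements, the kernel zero-initializes the output rows from the blockIdx.z == 0 blocks and accumulates the per-tile results with atomicAdd rather than running a separate reduction pass.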
