From 6a6fb2b8d8b0cd9bdcb74677a942537309dcc6d9 Mon Sep 17 00:00:00 2001
From: Jinzhen Lin
Date: Sat, 15 Feb 2025 13:35:33 +0800
Subject: [PATCH 01/15] add moe wna16 cuda kernel

Signed-off-by: Jinzhen Lin
---
 CMakeLists.txt                                |   1 +
 csrc/moe/moe_ops.h                            |  15 +
 csrc/moe/moe_wna16.cu                         | 359 ++++++++++++++++++
 csrc/moe/moe_wna16_utils.h                    | 213 +++++++++++
 csrc/moe/torch_bindings.cpp                   |  10 +
 vllm/_custom_ops.py                           |  15 +
 .../layers/fused_moe/fused_moe.py             | 100 ++++-
 7 files changed, 712 insertions(+), 1 deletion(-)
 create mode 100644 csrc/moe/moe_wna16.cu
 create mode 100644 csrc/moe/moe_wna16_utils.h

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 244ceb721c98..6d729847fd45 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -493,6 +493,7 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
 set(VLLM_MOE_EXT_SRC
   "csrc/moe/torch_bindings.cpp"
   "csrc/moe/moe_align_sum_kernels.cu"
+  "csrc/moe/moe_wna16.cu"
   "csrc/moe/topk_softmax_kernels.cu")
 
 set_gencode_flags_for_srcs(
diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h
index 66bb5f41b7f7..9ffc5dbafcd4 100644
--- a/csrc/moe/moe_ops.h
+++ b/csrc/moe/moe_ops.h
@@ -18,3 +18,18 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                               torch::Tensor sorted_token_ids,
                               torch::Tensor experts_ids,
                               torch::Tensor num_tokens_post_pad);
+
+torch::Tensor moe_wna16_gemm(torch::Tensor input,
+                             torch::Tensor output,
+                             torch::Tensor b_qweight,
+                             torch::Tensor b_scales,
+                             std::optional<torch::Tensor> b_qzeros,
+                             std::optional<torch::Tensor> topk_weights,
+                             torch::Tensor sorted_token_ids,
+                             torch::Tensor expert_ids,
+                             torch::Tensor num_tokens_post_pad,
+                             int64_t top_k,
+                             int64_t BLOCK_SIZE_M,
+                             int64_t BLOCK_SIZE_N,
+                             int64_t BLOCK_SIZE_K,
+                             int64_t bit);
diff --git a/csrc/moe/moe_wna16.cu b/csrc/moe/moe_wna16.cu
new file mode 100644
index 000000000000..3c705adac695
--- /dev/null
+++ b/csrc/moe/moe_wna16.cu
@@ -0,0 +1,359 @@
+
+#include <torch/all.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cuda.h>
+
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include "moe_wna16_utils.h"
+
+#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
+
+
+template <typename scalar_t, int bit, int GROUPS>
+__global__ void moe_wna16_gemm_kernel(
+    const scalar_t* __restrict__ input,
+    scalar_t* __restrict__ output,
+
+    const uint32_t* __restrict__ qweight,
+    const scalar_t* __restrict__ scales,
+    const uint32_t* __restrict__ qzeros,
+
+    const float* __restrict__ topk_weights,
+    const int32_t* __restrict__ sorted_token_ids,
+    const int32_t* __restrict__ expert_ids,
+    const int32_t* __restrict__ num_tokens_post_pad,
+
+    uint16_t num_experts, uint16_t group_size, uint16_t top_k,
+    uint32_t size_m, uint32_t size_n, uint32_t size_k,
+    uint16_t BLOCK_SIZE_M, uint16_t BLOCK_SIZE_N, uint16_t BLOCK_SIZE_K,
+    bool has_zp, bool mul_topk_weight) {
+
+#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
+  if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
+    return;
+  } else {
+#endif
+
+  using Dtype = ScalarType<scalar_t>;
+  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
+
+  if (blockIdx.x * BLOCK_SIZE_M >= num_tokens_post_pad[0]) return;
+
+  const int32_t offset_n = blockIdx.y * BLOCK_SIZE_N + threadIdx.x;
+  const int32_t offset_k = blockIdx.z * BLOCK_SIZE_K;
+
+  const int32_t expert_id = expert_ids[blockIdx.x];
+  const int8_t pack_factor = 32 / bit;
+
+  // note that (size_n * size_k * expert_id) may be greater than 2 ** 31
+  const uint64_t expert_offset = ((uint64_t) size_n) * size_k * expert_id;
+  const uint32_t* expert_qweight = qweight + expert_offset / pack_factor;
+  const scalar_t* expert_scales = scales + expert_offset / group_size;
+  const uint32_t* expert_qzeros =
+      qzeros + expert_offset / group_size / pack_factor;
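+  // e.g. with bit = 4: pack_factor = 8, so one uint32 of qweight holds 8
+  // weights along k; expert_offset counts weight elements, hence the
+  // divisions above: qweight is indexed per packed word, scales per group of
+  // group_size elements, and qzeros per packed word of per-group zero points.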
+
+  int32_t num_valid_tokens = 0;
+  extern __shared__ uint16_t block_input_tmp[];
+  scalar_t *block_input = reinterpret_cast<scalar_t*>(&block_input_tmp);
+  scalar_t2 *block_input_half2 = reinterpret_cast<scalar_t2*>(&block_input);
+
+  // load BLOCK_SIZE_M * BLOCK_SIZE_K into shared memory
+  for (int m = 0; m < BLOCK_SIZE_M; m++) {
+    const int32_t offset_m = blockIdx.x * BLOCK_SIZE_M + m;
+    const int32_t token_index = sorted_token_ids[offset_m];
+    if (token_index / top_k >= size_m) break;
+
+    num_valid_tokens = m + 1;
+    if (blockIdx.z == 0 && offset_n < size_n)
+      output[token_index * size_n + offset_n] = Dtype::int2num(0);
+
+    int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N);
+    for (int i = 0; i < k_per_thread; i++) {
+      int k = BLOCK_SIZE_N * i + threadIdx.x;
+      if (k >= BLOCK_SIZE_K) break;
+      if (offset_k + k >= size_k) break;
+
+      // load input to shared memory
+      // use a special layout to fit the layout of dequanted-weight
+      int origin_k;
+      if constexpr (bit == 4) {
+        // [0, 4, 1, 5, 2, 6, 3, 7]
+        int8_t order = (threadIdx.x % 2) * 4 + ((threadIdx.x % 8) / 2);
+        origin_k = BLOCK_SIZE_N * i + threadIdx.x / 8 * 8 + order;
+      } else {
+        // [0, 2, 1, 3]
+        int8_t order = (threadIdx.x % 2) * 4 + ((threadIdx.x % 8) / 2);
+        origin_k = BLOCK_SIZE_N * i + threadIdx.x / 4 * 4 + order;
+      }
+
+      origin_k += token_index / top_k * size_k + blockIdx.z * BLOCK_SIZE_K;
+      block_input[m * BLOCK_SIZE_K + k] = input[origin_k];
+    }
+  }
+
+  __syncthreads();
+  if (threadIdx.x >= BLOCK_SIZE_N || offset_n >= size_n) return;
+
+  float res[64];  // assume BLOCK_SIZE_M <= 64
+  scalar_t2 res2;
+  scalar_t2 scale_f2;
+  scalar_t2 qzero_f2;
+
+  // load 4*int32 one time: 4 int32 = 128 bit = 1 float4
+  // weight would be loaded in loop
+  uint32_t expert_qweight_tmp[4];
+  float4 *expert_qweight_tmp_float4 =
+      reinterpret_cast<float4*>(&expert_qweight_tmp);
+
+  // load all required scales one time
+  scalar_t expert_scales_groups[GROUPS];
+  int scales_offset_tmp = (offset_n * size_k + offset_k) / group_size / GROUPS;
+  if constexpr (GROUPS == 1) {
+    *expert_scales_groups = expert_scales[scales_offset_tmp];
+  } else if constexpr (GROUPS == 2) {
+    float *expert_scales_groups_tmp =
+        reinterpret_cast<float*>(&expert_scales_groups);
+    *expert_scales_groups_tmp =
+        reinterpret_cast<const float*>(&expert_scales)[scales_offset_tmp];
+  } else if constexpr (GROUPS == 4) {
+    float2 *expert_scales_groups_tmp =
+        reinterpret_cast<float2*>(&expert_scales_groups);
+    *expert_scales_groups_tmp =
+        reinterpret_cast<const float2*>(&expert_scales)[scales_offset_tmp];
+  } else if constexpr (GROUPS == 8) {
+    float4 *expert_scales_groups_tmp =
+        reinterpret_cast<float4*>(&expert_scales_groups);
+    *expert_scales_groups_tmp =
+        reinterpret_cast<const float4*>(&expert_scales)[scales_offset_tmp];
+  }
+
+  // load all required qzeros one time
+  uint8_t expert_qzeros_groups[GROUPS];
+  if (!has_zp) {
+    qzero_f2 = Dtype::num2num2(Dtype::int2num(8));
+  } else {
+    int qzeros_offset_tmp = (offset_n / 2) * (size_k / group_size / GROUPS) +
+                            offset_k / group_size / GROUPS;
+    if constexpr (GROUPS == 1) {
+      uint8_t *expert_qzeros_groups_tmp =
+          reinterpret_cast<uint8_t*>(&expert_qzeros_groups);
+      *expert_qzeros_groups_tmp =
+          reinterpret_cast<const uint8_t*>(&expert_qzeros)[qzeros_offset_tmp];
+    } else if constexpr (GROUPS == 2) {
+      uint16_t *expert_qzeros_groups_tmp =
+          reinterpret_cast<uint16_t*>(&expert_qzeros_groups);
+      *expert_qzeros_groups_tmp =
+          reinterpret_cast<const uint16_t*>(&expert_qzeros)[qzeros_offset_tmp];
+    } else if constexpr (GROUPS == 4) {
+      uint32_t *expert_qzeros_groups_tmp =
+          reinterpret_cast<uint32_t*>(&expert_qzeros_groups);
+      *expert_qzeros_groups_tmp =
+          reinterpret_cast<const uint32_t*>(&expert_qzeros)[qzeros_offset_tmp];
+    } else if constexpr (GROUPS == 8) {
+      uint64_t *expert_qzeros_groups_tmp =
+          reinterpret_cast<uint64_t*>(&expert_qzeros_groups);
+      *expert_qzeros_groups_tmp =
+          reinterpret_cast<const uint64_t*>(&expert_qzeros)[qzeros_offset_tmp];
+    }
+  }
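+  // e.g. GROUPS = BLOCK_SIZE_K / group_size = 4: the four half scales this
+  // thread needs fit in 8 bytes, so they are fetched as a single float2
+  // above, and the four packed zero-point bytes load as one uint32.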
+
+  for (int tmp_k = 0; tmp_k < BLOCK_SIZE_K / pack_factor; tmp_k++) {
+    int k = offset_k + tmp_k * pack_factor;
+    if (k >= size_k) break;
+    const int32_t weight_offset = offset_n * size_k + k;
+
+    if (tmp_k % 4 == 0) {
+      *expert_qweight_tmp_float4 = reinterpret_cast<const float4*>(
+          &expert_qweight)[weight_offset / pack_factor / 4];
+    }
+
+    if (tmp_k % (group_size / pack_factor) == 0) {
+      scalar_t scale_f =
+          expert_scales_groups[tmp_k / (group_size / pack_factor)];
+      scale_f2 = Dtype::num2num2(scale_f);
+
+      if (has_zp) {
+        uint8_t qzero =
+            expert_qzeros_groups[tmp_k / (group_size / pack_factor)];
+        qzero = (qzero >> ((threadIdx.x % 2) * 4)) & 0xF;
+        qzero_f2 = Dtype::num2num2(Dtype::int2num(qzero));
+      }
+    }
+
+    scalar_t2 weight_half2[16 / bit];
+    dequant<scalar_t2, bit>(expert_qweight_tmp[tmp_k % 4], weight_half2);
+
+    for (int m = 0; m < num_valid_tokens; m++) {
+      res2 = {};
+
+      #pragma unroll
+      for (int i = 0; i < 16 / bit; i++) {
+        int32_t offset_input = m * BLOCK_SIZE_K / 2 + tmp_k * (16 / bit) + i;
+        res2 = __hfma2(
+            __hmul2(__hsub2(weight_half2[i], qzero_f2), scale_f2),
+            block_input_half2[offset_input], res2);
+      }
+
+      if (tmp_k == 0) {
+        res[m] = Dtype::num2float(res2.x) + Dtype::num2float(res2.y);
+      } else {
+        res[m] += Dtype::num2float(res2.x) + Dtype::num2float(res2.y);
+      }
+    }
+  }
+
+  for (int m = 0; m < num_valid_tokens; ++m) {
+    const int32_t token_index = sorted_token_ids[blockIdx.x * BLOCK_SIZE_M + m];
+    if (mul_topk_weight) {
+      res[m] *= topk_weights[token_index];
+    }
+    atomicAdd(&output[token_index * size_n + offset_n], Dtype::float2num(res[m]));
+  }
+
+#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
+  }
+#endif
+}
+
+
+template <typename scalar_t>
+void run_moe_wna16_gemm(const scalar_t* input, scalar_t* output,
+                        const uint32_t* b_qweight,
+                        const scalar_t* b_scales,
+                        const uint32_t* b_qzeros,
+                        const float* topk_weights,
+                        const int32_t* sorted_token_ids,
+                        const int32_t* expert_ids,
+                        const int32_t* num_tokens_post_pad,
+                        int num_experts, int group_size,
+                        int num_token_blocks, int top_k,
+                        int size_m, int size_n, int size_k,
+                        int BLOCK_SIZE_M, int BLOCK_SIZE_N, int BLOCK_SIZE_K,
+                        int bit, bool has_zp, bool mul_topk_weight) {
+
+  dim3 blockDim, gridDim;
+  blockDim.x = BLOCK_SIZE_N;
+  blockDim.y = 1;
+  blockDim.z = 1;
+  gridDim.x = num_token_blocks;
+  gridDim.y = DIVIDE(size_n, BLOCK_SIZE_N);
+  gridDim.z = DIVIDE(size_k, BLOCK_SIZE_K);
+
+  auto kernel = moe_wna16_gemm_kernel<scalar_t, 4, 1>;
+  if (bit == 4) {
+    if (BLOCK_SIZE_K / group_size == 2) {
+      kernel = moe_wna16_gemm_kernel<scalar_t, 4, 2>;
+    } else if (BLOCK_SIZE_K / group_size == 4) {
+      kernel = moe_wna16_gemm_kernel<scalar_t, 4, 4>;
+    } else if (BLOCK_SIZE_K / group_size == 8) {
+      kernel = moe_wna16_gemm_kernel<scalar_t, 4, 8>;
+    }
+  } else {
+    if (BLOCK_SIZE_K / group_size == 1) {
+      kernel = moe_wna16_gemm_kernel<scalar_t, 8, 1>;
+    } else if (BLOCK_SIZE_K / group_size == 2) {
+      kernel = moe_wna16_gemm_kernel<scalar_t, 8, 2>;
+    } else if (BLOCK_SIZE_K / group_size == 4) {
+      kernel = moe_wna16_gemm_kernel<scalar_t, 8, 4>;
+    } else if (BLOCK_SIZE_K / group_size == 8) {
+      kernel = moe_wna16_gemm_kernel<scalar_t, 8, 8>;
+    }
+  }
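+  // GROUPS = BLOCK_SIZE_K / group_size is baked in as a template argument so
+  // the per-group scale/zero loads unroll at compile time; e.g.
+  // BLOCK_SIZE_K = 256 with group_size = 64 selects <scalar_t, bit, 4>.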
+  const int shared_mem_size = BLOCK_SIZE_M * BLOCK_SIZE_K * 2;
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  kernel<<<gridDim, blockDim, shared_mem_size, stream>>>(
+      input, output, b_qweight, b_scales, b_qzeros, topk_weights,
+      sorted_token_ids, expert_ids, num_tokens_post_pad,
+      num_experts, group_size, top_k, size_m, size_n, size_k,
+      BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K,
+      has_zp, mul_topk_weight);
+}
+
+
+torch::Tensor moe_wna16_gemm(torch::Tensor input,
+                             torch::Tensor output,
+                             torch::Tensor b_qweight,
+                             torch::Tensor b_scales,
+                             std::optional<torch::Tensor> b_qzeros,
+                             std::optional<torch::Tensor> topk_weights,
+                             torch::Tensor sorted_token_ids,
+                             torch::Tensor expert_ids,
+                             torch::Tensor num_tokens_post_pad,
+                             int64_t top_k,
+                             int64_t BLOCK_SIZE_M,
+                             int64_t BLOCK_SIZE_N,
+                             int64_t BLOCK_SIZE_K,
+                             int64_t bit) {
+
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  auto options =
+      torch::TensorOptions().dtype(input.dtype()).device(input.device());
+
+  const int num_experts = b_qweight.size(0);
+  const int size_m = input.size(0);
+  const int size_n = b_qweight.size(1);
+  const int size_k = input.size(1);
+  const int group_size = size_k / b_scales.size(2);
+
+  int64_t EM = sorted_token_ids.size(0);
+  if (size_m <= BLOCK_SIZE_M) {
+    EM = min(EM, size_m * BLOCK_SIZE_M * top_k);
+  }
+  const int num_token_blocks = (EM + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M;
+
+  const uint32_t* b_qzeros_ptr;
+  if (b_qzeros.has_value())
+    b_qzeros_ptr = (const uint32_t*) b_qzeros.value().data_ptr();
+  const float* topk_weights_ptr;
+  if (topk_weights.has_value())
+    topk_weights_ptr = (const float*) topk_weights.value().data_ptr();
+
+  int groups_per_block_row = BLOCK_SIZE_K / group_size;
+  TORCH_CHECK(bit == 4 || bit == 8,
+              "bit must be 4 or 8");
+  TORCH_CHECK(size_k % BLOCK_SIZE_K == 0,
+              "size_k must divisable by BLOCK_SIZE_K");
+  TORCH_CHECK(BLOCK_SIZE_K % group_size == 0,
+              "BLOCK_SIZE_K must divisable by group_size");
+  TORCH_CHECK(BLOCK_SIZE_M <= 64,
+              "BLOCK_SIZE_M must less or equal to 64");
+  TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 ||
+              groups_per_block_row == 4 || groups_per_block_row == 8,
+              "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]");
+
+  if (input.scalar_type() == at::ScalarType::Half) {
+    run_moe_wna16_gemm<half>(
+        (const half*)input.data_ptr(),
+        (half*)output.data_ptr(),
+        (const uint32_t*)b_qweight.data_ptr(),
+        (const half*)b_scales.data_ptr(),
+        b_qzeros_ptr, topk_weights_ptr,
+        sorted_token_ids.data_ptr<int32_t>(),
+        expert_ids.data_ptr<int32_t>(),
+        num_tokens_post_pad.data_ptr<int32_t>(),
+        num_experts, group_size, num_token_blocks,
+        top_k, size_m, size_n, size_k,
+        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K,
+        bit, b_qzeros.has_value(), topk_weights.has_value());
+  } else if (input.scalar_type() == at::ScalarType::BFloat16) {
+    run_moe_wna16_gemm<nv_bfloat16>(
+        (const nv_bfloat16*)input.data_ptr(),
+        (nv_bfloat16*)output.data_ptr(),
+        (const uint32_t*)b_qweight.data_ptr(),
+        (const nv_bfloat16*)b_scales.data_ptr(),
+        b_qzeros_ptr, topk_weights_ptr,
+        sorted_token_ids.data_ptr<int32_t>(),
+        expert_ids.data_ptr<int32_t>(),
+        num_tokens_post_pad.data_ptr<int32_t>(),
+        num_experts, group_size, num_token_blocks,
+        top_k, size_m, size_n, size_k,
+        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K,
+        bit, b_qzeros.has_value(), topk_weights.has_value());
+  } else {
+    TORCH_CHECK(false, "moe_wna16_gemm only supports bfloat16 and float16");
+  }
+  return output;
+}
diff --git a/csrc/moe/moe_wna16_utils.h b/csrc/moe/moe_wna16_utils.h
new file mode 100644
index 000000000000..5f6c81bc12eb
--- /dev/null
+++ b/csrc/moe/moe_wna16_utils.h
@@ -0,0 +1,213 @@
+
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+
+template <typename scalar_t>
+class ScalarType {};
+
+template <>
+class ScalarType<half> {
+ public:
+  using scalar_t = half;
+  using scalar_t2 = half2;
+
+  static __device__ float inline num2float(const half x) {
+    return __half2float(x);
+  }
+
+  static __device__ half2 inline num2num2(const half x) {
+    return __half2half2(x);
+  }
+
+  static __device__ half2 inline nums2num2(const half x1, const half x2) {
+    return __halves2half2(x1, x2);
+  }
+
+  static __host__ __device__ half inline float2num(const float x) {
+    return __float2half(x);
+  }
+
+  static __host__ __device__ half inline int2num(const float x) {
+    return __int2half_rn(x);
+  }
+
+  static __host__ __device__ float2 inline num22float2(const half2 x) {
+    return __half22float2(x);
+  }
+
+  static __host__ __device__ half2 inline float22num2(const float2 x) {
+    return __float22half2_rn(x);
+  }
+};
+
+template <>
+class ScalarType<nv_bfloat16> {
+ public:
+  using scalar_t = nv_bfloat16;
+  using scalar_t2 = nv_bfloat162;
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  static __device__ float inline num2float(const nv_bfloat16 x) {
+    return __bfloat162float(x);
+  }
+
+  static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
+    return __bfloat162bfloat162(x);
+  }
+
+  static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
+                                                  const nv_bfloat16 x2) {
+    return __halves2bfloat162(x1, x2);
+  }
+
+  static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
+    return __float2bfloat16(x);
+  }
+
+  static __host__ __device__ nv_bfloat16 inline int2num(const float x) {
+    return __int2bfloat16_rn(x);
+  }
+
+  static __host__ __device__ float2 inline num22float2(const nv_bfloat162 x) {
+    return __bfloat1622float2(x);
+  }
+
+  static __host__ __device__ nv_bfloat162 inline float22num2(const float2 x) {
+    return __float22bfloat162_rn(x);
+  }
+#endif
+};
+
+
+template <int lut>
+__device__ inline int lop3(int a, int b, int c) {
+  int res;
+  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+               : "=r"(res)
+               : "r"(a), "r"(b), "r"(c), "n"(lut));
+  return res;
+}
+
+
+template <int start_byte, int mask>
+__device__ inline uint32_t prmt(uint32_t a) {
+  uint32_t res;
+  asm volatile("prmt.b32 %0, %1, %2, %3;\n"
+               : "=r"(res)
+               : "r"(a), "n"(start_byte), "n"(mask));
+  return res;
+}
+
+
+template <typename scalar_t2, int bit>
+__device__ inline void dequant(int q, scalar_t2* res) {
+
+}
+
+
+
+template <>
+__device__ inline void dequant<half2, 4>(int q, half2* res) {
+  const int LO = 0x000f000f;
+  const int HI = 0x00f000f0;
+  const int EX = 0x64006400;
+  const int SUB = 0x64006400;
+  const int MUL = 0x2c002c00;
+  const int ADD = 0xd400d400;
+
+  int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
+  int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
+  q >>= 8;
+  int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
+  int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
+
+  res[0] = __hsub2(*reinterpret_cast<half2*>(&lo0),
+                   *reinterpret_cast<const half2*>(&SUB));
+  res[1] = __hfma2(*reinterpret_cast<half2*>(&hi0),
+                   *reinterpret_cast<const half2*>(&MUL),
+                   *reinterpret_cast<const half2*>(&ADD));
+  res[2] = __hsub2(*reinterpret_cast<half2*>(&lo1),
+                   *reinterpret_cast<const half2*>(&SUB));
+  res[3] = __hfma2(*reinterpret_cast<half2*>(&hi1),
+                   *reinterpret_cast<const half2*>(&MUL),
+                   *reinterpret_cast<const half2*>(&ADD));
+}
+
+
+template <>
+__device__ inline void dequant<half2, 8>(int q, half2* res) {
+  static constexpr uint32_t mask_for_elt_01 = 0x5250;
+  static constexpr uint32_t mask_for_elt_23 = 0x5351;
+  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
+
+  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
+  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
+
+  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
+
+  res[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
+                   *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
+  res[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
+                   *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
+}
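+
+// e.g. for the 8-bit path: prmt places each source byte next to 0x64 bytes,
+// turning byte b into the fp16 value 1024 + b; subtracting the 0x6400 magic
+// number (1024.0) then recovers b exactly, two values per half2.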
+
+
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+template <>
+__device__ inline void dequant<nv_bfloat162, 4>(int q, nv_bfloat162* res) {
+  static constexpr uint32_t MASK = 0x000f000f;
+  static constexpr uint32_t EX = 0x43004300;
+
+  int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
+  q >>= 4;
+  int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
+  q >>= 4;
+  int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
+  q >>= 4;
+  int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
+
+  static constexpr uint32_t MUL = 0x3F803F80;
+  static constexpr uint32_t ADD = 0xC300C300;
+
+  res[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo0),
+                   *reinterpret_cast<const nv_bfloat162*>(&MUL),
+                   *reinterpret_cast<const nv_bfloat162*>(&ADD));
+  res[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi0),
+                   *reinterpret_cast<const nv_bfloat162*>(&MUL),
+                   *reinterpret_cast<const nv_bfloat162*>(&ADD));
+  res[2] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo1),
+                   *reinterpret_cast<const nv_bfloat162*>(&MUL),
+                   *reinterpret_cast<const nv_bfloat162*>(&ADD));
+  res[3] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi1),
+                   *reinterpret_cast<const nv_bfloat162*>(&MUL),
+                   *reinterpret_cast<const nv_bfloat162*>(&ADD));
+}
+
+
+template <>
+__device__ inline void dequant<nv_bfloat162, 8>(int q, nv_bfloat162* res) {
+  float fp32_intermediates[4];
+  uint32_t* fp32_intermediates_casted =
+      reinterpret_cast<uint32_t*>(fp32_intermediates);
+
+  static constexpr uint32_t fp32_base = 0x4B000000;
+  fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
+  fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
+  fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
+  fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
+
+  fp32_intermediates[0] -= 8388608.f;
+  fp32_intermediates[1] -= 8388608.f;
+  fp32_intermediates[2] -= 8388608.f;
+  fp32_intermediates[3] -= 8388608.f;
+
+  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&res);
+  bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
+                                   fp32_intermediates_casted[1], 0x7632);
+  bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
+                                   fp32_intermediates_casted[3], 0x7632);
+
+}
+#endif
diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp
index 8540633dcc8b..d2c03c4d4bef 100644
--- a/csrc/moe/torch_bindings.cpp
+++ b/csrc/moe/torch_bindings.cpp
@@ -31,6 +31,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       "                    Tensor! num_tokens_post_pad) -> ()");
   m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size);
 
+  m.def(
+      "moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
+      "Tensor b_scales, Tensor? b_qzeros, "
+      "Tensor? topk_weights, Tensor sorted_token_ids, "
+      "Tensor expert_ids, Tensor num_tokens_post_pad, "
+      "int top_k, int BLOCK_SIZE_M, int BLOCK_SIZE_N, int BLOCK_SIZE_K, "
+      "int bit) -> Tensor");
+
+  m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm);
+
 #ifndef USE_ROCM
   m.def(
       "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor!
sorted_ids, " diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 67843c177403..bfe83aa59a48 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1018,6 +1018,21 @@ def sgl_moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, experts_ids, num_tokens_post_pad) +def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor, + b_qweight: torch.Tensor, b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], + topk_weights: Optional[torch.Tensor], + sorted_token_ids: torch.Tensor, experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, top_k: int, + BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, BLOCK_SIZE_K: int, + bit: int) -> torch.Tensor: + torch.ops._moe_C.moe_wna16_gemm(input, output, b_qweight, b_scales, + b_qzeros, topk_weights, sorted_token_ids, + experts_ids, num_tokens_post_pad, top_k, + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, + bit) + + def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, token_expert_indicies: torch.Tensor, gating_output: float) -> None: diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f14200e0288e..199a772c2cf0 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -676,6 +676,32 @@ def invoke_fused_moe_kernel(A: torch.Tensor, assert B_scale is not None and B_scale.ndim == 3 assert B_zp is None or B_zp.ndim == 3 + use_moe_wna16_cuda = should_moe_wna16_use_cuda( + num_valid_tokens=topk_ids.numel(), + group_size=block_shape[1], + num_experts=B.shape[0], + bit=4 if use_int4_w4a16 else 8) + config = config.copy() + config.update( + get_moe_wna16_block_config(config=config, + use_moe_wna16_cuda=use_moe_wna16_cuda, + num_valid_tokens=topk_ids.numel(), + size_k=A.shape[1], + size_n=B.shape[1], + num_experts=B.shape[1], + group_size=block_shape[1], + real_top_k=topk_ids.shape[1], + block_size_m=config["BLOCK_SIZE_M"])) + + if use_moe_wna16_cuda: + ops.moe_wna16_gemm( + A, C, B, B_scale, B_zp, + topk_weights if mul_routed_weight else None, sorted_token_ids, + expert_ids, num_tokens_post_padded, top_k, + config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_K"], 4) + return + fused_moe_kernel_gptq_awq[grid]( A, B, @@ -809,6 +835,60 @@ def get_moe_configs( return None +def get_moe_wna16_block_config(config: Dict[str, int], + use_moe_wna16_cuda: bool, num_valid_tokens: int, + size_k: int, size_n: int, num_experts: int, + group_size: int, real_top_k: int, + block_size_m: int): + if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config: + return {} + if not use_moe_wna16_cuda: + if num_valid_tokens // real_top_k == 1: + return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64} + else: + return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32} + else: + block_size_n = 128 + block_size_k = 128 + if block_size_k <= group_size: + block_size_k = group_size + + num_n_blocks = size_k // block_size_k + num_k_blocks = size_n // block_size_k + num_m_blocks = (num_valid_tokens + block_size_m - 1) / block_size_m + \ + num_experts + if num_valid_tokens // real_top_k <= block_size_m: + num_m_blocks = min(num_m_blocks, num_valid_tokens) + num_blocks = num_m_blocks * num_n_blocks * num_k_blocks + + if size_k % 256 == 0 and num_blocks >= 256 and \ + block_size_k < 256: + block_size_k = 256 + num_blocks = num_blocks / (256 / block_size_k) + + if num_m_blocks <= 16 and size_k % (block_size_k * 2) == 0 and \ + size_k % (block_size_k * 2) == 0 and block_size_k <= 512 and \ + num_blocks >= 512: + block_size_k = 
block_size_k * 2 + num_blocks /= 2 + + if num_blocks > 1024: + block_size_n = 256 + num_n_blocks = num_n_blocks / 2 + num_blocks /= 2 + + if size_n <= 1024 and num_blocks >= 1024: + block_size_n = 1024 + + return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k} + + +def should_moe_wna16_use_cuda(num_valid_tokens: int, group_size: int, + num_experts: int, bit: int): + return bit == 4 and group_size in [32, 64, 128] and \ + num_valid_tokens / num_experts <= 8 + + def get_default_config( M: int, E: int, @@ -830,6 +910,21 @@ def get_default_config( "num_warps": 4, "num_stages": 3, } + elif dtype in ["int4_w8a16", "int8_w8a16"] and block_shape is not None: + # moe wna16 kernels + # only set BLOCK_SIZE_M + # BLOCK_SIZE_N and BLOCK_SIZE_K would be set later + bit = 4 if dtype == "int4_w8a16" else 8 + use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk, + block_shape[1], E, bit) + if use_moe_wna16_cuda: + config = {"BLOCK_SIZE_M": min(16, M)} + elif M <= 20: + config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1} + elif M <= 40: + config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1} + else: + config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1} else: config = { "BLOCK_SIZE_M": 64, @@ -864,6 +959,8 @@ def try_get_optimal_moe_config( else: # First try to load optimal config from the file E, _, N = w2_shape + if dtype == "int4_w8a8": + N = N * 2 block_n = block_shape[0] if block_shape else 0 block_k = block_shape[1] if block_shape else 0 configs = get_moe_configs(E, N, dtype, block_n, block_k) @@ -1209,7 +1306,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, # so the cache size and config are already set correctly and # do not need to be adjusted. intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] - intermediate_cache2 = intermediate_cache2[:tokens_in_chunk] + intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * + topk_ids.shape[1]] intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] config = get_config_func(tokens_in_chunk) From b242bae9ac12c6ef591e6f8818c0d516a743a2f9 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Sat, 15 Feb 2025 14:12:24 +0800 Subject: [PATCH 02/15] fix format error Signed-off-by: Jinzhen Lin --- csrc/moe/moe_ops.h | 15 +- csrc/moe/moe_wna16.cu | 425 +++++++++--------- csrc/moe/moe_wna16_utils.h | 45 +- .../layers/fused_moe/fused_moe.py | 30 +- 4 files changed, 237 insertions(+), 278 deletions(-) diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 9ffc5dbafcd4..371edb6495b1 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -19,17 +19,12 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); -torch::Tensor moe_wna16_gemm(torch::Tensor input, - torch::Tensor output, - torch::Tensor b_qweight, - torch::Tensor b_scales, +torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, + torch::Tensor b_qweight, torch::Tensor b_scales, std::optional b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, - torch::Tensor num_tokens_post_pad, - int64_t top_k, - int64_t BLOCK_SIZE_M, - int64_t BLOCK_SIZE_N, - int64_t BLOCK_SIZE_K, - int64_t bit); + torch::Tensor num_tokens_post_pad, int64_t top_k, + int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, + int64_t BLOCK_SIZE_K, int64_t bit); diff --git a/csrc/moe/moe_wna16.cu b/csrc/moe/moe_wna16.cu index 3c705adac695..02972e667570 100644 --- a/csrc/moe/moe_wna16.cu +++ b/csrc/moe/moe_wna16.cu @@ -8,16 +8,13 @@ #include #include "moe_wna16_utils.h" -#define 
DIVIDE(x, size) (((x) + (size) - 1) / (size)) - +#define DIVIDE(x, size) (((x) + (size)-1) / (size)) template __global__ void moe_wna16_gemm_kernel( - const scalar_t* __restrict__ input, - scalar_t* __restrict__ output, + const scalar_t* __restrict__ input, scalar_t* __restrict__ output, - const uint32_t* __restrict__ qweight, - const scalar_t* __restrict__ scales, + const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales, const uint32_t* __restrict__ qzeros, const float* __restrict__ topk_weights, @@ -25,214 +22,210 @@ __global__ void moe_wna16_gemm_kernel( const int32_t* __restrict__ expert_ids, const int32_t* __restrict__ num_tokens_post_pad, - uint16_t num_experts, uint16_t group_size, uint16_t top_k, - uint32_t size_m, uint32_t size_n, uint32_t size_k, - uint16_t BLOCK_SIZE_M, uint16_t BLOCK_SIZE_N, uint16_t BLOCK_SIZE_K, - bool has_zp, bool mul_topk_weight) { - + uint16_t num_experts, uint16_t group_size, uint16_t top_k, uint32_t size_m, + uint32_t size_n, uint32_t size_k, uint16_t BLOCK_SIZE_M, + uint16_t BLOCK_SIZE_N, uint16_t BLOCK_SIZE_K, bool has_zp, + bool mul_topk_weight) { #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800 - if constexpr (std::is_same::value) { - return; - } else { + if constexpr (std::is_same::value) { + return; + } else { #endif - using Dtype = ScalarType; - using scalar_t2 = typename ScalarType::scalar_t2; - - if (blockIdx.x * BLOCK_SIZE_M >= num_tokens_post_pad[0]) return; - - const int32_t offset_n = blockIdx.y * BLOCK_SIZE_N + threadIdx.x; - const int32_t offset_k = blockIdx.z * BLOCK_SIZE_K; - - const int32_t expert_id = expert_ids[blockIdx.x]; - const int8_t pack_factor = 32 / bit; - - // note that (size_n * size_k * expert_id) may greater than 2 ** 31 - const uint64_t expert_offset = ((uint64_t) size_n) * size_k * expert_id; - const uint32_t* expert_qweight = qweight + expert_offset / pack_factor; - const scalar_t* expert_scales = scales + expert_offset / group_size; - const uint32_t* expert_qzeros = - qzeros + expert_offset / group_size / pack_factor; - - int32_t num_valid_tokens = 0; - extern __shared__ uint16_t block_input_tmp[]; - scalar_t *block_input = reinterpret_cast(&block_input_tmp); - scalar_t2 *block_input_half2 = reinterpret_cast(&block_input); - - // load BLOCK_SIZE_M * BLOCK_SIZE_K into shared memory - for (int m = 0; m < BLOCK_SIZE_M; m++) { - const int32_t offset_m = blockIdx.x * BLOCK_SIZE_M + m; - const int32_t token_index = sorted_token_ids[offset_m]; - if (token_index / top_k >= size_m) break; - - num_valid_tokens = m + 1; - if (blockIdx.z == 0 && offset_n < size_n) - output[token_index * size_n + offset_n] = Dtype::int2num(0); - - int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N); - for (int i = 0; i < k_per_thread; i++) { - int k = BLOCK_SIZE_N * i + threadIdx.x; - if (k >= BLOCK_SIZE_K) break; - if (offset_k + k >= size_k) break; - - // load input to shared memory - // use a special layout to fit the layout of dequanted-weight - int origin_k; - if constexpr (bit == 4) { - // [0, 4, 1, 5, 2, 6, 3, 7] - int8_t order = (threadIdx.x % 2) * 4 + ((threadIdx.x % 8) / 2); - origin_k = BLOCK_SIZE_N * i + threadIdx.x / 8 * 8 + order; - } else { - // [0, 2, 1, 3] - int8_t order = (threadIdx.x % 2) * 4 + ((threadIdx.x % 8) / 2); - origin_k = BLOCK_SIZE_N * i + threadIdx.x / 4 * 4 + order; + using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + + if (blockIdx.x * BLOCK_SIZE_M >= num_tokens_post_pad[0]) return; + + const int32_t offset_n = blockIdx.y * BLOCK_SIZE_N + threadIdx.x; + const 
int32_t offset_k = blockIdx.z * BLOCK_SIZE_K; + + const int32_t expert_id = expert_ids[blockIdx.x]; + const int8_t pack_factor = 32 / bit; + + // note that (size_n * size_k * expert_id) may greater than 2 ** 31 + const uint64_t expert_offset = ((uint64_t)size_n) * size_k * expert_id; + const uint32_t* expert_qweight = qweight + expert_offset / pack_factor; + const scalar_t* expert_scales = scales + expert_offset / group_size; + const uint32_t* expert_qzeros = + qzeros + expert_offset / group_size / pack_factor; + + int32_t num_valid_tokens = 0; + extern __shared__ uint16_t block_input_tmp[]; + scalar_t* block_input = reinterpret_cast(&block_input_tmp); + scalar_t2* block_input_half2 = reinterpret_cast(&block_input); + + // load BLOCK_SIZE_M * BLOCK_SIZE_K into shared memory + for (int m = 0; m < BLOCK_SIZE_M; m++) { + const int32_t offset_m = blockIdx.x * BLOCK_SIZE_M + m; + const int32_t token_index = sorted_token_ids[offset_m]; + if (token_index / top_k >= size_m) break; + + num_valid_tokens = m + 1; + if (blockIdx.z == 0 && offset_n < size_n) + output[token_index * size_n + offset_n] = Dtype::int2num(0); + + int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N); + for (int i = 0; i < k_per_thread; i++) { + int k = BLOCK_SIZE_N * i + threadIdx.x; + if (k >= BLOCK_SIZE_K) break; + if (offset_k + k >= size_k) break; + + // load input to shared memory + // use a special layout to fit the layout of dequanted-weight + int origin_k; + if constexpr (bit == 4) { + // [0, 4, 1, 5, 2, 6, 3, 7] + int8_t order = (threadIdx.x % 2) * 4 + ((threadIdx.x % 8) / 2); + origin_k = BLOCK_SIZE_N * i + threadIdx.x / 8 * 8 + order; + } else { + // [0, 2, 1, 3] + int8_t order = (threadIdx.x % 2) * 4 + ((threadIdx.x % 8) / 2); + origin_k = BLOCK_SIZE_N * i + threadIdx.x / 4 * 4 + order; + } + + origin_k += token_index / top_k * size_k + blockIdx.z * BLOCK_SIZE_K; + block_input[m * BLOCK_SIZE_K + k] = input[origin_k]; } - - origin_k += token_index / top_k * size_k + blockIdx.z * BLOCK_SIZE_K; - block_input[m * BLOCK_SIZE_K + k] = input[origin_k]; } - } - __syncthreads(); - if (threadIdx.x >= BLOCK_SIZE_N || offset_n >= size_n) return; - - float res[64]; // assume BLOCK_SIZE_M <= 64 - scalar_t2 res2; - scalar_t2 scale_f2; - scalar_t2 qzero_f2; - - // load 4*int32 one time: 4 int32 = 128 bit = 1 float4 - // weight would be loaded in loop - uint32_t expert_qweight_tmp[4]; - float4 *expert_qweight_tmp_float4 = - reinterpret_cast(&expert_qweight_tmp); - - // load all required scales one time - scalar_t expert_scales_groups[GROUPS]; - int scales_offset_tmp = (offset_n * size_k + offset_k) / group_size / GROUPS; - if constexpr (GROUPS == 1) { - *expert_scales_groups = expert_scales[scales_offset_tmp]; - } else if constexpr (GROUPS == 2) { - float *expert_scales_groups_tmp = - reinterpret_cast(&expert_scales_groups); - *expert_scales_groups_tmp = - reinterpret_cast(&expert_scales)[scales_offset_tmp]; - } else if constexpr (GROUPS == 4) { - float2 *expert_scales_groups_tmp = - reinterpret_cast(&expert_scales_groups); - *expert_scales_groups_tmp = - reinterpret_cast(&expert_scales)[scales_offset_tmp]; - } else if constexpr (GROUPS == 8) { - float4 *expert_scales_groups_tmp = - reinterpret_cast(&expert_scales_groups); - *expert_scales_groups_tmp = - reinterpret_cast(&expert_scales)[scales_offset_tmp]; - } + __syncthreads(); + if (threadIdx.x >= BLOCK_SIZE_N || offset_n >= size_n) return; - // load all required qzeros one time - uint8_t expert_qzeros_groups[GROUPS]; - if (!has_zp) { - qzero_f2 = 
Dtype::num2num2(Dtype::int2num(8)); - } else { - int qzeros_offset_tmp = (offset_n / 2) * (size_k / group_size / GROUPS) + - offset_k / group_size / GROUPS; + float res[64]; // assume BLOCK_SIZE_M <= 64 + scalar_t2 res2; + scalar_t2 scale_f2; + scalar_t2 qzero_f2; + + // load 4*int32 one time: 4 int32 = 128 bit = 1 float4 + // weight would be loaded in loop + uint32_t expert_qweight_tmp[4]; + float4* expert_qweight_tmp_float4 = + reinterpret_cast(&expert_qweight_tmp); + + // load all required scales one time + scalar_t expert_scales_groups[GROUPS]; + int scales_offset_tmp = + (offset_n * size_k + offset_k) / group_size / GROUPS; if constexpr (GROUPS == 1) { - uint8_t *expert_qzeros_groups_tmp = - reinterpret_cast(&expert_qzeros_groups); - *expert_qzeros_groups_tmp = - reinterpret_cast(&expert_qzeros)[qzeros_offset_tmp]; + *expert_scales_groups = expert_scales[scales_offset_tmp]; } else if constexpr (GROUPS == 2) { - uint16_t *expert_qzeros_groups_tmp = - reinterpret_cast(&expert_qzeros_groups); - *expert_qzeros_groups_tmp = - reinterpret_cast(&expert_qzeros)[qzeros_offset_tmp]; + float* expert_scales_groups_tmp = + reinterpret_cast(&expert_scales_groups); + *expert_scales_groups_tmp = + reinterpret_cast(&expert_scales)[scales_offset_tmp]; } else if constexpr (GROUPS == 4) { - uint32_t *expert_qzeros_groups_tmp = - reinterpret_cast(&expert_qzeros_groups); - *expert_qzeros_groups_tmp = - reinterpret_cast(&expert_qzeros)[qzeros_offset_tmp]; + float2* expert_scales_groups_tmp = + reinterpret_cast(&expert_scales_groups); + *expert_scales_groups_tmp = + reinterpret_cast(&expert_scales)[scales_offset_tmp]; } else if constexpr (GROUPS == 8) { - uint64_t *expert_qzeros_groups_tmp = - reinterpret_cast(&expert_qzeros_groups); - *expert_qzeros_groups_tmp = - reinterpret_cast(&expert_qzeros)[qzeros_offset_tmp]; + float4* expert_scales_groups_tmp = + reinterpret_cast(&expert_scales_groups); + *expert_scales_groups_tmp = + reinterpret_cast(&expert_scales)[scales_offset_tmp]; } - } - for (int tmp_k = 0; tmp_k < BLOCK_SIZE_K / pack_factor; tmp_k++) { - int k = offset_k + tmp_k * pack_factor; - if (k >= size_k) break; - const int32_t weight_offset = offset_n * size_k + k; - - if (tmp_k % 4 == 0) { - *expert_qweight_tmp_float4 = reinterpret_cast(&expert_qweight)[ - weight_offset / pack_factor / 4]; + // load all required qzeros one time + uint8_t expert_qzeros_groups[GROUPS]; + if (!has_zp) { + qzero_f2 = Dtype::num2num2(Dtype::int2num(8)); + } else { + int qzeros_offset_tmp = (offset_n / 2) * (size_k / group_size / GROUPS) + + offset_k / group_size / GROUPS; + if constexpr (GROUPS == 1) { + uint8_t* expert_qzeros_groups_tmp = + reinterpret_cast(&expert_qzeros_groups); + *expert_qzeros_groups_tmp = + reinterpret_cast(&expert_qzeros)[qzeros_offset_tmp]; + } else if constexpr (GROUPS == 2) { + uint16_t* expert_qzeros_groups_tmp = + reinterpret_cast(&expert_qzeros_groups); + *expert_qzeros_groups_tmp = + reinterpret_cast(&expert_qzeros)[qzeros_offset_tmp]; + } else if constexpr (GROUPS == 4) { + uint32_t* expert_qzeros_groups_tmp = + reinterpret_cast(&expert_qzeros_groups); + *expert_qzeros_groups_tmp = + reinterpret_cast(&expert_qzeros)[qzeros_offset_tmp]; + } else if constexpr (GROUPS == 8) { + uint64_t* expert_qzeros_groups_tmp = + reinterpret_cast(&expert_qzeros_groups); + *expert_qzeros_groups_tmp = + reinterpret_cast(&expert_qzeros)[qzeros_offset_tmp]; + } } - if (tmp_k % (group_size / pack_factor) == 0) { - scalar_t scale_f = - expert_scales_groups[tmp_k / (group_size / pack_factor)]; - scale_f2 = 
Dtype::num2num2(scale_f); + for (int tmp_k = 0; tmp_k < BLOCK_SIZE_K / pack_factor; tmp_k++) { + int k = offset_k + tmp_k * pack_factor; + if (k >= size_k) break; + const int32_t weight_offset = offset_n * size_k + k; - if (has_zp) { - uint8_t qzero = - expert_qzeros_groups[tmp_k / (group_size / pack_factor)]; - qzero = (qzero >> ((threadIdx.x % 2) * 4)) & 0xF; - qzero_f2 = Dtype::num2num2(Dtype::int2num(qzero)); + if (tmp_k % 4 == 0) { + *expert_qweight_tmp_float4 = reinterpret_cast( + &expert_qweight)[weight_offset / pack_factor / 4]; } - } - scalar_t2 weight_half2[16 / bit]; - dequant(expert_qweight_tmp[tmp_k % 4], weight_half2); + if (tmp_k % (group_size / pack_factor) == 0) { + scalar_t scale_f = + expert_scales_groups[tmp_k / (group_size / pack_factor)]; + scale_f2 = Dtype::num2num2(scale_f); + + if (has_zp) { + uint8_t qzero = + expert_qzeros_groups[tmp_k / (group_size / pack_factor)]; + qzero = (qzero >> ((threadIdx.x % 2) * 4)) & 0xF; + qzero_f2 = Dtype::num2num2(Dtype::int2num(qzero)); + } + } - for (int m = 0; m < num_valid_tokens; m++) { - res2 = {}; + scalar_t2 weight_half2[16 / bit]; + dequant(expert_qweight_tmp[tmp_k % 4], weight_half2); - #pragma unroll - for (int i = 0; i < 16 / bit; i++) { - int32_t offset_input = m * BLOCK_SIZE_K / 2 + tmp_k * (16 / bit) + i; - res2 = __hfma2( - __hmul2(__hsub2(weight_half2[i], qzero_f2), scale_f2), - block_input_half2[offset_input], res2); - } + for (int m = 0; m < num_valid_tokens; m++) { + res2 = {}; - if (tmp_k == 0) { - res[m] = Dtype::num2float(res2.x) + Dtype::num2float(res2.y); - } else { - res[m] += Dtype::num2float(res2.x) + Dtype::num2float(res2.y); +#pragma unroll + for (int i = 0; i < 16 / bit; i++) { + int32_t offset_input = m * BLOCK_SIZE_K / 2 + tmp_k * (16 / bit) + i; + res2 = __hfma2(__hmul2(__hsub2(weight_half2[i], qzero_f2), scale_f2), + block_input_half2[offset_input], res2); + } + + if (tmp_k == 0) { + res[m] = Dtype::num2float(res2.x) + Dtype::num2float(res2.y); + } else { + res[m] += Dtype::num2float(res2.x) + Dtype::num2float(res2.y); + } } } - } - for (int m = 0; m < num_valid_tokens; ++m) { - const int32_t token_index = sorted_token_ids[blockIdx.x * BLOCK_SIZE_M + m]; - if (mul_topk_weight) { - res[m] *= topk_weights[token_index]; + for (int m = 0; m < num_valid_tokens; ++m) { + const int32_t token_index = + sorted_token_ids[blockIdx.x * BLOCK_SIZE_M + m]; + if (mul_topk_weight) { + res[m] *= topk_weights[token_index]; + } + atomicAdd(&output[token_index * size_n + offset_n], + Dtype::float2num(res[m])); } - atomicAdd(&output[token_index * size_n + offset_n], Dtype::float2num(res[m])); - } #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800 - } + } #endif } - template void run_moe_wna16_gemm(const scalar_t* input, scalar_t* output, - const uint32_t* b_qweight, - const scalar_t* b_scales, - const uint32_t* b_qzeros, - const float* topk_weights, + const uint32_t* b_qweight, const scalar_t* b_scales, + const uint32_t* b_qzeros, const float* topk_weights, const int32_t* sorted_token_ids, const int32_t* expert_ids, - const int32_t* num_tokens_post_pad, - int num_experts, int group_size, - int num_token_blocks, int top_k, - int size_m, int size_n, int size_k, - int BLOCK_SIZE_M, int BLOCK_SIZE_N, int BLOCK_SIZE_K, - int bit, bool has_zp, bool mul_topk_weight) { - + const int32_t* num_tokens_post_pad, int num_experts, + int group_size, int num_token_blocks, int top_k, + int size_m, int size_n, int size_k, int BLOCK_SIZE_M, + int BLOCK_SIZE_N, int BLOCK_SIZE_K, int bit, + bool has_zp, bool mul_topk_weight) { dim3 
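  // gridDim below maps blockIdx.x to blocks of sorted tokens, blockIdx.y to
  // BLOCK_SIZE_N tiles of size_n, and blockIdx.z to BLOCK_SIZE_K tiles of
  // size_k; partial sums along z are combined via atomicAdd in the kernel.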
blockDim, gridDim; blockDim.x = BLOCK_SIZE_N; blockDim.y = 1; @@ -266,31 +259,23 @@ void run_moe_wna16_gemm(const scalar_t* input, scalar_t* output, const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); kernel<<>>( input, output, b_qweight, b_scales, b_qzeros, topk_weights, - sorted_token_ids, expert_ids, num_tokens_post_pad, - num_experts, group_size, top_k, size_m, size_n, size_k, - BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, - has_zp, mul_topk_weight); + sorted_token_ids, expert_ids, num_tokens_post_pad, num_experts, + group_size, top_k, size_m, size_n, size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, has_zp, mul_topk_weight); } - -torch::Tensor moe_wna16_gemm(torch::Tensor input, - torch::Tensor output, - torch::Tensor b_qweight, - torch::Tensor b_scales, +torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, + torch::Tensor b_qweight, torch::Tensor b_scales, std::optional b_qzeros, std::optional topk_weights, torch::Tensor sorted_token_ids, torch::Tensor expert_ids, - torch::Tensor num_tokens_post_pad, - int64_t top_k, - int64_t BLOCK_SIZE_M, - int64_t BLOCK_SIZE_N, - int64_t BLOCK_SIZE_K, - int64_t bit) { - + torch::Tensor num_tokens_post_pad, int64_t top_k, + int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, + int64_t BLOCK_SIZE_K, int64_t bit) { const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); auto options = - torch::TensorOptions().dtype(input.dtype()).device(input.device()); + torch::TensorOptions().dtype(input.dtype()).device(input.device()); const int num_experts = b_qweight.size(0); const int size_m = input.size(0); @@ -306,22 +291,20 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, const uint32_t* b_qzeros_ptr; if (b_qzeros.has_value()) - b_qzeros_ptr = (const uint32_t*) b_qzeros.value().data_ptr(); + b_qzeros_ptr = (const uint32_t*)b_qzeros.value().data_ptr(); const float* topk_weights_ptr; if (topk_weights.has_value()) - topk_weights_ptr = (const float*) topk_weights.value().data_ptr(); + topk_weights_ptr = (const float*)topk_weights.value().data_ptr(); int groups_per_block_row = BLOCK_SIZE_K / group_size; - TORCH_CHECK(bit == 4 || bit == 8, - "bit must be 4 or 8"); + TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8"); TORCH_CHECK(size_k % BLOCK_SIZE_K == 0, - "size_k must divisable by BLOCK_SIZE_K"); + "size_k must divisible by BLOCK_SIZE_K"); TORCH_CHECK(BLOCK_SIZE_K % group_size == 0, - "BLOCK_SIZE_K must divisable by group_size"); - TORCH_CHECK(BLOCK_SIZE_M <= 64, - "BLOCK_SIZE_M must less or equal to 64"); + "BLOCK_SIZE_K must divisible by group_size"); + TORCH_CHECK(BLOCK_SIZE_M <= 64, "BLOCK_SIZE_M must less or equal to 64"); TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 || - groups_per_block_row == 4 || groups_per_block_row == 8, + groups_per_block_row == 4 || groups_per_block_row == 8, "BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]"); if (input.scalar_type() == at::ScalarType::Half) { @@ -329,29 +312,23 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, (const half*)input.data_ptr(), (half*)output.data_ptr(), (const uint32_t*)b_qweight.data_ptr(), - (const half*)b_scales.data_ptr(), - b_qzeros_ptr, topk_weights_ptr, - sorted_token_ids.data_ptr(), - expert_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), - num_experts, group_size, num_token_blocks, - top_k, size_m, size_n, size_k, - BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, - bit, b_qzeros.has_value(), topk_weights.has_value()); + (const half*)b_scales.data_ptr(), b_qzeros_ptr, + topk_weights_ptr, sorted_token_ids.data_ptr(), + 
expert_ids.data_ptr(), num_tokens_post_pad.data_ptr(), + num_experts, group_size, num_token_blocks, top_k, size_m, size_n, + size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, bit, + b_qzeros.has_value(), topk_weights.has_value()); } else if (input.scalar_type() == at::ScalarType::BFloat16) { run_moe_wna16_gemm( (const nv_bfloat16*)input.data_ptr(), (nv_bfloat16*)output.data_ptr(), (const uint32_t*)b_qweight.data_ptr(), - (const nv_bfloat16*)b_scales.data_ptr(), - b_qzeros_ptr, topk_weights_ptr, - sorted_token_ids.data_ptr(), - expert_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), - num_experts, group_size, num_token_blocks, - top_k, size_m, size_n, size_k, - BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, - bit, b_qzeros.has_value(), topk_weights.has_value()); + (const nv_bfloat16*)b_scales.data_ptr(), b_qzeros_ptr, + topk_weights_ptr, sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), num_tokens_post_pad.data_ptr(), + num_experts, group_size, num_token_blocks, top_k, size_m, size_n, + size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, bit, + b_qzeros.has_value(), topk_weights.has_value()); } else { TORCH_CHECK(false, "moe_wna16_gemm only supports bfloat16 and float16"); } diff --git a/csrc/moe/moe_wna16_utils.h b/csrc/moe/moe_wna16_utils.h index 5f6c81bc12eb..e969577a8aea 100644 --- a/csrc/moe/moe_wna16_utils.h +++ b/csrc/moe/moe_wna16_utils.h @@ -78,7 +78,6 @@ class ScalarType { #endif }; - template __device__ inline int lop3(int a, int b, int c) { int res; @@ -88,7 +87,6 @@ __device__ inline int lop3(int a, int b, int c) { return res; } - template __device__ inline uint32_t prmt(uint32_t a) { uint32_t res; @@ -98,16 +96,11 @@ __device__ inline uint32_t prmt(uint32_t a) { return res; } - template -__device__ inline void dequant(int q, scalar_t2* res) { - -} - - +__device__ inline void dequant(int q, scalar_t2* res) {} template <> -__device__ inline void dequant(int q, half2* res) { +__device__ inline void dequant(int q, half2* res) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -115,17 +108,17 @@ __device__ inline void dequant(int q, half2* res) { const int MUL = 0x2c002c00; const int ADD = 0xd400d400; - int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); - int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); + int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); q >>= 8; - int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); - int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); + int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); res[0] = __hsub2(*reinterpret_cast(&lo0), *reinterpret_cast(&SUB)); res[1] = __hfma2(*reinterpret_cast(&hi0), *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); + *reinterpret_cast(&ADD)); res[2] = __hsub2(*reinterpret_cast(&lo1), *reinterpret_cast(&SUB)); res[3] = __hfma2(*reinterpret_cast(&hi1), @@ -133,9 +126,8 @@ __device__ inline void dequant(int q, half2* res) { *reinterpret_cast(&ADD)); } - template <> -__device__ inline void dequant(int q, half2* res) { +__device__ inline void dequant(int q, half2* res) { static constexpr uint32_t mask_for_elt_01 = 0x5250; static constexpr uint32_t mask_for_elt_23 = 0x5351; static constexpr uint32_t start_byte_for_fp16 = 0x64646464; @@ -151,22 +143,19 @@ __device__ inline void dequant(int q, half2* res) { *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); } - - - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 template <> -__device__ inline void dequant(int q, nv_bfloat162* res) { 
+__device__ inline void dequant(int q, nv_bfloat162* res) { static constexpr uint32_t MASK = 0x000f000f; static constexpr uint32_t EX = 0x43004300; - int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); + int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); q >>= 4; - int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); + int hi0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); q >>= 4; - int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); + int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); q >>= 4; - int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); + int hi1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); static constexpr uint32_t MUL = 0x3F803F80; static constexpr uint32_t ADD = 0xC300C300; @@ -178,16 +167,15 @@ __device__ inline void dequant(int q, nv_bfloat162* res) { *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); res[2] = __hfma2(*reinterpret_cast(&lo1), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); res[3] = __hfma2(*reinterpret_cast(&hi1), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); } - template <> -__device__ inline void dequant(int q, nv_bfloat162* res) { +__device__ inline void dequant(int q, nv_bfloat162* res) { float fp32_intermediates[4]; uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); @@ -208,6 +196,5 @@ __device__ inline void dequant(int q, nv_bfloat162* res) { fp32_intermediates_casted[1], 0x7632); bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); - } #endif diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 199a772c2cf0..e3b0ced1b7d5 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -694,12 +694,12 @@ def invoke_fused_moe_kernel(A: torch.Tensor, block_size_m=config["BLOCK_SIZE_M"])) if use_moe_wna16_cuda: - ops.moe_wna16_gemm( - A, C, B, B_scale, B_zp, - topk_weights if mul_routed_weight else None, sorted_token_ids, - expert_ids, num_tokens_post_padded, top_k, - config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"], - config["BLOCK_SIZE_K"], 4) + ops.moe_wna16_gemm(A, C, B, B_scale, B_zp, + topk_weights if mul_routed_weight else None, + sorted_token_ids, expert_ids, + num_tokens_post_padded, top_k, + config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_K"], 4) return fused_moe_kernel_gptq_awq[grid]( @@ -835,11 +835,11 @@ def get_moe_configs( return None -def get_moe_wna16_block_config(config: Dict[str, int], - use_moe_wna16_cuda: bool, num_valid_tokens: int, - size_k: int, size_n: int, num_experts: int, - group_size: int, real_top_k: int, - block_size_m: int): +def get_moe_wna16_block_config(config: Dict[str, + int], use_moe_wna16_cuda: bool, + num_valid_tokens: int, size_k: int, size_n: int, + num_experts: int, group_size: int, + real_top_k: int, block_size_m: int): if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config: return {} if not use_moe_wna16_cuda: @@ -864,18 +864,18 @@ def get_moe_wna16_block_config(config: Dict[str, int], if size_k % 256 == 0 and num_blocks >= 256 and \ block_size_k < 256: block_size_k = 256 - num_blocks = num_blocks / (256 / block_size_k) + num_blocks = num_blocks // (256 // block_size_k) if num_m_blocks <= 16 and size_k % (block_size_k * 2) == 0 and \ size_k % (block_size_k * 2) == 0 and block_size_k <= 512 and \ num_blocks >= 512: block_size_k = block_size_k * 2 - num_blocks /= 2 + num_blocks = num_blocks // 2 if num_blocks > 1024: 
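             # e.g. a large layer with thousands of tiles: widening the N tile
             # from 128 to 256 halves num_n_blocks, trading tile count for
             # per-tile work.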
block_size_n = 256 - num_n_blocks = num_n_blocks / 2 - num_blocks /= 2 + num_n_blocks = num_n_blocks // 2 + num_blocks = num_blocks // 2 if size_n <= 1024 and num_blocks >= 1024: block_size_n = 1024 From 355b09617fd0d1fc10f3b8af0fd238fdff16ccf2 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Sat, 15 Feb 2025 14:19:40 +0800 Subject: [PATCH 03/15] fix format error Signed-off-by: Jinzhen Lin --- csrc/moe/moe_wna16.cu | 2 +- csrc/moe/moe_wna16_utils.h | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/csrc/moe/moe_wna16.cu b/csrc/moe/moe_wna16.cu index 02972e667570..8b837d617309 100644 --- a/csrc/moe/moe_wna16.cu +++ b/csrc/moe/moe_wna16.cu @@ -8,7 +8,7 @@ #include #include "moe_wna16_utils.h" -#define DIVIDE(x, size) (((x) + (size)-1) / (size)) +#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) template __global__ void moe_wna16_gemm_kernel( diff --git a/csrc/moe/moe_wna16_utils.h b/csrc/moe/moe_wna16_utils.h index e969577a8aea..e7226bd71ada 100644 --- a/csrc/moe/moe_wna16_utils.h +++ b/csrc/moe/moe_wna16_utils.h @@ -108,11 +108,11 @@ __device__ inline void dequant(int q, half2* res) { const int MUL = 0x2c002c00; const int ADD = 0xd400d400; - int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa>(q, HI, EX); q >>= 8; - int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa>(q, HI, EX); res[0] = __hsub2(*reinterpret_cast(&lo0), *reinterpret_cast(&SUB)); @@ -149,13 +149,13 @@ __device__ inline void dequant(int q, nv_bfloat162* res) { static constexpr uint32_t MASK = 0x000f000f; static constexpr uint32_t EX = 0x43004300; - int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa>(q, MASK, EX); q >>= 4; - int hi0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa>(q, MASK, EX); q >>= 4; - int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa>(q, MASK, EX); q >>= 4; - int hi1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa>(q, MASK, EX); static constexpr uint32_t MUL = 0x3F803F80; static constexpr uint32_t ADD = 0xC300C300; From 31dc339b1be11c4712e7768acb9b9fbb707501c6 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Sat, 15 Feb 2025 14:25:31 +0800 Subject: [PATCH 04/15] fix format error Signed-off-by: Jinzhen Lin --- csrc/moe/moe_wna16_utils.h | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/csrc/moe/moe_wna16_utils.h b/csrc/moe/moe_wna16_utils.h index e7226bd71ada..265fbe028dee 100644 --- a/csrc/moe/moe_wna16_utils.h +++ b/csrc/moe/moe_wna16_utils.h @@ -108,11 +108,11 @@ __device__ inline void dequant(int q, half2* res) { const int MUL = 0x2c002c00; const int ADD = 0xd400d400; - int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa>(q, HI, EX); + int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); + int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); q >>= 8; - int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa>(q, HI, EX); + int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); + int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); res[0] = __hsub2(*reinterpret_cast(&lo0), *reinterpret_cast(&SUB)); @@ -149,13 +149,13 @@ __device__ inline 
void dequant(int q, nv_bfloat162* res) { static constexpr uint32_t MASK = 0x000f000f; static constexpr uint32_t EX = 0x43004300; - int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); q >>= 4; - int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); q >>= 4; - int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); q >>= 4; - int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); static constexpr uint32_t MUL = 0x3F803F80; static constexpr uint32_t ADD = 0xC300C300; From 2c1497bbae1172100d43e54ccf58ac0b8f2b8b78 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Sun, 16 Feb 2025 00:21:26 +0800 Subject: [PATCH 05/15] fix moe wna16 kernel Signed-off-by: Jinzhen Lin --- csrc/moe/moe_wna16.cu | 40 ++++++++++--------- csrc/moe/moe_wna16_utils.h | 2 +- .../layers/fused_moe/fused_moe.py | 3 +- 3 files changed, 25 insertions(+), 20 deletions(-) diff --git a/csrc/moe/moe_wna16.cu b/csrc/moe/moe_wna16.cu index 8b837d617309..41d6ad055ba3 100644 --- a/csrc/moe/moe_wna16.cu +++ b/csrc/moe/moe_wna16.cu @@ -52,8 +52,8 @@ __global__ void moe_wna16_gemm_kernel( int32_t num_valid_tokens = 0; extern __shared__ uint16_t block_input_tmp[]; - scalar_t* block_input = reinterpret_cast(&block_input_tmp); - scalar_t2* block_input_half2 = reinterpret_cast(&block_input); + scalar_t* block_input = reinterpret_cast(block_input_tmp); + scalar_t2* block_input_half2 = reinterpret_cast(block_input); // load BLOCK_SIZE_M * BLOCK_SIZE_K into shared memory for (int m = 0; m < BLOCK_SIZE_M; m++) { @@ -101,7 +101,7 @@ __global__ void moe_wna16_gemm_kernel( // weight would be loaded in loop uint32_t expert_qweight_tmp[4]; float4* expert_qweight_tmp_float4 = - reinterpret_cast(&expert_qweight_tmp); + reinterpret_cast(expert_qweight_tmp); // load all required scales one time scalar_t expert_scales_groups[GROUPS]; @@ -111,48 +111,52 @@ __global__ void moe_wna16_gemm_kernel( *expert_scales_groups = expert_scales[scales_offset_tmp]; } else if constexpr (GROUPS == 2) { float* expert_scales_groups_tmp = - reinterpret_cast(&expert_scales_groups); + reinterpret_cast(expert_scales_groups); *expert_scales_groups_tmp = - reinterpret_cast(&expert_scales)[scales_offset_tmp]; + reinterpret_cast(expert_scales)[scales_offset_tmp]; } else if constexpr (GROUPS == 4) { float2* expert_scales_groups_tmp = - reinterpret_cast(&expert_scales_groups); + reinterpret_cast(expert_scales_groups); *expert_scales_groups_tmp = - reinterpret_cast(&expert_scales)[scales_offset_tmp]; + reinterpret_cast(expert_scales)[scales_offset_tmp]; } else if constexpr (GROUPS == 8) { float4* expert_scales_groups_tmp = - reinterpret_cast(&expert_scales_groups); + reinterpret_cast(expert_scales_groups); *expert_scales_groups_tmp = - reinterpret_cast(&expert_scales)[scales_offset_tmp]; + reinterpret_cast(expert_scales)[scales_offset_tmp]; } // load all required qzeros one time uint8_t expert_qzeros_groups[GROUPS]; if (!has_zp) { - qzero_f2 = Dtype::num2num2(Dtype::int2num(8)); + if constexpr (bit == 4) { + qzero_f2 = Dtype::num2num2(Dtype::int2num(8)); + } else { + qzero_f2 = Dtype::num2num2(Dtype::int2num(128)); + } } else { int qzeros_offset_tmp = (offset_n / 2) * (size_k / group_size / GROUPS) + offset_k / group_size / GROUPS; if constexpr (GROUPS == 1) { uint8_t* expert_qzeros_groups_tmp = - reinterpret_cast(&expert_qzeros_groups); + 
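+          // drop the '&': the local array already decays to the right
+          // address, and '&expert_qzeros' below would reinterpret the
+          // pointer variable itself rather than the zero-point data.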
From 2c1497bbae1172100d43e54ccf58ac0b8f2b8b78 Mon Sep 17 00:00:00 2001
From: Jinzhen Lin
Date: Sun, 16 Feb 2025 00:21:26 +0800
Subject: [PATCH 05/15] fix moe wna16 kernel

Signed-off-by: Jinzhen Lin
---
 csrc/moe/moe_wna16.cu             | 40 ++++++++++---------
 csrc/moe/moe_wna16_utils.h        |  2 +-
 .../layers/fused_moe/fused_moe.py |  3 +-
 3 files changed, 25 insertions(+), 20 deletions(-)

diff --git a/csrc/moe/moe_wna16.cu b/csrc/moe/moe_wna16.cu
index 8b837d617309..41d6ad055ba3 100644
--- a/csrc/moe/moe_wna16.cu
+++ b/csrc/moe/moe_wna16.cu
@@ -52,8 +52,8 @@ __global__ void moe_wna16_gemm_kernel(
 
   int32_t num_valid_tokens = 0;
   extern __shared__ uint16_t block_input_tmp[];
-  scalar_t* block_input = reinterpret_cast<scalar_t*>(&block_input_tmp);
-  scalar_t2* block_input_half2 = reinterpret_cast<scalar_t2*>(&block_input);
+  scalar_t* block_input = reinterpret_cast<scalar_t*>(block_input_tmp);
+  scalar_t2* block_input_half2 = reinterpret_cast<scalar_t2*>(block_input);
 
   // load BLOCK_SIZE_M * BLOCK_SIZE_K into shared memory
   for (int m = 0; m < BLOCK_SIZE_M; m++) {
@@ -101,7 +101,7 @@ __global__ void moe_wna16_gemm_kernel(
   // weight would be loaded in loop
   uint32_t expert_qweight_tmp[4];
   float4* expert_qweight_tmp_float4 =
-      reinterpret_cast<float4*>(&expert_qweight_tmp);
+      reinterpret_cast<float4*>(expert_qweight_tmp);
 
   // load all required scales one time
   scalar_t expert_scales_groups[GROUPS];
@@ -111,48 +111,52 @@ __global__ void moe_wna16_gemm_kernel(
     *expert_scales_groups = expert_scales[scales_offset_tmp];
   } else if constexpr (GROUPS == 2) {
     float* expert_scales_groups_tmp =
-        reinterpret_cast<float*>(&expert_scales_groups);
+        reinterpret_cast<float*>(expert_scales_groups);
     *expert_scales_groups_tmp =
-        reinterpret_cast<const float*>(&expert_scales)[scales_offset_tmp];
+        reinterpret_cast<const float*>(expert_scales)[scales_offset_tmp];
   } else if constexpr (GROUPS == 4) {
     float2* expert_scales_groups_tmp =
-        reinterpret_cast<float2*>(&expert_scales_groups);
+        reinterpret_cast<float2*>(expert_scales_groups);
     *expert_scales_groups_tmp =
-        reinterpret_cast<const float2*>(&expert_scales)[scales_offset_tmp];
+        reinterpret_cast<const float2*>(expert_scales)[scales_offset_tmp];
   } else if constexpr (GROUPS == 8) {
     float4* expert_scales_groups_tmp =
-        reinterpret_cast<float4*>(&expert_scales_groups);
+        reinterpret_cast<float4*>(expert_scales_groups);
     *expert_scales_groups_tmp =
-        reinterpret_cast<const float4*>(&expert_scales)[scales_offset_tmp];
+        reinterpret_cast<const float4*>(expert_scales)[scales_offset_tmp];
   }
 
   // load all required qzeros one time
   uint8_t expert_qzeros_groups[GROUPS];
   if (!has_zp) {
-    qzero_f2 = Dtype::num2num2(Dtype::int2num(8));
+    if constexpr (bit == 4) {
+      qzero_f2 = Dtype::num2num2(Dtype::int2num(8));
+    } else {
+      qzero_f2 = Dtype::num2num2(Dtype::int2num(128));
+    }
   } else {
     int qzeros_offset_tmp = (offset_n / 2) * (size_k / group_size / GROUPS) +
                             offset_k / group_size / GROUPS;
     if constexpr (GROUPS == 1) {
       uint8_t* expert_qzeros_groups_tmp =
-          reinterpret_cast<uint8_t*>(&expert_qzeros_groups);
+          reinterpret_cast<uint8_t*>(expert_qzeros_groups);
       *expert_qzeros_groups_tmp =
-          reinterpret_cast<const uint8_t*>(&expert_qzeros)[qzeros_offset_tmp];
+          reinterpret_cast<const uint8_t*>(expert_qzeros)[qzeros_offset_tmp];
     } else if constexpr (GROUPS == 2) {
       uint16_t* expert_qzeros_groups_tmp =
-          reinterpret_cast<uint16_t*>(&expert_qzeros_groups);
+          reinterpret_cast<uint16_t*>(expert_qzeros_groups);
       *expert_qzeros_groups_tmp =
-          reinterpret_cast<const uint16_t*>(&expert_qzeros)[qzeros_offset_tmp];
+          reinterpret_cast<const uint16_t*>(expert_qzeros)[qzeros_offset_tmp];
     } else if constexpr (GROUPS == 4) {
       uint32_t* expert_qzeros_groups_tmp =
-          reinterpret_cast<uint32_t*>(&expert_qzeros_groups);
+          reinterpret_cast<uint32_t*>(expert_qzeros_groups);
       *expert_qzeros_groups_tmp =
-          reinterpret_cast<const uint32_t*>(&expert_qzeros)[qzeros_offset_tmp];
+          reinterpret_cast<const uint32_t*>(expert_qzeros)[qzeros_offset_tmp];
     } else if constexpr (GROUPS == 8) {
       uint64_t* expert_qzeros_groups_tmp =
-          reinterpret_cast<uint64_t*>(&expert_qzeros_groups);
+          reinterpret_cast<uint64_t*>(expert_qzeros_groups);
       *expert_qzeros_groups_tmp =
-          reinterpret_cast<const uint64_t*>(&expert_qzeros)[qzeros_offset_tmp];
+          reinterpret_cast<const uint64_t*>(expert_qzeros)[qzeros_offset_tmp];
     }
   }
 
diff --git a/csrc/moe/moe_wna16_utils.h b/csrc/moe/moe_wna16_utils.h
index 265fbe028dee..4396b80240ef 100644
--- a/csrc/moe/moe_wna16_utils.h
+++ b/csrc/moe/moe_wna16_utils.h
@@ -191,7 +191,7 @@ __device__ inline void dequant(int q, nv_bfloat162* res) {
   fp32_intermediates[2] -= 8388608.f;
   fp32_intermediates[3] -= 8388608.f;
 
-  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&res);
+  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(res);
   bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
                                    fp32_intermediates_casted[1], 0x7632);
   bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index e3b0ced1b7d5..88a2e0d499d8 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -694,12 +694,13 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
                 block_size_m=config["BLOCK_SIZE_M"]))
 
     if use_moe_wna16_cuda:
+        bit = 4 if use_int4_w4a16 else 8
        ops.moe_wna16_gemm(A, C, B, B_scale, B_zp,
                           topk_weights if mul_routed_weight else None,
                           sorted_token_ids, expert_ids,
                           num_tokens_post_padded, top_k,
                           config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"],
-                           config["BLOCK_SIZE_K"], 4)
+                           config["BLOCK_SIZE_K"], bit)
         return
 
     fused_moe_kernel_gptq_awq[grid](

From 4202c4b12b2e9268ea574edc67d4e8d72742bc49 Mon Sep 17 00:00:00 2001
From: Jinzhen Lin
Date: Sun, 16 Feb 2025 00:48:16 +0800
Subject: [PATCH 06/15] fix error

Signed-off-by: Jinzhen Lin
---
 csrc/moe/moe_wna16.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/moe/moe_wna16.cu b/csrc/moe/moe_wna16.cu
index 41d6ad055ba3..942cfdffc177 100644
--- a/csrc/moe/moe_wna16.cu
+++ b/csrc/moe/moe_wna16.cu
@@ -167,7 +167,7 @@ __global__ void moe_wna16_gemm_kernel(
     if (tmp_k % 4 == 0) {
       *expert_qweight_tmp_float4 = reinterpret_cast<float4*>(
-          &expert_qweight)[weight_offset / pack_factor / 4];
+          expert_qweight)[weight_offset / pack_factor / 4];
     }
 
     if (tmp_k % (group_size / pack_factor) == 0) {
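Note: with no zero-point tensor, patch 05 above hard-codes the symmetric midpoint 2**(bit - 1) as the zero-point: 8 for 4-bit and 128 for 8-bit. A scalar reference of the dequantization this implies (a sketch; `dequant_ref` is an illustrative name, not a vLLM function):

    def dequant_ref(q, scale, bit, zero_point=None):
        # Reference for one weight: (q - zp) * scale. With no zero-point
        # tensor, zp defaults to 2**(bit - 1): 8 for 4-bit, 128 for 8-bit.
        zp = 2 ** (bit - 1) if zero_point is None else zero_point
        return (q - zp) * scale

    assert dequant_ref(8, 0.5, bit=4) == 0.0
    assert dequant_ref(130, 0.5, bit=8) == 1.0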
From e20d529493f58a80896882b5d172083294c05369 Mon Sep 17 00:00:00 2001
From: Jinzhen Lin
Date: Sun, 16 Feb 2025 00:51:07 +0800
Subject: [PATCH 07/15] fix error

Signed-off-by: Jinzhen Lin
---
 csrc/moe/moe_wna16.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/moe/moe_wna16.cu b/csrc/moe/moe_wna16.cu
index 942cfdffc177..034cb63ce6ed 100644
--- a/csrc/moe/moe_wna16.cu
+++ b/csrc/moe/moe_wna16.cu
@@ -166,7 +166,7 @@ __global__ void moe_wna16_gemm_kernel(
     const int32_t weight_offset = offset_n * size_k + k;
 
     if (tmp_k % 4 == 0) {
-      *expert_qweight_tmp_float4 = reinterpret_cast<float4*>(
+      *expert_qweight_tmp_float4 = reinterpret_cast<const float4*>(
           expert_qweight)[weight_offset / pack_factor / 4];
     }

From c37abcbae0807f6a4dac43c1e200e9b143f68d62 Mon Sep 17 00:00:00 2001
From: Jinzhen Lin
Date: Sun, 16 Feb 2025 01:09:30 +0800
Subject: [PATCH 08/15] fix typo

Signed-off-by: Jinzhen Lin
---
 vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 88a2e0d499d8..11731f26b5fb 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -960,7 +960,7 @@ def try_get_optimal_moe_config(
     else:
         # First try to load optimal config from the file
         E, _, N = w2_shape
-        if dtype == "int4_w8a8":
+        if dtype == "int4_w8a16":
             N = N * 2
         block_n = block_shape[0] if block_shape else 0
         block_k = block_shape[1] if block_shape else 0

From 28d1c497db92e0e8a66009df748f8199f1504bdb Mon Sep 17 00:00:00 2001
From: Jinzhen Lin
Date: Fri, 28 Feb 2025 13:21:14 +0800
Subject: [PATCH 09/15] fix typo and cmake config

Signed-off-by: Jinzhen Lin
---
 CMakeLists.txt                                    | 7 +++++++
 vllm/model_executor/layers/fused_moe/fused_moe.py | 8 ++++----
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index c62a1ce96448..647ffeea20a1 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -522,6 +522,13 @@ set_gencode_flags_for_srcs(
   CUDA_ARCHS "${CUDA_ARCHS}")
 
 if(VLLM_GPU_LANG STREQUAL "CUDA")
+  set(VLLM_MOE_WNA16_SRC
+      "csrc/moe/moe_wna16.cu")
+
+  set_gencode_flags_for_srcs(
+    SRCS "${VLLM_MOE_WNA16_SRC}"
+    CUDA_ARCHS "${CUDA_ARCHS}")
+
   cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
   if (MARLIN_MOE_ARCHS)
     set(MARLIN_MOE_SRC
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 4f33a17354ca..355beadf4d46 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -954,11 +954,11 @@ def get_default_config(
             "num_warps": 4,
             "num_stages": 3,
         }
-    elif dtype in ["int4_w8a16", "int8_w8a16"] and block_shape is not None:
+    elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None:
         # moe wna16 kernels
         # only set BLOCK_SIZE_M
         # BLOCK_SIZE_N and BLOCK_SIZE_K would be set later
-        bit = 4 if dtype == "int4_w8a16" else 8
+        bit = 4 if dtype == "int4_w4a16" else 8
         use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk,
                                                        block_shape[1], E, bit)
         if use_moe_wna16_cuda:
@@ -1003,7 +1003,7 @@ def try_get_optimal_moe_config(
     else:
         # First try to load optimal config from the file
         E, _, N = w2_shape
-        if dtype == "int4_w8a16":
+        if dtype == "int4_w4a16":
             N = N * 2
         block_n = block_shape[0] if block_shape else 0
         block_k = block_shape[1] if block_shape else 0
@@ -1125,7 +1125,7 @@ def get_config_dtype_str(dtype: torch.dtype,
     elif use_int8_w8a16:
         return "int8_w8a16"
     elif use_int4_w4a16:
-        return "int4_w8a16"
+        return "int4_w4a16"
     elif dtype == torch.float:
         # avoiding cases where kernel fails when float32 MoE
         # use fp16/bfloat16 configs
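Note: the `N = N * 2` above exists because int4 weights are packed two per element along the last dimension, so the stored w2 tensor's N is half the logical intermediate size. A toy sketch of the relationship (the helper name is illustrative, not a vLLM API):

    def logical_n(packed_n, dtype):
        # int4 weights store two values per packed element, so the tensor's
        # N dimension is half the logical one; int8 stores one per element.
        return packed_n * 2 if dtype == "int4_w4a16" else packed_n

    assert logical_n(1024, "int4_w4a16") == 2048
    assert logical_n(1024, "int8_w8a16") == 1024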
 .../layers/fused_moe/fused_moe.py | 21 ++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 355beadf4d46..64665ab1dbcc 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -722,8 +722,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
         use_moe_wna16_cuda = should_moe_wna16_use_cuda(
             num_valid_tokens=topk_ids.numel(),
             group_size=block_shape[1],
-            num_experts=B.shape[0],
-            bit=4 if use_int4_w4a16 else 8)
+            num_experts=B.shape[0])
         config = config.copy()
         config.update(
             get_moe_wna16_block_config(config=config,
@@ -885,13 +884,19 @@ def get_moe_wna16_block_config(config: Dict[str,
                                num_experts: int, group_size: int,
                                real_top_k: int, block_size_m: int):
     if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config:
+        # optimal block config is set
         return {}
     if not use_moe_wna16_cuda:
+        # triton moe wna16 kernel
         if num_valid_tokens // real_top_k == 1:
+            # if bs=1, use a smaller BLOCK_SIZE_N
            return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}
        else:
            return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}
    else:
+        # cuda moe wna16 kernel
+        # set default block_size to 128, and increase it when num_blocks
+        # is too large.
        block_size_n = 128
        block_size_k = 128
        if block_size_k <= group_size:
@@ -922,15 +927,18 @@ def get_moe_wna16_block_config(config: Dict[str,
             num_blocks = num_blocks // 2
 
     if size_n <= 1024 and num_blocks >= 1024:
+        # The kernel performance got much better with BLOCK_SIZE_N=1024
+        # when num_blocks is large, even when N is small.
+        # Not sure why; maybe it forces each CUDA SM to process only one
+        # block at a time.
         block_size_n = 1024
 
     return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}
 
 
 def should_moe_wna16_use_cuda(num_valid_tokens: int, group_size: int,
-                              num_experts: int, bit: int):
-    return bit == 4 and group_size in [32, 64, 128] and \
-        num_valid_tokens / num_experts <= 8
+                              num_experts: int):
+    return group_size in [32, 64, 128] and num_valid_tokens / num_experts <= 8
 
 
 def get_default_config(
@@ -966,9 +968,8 @@ def get_default_config(
         # moe wna16 kernels
         # only set BLOCK_SIZE_M
         # BLOCK_SIZE_N and BLOCK_SIZE_K would be set later
-        bit = 4 if dtype == "int4_w4a16" else 8
         use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk,
-                                                       block_shape[1], E, bit)
+                                                       block_shape[1], E)
         if use_moe_wna16_cuda:
             config = {"BLOCK_SIZE_M": min(16, M)}
         elif M <= 20:
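Note: the block-size heuristic tuned above is easier to read in isolation. The sketch below is a condensed reconstruction of its shape; the 1024 grid-size threshold and the single halving step are illustrative placeholders, not the exact control flow of `get_moe_wna16_block_config`:

    def sketch_block_config(size_m, size_n, size_k, block_size_m):
        # Start from the 128x128 defaults, shrink the grid by widening
        # BLOCK_SIZE_N, and prefer N=1024 for very large grids on small N.
        ceil_div = lambda a, b: -(-a // b)
        block_size_n, block_size_k = 128, 128
        num_n_blocks = ceil_div(size_n, block_size_n)
        num_blocks = (ceil_div(size_m, block_size_m) * num_n_blocks *
                      ceil_div(size_k, block_size_k))
        if num_blocks >= 1024:            # grid too large: widen N
            block_size_n = 256
            num_n_blocks = num_n_blocks // 2
            num_blocks = num_blocks // 2
        if size_n <= 1024 and num_blocks >= 1024:
            block_size_n = 1024
        return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}

    cfg = sketch_block_config(size_m=16, size_n=4096, size_k=4096, block_size_m=16)
    assert cfg["BLOCK_SIZE_K"] == 128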
From 14ac3289ea81ffb2cc237e675488a11911059077 Mon Sep 17 00:00:00 2001
From: Jinzhen Lin
Date: Fri, 28 Feb 2025 23:58:41 +0800
Subject: [PATCH 11/15] fix int8 error

Signed-off-by: Jinzhen Lin
---
 csrc/moe/moe_wna16.cu | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/csrc/moe/moe_wna16.cu b/csrc/moe/moe_wna16.cu
index 034cb63ce6ed..ead9c8c415d6 100644
--- a/csrc/moe/moe_wna16.cu
+++ b/csrc/moe/moe_wna16.cu
@@ -80,7 +80,7 @@ __global__ void moe_wna16_gemm_kernel(
         origin_k = BLOCK_SIZE_N * i + threadIdx.x / 8 * 8 + order;
       } else {
         // [0, 2, 1, 3]
-        int8_t order = (threadIdx.x % 2) * 4 + ((threadIdx.x % 8) / 2);
+        int8_t order = (threadIdx.x % 2) * 2 + ((threadIdx.x % 4) / 2);
         origin_k = BLOCK_SIZE_N * i + threadIdx.x / 4 * 4 + order;
       }
 
@@ -135,7 +135,7 @@ __global__ void moe_wna16_gemm_kernel(
       qzero_f2 = Dtype::num2num2(Dtype::int2num(128));
     }
   } else {
-    int qzeros_offset_tmp = (offset_n / 2) * (size_k / group_size / GROUPS) +
+    int qzeros_offset_tmp = (offset_n / (8 / bit)) * (size_k / group_size / GROUPS) +
                             offset_k / group_size / GROUPS;
     if constexpr (GROUPS == 1) {
       uint8_t* expert_qzeros_groups_tmp =
@@ -178,7 +178,9 @@ __global__ void moe_wna16_gemm_kernel(
     if (has_zp) {
       uint8_t qzero =
           expert_qzeros_groups[tmp_k / (group_size / pack_factor)];
-      qzero = (qzero >> ((threadIdx.x % 2) * 4)) & 0xF;
+      if constexpr (bit == 4) {
+      qzero = (qzero >> ((threadIdx.x % 2) * 4)) & 0xF;
+      }
       qzero_f2 = Dtype::num2num2(Dtype::int2num(qzero));
     }

From 38a0eed21e9380cdeae374fb01f812b23f138b43 Mon Sep 17 00:00:00 2001
From: Jinzhen Lin
Date: Sat, 1 Mar 2025 00:10:04 +0800
Subject: [PATCH 12/15] disable moe wna16 cuda kernel for int8

Signed-off-by: Jinzhen Lin
---
 vllm/model_executor/layers/fused_moe/fused_moe.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 64665ab1dbcc..f3136576e3f9 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -722,7 +722,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
         use_moe_wna16_cuda = should_moe_wna16_use_cuda(
             num_valid_tokens=topk_ids.numel(),
             group_size=block_shape[1],
-            num_experts=B.shape[0])
+            num_experts=B.shape[0],
+            bit=4 if use_int4_w4a16 else 8)
         config = config.copy()
         config.update(
             get_moe_wna16_block_config(config=config,
@@ -937,8 +938,9 @@ def get_moe_wna16_block_config(config: Dict[str,
 
 
 def should_moe_wna16_use_cuda(num_valid_tokens: int, group_size: int,
-                              num_experts: int):
-    return group_size in [32, 64, 128] and num_valid_tokens / num_experts <= 8
+                              num_experts: int, bit: int):
+    return bit == 4 and group_size in [32, 64, 128] and \
+        num_valid_tokens / num_experts <= 8
 
 
 def get_default_config(
@@ -966,8 +968,9 @@ def get_default_config(
         # moe wna16 kernels
         # only set BLOCK_SIZE_M
         # BLOCK_SIZE_N and BLOCK_SIZE_K would be set later
+        bit = 4 if dtype == "int4_w4a16" else 8
         use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk,
-                                                       block_shape[1], E)
+                                                       block_shape[1], E, bit)
         if use_moe_wna16_cuda:
             config = {"BLOCK_SIZE_M": min(16, M)}
         elif M <= 20:

From 22ccc61b83d38d5ac610370edb8e49fff85b979c Mon Sep 17 00:00:00 2001
From: Jinzhen Lin
Date: Sat, 1 Mar 2025 00:18:05 +0800
Subject: [PATCH 13/15] fix ci error

Signed-off-by: Jinzhen Lin
---
 csrc/moe/moe_wna16.cu | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/csrc/moe/moe_wna16.cu b/csrc/moe/moe_wna16.cu
index ead9c8c415d6..a897693c4f62 100644
--- a/csrc/moe/moe_wna16.cu
+++ b/csrc/moe/moe_wna16.cu
@@ -135,8 +135,9 @@ __global__ void moe_wna16_gemm_kernel(
       qzero_f2 = Dtype::num2num2(Dtype::int2num(128));
     }
   } else {
-    int qzeros_offset_tmp = (offset_n / (8 / bit)) * (size_k / group_size / GROUPS) +
-                            offset_k / group_size / GROUPS;
+    int qzeros_offset_tmp =
+        (offset_n / (8 / bit)) * (size_k / group_size / GROUPS) +
+        offset_k / group_size / GROUPS;
     if constexpr (GROUPS == 1) {
       uint8_t* expert_qzeros_groups_tmp =
@@ -179,7 +180,7 @@ __global__ void moe_wna16_gemm_kernel(
       uint8_t qzero =
           expert_qzeros_groups[tmp_k / (group_size / pack_factor)];
       if constexpr (bit == 4) {
-      qzero = (qzero >> ((threadIdx.x % 2) * 4)) & 0xF;
+        qzero = (qzero >> ((threadIdx.x % 2) * 4)) & 0xF;
       }
       qzero_f2 = Dtype::num2num2(Dtype::int2num(qzero));
     }
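Note: the `offset_n / (8 / bit)` fix above reflects how zero-points are stored: one byte holds two 4-bit zeros but only one 8-bit zero, so only the 4-bit path needs a nibble select. A small sketch of the addressing (the layout assumptions here are illustrative):

    def qzero_for(offset_n, offset_k, size_k, group_size, bit, qzeros):
        # Pick the zero-point for output channel offset_n at depth offset_k.
        # Zeros are packed (8 // bit) channels per byte, one per k-group.
        per_byte = 8 // bit                  # 2 for int4, 1 for int8
        groups_per_n = size_k // group_size
        byte = qzeros[(offset_n // per_byte) * groups_per_n
                      + offset_k // group_size]
        if bit == 4:                         # select the right nibble
            return (byte >> ((offset_n % 2) * 4)) & 0xF
        return byte

    # for int4, two adjacent channels share one byte per group:
    zeros = bytes([0x83])                    # channel 0 -> 3, channel 1 -> 8
    assert qzero_for(0, 0, 128, 128, 4, zeros) == 3
    assert qzero_for(1, 0, 128, 128, 4, zeros) == 8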
 csrc/moe/moe_wna16.cu | 59 +++++++++++++++++++++++--------------------
 1 file changed, 31 insertions(+), 28 deletions(-)

diff --git a/csrc/moe/moe_wna16.cu b/csrc/moe/moe_wna16.cu
index a897693c4f62..51ae76c1ec88 100644
--- a/csrc/moe/moe_wna16.cu
+++ b/csrc/moe/moe_wna16.cu
@@ -41,14 +41,6 @@ __global__ void moe_wna16_gemm_kernel(
   const int32_t offset_k = blockIdx.z * BLOCK_SIZE_K;
 
   const int32_t expert_id = expert_ids[blockIdx.x];
-  const int8_t pack_factor = 32 / bit;
-
-  // note that (size_n * size_k * expert_id) may greater than 2 ** 31
-  const uint64_t expert_offset = ((uint64_t)size_n) * size_k * expert_id;
-  const uint32_t* expert_qweight = qweight + expert_offset / pack_factor;
-  const scalar_t* expert_scales = scales + expert_offset / group_size;
-  const uint32_t* expert_qzeros =
-      qzeros + expert_offset / group_size / pack_factor;
 
   int32_t num_valid_tokens = 0;
   extern __shared__ uint16_t block_input_tmp[];
@@ -65,30 +57,33 @@ __global__ void moe_wna16_gemm_kernel(
     if (blockIdx.z == 0 && offset_n < size_n)
       output[token_index * size_n + offset_n] = Dtype::int2num(0);
 
-    int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N);
-    for (int i = 0; i < k_per_thread; i++) {
-      int k = BLOCK_SIZE_N * i + threadIdx.x;
-      if (k >= BLOCK_SIZE_K) break;
-      if (offset_k + k >= size_k) break;
-
-      // load input to shared memory
-      // use a special layout to fit the layout of dequanted-weight
-      int origin_k;
-      if constexpr (bit == 4) {
-        // [0, 4, 1, 5, 2, 6, 3, 7]
-        int8_t order = (threadIdx.x % 2) * 4 + ((threadIdx.x % 8) / 2);
-        origin_k = BLOCK_SIZE_N * i + threadIdx.x / 8 * 8 + order;
-      } else {
-        // [0, 2, 1, 3]
-        int8_t order = (threadIdx.x % 2) * 2 + ((threadIdx.x % 4) / 2);
-        origin_k = BLOCK_SIZE_N * i + threadIdx.x / 4 * 4 + order;
-      }
+    if (expert_id != -1) {
+      int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N);
+      for (int i = 0; i < k_per_thread; i++) {
+        int k = BLOCK_SIZE_N * i + threadIdx.x;
+        if (k >= BLOCK_SIZE_K) break;
+        if (offset_k + k >= size_k) break;
+
+        // load input to shared memory
+        // use a special layout to fit the layout of dequanted-weight
+        int origin_k;
+        if constexpr (bit == 4) {
+          // [0, 4, 1, 5, 2, 6, 3, 7]
+          int8_t order = (threadIdx.x % 2) * 4 + ((threadIdx.x % 8) / 2);
+          origin_k = BLOCK_SIZE_N * i + threadIdx.x / 8 * 8 + order;
+        } else {
+          // [0, 2, 1, 3]
+          int8_t order = (threadIdx.x % 2) * 2 + ((threadIdx.x % 4) / 2);
+          origin_k = BLOCK_SIZE_N * i + threadIdx.x / 4 * 4 + order;
+        }
 
-      origin_k += token_index / top_k * size_k + blockIdx.z * BLOCK_SIZE_K;
-      block_input[m * BLOCK_SIZE_K + k] = input[origin_k];
+        origin_k += token_index / top_k * size_k + blockIdx.z * BLOCK_SIZE_K;
+        block_input[m * BLOCK_SIZE_K + k] = input[origin_k];
+      }
     }
   }
 
+  if (expert_id == -1) return;
   __syncthreads();
   if (threadIdx.x >= BLOCK_SIZE_N || offset_n >= size_n) return;
 
@@ -97,6 +92,14 @@ __global__ void moe_wna16_gemm_kernel(
   scalar_t2 scale_f2;
   scalar_t2 qzero_f2;
 
+  // note that (size_n * size_k * expert_id) may greater than 2 ** 31
+  constexpr int8_t pack_factor = 32 / bit;
+  const uint64_t expert_offset = ((uint64_t)size_n) * size_k * expert_id;
+  const uint32_t* expert_qweight = qweight + expert_offset / pack_factor;
+  const scalar_t* expert_scales = scales + expert_offset / group_size;
+  const uint32_t* expert_qzeros =
+      qzeros + expert_offset / group_size / pack_factor;
+
   // load 4*int32 one time: 4 int32 = 128 bit = 1 float4
   // weight would be loaded in loop
   uint32_t expert_qweight_tmp[4];
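Note: under expert parallelism each rank holds only a slice of the experts, so `expert_ids` may contain -1 for blocks whose expert lives on another rank; the kernel above now zeroes its output slice for those blocks and returns early. A toy model of the host-side id mapping (illustrative, not vLLM's actual routing code):

    def local_expert_ids(global_expert_ids, first_local, num_local):
        # Map global expert ids to this rank's local ids; -1 marks experts
        # owned by another rank, which the kernel skips after zeroing output.
        out = []
        for e in global_expert_ids:
            local = e - first_local
            out.append(local if 0 <= local < num_local else -1)
        return out

    # a rank owning experts 4..7 out of 8:
    assert local_expert_ids([0, 5, 7, 3], first_local=4, num_local=4) == [-1, 1, 3, -1]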
From 4b740b0d28bd623f58dfe9e4e45ce2ae638fe3e7 Mon Sep 17 00:00:00 2001
From: Jinzhen Lin
Date: Tue, 11 Mar 2025 01:29:45 +0800
Subject: [PATCH 15/15] update condition

Signed-off-by: Jinzhen Lin
---
 vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 9134a0b46d5f..89ceba122748 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -940,7 +940,7 @@ def get_moe_wna16_block_config(config: Dict[str,
 def should_moe_wna16_use_cuda(num_valid_tokens: int, group_size: int,
                               num_experts: int, bit: int):
     return bit == 4 and group_size in [32, 64, 128] and \
-        num_valid_tokens / num_experts <= 8
+        num_valid_tokens / num_experts <= 6
 
 
 def get_default_config(