Skip to content

Commit 28d1c49

Browse files
committed
fix typo and cmake config
Signed-off-by: Jinzhen Lin <[email protected]>
1 parent 8c9c914 commit 28d1c49

File tree

2 files changed

11 additions and 4 deletions (lines changed)

2 files changed

+11
-4
lines changed

CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -522,6 +522,13 @@ set_gencode_flags_for_srcs(
522522
CUDA_ARCHS "${CUDA_ARCHS}")
523523

524524
if(VLLM_GPU_LANG STREQUAL "CUDA")
525+
set(VLLM_MOE_WNA16_SRC
526+
"csrc/moe/moe_wna16.cu")
527+
528+
set_gencode_flags_for_srcs(
529+
SRCS "${VLLM_MOE_WNA16_SRC}"
530+
CUDA_ARCHS "${CUDA_ARCHS}")
531+
525532
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
526533
if (MARLIN_MOE_ARCHS)
527534
set(MARLIN_MOE_SRC

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -954,11 +954,11 @@ def get_default_config(
954954
"num_warps": 4,
955955
"num_stages": 3,
956956
}
957-
elif dtype in ["int4_w8a16", "int8_w8a16"] and block_shape is not None:
957+
elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None:
958958
# moe wna16 kernels
959959
# only set BLOCK_SIZE_M
960960
# BLOCK_SIZE_N and BLOCK_SIZE_K would be set later
961-
bit = 4 if dtype == "int4_w8a16" else 8
961+
bit = 4 if dtype == "int4_w4a16" else 8
962962
use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk,
963963
block_shape[1], E, bit)
964964
if use_moe_wna16_cuda:
@@ -1003,7 +1003,7 @@ def try_get_optimal_moe_config(
10031003
else:
10041004
# First try to load optimal config from the file
10051005
E, _, N = w2_shape
1006-
if dtype == "int4_w8a16":
1006+
if dtype == "int4_w4a16":
10071007
N = N * 2
10081008
block_n = block_shape[0] if block_shape else 0
10091009
block_k = block_shape[1] if block_shape else 0
@@ -1125,7 +1125,7 @@ def get_config_dtype_str(dtype: torch.dtype,
11251125
elif use_int8_w8a16:
11261126
return "int8_w8a16"
11271127
elif use_int4_w4a16:
1128-
return "int4_w8a16"
1128+
return "int4_w4a16"
11291129
elif dtype == torch.float:
11301130
# avoiding cases where kernel fails when float32 MoE
11311131
# use fp16/bfloat16 configs

0 commit comments

Comments
 (0)