File tree Expand file tree Collapse file tree 2 files changed +11
-4
lines changed
vllm/model_executor/layers/fused_moe Expand file tree Collapse file tree 2 files changed +11
-4
lines changed Original file line number Diff line number Diff line change @@ -522,6 +522,13 @@ set_gencode_flags_for_srcs(
522
522
CUDA_ARCHS "${CUDA_ARCHS}")
523
523
524
524
if(VLLM_GPU_LANG STREQUAL "CUDA")
525
+ set(VLLM_MOE_WNA16_SRC
526
+ "csrc/moe/moe_wna16.cu")
527
+
528
+ set_gencode_flags_for_srcs(
529
+ SRCS "${VLLM_MOE_WNA16_SRC}"
530
+ CUDA_ARCHS "${CUDA_ARCHS}")
531
+
525
532
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
526
533
if (MARLIN_MOE_ARCHS)
527
534
set(MARLIN_MOE_SRC
Original file line number Diff line number Diff line change @@ -954,11 +954,11 @@ def get_default_config(
954
954
"num_warps": 4,
955
955
"num_stages": 3,
956
956
}
957
- elif dtype in ["int4_w8a16", "int8_w8a16"] and block_shape is not None:
957
+ elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None:
958
958
# moe wna16 kernels
959
959
# only set BLOCK_SIZE_M
960
960
# BLOCK_SIZE_N and BLOCK_SIZE_K would be set later
961
- bit = 4 if dtype == "int4_w8a16" else 8
961
+ bit = 4 if dtype == "int4_w4a16" else 8
962
962
use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk,
963
963
block_shape[1], E, bit)
964
964
if use_moe_wna16_cuda :
@@ -1003,7 +1003,7 @@ def try_get_optimal_moe_config(
1003
1003
else :
1004
1004
# First try to load optimal config from the file
1005
1005
E , _ , N = w2_shape
1006
- if dtype == "int4_w8a16":
1006
+ if dtype == "int4_w4a16":
1007
1007
N = N * 2
1008
1008
block_n = block_shape[0] if block_shape else 0
1009
1009
block_k = block_shape[1] if block_shape else 0
@@ -1125,7 +1125,7 @@ def get_config_dtype_str(dtype: torch.dtype,
1125
1125
elif use_int8_w8a16 :
1126
1126
return "int8_w8a16"
1127
1127
elif use_int4_w4a16 :
1128
- return "int4_w8a16"
1128
+ return "int4_w4a16"
1129
1129
elif dtype == torch.float:
1130
1130
# avoiding cases where kernel fails when float32 MoE
1131
1131
# use fp16/bfloat16 configs
You can’t perform that action at this time.
0 commit comments