@@ -722,8 +722,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
         use_moe_wna16_cuda = should_moe_wna16_use_cuda(
             num_valid_tokens=topk_ids.numel(),
             group_size=block_shape[1],
-            num_experts=B.shape[0],
-            bit=4 if use_int4_w4a16 else 8)
+            num_experts=B.shape[0])
         config = config.copy()
         config.update(
             get_moe_wna16_block_config(config=config,
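For context, a minimal standalone sketch of the merge pattern in this hunk (toy_block_config is a made-up stand-in, not the real get_moe_wna16_block_config): the helper returns either an empty dict, which leaves a pre-tuned config untouched, or BLOCK_SIZE_N/BLOCK_SIZE_K overrides that config.update() layers on top of the BLOCK_SIZE_M choice.

def toy_block_config(config):
    # Hypothetical stand-in for get_moe_wna16_block_config (illustration only).
    if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config:
        return {}  # block sizes already tuned; keep them
    return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}

config = {"BLOCK_SIZE_M": 16}
config = config.copy()
config.update(toy_block_config(config))
print(config)  # {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}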
@@ -885,13 +884,19 @@ def get_moe_wna16_block_config(config: Dict[str,
                                num_experts: int, group_size: int,
                                real_top_k: int, block_size_m: int):
     if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config:
+        # optimal block config is set
         return {}
     if not use_moe_wna16_cuda:
+        # triton moe wna16 kernel
         if num_valid_tokens // real_top_k == 1:
+            # if bs=1, use a smaller BLOCK_SIZE_N
             return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}
         else:
             return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}
     else:
+        # cuda moe wna16 kernel
+        # set default block_size 128, and increase them when num_blocks
+        # is too large.
         block_size_n = 128
         block_size_k = 128
         if block_size_k <= group_size:
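A standalone sketch of the Triton branch above, with illustrative call values that are not taken from the PR: an effective batch size of one (num_valid_tokens // real_top_k == 1) gets the smaller BLOCK_SIZE_N=32 / larger BLOCK_SIZE_K=64 tile, and anything bigger gets 64/32.

def triton_wna16_block_sizes(num_valid_tokens: int, real_top_k: int):
    # Illustration only: mirrors the Triton-path choice of get_moe_wna16_block_config.
    if num_valid_tokens // real_top_k == 1:
        # bs=1: smaller BLOCK_SIZE_N, larger BLOCK_SIZE_K
        return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}
    return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}

print(triton_wna16_block_sizes(num_valid_tokens=8, real_top_k=8))    # bs=1  -> 32/64
print(triton_wna16_block_sizes(num_valid_tokens=256, real_top_k=8))  # bs=32 -> 64/32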
@@ -922,15 +927,18 @@ def get_moe_wna16_block_config(config: Dict[str,
             num_blocks = num_blocks // 2
 
         if size_n <= 1024 and num_blocks >= 1024:
+            # The kernel performance got much better with BLOCK_SIZE_N=1024
+            # when num_blocks is large, even when N is small.
+            # Not sure why; maybe it forces the CUDA SM to process only one
+            # block at a time.
             block_size_n = 1024
 
         return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}
 
 
 def should_moe_wna16_use_cuda(num_valid_tokens: int, group_size: int,
-                              num_experts: int, bit: int):
-    return bit == 4 and group_size in [32, 64, 128] and \
-        num_valid_tokens / num_experts <= 8
+                              num_experts: int):
+    return group_size in [32, 64, 128] and num_valid_tokens / num_experts <= 8
 
 
 def get_default_config(
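To make the effect of the signature change concrete (toy numbers, not from the PR): with the bit == 4 guard removed, 8-bit (w8a16) weights can now also take the CUDA kernel, provided the group size is 32/64/128 and there are at most 8 valid tokens per expert.

def should_moe_wna16_use_cuda(num_valid_tokens: int, group_size: int,
                              num_experts: int):
    # The new predicate, reproduced standalone for illustration.
    return group_size in [32, 64, 128] and num_valid_tokens / num_experts <= 8

# e.g. 4 tokens with top-8 routing over 64 experts -> 0.5 tokens per expert
print(should_moe_wna16_use_cuda(32, 128, 64))    # True, whether 4-bit or 8-bit
# a large prefill: 512 tokens with top-8 over 64 experts -> 64 tokens per expert
print(should_moe_wna16_use_cuda(4096, 128, 64))  # False, keep the Triton kernel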
@@ -958,9 +966,8 @@ def get_default_config(
         # moe wna16 kernels
         # only set BLOCK_SIZE_M
         # BLOCK_SIZE_N and BLOCK_SIZE_K would be set later
-        bit = 4 if dtype == "int4_w4a16" else 8
         use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk,
-                                                       block_shape[1], E, bit)
+                                                       block_shape[1], E)
         if use_moe_wna16_cuda:
             config = {"BLOCK_SIZE_M": min(16, M)}
         elif M <= 20: