diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index b14bc06e913c..3052bdb4dc1e 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -288,7 +288,11 @@ def use_int4_w4a16(self) -> bool:
 
     @property
     def use_mxfp4_w4a4(self) -> bool:
-        return self.quant_dtype == "mxfp4"
+        return (self._a1.dtype == "mxfp4" and self._w1.dtype == "mxfp4")
+
+    @property
+    def use_mxfp4_w4a16(self) -> bool:
+        return (self._a1.dtype is None and self._w1.dtype == "mxfp4")
 
     @property
     def use_nvfp4_w4a4(self) -> bool:
@@ -453,6 +457,22 @@ def int8_w8a8_moe_quant_config(
     )
 
 
+def mxfp4_w4a16_moe_quant_config(
+        w1_scale: Union[torch.Tensor, "PrecisionConfig"],
+        w2_scale: Union[torch.Tensor, "PrecisionConfig"],
+        w1_bias: Optional[torch.Tensor] = None,
+        w2_bias: Optional[torch.Tensor] = None) -> FusedMoEQuantConfig:
+    """
+    Construct a quant config for unquantized activations and mxfp4 weights.
+    """
+    return FusedMoEQuantConfig(
+        _a1=FusedMoEQuantDesc(),
+        _a2=FusedMoEQuantDesc(),
+        _w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
+        _w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
+    )
+
+
 def mxfp4_w4a4_moe_quant_config(
         w1_scale: Union[torch.Tensor, "PrecisionConfig"],
         w2_scale: Union[torch.Tensor, "PrecisionConfig"],
diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
index f390f0a25875..a250a6218715 100644
--- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
@@ -11,6 +11,7 @@
     TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import (
     moe_kernel_quantize_input)
+from vllm.utils import round_up
 
 
 class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
@@ -18,6 +19,23 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
     Prepare/Finalize using DeepEP High-Throughput kernels.
     """
 
+    @staticmethod
+    def maybe_roundup_layer_hidden_size(hidden_size: int,
+                                        dtype: torch.dtype) -> int:
+        # Round up hidden size so it is compatible with DeepEP High-Throughput
+        # kernels.
+        # DeepEP intranode kernels make copies in units of 32 (warp-size)
+        # int4 elements. Round up the hidden size to respect this.
+        # For example, an input hidden size of 2880 with dtype torch.bfloat16
+        # will be rounded up to 3072.
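+        # (2880 bf16 elements * 2 bytes = 5760 bytes, which is not a multiple
+        # of 512; rounding up gives 6144 bytes, i.e. 3072 bf16 elements.)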
+        hidden_size_bytes = hidden_size * dtype.itemsize
+        xfer_atom_size = 512  # 32 * 16 bytes (sizeof(int4) == 16)
+        if hidden_size_bytes % xfer_atom_size == 0:
+            return hidden_size
+
+        hidden_size_bytes = round_up(hidden_size_bytes, xfer_atom_size)
+        return hidden_size_bytes // dtype.itemsize
+
     def __init__(self, buffer: deep_ep.Buffer, num_dispatchers: int,
                  dp_size: int, rank_expert_offset: int):
         super().__init__()
diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
index f12d3807517f..0e84a9241e90 100644
--- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
@@ -9,7 +9,8 @@
 from vllm.model_executor.layers.fused_moe.config import (
     FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
-    TopKWeightAndReduceDelegate)
+    TopKWeightAndReduceNoOP)
+from vllm.triton_utils import tl, triton
 from vllm.utils import has_triton_kernels
 
 logger = init_logger(__name__)
@@ -19,13 +20,55 @@
     import triton_kernels.swiglu
     from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
                                            matmul_ogs)
-    from triton_kernels.routing import routing
+    from triton_kernels.routing import (RoutingData, routing,
+                                        routing_from_bitmatrix)
+    from triton_kernels.tensor import Bitmatrix
 except (ModuleNotFoundError, AttributeError) as e:
     logger.error(
         "Failed to import Triton kernels. Please make sure your triton "
         "version is compatible. Error: %s", e)
 
 
+@triton.jit
+def pack_bitmatrix(
+    bitmatrix,
+    topk_ids,
+    n_rows,  # n_rows in bitmatrix / topk_ids
+    bm_cols: tl.constexpr,  # n int32_t bitpacks in bitmatrix
+    n_expts_act,  # num_topk
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    """
+    Packs topk_ids into a bitmatrix.
+    Code reference:
+    https://github.com/triton-lang/triton/blob/dd1bbc52b34d202dfe5ffea1e04fb16166c5c04e/python/triton_kernels/bench/distributed.py#L264
+    """
+    pid_m = tl.program_id(0)
+    offsets_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offsets_k = tl.arange(0, BLOCK_SIZE_K)
+    offsets = offsets_m[:, None] * n_expts_act + offsets_k[None, :]
+    mask = (offsets_m < n_rows)[:, None] & (offsets_k < n_expts_act)[None, :]
+    indices = tl.load(topk_ids + offsets, mask=mask, other=-1)
+    div = indices // 32
+    rem = indices % 32
+    one = tl.cast(1, tl.uint32)
+
+    # Iterate through all the relevant bitmatrix columns.
+    for i in range(bm_cols):
+        # When BLOCK_SIZE_K=32, offs is just the column index.
+        offs = tl.arange(0, BLOCK_SIZE_K // 32) + i * (BLOCK_SIZE_K // 32)
+        # Every top-k id that belongs to this column gets its bit set;
+        # all other bits are 0. x is a 3D tensor.
+        x = tl.where(div[:, :, None] == offs[None, None, :],
+                     (one << rem)[:, :, None], 0)
+        # Reduce x to get a single int32_t bitpack.
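+        # (OR-ing over the top-k axis merges the per-top-k one-hot masks into
+        # a single 32-bit word per row and bitmatrix column.)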
+        y = tl.reduce_or(x, axis=1)
+        bitmatrix_ptrs = bitmatrix + offsets_m[:,
+                                               None] * bm_cols + offs[None, :]
+        tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows)
+
+
 def triton_kernel_moe_forward(
     hidden_states: torch.Tensor,
     w1,  # Tensor or triton_kernels.Tensor
@@ -124,34 +167,88 @@ def triton_kernel_fused_experts(
     return intermediate_cache3
 
 
-class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
+def make_routing_data(
+    topk_ids: torch.Tensor,
+    topk_weights: torch.Tensor,
+    num_local_experts: int,
+) -> tuple["RoutingData", torch.Tensor, torch.Tensor]:
+
+    topk_ids = topk_ids.to(torch.int16)
+    topk_weights = topk_weights.to(torch.bfloat16)
+
+    n_rows, num_topk = topk_ids.size()
+
+    BLOCK_SIZE_M = 512
+    BLOCK_SIZE_K = 32
+
+    bm_cols = triton.cdiv(num_local_experts, BLOCK_SIZE_K)  # n_bitpacks
+    bitmatrix = torch.zeros((n_rows, bm_cols),
+                            dtype=torch.uint32,
+                            device=topk_ids.device)
+
+    grid = (triton.cdiv(n_rows, BLOCK_SIZE_M), )
+    pack_bitmatrix[grid](
+        bitmatrix,
+        topk_ids,
+        n_rows,
+        bm_cols,
+        num_topk,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+    )
+
+    bitmatrix_shape = [n_rows, bm_cols * 32]
+    bitmatrix_shape_max = [n_rows, None]
+    bitmatrix = Bitmatrix(bitmatrix,
+                          shape=bitmatrix_shape,
+                          shape_max=bitmatrix_shape_max,
+                          scratchpad=None)
+
+    # matmul_ogs expects invalid topk_weights to be -1s
+    topk_weights = torch.where(topk_ids == -1, -1.0, topk_weights)
+    routing_data, gather_indx, scatter_indx = routing_from_bitmatrix(
+        bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk)
+
+    return routing_data, gather_indx, scatter_indx
+
+
+class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
+
+    def __init__(self, quant_config: FusedMoEQuantConfig):
+        super().__init__(quant_config)
+
+    def supports_expert_map(self) -> bool:
+        return True
+
+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        # Weight application and reduction happens in the fused_experts kernel.
+        return TopKWeightAndReduceNoOP()
 
-    def __init__(
+    def _make_routing_data(
         self,
-        max_num_tokens: int,
-        num_dispatchers: int,
-        quant_config: FusedMoEQuantConfig,
-    ):
+        topk_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_local_experts: int,
+    ) -> tuple["RoutingData", torch.Tensor, torch.Tensor]:
+        return make_routing_data(topk_ids, topk_weights, num_local_experts)
+
+
+class OAITritonExperts(BaseOAITritonExperts):
+
+    def __init__(self, quant_config: FusedMoEQuantConfig):
+        # TODO (varun): Enable activation quantization
+        assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
         super().__init__(quant_config)
-        self.max_num_tokens = max_num_tokens
-        self.num_dispatchers = num_dispatchers
 
     @property
     def activation_formats(
         self
     ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
-        return (mk.FusedMoEActivationFormat.BatchedExperts,
-                mk.FusedMoEActivationFormat.BatchedExperts)
+        return (mk.FusedMoEActivationFormat.Standard,
+                mk.FusedMoEActivationFormat.Standard)
 
     def supports_chunking(self) -> bool:
-        return False
-
-    def supports_expert_map(self) -> bool:
-        return False
-
-    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
-        # Let PrepareAndFinalize::finalize() decide the impl.
-        return TopKWeightAndReduceDelegate()
+        return True
 
     def workspace_shapes(
         self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
@@ -159,13 +256,10 @@ def workspace_shapes(
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata]
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         # workspaces are allocated inside the kernel
-        assert a.dim() == 2
-        num_dp = self.num_dispatchers
-        num_experts = local_num_experts
-        max_num_tokens = self.max_num_tokens
-        workspace2 = (0, 0, 0)
-        output = (num_experts, max_num_tokens * num_dp, N)
-        return (output, workspace2, output, a.dtype)
+        workspace1 = (M, K)
+        workspace2 = (0, 0)
+        output = (M, K)
+        return (workspace1, workspace2, output, a.dtype)
 
     def apply(
         self,
@@ -185,17 +279,29 @@ def apply(
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
     ):
-        return triton_kernel_fused_experts(
-            output,
+        if expert_map is not None:
+            topk_ids = expert_map[topk_ids]
+
+        local_num_experts = w1.size(0)
+        if global_num_experts == -1:
+            global_num_experts = local_num_experts
+
+        routing_data, gather_indx, scatter_indx = self._make_routing_data(
+            topk_ids, topk_weights, local_num_experts)
+
+        experts_output = triton_kernel_fused_experts(
+            None,
             hidden_states,
             w1,
             w2,
-            routing_data=None,
-            gather_indx=None,
-            scatter_indx=None,
+            routing_data,
+            gather_indx,
+            scatter_indx,
             activation=activation,
             quant_config=self.quant_config,
             apply_router_weight_on_input=False,
-            global_num_experts=global_num_experts,
-            expert_map=expert_map,
+            global_num_experts=local_num_experts,
+            expert_map=None,  # applied already
             a1q_scale=a1q_scale)
+
+        output.copy_(experts_output, non_blocking=True)
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 17ad75584a3f..1f80e972b7f0 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -800,6 +800,49 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
         for local_index, global_index in zip(local_indices, global_indices))
 
 
+def maybe_roundup_hidden_size(
+        hidden_size: int, act_dtype: torch.dtype,
+        quant_config: Optional[QuantizationConfig],
+        moe_parallel_config: FusedMoEParallelConfig) -> int:
+    """
+    Given the layer hidden size and MoE configurations, round up hidden_size
+    if necessary.
+
+    Args:
+        hidden_size (int): Layer hidden size.
+        act_dtype: Data type of the layer activations.
+        quant_config (QuantizationConfig): Fused MoE quantization configuration.
+        moe_parallel_config (FusedMoEParallelConfig): Fused MoE parallelization
+            strategy configuration.
+
+    Returns:
+        The rounded-up hidden_size if rounding up is required by the configs,
+        the original hidden size otherwise.
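+
+    For example, with DeepEP high-throughput kernels and bfloat16 activations,
+    a hidden size of 2880 is rounded up to 3072 (see
+    DeepEPHTPrepareAndFinalize.maybe_roundup_layer_hidden_size).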
+    """
+
+    if (moe_parallel_config.use_deepep_ht_kernels):
+        hidden_size = (
+            DeepEPHTPrepareAndFinalize.maybe_roundup_layer_hidden_size(
+                hidden_size, act_dtype))
+
+    # we are padding globally so EP buffer allocation works
+    if quant_config and quant_config.get_name() == "mxfp4":
+
+        from vllm.model_executor.layers.quantization.mxfp4 import (
+            Mxfp4Backend, get_mxfp4_backend)
+        current_mxfp4_backend = get_mxfp4_backend()
+        if (current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
+                or current_mxfp4_backend
+                == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS):
+            hidden_size = round_up(hidden_size, 128)
+        elif (current_platform.is_rocm() or current_mxfp4_backend
+              == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
+              or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
+            hidden_size = round_up(hidden_size, 256)
+
+    return hidden_size
+
+
 @CustomOp.register("fused_moe")
 class FusedMoE(CustomOp):
     """FusedMoE layer for MoE models.
@@ -856,6 +899,18 @@ def __init__(
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
 
+        vllm_config = get_current_vllm_config()
+
+        # FIXME (varun): We should have a better way of inferring the activation
+        # datatype. This works for now as the tensor datatype entering the MoE
+        # operation is typically unquantized (i.e. float16/bfloat16).
+        if vllm_config.model_config is not None:
+            moe_in_dtype = vllm_config.model_config.dtype
+        else:
+            # TODO (bnell): This is a hack to get test_mixtral_moe to work
+            # since model_config is not set in the pytest test.
+            moe_in_dtype = params_dtype
+
         tp_size_ = (tp_size if tp_size is not None else
                     get_tensor_model_parallel_world_size())
         dp_size_ = (dp_size
@@ -865,7 +920,6 @@ def __init__(
         if self.is_sequence_parallel:
             self.sp_size = tp_size_
 
-        vllm_config = get_current_vllm_config()
         self.moe_parallel_config: FusedMoEParallelConfig = (
             FusedMoEParallelConfig.make(
                 tp_size_=tp_size_,
@@ -874,19 +928,10 @@ def __init__(
 
         self.global_num_experts = num_experts + num_redundant_experts
 
-        # we are padding globally so EP buffer allocation works
-        if quant_config and quant_config.get_name() == "mxfp4":
-            from vllm.model_executor.layers.quantization.mxfp4 import (
-                Mxfp4Backend, get_mxfp4_backend)
-            current_mxfp4_backend = get_mxfp4_backend()
-            if (current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
-                    or current_mxfp4_backend
-                    == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS):
-                hidden_size = round_up(hidden_size, 128)
-            elif (current_platform.is_rocm() or current_mxfp4_backend
-                  == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or
-                  current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
-                hidden_size = round_up(hidden_size, 256)
+        # Round up hidden size if needed.
+        hidden_size = maybe_roundup_hidden_size(hidden_size, moe_in_dtype,
+                                                quant_config,
+                                                self.moe_parallel_config)
 
         # For smuggling this layer into the fused moe custom op
         compilation_config = vllm_config.compilation_config
@@ -967,20 +1012,13 @@ def __init__(
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")
 
-        if vllm_config.model_config is not None:
-            model_dtype = vllm_config.model_config.dtype
-        else:
-            # TODO (bnell): This is a hack to get test_mixtral_moe to work
-            # since model_config is not set in the pytest test.
-            model_dtype = params_dtype
-
         moe = FusedMoEConfig(
             num_experts=self.global_num_experts,
             experts_per_token=top_k,
             hidden_dim=hidden_size,
             num_local_experts=self.local_num_experts,
             moe_parallel_config=self.moe_parallel_config,
-            in_dtype=model_dtype,
+            in_dtype=moe_in_dtype,
             max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
             has_bias=has_bias,
         )
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index a16c254fadf6..5fce24018e64 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -76,7 +76,7 @@ def _moe_problem_size(
     """
     assert w1.dim() == 3 and w2.dim() == 3
     E, N, _ = w1.size()
-    K = w2.size(1)
+    K = a1.size(-1)
 
     if a1.dim() == 2:
         # Make sure we are using the correct a1 (pre-permute).
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index 5c3f8a891276..a71c8d32a22c 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -13,7 +13,10 @@
     FusedMoEMethodBase)
 from vllm.model_executor.layers.fused_moe import modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.config import (
-    FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config)
+    FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config,
+    mxfp4_w4a16_moe_quant_config)
+from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
+    OAITritonExperts)
 from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
@@ -578,9 +581,14 @@ def _interleave_mxfp4_cutlass_sm90(w):
         layer.w13_bias = Parameter(w13_bias, requires_grad=False)
         layer.w2_bias = Parameter(w2_bias, requires_grad=False)
 
-        # FIXME warp need to be adjusted based on batch size
-        # only apply to batched mode
-        if self.moe.use_ep:
+        # Ideally we'd use the FusedMoEModularKernel.prepare_finalize object
+        # (stored in self.fused_experts) to determine if the MoE has a
+        # batched activation format. As self.fused_experts is not
+        # initialized at this point, we resort to checking the MoE config
+        # directly.
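+        # (pplx and DeepEP low-latency are the prepare/finalize paths that
+        # use the batched activation format here.)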
+        is_batched_moe = (self.moe.use_pplx_kernels
+                          or self.moe.use_deepep_ll_kernels)
+        if is_batched_moe:
             num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
         else:
             num_warps = 8
@@ -640,16 +648,21 @@ def get_fused_moe_quant_config(
         if self.mxfp4_backend == Mxfp4Backend.TRITON:
             w1_scale = self.w13_precision_config
             w2_scale = self.w2_precision_config
+            return mxfp4_w4a16_moe_quant_config(
+                w1_bias=layer.w13_bias,
+                w2_bias=layer.w2_bias,
+                w1_scale=w1_scale,
+                w2_scale=w2_scale,
+            )
         else:
             w1_scale = layer.w13_weight_scale
             w2_scale = layer.w2_weight_scale
-
-        return mxfp4_w4a4_moe_quant_config(
-            w1_bias=layer.w13_bias,
-            w2_bias=layer.w2_bias,
-            w1_scale=w1_scale,
-            w2_scale=w2_scale,
-        )
+            return mxfp4_w4a4_moe_quant_config(
+                w1_bias=layer.w13_bias,
+                w2_bias=layer.w2_bias,
+                w1_scale=w1_scale,
+                w2_scale=w2_scale,
+            )
 
     def select_gemm_impl(
         self,
@@ -661,6 +674,7 @@ def select_gemm_impl(
             raise NotImplementedError(
                 "Mxfp4 does not support batched experts format for EP")
         else:
+            assert self.moe_quant_config is not None
             if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
                     or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
                 # B200 code-path
@@ -671,13 +685,10 @@ def select_gemm_impl(
                     # TODO(bnell): part of quant_config
                     "max_capture_size": self.max_capture_size,
                 }
-                assert self.moe_quant_config is not None
                 return TrtLlmGenExperts(self.moe, self.moe_quant_config,
                                         **kwargs)
             else:
-                # Use matmul_ogs from triton_kernels here!
-                raise NotImplementedError(
-                    "Mxfp4 does not support non-batched experts format for EP")
+                return OAITritonExperts(self.moe_quant_config)
 
     def _route_and_experts(
         self,
@@ -722,10 +733,16 @@ def _route_and_experts(
             logical_to_physical_map=logical_to_physical_map,
             logical_replica_count=logical_replica_count)
 
+        w13_weight = (self.w13_weight_triton_tensor
+                      if layer.w13_weight is None else layer.w13_weight)
+        w2_weight = (self.w2_weight_triton_tensor
+                     if layer.w2_weight is None else layer.w2_weight)
+        assert all([w is not None for w in [w13_weight, w2_weight]])
+
         return self.fused_experts(
             hidden_states=x,
-            w1=layer.w13_weight,
-            w2=layer.w2_weight,
+            w1=w13_weight,
+            w2=w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=True,