From 3b3356be05fb42101552cf5f0a40e5c746c6d3ce Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 11 Aug 2025 15:24:36 +0200 Subject: [PATCH 01/64] Add MoE base class to Transformers backend Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config/__init__.py | 27 ++++++---- vllm/model_executor/models/transformers.py | 60 ++++++++++++++++++++++ 2 files changed, 76 insertions(+), 11 deletions(-) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 700d29f956a8..235360ba6b54 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -1205,17 +1205,7 @@ def _verify_bnb_config(self) -> None: self.enforce_eager = True def _verify_with_expert_parallelism(self) -> None: - num_expert_names = [ - "moe_num_experts", # Dbrx - "num_experts", # Jamba - "n_routed_experts", # DeepSeek - "num_local_experts", # Mixtral - ] - num_experts = 0 - for name in num_expert_names: - num_experts = getattr(self.hf_text_config, name, 0) - if num_experts > 0: - break + num_experts = self.get_num_experts() if num_experts < 1: raise ValueError( "Number of experts in the model must be greater than 0 " @@ -1429,6 +1419,21 @@ def get_num_attention_heads(self, num_heads = getattr(self.hf_text_config, "num_attention_heads", 0) return num_heads // parallel_config.tensor_parallel_size + def get_num_experts(self) -> int: + """Returns the number of experts in the model.""" + num_expert_names = [ + "moe_num_experts", # Dbrx + "num_experts", # Jamba + "n_routed_experts", # DeepSeek + "num_local_experts", # Mixtral + ] + num_experts = 0 + for name in num_expert_names: + num_experts = getattr(self.hf_text_config, name, 0) + if num_experts > 0: + break + return num_experts + def get_layers_start_end_indices( self, parallel_config: "ParallelConfig") -> tuple[int, int]: from vllm.distributed.utils import get_pp_indices diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index fc4585618b04..af7a54347620 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -33,6 +33,7 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed.utils import get_pp_indices from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import (ColumnParallelLinear, ReplicatedLinear, RowParallelLinear) @@ -419,6 +420,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.pipeline_parallel() + self.fused_moe() self.tensor_parallel() # Input embeddings @@ -492,6 +494,13 @@ def pipeline_parallel(self): if not self.pp_group.is_last_rank: setattr(self.model, name, PPMissingLayer()) + def fused_moe(self): + """ + Substitute the model's MoE layers with vLLM's FusedMoE. + To be overridden by child classes if they support MoE. + """ + pass + def tensor_parallel(self): """ Apply the model's tensor parallelization plan. 
@@ -617,6 +626,57 @@ def load_weights(self, weights: Iterable[tuple[str, return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) +class TransformersMoEBase(TransformersBase): + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.model_config.get_num_experts(), + # num_redundant_experts=self.num_redundant_experts, + ) + + def fused_moe(self): + + def _fused_moe(module: nn.Module, prefix: str = ""): + for child_name, child_module in module.named_children(): + qual_name = maybe_prefix(prefix, child_name) + if (child_name == "experts" + and isinstance(child_module, nn.ModuleList)): + new_module = FusedMoE( + # num_experts=self.text_config.num_experts, + num_experts=self.model_config.get_num_experts(), + top_k=8, # TODO: set this properly + hidden_size=self.text_config.hidden_size, + intermediate_size=768, # TODO: set this properly + # params_dtype + # reduce_results + # renormalize + # use_grouped_topk + # num_expert_group + # topk_group + quant_config=self.quant_config, + prefix=qual_name, + # custom_routing_function + # scoring_func + # e_score_correction_bias + # apply_router_weight_on_input + # activation + # enable_eplb + # num_redundant_experts + # has_bias + ) + setattr(module, child_name, new_module) + log_replacement(qual_name, child_module, new_module) + else: + _fused_moe(child_module, prefix=qual_name) + + _fused_moe(self.model) + + @support_torch_compile class TransformersModel(TransformersBase): hf_to_vllm_mapper = WeightsMapper( From ee32a6eda8cb90dec4270c1b6699a1fbbf0b341a Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 11 Aug 2025 15:36:25 +0200 Subject: [PATCH 02/64] Add MoE model classes to Transformers backend Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config/__init__.py | 20 +++++++++++++++++--- vllm/model_executor/models/registry.py | 5 ++++- vllm/model_executor/models/transformers.py | 19 +++++++++++++++++++ 3 files changed, 40 insertions(+), 4 deletions(-) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 235360ba6b54..ee59a3e92e7a 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -774,13 +774,27 @@ def validate_model_config_after(self: "ModelConfig") -> "ModelConfig": def _get_transformers_backend_cls(self) -> str: """Determine which Transformers backend class will be used if `model_impl` is set to `transformers` or `auto`.""" + architecture_family = "moe" if self.get_num_experts() > 1 else "dense" + transformers_backend_cls_map = { + "dense": { + "model": "TransformersModel", + "for_causal_lm": "TransformersForCausalLM", + "for_multimodal_lm": "TransformersForMultimodalLM", + }, + "moe": { + "model": "TransformersMoEModel", + "for_causal_lm": "TransformersMoEForCausalLM", + "for_multimodal_lm": "TransformersMoEForMultimodalLM", + }, + }.get(architecture_family) + if getattr(self, "runner_type", self.runner) == "pooling": - return "TransformersModel" + return transformers_backend_cls_map["model"] if self.hf_config != self.hf_text_config: # If 'hf_text_config' is the same as 'hf_config'. If not, it is # probably a composite config, i.e. 
multimodal - return "TransformersForMultimodalLM" - return "TransformersForCausalLM" + return transformers_backend_cls_map["for_multimodal_lm"] + return transformers_backend_cls_map["for_causal_lm"] def using_transformers_backend(self) -> bool: """Check if the model is using the Transformers backend class.""" diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index aca3d84f0071..b50d9d4d8c10 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -274,8 +274,11 @@ _TRANSFORMERS_BACKEND_MODELS = { "TransformersModel": ("transformers", "TransformersModel"), + "TransformersMoEModel": ("transformers", "TransformersMoEModel"), "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"), - "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501 + "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"), # noqa: E501 + "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501 + "TransformersMoEForMultimodalLM": ("transformers", "TransformersMoEForMultimodalLM"), # noqa: E501 } # yapf: enable diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index af7a54347620..446819f8e4a0 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -689,6 +689,11 @@ class TransformersModel(TransformersBase): }) +@support_torch_compile +class TransformersMoEModel(TransformersMoEBase, TransformersModel): + pass + + @support_torch_compile class TransformersForCausalLM(TransformersBase): @@ -729,6 +734,11 @@ def compute_logits( return logits +@support_torch_compile +class TransformersMoEForCausalLM(TransformersMoEBase, TransformersForCausalLM): + pass + + @MULTIMODAL_REGISTRY.register_processor( MultiModalProcessor, info=MultiModalProcessingInfo, @@ -851,3 +861,12 @@ def get_input_embeddings( inputs_embeds = inputs_embeds.masked_scatter( mask, multimodal_embeddings) return inputs_embeds + + +@MULTIMODAL_REGISTRY.register_processor( + MultiModalProcessor, + info=MultiModalProcessingInfo, + dummy_inputs=MultiModalDummyInputsBuilder) +class TransformersMoEForMultimodalLM(TransformersMoEForCausalLM, + TransformersForMultimodalLM): + pass From 64e827eb2fb461d09690bd17ef602285e17eef7c Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Wed, 13 Aug 2025 09:29:25 +0200 Subject: [PATCH 03/64] Make expert mapping comprehensive Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 34 ++++++++++++++++++---- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index ef2e63a2e935..c94a83142f34 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -654,15 +654,39 @@ def load_weights(self, weights: Iterable[tuple[str, class TransformersMoEBase(TransformersBase): def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - return FusedMoE.make_expert_params_mapping( + """ + Params for weights, fp8 weight scales, fp8 activation scales + (param_name, weight_name, expert_id, shard_id) + """ + num_experts = self.model_config.get_num_experts() + num_redundant_experts = 0 # TODO: 
enable EPLB + # Most common MoE style + expert_mapping = FusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.model_config.get_num_experts(), - # num_redundant_experts=self.num_redundant_experts, + num_experts=num_experts, + num_redundant_experts=num_redundant_experts, ) + # Granite, Mixtral, Phi MoE style + expert_mapping.extend( + FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=num_experts, + num_redundant_experts=num_redundant_experts, + )) + # Grok1 style + expert_mapping.extend( + FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="linear", + ckpt_down_proj_name="linear_1", + ckpt_up_proj_name="linear_v", + num_experts=num_experts, + num_redundant_experts=num_redundant_experts, + )) + return expert_mapping def fused_moe(self): From 48ab51a9542cd001dcab486ccc7848fe6a150020 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Wed, 13 Aug 2025 14:19:37 +0200 Subject: [PATCH 04/64] Add guard against EPLB until it's tested Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index c94a83142f34..a5550c212589 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -690,6 +690,10 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def fused_moe(self): + if self.parallel_config.enable_eplb: + raise NotImplementedError( + "Transformers backend does not support EPLB yet!") + def _fused_moe(module: nn.Module, prefix: str = ""): for child_name, child_module in module.named_children(): qual_name = maybe_prefix(prefix, child_name) From 80f15e2a2cf8692dc41d50bf0cb86f21a2beb1c1 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Wed, 13 Aug 2025 14:27:54 +0200 Subject: [PATCH 05/64] Add FusedMoE weight loading to `AutoWeightLoader` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 9 +++++++ vllm/model_executor/models/utils.py | 29 ++++++++++++++++++++-- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index a5550c212589..b2529d42d9ec 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -729,6 +729,15 @@ def _fused_moe(module: nn.Module, prefix: str = ""): _fused_moe(self.model) + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, skip_prefixes=self.skip_prefixes) + return loader.load_weights( + weights, + mapper=self.hf_to_vllm_mapper, + expert_mapping=self.get_expert_mapping(), + ) + @support_torch_compile class TransformersModel(TransformersBase): diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 6c27fedc61b1..04a2a8eac61f 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -210,6 +210,7 @@ def _load_module( base_prefix: str, module: nn.Module, weights: Iterable[tuple[str, torch.Tensor]], + expert_mapping: Optional[list[tuple[str, str, int, str]]], ) -> Iterable[str]: if 
isinstance(module, PPMissingLayer): return @@ -246,9 +247,31 @@ def _load_module( continue + if expert_mapping is not None and child_prefix == "experts": + for expert_name, loaded_weight in child_weights: + for (param_name, weight_name, expert_id, + shard_id) in expert_mapping: + if weight_name not in f"experts.{expert_name}": + continue + fused_moe = child_modules[child_prefix] + param_name = ( + f"{param_name.removeprefix('experts.')}weight") + param = getattr(fused_moe, param_name) + weight_name = maybe_prefix(prefix, param_name) + fused_moe.weight_loader( + param=param, + loaded_weight=loaded_weight, + weight_name=weight_name, + shard_id=shard_id, + expert_id=expert_id, + ) + logger.debug("Loaded expert %d into %s", shard_id, + prefix) + yield weight_name + yield from self._load_module(prefix, child_modules[child_prefix], - child_weights) + child_weights, expert_mapping) elif child_prefix in child_params: if self._can_skip(prefix): logger.debug("Skipping param %s", prefix) @@ -281,6 +304,7 @@ def load_weights( weights: Iterable[tuple[str, torch.Tensor]], *, mapper: Optional[WeightsMapper] = None, + expert_mapping: Optional[list[tuple[str, str, int, str]]] = None, ) -> set[str]: if mapper is not None: weights = mapper.apply(weights) @@ -288,7 +312,8 @@ def load_weights( weights = ((name, weight) for name, weight in weights if not self._can_skip(name)) - autoloaded_weights = set(self._load_module("", self.module, weights)) + autoloaded_weights = set( + self._load_module("", self.module, weights, expert_mapping)) return autoloaded_weights From 68a132fb2802d44ea13495f242a43d4b1524b351 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Wed, 13 Aug 2025 17:08:55 +0200 Subject: [PATCH 06/64] More compact expert mapping generation Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 44 ++++++++-------------- 1 file changed, 16 insertions(+), 28 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index b2529d42d9ec..9d8f170675ee 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -658,34 +658,22 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: Params for weights, fp8 weight scales, fp8 activation scales (param_name, weight_name, expert_id, shard_id) """ - num_experts = self.model_config.get_num_experts() - num_redundant_experts = 0 # TODO: enable EPLB - # Most common MoE style - expert_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=num_experts, - num_redundant_experts=num_redundant_experts, - ) - # Granite, Mixtral, Phi MoE style - expert_mapping.extend( - FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="w1", - ckpt_down_proj_name="w2", - ckpt_up_proj_name="w3", - num_experts=num_experts, - num_redundant_experts=num_redundant_experts, - )) - # Grok1 style - expert_mapping.extend( - FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="linear", - ckpt_down_proj_name="linear_1", - ckpt_up_proj_name="linear_v", - num_experts=num_experts, - num_redundant_experts=num_redundant_experts, - )) + ckpt_names = [ + # (ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name) + ("gate_proj", "down_proj", "up_proj"), # Most common MoE style + ("w1", "w2", "w3"), # Granite, Mixtral, Phi MoE style + ("linear", "linear_1", 
"linear_v"), # Grok1 style + ] + expert_mapping = [] + for gate_proj, down_proj, up_proj in ckpt_names: + expert_mapping.extend( + FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name=gate_proj, + ckpt_down_proj_name=down_proj, + ckpt_up_proj_name=up_proj, + num_experts=self.model_config.get_num_experts(), + num_redundant_experts=0, # TODO: enable EPLB + )) return expert_mapping def fused_moe(self): From 781d15d75a3b081669363eea49313ea162e80ca6 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Wed, 13 Aug 2025 18:49:19 +0200 Subject: [PATCH 07/64] Use forward hook to perform all reduce after experts in TP/EP Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 9d8f170675ee..88f92fe57995 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -678,6 +678,10 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def fused_moe(self): + def reduce_results(module, _, output): + if (experts := module.experts).tp_size > 1 or experts.ep_size > 1: + return experts.maybe_all_reduce_tensor_model_parallel(output) + if self.parallel_config.enable_eplb: raise NotImplementedError( "Transformers backend does not support EPLB yet!") @@ -687,6 +691,7 @@ def _fused_moe(module: nn.Module, prefix: str = ""): qual_name = maybe_prefix(prefix, child_name) if (child_name == "experts" and isinstance(child_module, nn.ModuleList)): + # Replace experts module with FusedMoE new_module = FusedMoE( # num_experts=self.text_config.num_experts, num_experts=self.model_config.get_num_experts(), @@ -694,7 +699,7 @@ def _fused_moe(module: nn.Module, prefix: str = ""): hidden_size=self.text_config.hidden_size, intermediate_size=768, # TODO: set this properly # params_dtype - # reduce_results + reduce_results=False, # renormalize # use_grouped_topk # num_expert_group @@ -712,6 +717,12 @@ def _fused_moe(module: nn.Module, prefix: str = ""): ) setattr(module, child_name, new_module) log_replacement(qual_name, child_module, new_module) + # Register all-reduce hook to the parent of the experts + # if tensor parallel or expert parallel is enabled. 
We do + # this instead of setting reduce_results=True to guarantee + # that the all-reduce happens after any shared experts have + # been added to the hidden state + module.register_forward_hook(reduce_results) else: _fused_moe(child_module, prefix=qual_name) From 8b1f3e881f7a2c49d16c2fd96b6dfbd0644475af Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Wed, 13 Aug 2025 18:51:43 +0200 Subject: [PATCH 08/64] Remove `params_dtype` because it's not used for most MoEs Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 88f92fe57995..4c6657540c71 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -698,7 +698,6 @@ def _fused_moe(module: nn.Module, prefix: str = ""): top_k=8, # TODO: set this properly hidden_size=self.text_config.hidden_size, intermediate_size=768, # TODO: set this properly - # params_dtype reduce_results=False, # renormalize # use_grouped_topk From 5b6f5fdbe5b619fa61e09d25fa80f8795b49e1db Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Wed, 13 Aug 2025 19:25:54 +0200 Subject: [PATCH 09/64] Better debug log for expert loading Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 04a2a8eac61f..5d5d8c4141d1 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -265,8 +265,8 @@ def _load_module( shard_id=shard_id, expert_id=expert_id, ) - logger.debug("Loaded expert %d into %s", shard_id, - prefix) + logger.debug("Loaded %s for expert %d into %s", + param_name, expert_id, prefix) yield weight_name yield from self._load_module(prefix, From c19ae9b437a1eaae5e919d1acea5115009c6de80 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 14 Aug 2025 09:58:54 +0200 Subject: [PATCH 10/64] Update transformers backend doc Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- docs/models/supported_models.md | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index dbbbc5122b80..2977902e4017 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -17,9 +17,28 @@ These models are what we list in [supported-text-models][supported-text-models] ### Transformers -vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models and common vision language models are supported! Vision-language models currently accept only image inputs. Support for video inputs will be added in future releases. +vLLM also supports model implementations that are available in Transformers. You should expect the performance of a Transformers model implementation used in vLLM to be within <1% of the performance of a dedicated vLLM model implementation. We call this feature the "Transformers backend". 
-To check if the modeling backend is Transformers, you can simply do this: +Currently, the Transformers backend works for the following: + +- Modalities: embedding models, language models and vision-language models* +- Attention types: full attention and/or sliding attention +- MLP types: dense and/or mixture-of-experts + +_*Vision-language models currently accept only image inputs. Support for video inputs will be added in a future release._ + +If the Transformers model implementation follows all the steps in [writing a custom model](#writing-custom-models) then, when used with the Transformers backend, it will be compatible with the following features of vLLM: + +- All the features listed in the [compatibility matrix](../features/compatibility_matrix.md#feature-x-feature)* +- Any combination of vLLM's parallelisation schemes: + - Data parallel + - Expert parallel + - Pipeline parallel + - Tensor parallel + +_*except encoder-decoder models. Support for encoder-decoder models will be added in a future release._ + +Checking if the modeling backend is Transformers is as simple as: ```python from vllm import LLM @@ -27,13 +46,9 @@ llm = LLM(model=...) # Name or path of your model llm.apply_model(lambda model: print(type(model))) ``` -If it is `TransformersForCausalLM` or `TransformersForMultimodalLM` then it means it's based on Transformers! - -!!! tip - You can force the use of `TransformersForCausalLM` by setting `model_impl="transformers"` for [offline-inference](../serving/offline_inference.md) or `--model-impl transformers` for the [openai-compatible-server](../serving/openai_compatible_server.md). +If the printed type starts with `Transformers...` then it's using the Transformers model implementation! -!!! note - vLLM may not fully optimise the Transformers implementation so you may see degraded performance if comparing a native model to a Transformers model in vLLM. +If your model has a vLLM implementation but you would prefer to use the Transformers implementation via the Transformers backend, set `model_impl="transformers"` for [offline inference](../serving/offline_inference.md) or `--model-impl transformers` for the [online serving](../serving/openai_compatible_server.md). !!! note In case of vision language models if you are loading with `dtype="auto"`, vLLM loads the whole model with config's `dtype` if it exists. In contrast the native Transformers will respect the `dtype` attribute of each backbone in the model. That might cause a slight difference in performance. From a9272bb54906564c4545035a016707deeffe1598 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 14 Aug 2025 09:59:20 +0200 Subject: [PATCH 11/64] Small doc tweak Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- docs/models/supported_models.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 2977902e4017..5619f1366dc3 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -51,7 +51,7 @@ If the printed type starts with `Transformers...` then it's using the Transforme If your model has a vLLM implementation but you would prefer to use the Transformers implementation via the Transformers backend, set `model_impl="transformers"` for [offline inference](../serving/offline_inference.md) or `--model-impl transformers` for the [online serving](../serving/openai_compatible_server.md). !!! 
note - In case of vision language models if you are loading with `dtype="auto"`, vLLM loads the whole model with config's `dtype` if it exists. In contrast the native Transformers will respect the `dtype` attribute of each backbone in the model. That might cause a slight difference in performance. + For vision-language models, if you are loading with `dtype="auto"`, vLLM loads the whole model with config's `dtype` if it exists. In contrast the native Transformers will respect the `dtype` attribute of each backbone in the model. That might cause a slight difference in performance. #### Custom models From 690df42da60fd26d36e65cd437b1680e24adf6aa Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 14 Aug 2025 10:20:34 +0200 Subject: [PATCH 12/64] Add docstring to `reduce_results` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 4c6657540c71..0afe7b26c7ff 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -679,6 +679,8 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def fused_moe(self): def reduce_results(module, _, output): + """Forward hook that performs all-reduce on a nn.Module's + output if tensor parallel or expert parallel is enabled.""" if (experts := module.experts).tp_size > 1 or experts.ep_size > 1: return experts.maybe_all_reduce_tensor_model_parallel(output) From fd8bddb37e6b96a5e47ba534fdf82e1d19a6287c Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 14 Aug 2025 10:21:52 +0200 Subject: [PATCH 13/64] Set `renormalize` correctly Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 0afe7b26c7ff..aaddd909212b 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -678,6 +678,14 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def fused_moe(self): + # TODO: replace with self.text_config.num_experts, + num_experts = self.model_config.get_num_experts() + top_k = 8 # TODO: set this properly + hidden_size = self.text_config.hidden_size + intermediate_size = 768 # TODO: set this properly + renormalize: bool = getattr(self.hf_text_config, "norm_topk_prob", + top_k > 1) + def reduce_results(module, _, output): """Forward hook that performs all-reduce on a nn.Module's output if tensor parallel or expert parallel is enabled.""" @@ -695,13 +703,12 @@ def _fused_moe(module: nn.Module, prefix: str = ""): and isinstance(child_module, nn.ModuleList)): # Replace experts module with FusedMoE new_module = FusedMoE( - # num_experts=self.text_config.num_experts, - num_experts=self.model_config.get_num_experts(), - top_k=8, # TODO: set this properly - hidden_size=self.text_config.hidden_size, - intermediate_size=768, # TODO: set this properly + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, reduce_results=False, - # renormalize + renormalize=renormalize, # use_grouped_topk # num_expert_group # topk_group From 87473f762889069e0155ae4e7fd198c259913814 
Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 14 Aug 2025 12:05:03 +0200 Subject: [PATCH 14/64] Set `top_k` correctly Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index aaddd909212b..f145d61aa7d1 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -678,10 +678,12 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def fused_moe(self): + text_config = self.text_config + # TODO: replace with self.text_config.num_experts, num_experts = self.model_config.get_num_experts() - top_k = 8 # TODO: set this properly - hidden_size = self.text_config.hidden_size + top_k = text_config.num_experts_per_token + hidden_size = text_config.hidden_size intermediate_size = 768 # TODO: set this properly renormalize: bool = getattr(self.hf_text_config, "norm_topk_prob", top_k > 1) From 4d7e41c1dfd0b24d7072fa46c1d6bac2cf242eb9 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 14 Aug 2025 12:56:27 +0200 Subject: [PATCH 15/64] Add support for grouped topk expert selection Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index f145d61aa7d1..9cbcd6d29dba 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -688,6 +688,12 @@ def fused_moe(self): renormalize: bool = getattr(self.hf_text_config, "norm_topk_prob", top_k > 1) + # Grouped topk kwargs. 
If either config is set, enable grouped topk + # and let FusedMoE handle any errors from misconfiguration + num_expert_group = getattr(text_config, "n_group", None) + topk_group = getattr(text_config, "topk_group", None) + use_grouped_topk = num_expert_group or topk_group + def reduce_results(module, _, output): """Forward hook that performs all-reduce on a nn.Module's output if tensor parallel or expert parallel is enabled.""" @@ -711,9 +717,9 @@ def _fused_moe(module: nn.Module, prefix: str = ""): intermediate_size=intermediate_size, reduce_results=False, renormalize=renormalize, - # use_grouped_topk - # num_expert_group - # topk_group + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group, quant_config=self.quant_config, prefix=qual_name, # custom_routing_function From a6c0483d4b3783fda0402d91532f6f234ca2ad4e Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 14 Aug 2025 12:57:16 +0200 Subject: [PATCH 16/64] Make `use_grouped_topk` a bool Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 9cbcd6d29dba..24d92ba3dba0 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -692,7 +692,7 @@ def fused_moe(self): # and let FusedMoE handle any errors from misconfiguration num_expert_group = getattr(text_config, "n_group", None) topk_group = getattr(text_config, "topk_group", None) - use_grouped_topk = num_expert_group or topk_group + use_grouped_topk = bool(num_expert_group or topk_group) def reduce_results(module, _, output): """Forward hook that performs all-reduce on a nn.Module's From cc45642391d5f148fa3b606f985b51098c41ad82 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 25 Aug 2025 15:56:23 +0200 Subject: [PATCH 17/64] Add note for removal of `get_num_experts` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config/__init__.py | 3 +++ vllm/model_executor/models/transformers.py | 4 +++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index aada12617943..a3aa72d91bba 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -1496,6 +1496,9 @@ def get_num_attention_heads(self, num_heads = getattr(self.hf_text_config, "num_attention_heads", 0) return num_heads // parallel_config.tensor_parallel_size + # TODO: Remove once https://github.com/huggingface/transformers/pull/40156 + # is released. Attribute mapping will allow us to simply read + # self.hf_text_config.num_experts def get_num_experts(self) -> int: """Returns the number of experts in the model.""" num_expert_names = [ diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 63cfeb74da4e..b26f128fae64 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -679,7 +679,9 @@ def fused_moe(self): text_config = self.text_config - # TODO: replace with self.text_config.num_experts, + # TODO: Remove once https://github.com/huggingface/transformers/pull/40156 + # is released. 
Attribute mapping will allow us to simply read + # text_config.num_experts num_experts = self.model_config.get_num_experts() top_k = text_config.num_experts_per_token hidden_size = text_config.hidden_size From d711ab52a3797967e3fe318943f022634877a489 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 25 Aug 2025 15:57:25 +0200 Subject: [PATCH 18/64] Better handling of shared experts and renosmalisation Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 24 +++++++++++----------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index b26f128fae64..b4248e84a207 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -686,8 +686,8 @@ def fused_moe(self): top_k = text_config.num_experts_per_token hidden_size = text_config.hidden_size intermediate_size = 768 # TODO: set this properly - renormalize: bool = getattr(self.hf_text_config, "norm_topk_prob", - top_k > 1) + reduce_results = getattr(text_config, "num_experts_shared", 0) == 0 + renormalize = getattr(text_config, "norm_topk_prob", top_k > 1) # Grouped topk kwargs. If either config is set, enable grouped topk # and let FusedMoE handle any errors from misconfiguration @@ -695,9 +695,11 @@ def fused_moe(self): topk_group = getattr(text_config, "topk_group", None) use_grouped_topk = bool(num_expert_group or topk_group) - def reduce_results(module, _, output): - """Forward hook that performs all-reduce on a nn.Module's - output if tensor parallel or expert parallel is enabled.""" + def reduce_results_hook(module, _, output): + """Forward hook that performs all-reduce on a nn.Module's output if + tensor parallel or expert parallel is enabled. This is used for + models with shared experts where the all reduce happens after any + shared experts have been added to the hidden state.""" if (experts := module.experts).tp_size > 1 or experts.ep_size > 1: return experts.maybe_all_reduce_tensor_model_parallel(output) @@ -716,7 +718,7 @@ def _fused_moe(module: nn.Module, prefix: str = ""): top_k=top_k, hidden_size=hidden_size, intermediate_size=intermediate_size, - reduce_results=False, + reduce_results=reduce_results, renormalize=renormalize, use_grouped_topk=use_grouped_topk, num_expert_group=num_expert_group, @@ -734,12 +736,10 @@ def _fused_moe(module: nn.Module, prefix: str = ""): ) setattr(module, child_name, new_module) log_replacement(qual_name, child_module, new_module) - # Register all-reduce hook to the parent of the experts - # if tensor parallel or expert parallel is enabled. 
We do - # this instead of setting reduce_results=True to guarantee - # that the all-reduce happens after any shared experts have - # been added to the hidden state - module.register_forward_hook(reduce_results) + # If results are not all-reduced in FusedMoE, ensure they + # are all-reduced at the end of module.forward() + if not reduce_results: + module.register_forward_hook(reduce_results_hook) else: _fused_moe(child_module, prefix=qual_name) From 88154b61a369698b70c92fb56819d7b8c0388524 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 25 Aug 2025 17:48:17 +0200 Subject: [PATCH 19/64] Add util which does `getattr` for a list of possible names Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config/utils.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/vllm/config/utils.py b/vllm/config/utils.py index 98fbeb1fa86a..b91633c3c514 100644 --- a/vllm/config/utils.py +++ b/vllm/config/utils.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, TypeVar +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, TypeVar if TYPE_CHECKING: from _typeshed import DataclassInstance @@ -27,3 +28,15 @@ def config(cls: ConfigT) -> ConfigT: script, which is invoked during the pre-commit checks. """ return cls + + +def getattr_iter(object: object, names: Iterable[str], default: Any) -> Any: + """ + A helper function that retrieves an attribute from an object which may + have multiple possible names. This is useful when fetching attributes from + arbitrary `transformers.PretrainedConfig` instances. + """ + for name in names: + if hasattr(object, name): + return getattr(object, name) + return default From bf26f607fa63fe0e5372c7c2306e6db3a7f189a6 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 25 Aug 2025 17:50:54 +0200 Subject: [PATCH 20/64] Use new util in ModelConfig Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config/__init__.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index a3aa72d91bba..a5ec3f713d6d 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -36,7 +36,7 @@ from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig, ParallelConfig) from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy -from vllm.config.utils import ConfigType, config +from vllm.config.utils import ConfigType, config, getattr_iter from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.platforms import current_platform @@ -1502,17 +1502,12 @@ def get_num_attention_heads(self, def get_num_experts(self) -> int: """Returns the number of experts in the model.""" num_expert_names = [ - "moe_num_experts", # Dbrx "num_experts", # Jamba + "moe_num_experts", # Dbrx "n_routed_experts", # DeepSeek "num_local_experts", # Mixtral ] - num_experts = 0 - for name in num_expert_names: - num_experts = getattr(self.hf_text_config, name, 0) - if num_experts > 0: - break - return num_experts + return getattr_iter(self.hf_text_config, num_expert_names, 0) def get_layers_start_end_indices( self, parallel_config: "ParallelConfig") -> tuple[int, int]: From 982643ca5fe92ab8b1f7d8bcbb84f9e472df1ee8 Mon Sep 17 00:00:00 2001 From: Harry Mellor 
<19981378+hmellor@users.noreply.github.com> Date: Mon, 25 Aug 2025 17:52:13 +0200 Subject: [PATCH 21/64] Set `intermediate_size` properly Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index b4248e84a207..9dcc1d2d5aec 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -30,6 +30,7 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, VllmConfig) +from vllm.config.utils import getattr_iter from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed.utils import get_pp_indices from vllm.logger import init_logger @@ -679,13 +680,14 @@ def fused_moe(self): text_config = self.text_config - # TODO: Remove once https://github.com/huggingface/transformers/pull/40156 - # is released. Attribute mapping will allow us to simply read - # text_config.num_experts + # Positional arguments num_experts = self.model_config.get_num_experts() top_k = text_config.num_experts_per_token hidden_size = text_config.hidden_size - intermediate_size = 768 # TODO: set this properly + names = ["moe_intermediate_size", "intermediate_size"] + intermediate_size = getattr_iter(text_config, names, None) + + # Reduction kwargs reduce_results = getattr(text_config, "num_experts_shared", 0) == 0 renormalize = getattr(text_config, "norm_topk_prob", top_k > 1) From 583d5c60777397cc5c68c34ff0da226f41853b19 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 25 Aug 2025 17:53:37 +0200 Subject: [PATCH 22/64] Move reduction kwargs to be together Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 9dcc1d2d5aec..8ff9bef73a7d 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -691,12 +691,6 @@ def fused_moe(self): reduce_results = getattr(text_config, "num_experts_shared", 0) == 0 renormalize = getattr(text_config, "norm_topk_prob", top_k > 1) - # Grouped topk kwargs. If either config is set, enable grouped topk - # and let FusedMoE handle any errors from misconfiguration - num_expert_group = getattr(text_config, "n_group", None) - topk_group = getattr(text_config, "topk_group", None) - use_grouped_topk = bool(num_expert_group or topk_group) - def reduce_results_hook(module, _, output): """Forward hook that performs all-reduce on a nn.Module's output if tensor parallel or expert parallel is enabled. This is used for @@ -705,6 +699,12 @@ def reduce_results_hook(module, _, output): if (experts := module.experts).tp_size > 1 or experts.ep_size > 1: return experts.maybe_all_reduce_tensor_model_parallel(output) + # Grouped topk kwargs. 
If either config is set, enable grouped topk + # and let FusedMoE handle any errors from misconfiguration + num_expert_group = getattr(text_config, "n_group", None) + topk_group = getattr(text_config, "topk_group", None) + use_grouped_topk = bool(num_expert_group or topk_group) + if self.parallel_config.enable_eplb: raise NotImplementedError( "Transformers backend does not support EPLB yet!") From e4fad6f56c119f479cb04b47056c638ddab2ed87 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 25 Aug 2025 17:55:14 +0200 Subject: [PATCH 23/64] Set `e_score_correction_bias` correctly Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 8ff9bef73a7d..f1eee1f39f56 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -714,6 +714,10 @@ def _fused_moe(module: nn.Module, prefix: str = ""): qual_name = maybe_prefix(prefix, child_name) if (child_name == "experts" and isinstance(child_module, nn.ModuleList)): + # Get e_score_correction_bias if present + gate = getattr(module, "gate", None) + e_score_correction_bias = getattr( + gate, "e_score_correction_bias", None) # Replace experts module with FusedMoE new_module = FusedMoE( num_experts=num_experts, @@ -729,7 +733,7 @@ def _fused_moe(module: nn.Module, prefix: str = ""): prefix=qual_name, # custom_routing_function # scoring_func - # e_score_correction_bias + e_score_correction_bias=e_score_correction_bias, # apply_router_weight_on_input # activation # enable_eplb From b5a5cbc838f33c3fa46b54c8baeecbf27a068e3a Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 25 Aug 2025 17:59:40 +0200 Subject: [PATCH 24/64] Add EPLB support Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index f1eee1f39f56..4ff7452d3e04 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -705,10 +705,14 @@ def reduce_results_hook(module, _, output): topk_group = getattr(text_config, "topk_group", None) use_grouped_topk = bool(num_expert_group or topk_group) - if self.parallel_config.enable_eplb: - raise NotImplementedError( - "Transformers backend does not support EPLB yet!") + # Expert parallel load balancing kwargs + parallel_config = self.parallel_config + eplb_config = parallel_config.eplb_config + enable_eplb = parallel_config.enable_eplb + num_redundant_experts = eplb_config.num_redundant_experts + + # Recursively fuse MoE layers def _fused_moe(module: nn.Module, prefix: str = ""): for child_name, child_module in module.named_children(): qual_name = maybe_prefix(prefix, child_name) @@ -736,8 +740,8 @@ def _fused_moe(module: nn.Module, prefix: str = ""): e_score_correction_bias=e_score_correction_bias, # apply_router_weight_on_input # activation - # enable_eplb - # num_redundant_experts + enable_eplb=enable_eplb, + num_redundant_experts=num_redundant_experts, # has_bias ) setattr(module, child_name, new_module) From dbd0c4ea082b5686535515959e8f5c1e242f45d9 Mon Sep 17 00:00:00 2001 From: Harry Mellor 
<19981378+hmellor@users.noreply.github.com> Date: Mon, 25 Aug 2025 18:00:14 +0200 Subject: [PATCH 25/64] Label remaining missing features with the models that require them Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 4ff7452d3e04..7691cc3c075f 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -735,14 +735,14 @@ def _fused_moe(module: nn.Module, prefix: str = ""): topk_group=topk_group, quant_config=self.quant_config, prefix=qual_name, - # custom_routing_function - # scoring_func + # TODO: custom_routing_function - llama4, phimoe + # TODO: scoring_func - deepseek_v2, dots1, glm4_moe e_score_correction_bias=e_score_correction_bias, - # apply_router_weight_on_input - # activation + # TODO: apply_router_weight_on_input - llama4 + # TODO: activation - grok1, gpt-oss enable_eplb=enable_eplb, num_redundant_experts=num_redundant_experts, - # has_bias + # TODO: has_bias - gpt-oss ) setattr(module, child_name, new_module) log_replacement(qual_name, child_module, new_module) From 40c097b8302bd8e2b3ba360f2db9e1c2955d0c78 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 28 Aug 2025 18:30:37 +0200 Subject: [PATCH 26/64] Fix `support_torch_compile` for MoE classes Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index c5bd89e2094e..d24e3ee048a0 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -848,7 +848,7 @@ def compute_logits( return logits -@support_torch_compile +@support_torch_compile(enable_if=can_enable_torch_compile) class TransformersMoEForCausalLM(TransformersMoEBase, TransformersForCausalLM): pass @@ -996,10 +996,15 @@ def get_input_embeddings( return inputs_embeds -@MULTIMODAL_REGISTRY.register_processor( - MultiModalProcessor, - info=MultiModalProcessingInfo, - dummy_inputs=MultiModalDummyInputsBuilder) +@support_torch_compile( + # set `positions` to last dim to support Qwen-mrope + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }, + enable_if=can_enable_torch_compile) class TransformersMoEForMultimodalLM(TransformersMoEForCausalLM, TransformersForMultimodalLM): pass From 55fb46feb3f3191d0d5f87b5e519ccbd8d806281 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 28 Aug 2025 18:31:34 +0200 Subject: [PATCH 27/64] Fix typo Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index d24e3ee048a0..27ecb9c3af95 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -708,7 +708,7 @@ def fused_moe(self): # Positional arguments num_experts = self.model_config.get_num_experts() - top_k = text_config.num_experts_per_token + top_k = text_config.num_experts_per_tok hidden_size = 
text_config.hidden_size names = ["moe_intermediate_size", "intermediate_size"] intermediate_size = getattr_iter(text_config, names, None) From 47b9d5357eb15b9cdecbdf2d62370c238828cb07 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 23 Sep 2025 12:46:16 +0200 Subject: [PATCH 28/64] CustomFusedMoE Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 42 +++++++++++++++++++--- 1 file changed, 37 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 91594401f8d5..7cd69325b4f4 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -755,11 +755,44 @@ def reduce_results_hook(module, _, output): if (experts := module.experts).tp_size > 1 or experts.ep_size > 1: return experts.maybe_all_reduce_tensor_model_parallel(output) - # Grouped topk kwargs. If either config is set, enable grouped topk - # and let FusedMoE handle any errors from misconfiguration + # Grouped topk kwargs. We hardcode use_grouped_topk = False so that + # FusedMoE will use the custom_routing_function, which is necessary + # because the routing happens on the Transformers side num_expert_group = getattr(text_config, "n_group", None) topk_group = getattr(text_config, "topk_group", None) - use_grouped_topk = bool(num_expert_group or topk_group) + use_grouped_topk = False + + class CustomFusedMoE(FusedMoE): + + def forward(self, hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor): + """Custom forward which allows us to bypass the routing inside + FusedMoE. Instead, we use the top_k_index and top_k_weights + computed on the Transformers side. + + On the Transformers side the schema for `args` is + `[hidden_states, top_k_index, top_k_weights]`. 
+ + On the vLLM side these correspond to + `[hidden_states, topk_ids, topk_weights]` respectively.""" + + def custom_routing_function( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + ) -> tuple[torch.Tensor, torch.Tensor]: + # gating_output is actually topk_weights + topk_weights = gating_output + # Use the saved topk_ids + return topk_weights, top_k_index + + self.custom_routing_function = custom_routing_function + + # Call the parent forward method with hidden_states and + # topk_weights as gating_output + return super().forward(hidden_states, top_k_weights) # Expert parallel load balancing kwargs parallel_config = self.parallel_config @@ -779,7 +812,7 @@ def _fused_moe(module: nn.Module, prefix: str = ""): e_score_correction_bias = getattr( gate, "e_score_correction_bias", None) # Replace experts module with FusedMoE - new_module = FusedMoE( + new_module = CustomFusedMoE( num_experts=num_experts, top_k=top_k, hidden_size=hidden_size, @@ -791,7 +824,6 @@ def _fused_moe(module: nn.Module, prefix: str = ""): topk_group=topk_group, quant_config=self.quant_config, prefix=qual_name, - # TODO: custom_routing_function - llama4, phimoe # TODO: scoring_func - deepseek_v2, dots1, glm4_moe e_score_correction_bias=e_score_correction_bias, # TODO: apply_router_weight_on_input - llama4 From 8b351172b0563990f98643c6cda658b955d2c581 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 25 Sep 2025 12:54:13 +0000 Subject: [PATCH 29/64] Use custom op for top_k_index bypass Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 87 +++++++++++++--------- 1 file changed, 53 insertions(+), 34 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 7cd69325b4f4..216bda0d1e91 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -34,7 +34,9 @@ from vllm.config.utils import getattr_iter from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed.utils import get_pp_indices +from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger +from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import (ColumnParallelLinear, ReplicatedLinear, @@ -51,8 +53,9 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo) from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.utils import is_list_of +from vllm.utils import direct_register_custom_op, is_list_of from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsQuant) @@ -707,6 +710,54 @@ def load_weights(self, weights: Iterable[tuple[str, return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) +@CustomOp.register("transformers_fused_moe") +class TransformersFusedMoE(FusedMoE): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._top_k_index: torch.Tensor = None + + def custom_routing_function(hidden_states, gating_output, topk, + renormalize, *extra): + return gating_output, self._top_k_index + + self.custom_routing_function = custom_routing_function + + def forward(self, hidden_states: torch.Tensor, top_k_index: 
torch.Tensor, + top_k_weights: torch.Tensor): + return torch.ops.vllm.transformers_moe_forward(hidden_states, + top_k_index, + top_k_weights, + self.layer_name) + + +def transformers_moe_forward(hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + layer_name: str) -> torch.Tensor: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self._top_k_index = top_k_index + return self.forward_impl(hidden_states.clone(), top_k_weights) + + +def transformers_moe_forward_fake(hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + layer_name: str) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +direct_register_custom_op( + op_name="transformers_moe_forward", + op_func=transformers_moe_forward, + mutates_args=["hidden_states"], + fake_impl=transformers_moe_forward_fake, + dispatch_key=current_platform.dispatch_key, + tags=(torch.Tag.needs_fixed_stride_order, ), +) + + class TransformersMoEBase(TransformersBase): def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: @@ -762,38 +813,6 @@ def reduce_results_hook(module, _, output): topk_group = getattr(text_config, "topk_group", None) use_grouped_topk = False - class CustomFusedMoE(FusedMoE): - - def forward(self, hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor): - """Custom forward which allows us to bypass the routing inside - FusedMoE. Instead, we use the top_k_index and top_k_weights - computed on the Transformers side. - - On the Transformers side the schema for `args` is - `[hidden_states, top_k_index, top_k_weights]`. - - On the vLLM side these correspond to - `[hidden_states, topk_ids, topk_weights]` respectively.""" - - def custom_routing_function( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - ) -> tuple[torch.Tensor, torch.Tensor]: - # gating_output is actually topk_weights - topk_weights = gating_output - # Use the saved topk_ids - return topk_weights, top_k_index - - self.custom_routing_function = custom_routing_function - - # Call the parent forward method with hidden_states and - # topk_weights as gating_output - return super().forward(hidden_states, top_k_weights) - # Expert parallel load balancing kwargs parallel_config = self.parallel_config eplb_config = parallel_config.eplb_config @@ -812,7 +831,7 @@ def _fused_moe(module: nn.Module, prefix: str = ""): e_score_correction_bias = getattr( gate, "e_score_correction_bias", None) # Replace experts module with FusedMoE - new_module = CustomFusedMoE( + new_module = TransformersFusedMoE( num_experts=num_experts, top_k=top_k, hidden_size=hidden_size, From 524b4c401e0f53c973177c1817e0bb7d8d4635c7 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 25 Sep 2025 12:54:53 +0000 Subject: [PATCH 30/64] Fix reduce_results handling Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 216bda0d1e91..dd5409d36ac5 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -795,7 +795,8 @@ def fused_moe(self): intermediate_size = getattr_iter(text_config, names, None) # Reduction kwargs - reduce_results = getattr(text_config, 
"num_experts_shared", 0) == 0 + names = ["num_experts_shared", "shared_expert_intermediate_size"] + reduce_results = getattr_iter(text_config, names, 0) == 0 renormalize = getattr(text_config, "norm_topk_prob", top_k > 1) def reduce_results_hook(module, _, output): @@ -803,8 +804,10 @@ def reduce_results_hook(module, _, output): tensor parallel or expert parallel is enabled. This is used for models with shared experts where the all reduce happens after any shared experts have been added to the hidden state.""" - if (experts := module.experts).tp_size > 1 or experts.ep_size > 1: - return experts.maybe_all_reduce_tensor_model_parallel(output) + if isinstance(output, tuple): + output = output[0] + return module.experts.maybe_all_reduce_tensor_model_parallel( + output) # Grouped topk kwargs. We hardcode use_grouped_topk = False so that # FusedMoE will use the custom_routing_function, which is necessary @@ -855,7 +858,8 @@ def _fused_moe(module: nn.Module, prefix: str = ""): log_replacement(qual_name, child_module, new_module) # If results are not all-reduced in FusedMoE, ensure they # are all-reduced at the end of module.forward() - if not reduce_results: + if not reduce_results and (new_module.tp_size > 1 + or new_module.ep_size > 1): module.register_forward_hook(reduce_results_hook) else: _fused_moe(child_module, prefix=qual_name) From 8b2bd5983cdd128dcea2ad3cbaf291c676098a4f Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 25 Sep 2025 13:44:17 +0000 Subject: [PATCH 31/64] Reorganise the kwargs a little Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 24 ++++++++++------------ 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index dd5409d36ac5..d5a577f59025 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -797,7 +797,6 @@ def fused_moe(self): # Reduction kwargs names = ["num_experts_shared", "shared_expert_intermediate_size"] reduce_results = getattr_iter(text_config, names, 0) == 0 - renormalize = getattr(text_config, "norm_topk_prob", top_k > 1) def reduce_results_hook(module, _, output): """Forward hook that performs all-reduce on a nn.Module's output if @@ -809,12 +808,17 @@ def reduce_results_hook(module, _, output): return module.experts.maybe_all_reduce_tensor_model_parallel( output) - # Grouped topk kwargs. 
We hardcode use_grouped_topk = False so that - # FusedMoE will use the custom_routing_function, which is necessary - # because the routing happens on the Transformers side + # Unused kwargs since we use custom_routing_function: + # - `scoring_func` and `e_score_correction_bias` only used for grouped + # topk routing inside vLLM and are non-trivial to infer + # and hard code `use_grouped_topk=False` + # - `renormalize` passed anyway because it's easy to infer + # - `num_expert_group` and `topk_group` used for inferring expert + # placement strategy in FusedMoE + # - `apply_router_weight_on_input` is already applied in Transformers + renormalize = getattr(text_config, "norm_topk_prob", top_k > 1) num_expert_group = getattr(text_config, "n_group", None) topk_group = getattr(text_config, "topk_group", None) - use_grouped_topk = False # Expert parallel load balancing kwargs parallel_config = self.parallel_config @@ -829,10 +833,6 @@ def _fused_moe(module: nn.Module, prefix: str = ""): qual_name = maybe_prefix(prefix, child_name) if (child_name == "experts" and isinstance(child_module, nn.ModuleList)): - # Get e_score_correction_bias if present - gate = getattr(module, "gate", None) - e_score_correction_bias = getattr( - gate, "e_score_correction_bias", None) # Replace experts module with FusedMoE new_module = TransformersFusedMoE( num_experts=num_experts, @@ -841,14 +841,12 @@ def _fused_moe(module: nn.Module, prefix: str = ""): intermediate_size=intermediate_size, reduce_results=reduce_results, renormalize=renormalize, - use_grouped_topk=use_grouped_topk, + # Hard coded because topk happens in Transformers + use_grouped_topk=False, num_expert_group=num_expert_group, topk_group=topk_group, quant_config=self.quant_config, prefix=qual_name, - # TODO: scoring_func - deepseek_v2, dots1, glm4_moe - e_score_correction_bias=e_score_correction_bias, - # TODO: apply_router_weight_on_input - llama4 # TODO: activation - grok1, gpt-oss enable_eplb=enable_eplb, num_redundant_experts=num_redundant_experts, From c4120fb42095beef214d9959c29d1f64a5708804 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 25 Sep 2025 13:44:52 +0000 Subject: [PATCH 32/64] Make docs claim more conservative Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- docs/models/supported_models.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index b5f5ccfa213b..7d8581f73aa1 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -17,7 +17,7 @@ These models are what we list in [supported-text-models][supported-text-models] ### Transformers -vLLM also supports model implementations that are available in Transformers. You should expect the performance of a Transformers model implementation used in vLLM to be within <1% of the performance of a dedicated vLLM model implementation. We call this feature the "Transformers backend". +vLLM also supports model implementations that are available in Transformers. You should expect the performance of a Transformers model implementation used in vLLM to be within <5% of the performance of a dedicated vLLM model implementation. We call this feature the "Transformers backend". 
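For example, one minimal way to opt in to this backend explicitly (a sketch assuming the `model_impl` engine argument; the checkpoint is simply the MoE model exercised by the test added later in this series):

```python
from vllm import LLM, SamplingParams

# Force the Transformers modeling code instead of a native vLLM
# implementation. With model_impl="auto", vLLM only falls back to this
# backend when no native implementation is registered for the architecture.
llm = LLM(model="Qwen/Qwen1.5-MoE-A2.7B-Chat", model_impl="transformers")

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```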
Currently, the Transformers backend works for the following: From 0134d623233ea10501d45e9204a4aca1bc013b9c Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 25 Sep 2025 13:56:25 +0000 Subject: [PATCH 33/64] Better transformers backend class resolution Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config/model.py | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/vllm/config/model.py b/vllm/config/model.py index aa9e88210ba0..88cfe8556b40 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -665,27 +665,17 @@ def validate_model_config_after(self: "ModelConfig") -> "ModelConfig": def _get_transformers_backend_cls(self) -> str: """Determine which Transformers backend class will be used if `model_impl` is set to `transformers` or `auto`.""" - architecture_family = "moe" if self.get_num_experts() > 1 else "dense" - transformers_backend_cls_map = { - "dense": { - "model": "TransformersModel", - "for_causal_lm": "TransformersForCausalLM", - "for_multimodal_lm": "TransformersForMultimodalLM", - }, - "moe": { - "model": "TransformersMoEModel", - "for_causal_lm": "TransformersMoEForCausalLM", - "for_multimodal_lm": "TransformersMoEForMultimodalLM", - }, - }.get(architecture_family) - + prefix = "Transformers" + # Resolve Transformers backend pooling class if getattr(self, "runner_type", self.runner) == "pooling": - return transformers_backend_cls_map["model"] + return prefix + "Model" + # Resolve Transformers backend generate classes + prefix += "MoE" if self.get_num_experts() > 1 else "" if self.hf_config != self.hf_text_config: # If 'hf_text_config' is the same as 'hf_config'. If not, it is # probably a composite config, i.e. multimodal - return transformers_backend_cls_map["for_multimodal_lm"] - return transformers_backend_cls_map["for_causal_lm"] + return prefix + "ForMultimodalLM" + return prefix + "ForCausalLM" def using_transformers_backend(self) -> bool: """Check if the model is using the Transformers backend class.""" From dbe1352252a588c5d51c4cda948c917532a783ea Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 25 Sep 2025 14:00:32 +0000 Subject: [PATCH 34/64] Allow MoE pooling Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config/model.py b/vllm/config/model.py index 88cfe8556b40..c5ef5083d504 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -666,11 +666,11 @@ def _get_transformers_backend_cls(self) -> str: """Determine which Transformers backend class will be used if `model_impl` is set to `transformers` or `auto`.""" prefix = "Transformers" + prefix += "MoE" if self.get_num_experts() > 1 else "" # Resolve Transformers backend pooling class if getattr(self, "runner_type", self.runner) == "pooling": return prefix + "Model" # Resolve Transformers backend generate classes - prefix += "MoE" if self.get_num_experts() > 1 else "" if self.hf_config != self.hf_text_config: # If 'hf_text_config' is the same as 'hf_config'. If not, it is # probably a composite config, i.e. 
multimodal From 152d7ff7126e71739c8da7b67bc03bf7c93bca77 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 25 Sep 2025 14:00:51 +0000 Subject: [PATCH 35/64] Remove TODO Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config/model.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/config/model.py b/vllm/config/model.py index c5ef5083d504..dc27552bba7f 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -1257,9 +1257,6 @@ def get_num_attention_heads(self, parallel_config: ParallelConfig) -> int: num_heads = getattr(self.hf_text_config, "num_attention_heads", 0) return num_heads // parallel_config.tensor_parallel_size - # TODO: Remove once https://github.com/huggingface/transformers/pull/40156 - # is released. Attribute mapping will allow us to simply read - # self.hf_text_config.num_experts def get_num_experts(self) -> int: """Returns the number of experts in the model.""" num_expert_names = [ From 199d1db50a0a1fa64091efca3a3a5dca57e032af Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 25 Sep 2025 15:34:02 +0000 Subject: [PATCH 36/64] Extract `TransformersMoE` classes to `transformers_moe` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/registry.py | 6 +- vllm/model_executor/models/transformers.py | 217 +-------------- .../model_executor/models/transformers_moe.py | 255 ++++++++++++++++++ 3 files changed, 272 insertions(+), 206 deletions(-) create mode 100644 vllm/model_executor/models/transformers_moe.py diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 1d614b059af3..8ef6f85c529c 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -305,11 +305,11 @@ _TRANSFORMERS_BACKEND_MODELS = { "TransformersModel": ("transformers", "TransformersModel"), - "TransformersMoEModel": ("transformers", "TransformersMoEModel"), "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"), - "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"), # noqa: E501 "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501 - "TransformersMoEForMultimodalLM": ("transformers", "TransformersMoEForMultimodalLM"), # noqa: E501 + "TransformersMoEModel": ("transformers_moe", "TransformersMoEModel"), + "TransformersMoEForCausalLM": ("transformers_moe", "TransformersMoEForCausalLM"), # noqa: E501 + "TransformersMoEForMultimodalLM": ("transformers_moe", "TransformersMoEForMultimodalLM"), # noqa: E501 } # yapf: enable diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index d5a577f59025..a6b2b8ecf06d 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -31,13 +31,9 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, VllmConfig) -from vllm.config.utils import getattr_iter from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed.utils import get_pp_indices -from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger -from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import 
(ColumnParallelLinear, ReplicatedLinear, RowParallelLinear) @@ -53,9 +49,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.utils import direct_register_custom_op, is_list_of +from vllm.utils import is_list_of from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsQuant) @@ -449,7 +444,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.device_config: DeviceConfig = vllm_config.device_config self.model_config: ModelConfig = vllm_config.model_config self.parallel_config: ParallelConfig = vllm_config.parallel_config - self.quant_config: QuantizationConfig = vllm_config.quant_config + self.quant_config: Optional[ + QuantizationConfig] = vllm_config.quant_config self.pp_group = get_pp_group() self.pp_size = self.pp_group.world_size @@ -458,7 +454,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Weights to skip in `self.load_weights` self.skip_prefixes: list[str] = [] + """Skip loading weights whose qualname starts with these prefixes.""" self.skip_substrs: list[str] = [] + """Skip loading weights whose qualname contains these substrings.""" # Set correct attn and init on "meta" to delay allocating GPU tensors # TODO: @raushan, use the public `model.set_attn_implementation()` @@ -472,7 +470,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.pipeline_parallel() - self.fused_moe() + self.init_hook() self.tensor_parallel() # Input embeddings @@ -549,11 +547,10 @@ def pipeline_parallel(self): if not self.pp_group.is_last_rank: setattr(self.model, name, PPMissingLayer()) - def fused_moe(self): - """ - Substitute the model's MoE layers with vLLM's FusedMoE. - To be overridden by child classes if they support MoE. - """ + def init_hook(self): + """Method to be overridden by child classes if necessary. 
+ + Called after `pipeline_parallel()` but before `tensor_parallel()`.""" pass def tensor_parallel(self): @@ -700,8 +697,10 @@ def forward( return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights( + self, + weights: Iterable[tuple[str, torch.Tensor]], + ) -> set[str]: loader = AutoWeightsLoader( self, skip_prefixes=self.skip_prefixes, @@ -710,170 +709,6 @@ def load_weights(self, weights: Iterable[tuple[str, return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) -@CustomOp.register("transformers_fused_moe") -class TransformersFusedMoE(FusedMoE): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._top_k_index: torch.Tensor = None - - def custom_routing_function(hidden_states, gating_output, topk, - renormalize, *extra): - return gating_output, self._top_k_index - - self.custom_routing_function = custom_routing_function - - def forward(self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, - top_k_weights: torch.Tensor): - return torch.ops.vllm.transformers_moe_forward(hidden_states, - top_k_index, - top_k_weights, - self.layer_name) - - -def transformers_moe_forward(hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, - layer_name: str) -> torch.Tensor: - forward_context: ForwardContext = get_forward_context() - self = forward_context.no_compile_layers[layer_name] - self._top_k_index = top_k_index - return self.forward_impl(hidden_states.clone(), top_k_weights) - - -def transformers_moe_forward_fake(hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, - layer_name: str) -> torch.Tensor: - return torch.empty_like(hidden_states) - - -direct_register_custom_op( - op_name="transformers_moe_forward", - op_func=transformers_moe_forward, - mutates_args=["hidden_states"], - fake_impl=transformers_moe_forward_fake, - dispatch_key=current_platform.dispatch_key, - tags=(torch.Tag.needs_fixed_stride_order, ), -) - - -class TransformersMoEBase(TransformersBase): - - def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - """ - Params for weights, fp8 weight scales, fp8 activation scales - (param_name, weight_name, expert_id, shard_id) - """ - ckpt_names = [ - # (ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name) - ("gate_proj", "down_proj", "up_proj"), # Most common MoE style - ("w1", "w2", "w3"), # Granite, Mixtral, Phi MoE style - ("linear", "linear_1", "linear_v"), # Grok1 style - ] - expert_mapping = [] - for gate_proj, down_proj, up_proj in ckpt_names: - expert_mapping.extend( - FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name=gate_proj, - ckpt_down_proj_name=down_proj, - ckpt_up_proj_name=up_proj, - num_experts=self.model_config.get_num_experts(), - num_redundant_experts=0, # TODO: enable EPLB - )) - return expert_mapping - - def fused_moe(self): - - text_config = self.text_config - - # Positional arguments - num_experts = self.model_config.get_num_experts() - top_k = text_config.num_experts_per_tok - hidden_size = text_config.hidden_size - names = ["moe_intermediate_size", "intermediate_size"] - intermediate_size = getattr_iter(text_config, names, None) - - # Reduction kwargs - names = ["num_experts_shared", "shared_expert_intermediate_size"] - reduce_results = getattr_iter(text_config, names, 0) == 0 - - def reduce_results_hook(module, _, output): - """Forward hook that performs all-reduce on a nn.Module's output if - tensor parallel or expert parallel is enabled. 
This is used for - models with shared experts where the all reduce happens after any - shared experts have been added to the hidden state.""" - if isinstance(output, tuple): - output = output[0] - return module.experts.maybe_all_reduce_tensor_model_parallel( - output) - - # Unused kwargs since we use custom_routing_function: - # - `scoring_func` and `e_score_correction_bias` only used for grouped - # topk routing inside vLLM and are non-trivial to infer - # and hard code `use_grouped_topk=False` - # - `renormalize` passed anyway because it's easy to infer - # - `num_expert_group` and `topk_group` used for inferring expert - # placement strategy in FusedMoE - # - `apply_router_weight_on_input` is already applied in Transformers - renormalize = getattr(text_config, "norm_topk_prob", top_k > 1) - num_expert_group = getattr(text_config, "n_group", None) - topk_group = getattr(text_config, "topk_group", None) - - # Expert parallel load balancing kwargs - parallel_config = self.parallel_config - eplb_config = parallel_config.eplb_config - - enable_eplb = parallel_config.enable_eplb - num_redundant_experts = eplb_config.num_redundant_experts - - # Recursively fuse MoE layers - def _fused_moe(module: nn.Module, prefix: str = ""): - for child_name, child_module in module.named_children(): - qual_name = maybe_prefix(prefix, child_name) - if (child_name == "experts" - and isinstance(child_module, nn.ModuleList)): - # Replace experts module with FusedMoE - new_module = TransformersFusedMoE( - num_experts=num_experts, - top_k=top_k, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - reduce_results=reduce_results, - renormalize=renormalize, - # Hard coded because topk happens in Transformers - use_grouped_topk=False, - num_expert_group=num_expert_group, - topk_group=topk_group, - quant_config=self.quant_config, - prefix=qual_name, - # TODO: activation - grok1, gpt-oss - enable_eplb=enable_eplb, - num_redundant_experts=num_redundant_experts, - # TODO: has_bias - gpt-oss - ) - setattr(module, child_name, new_module) - log_replacement(qual_name, child_module, new_module) - # If results are not all-reduced in FusedMoE, ensure they - # are all-reduced at the end of module.forward() - if not reduce_results and (new_module.tp_size > 1 - or new_module.ep_size > 1): - module.register_forward_hook(reduce_results_hook) - else: - _fused_moe(child_module, prefix=qual_name) - - _fused_moe(self.model) - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self, skip_prefixes=self.skip_prefixes) - return loader.load_weights( - weights, - mapper=self.hf_to_vllm_mapper, - expert_mapping=self.get_expert_mapping(), - ) - - @support_torch_compile(enable_if=can_enable_torch_compile) class TransformersModel(TransformersBase): hf_to_vllm_mapper = WeightsMapper( @@ -941,11 +776,6 @@ def create_attention_instances( return super().create_attention_instances(attn_type) -@support_torch_compile(enable_if=can_enable_torch_compile) -class TransformersMoEModel(TransformersMoEBase, TransformersModel): - pass - - @support_torch_compile(enable_if=can_enable_torch_compile) class TransformersForCausalLM(TransformersBase): @@ -984,11 +814,6 @@ def compute_logits( return logits -@support_torch_compile(enable_if=can_enable_torch_compile) -class TransformersMoEForCausalLM(TransformersMoEBase, TransformersForCausalLM): - pass - - def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor: """Flatten until a list of tensors can be concatenated then do concat""" @@ 
-1130,17 +955,3 @@ def get_input_embeddings( inputs_embeds = inputs_embeds.masked_scatter( mask, multimodal_embeddings) return inputs_embeds - - -@support_torch_compile( - # set `positions` to last dim to support Qwen-mrope - dynamic_arg_dims={ - "input_ids": 0, - "positions": -1, - "intermediate_tensors": 0, - "inputs_embeds": 0, - }, - enable_if=can_enable_torch_compile) -class TransformersMoEForMultimodalLM(TransformersMoEForCausalLM, - TransformersForMultimodalLM): - pass diff --git a/vllm/model_executor/models/transformers_moe.py b/vllm/model_executor/models/transformers_moe.py new file mode 100644 index 000000000000..6a57e104cd4b --- /dev/null +++ b/vllm/model_executor/models/transformers_moe.py @@ -0,0 +1,255 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Wrapper around `transformers` MoE models.""" +from collections.abc import Iterable +from typing import Any + +import torch +import torch.nn as nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config.utils import getattr_iter +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op + +from .transformers import (TransformersBase, TransformersForCausalLM, + TransformersForMultimodalLM, TransformersModel, + can_enable_torch_compile, log_replacement) +from .utils import AutoWeightsLoader, maybe_prefix + + +@CustomOp.register("transformers_fused_moe") +class TransformersFusedMoE(FusedMoE): + """Custom FusedMoE for the Transformers backend.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._top_k_index: torch.Tensor = None + + def custom_routing_function(hidden_states, gating_output, topk, + renormalize): + """Return `top_k_weights` from `gating_output` and the + `top_k_index` we stored in the layer earlier.""" + return gating_output, self._top_k_index + + self.custom_routing_function = custom_routing_function + + def forward(self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, **kwargs: Any) -> torch.Tensor: + """In Transformers `experts.forward` will have this signature. 
+ + We discard any extra kwargs because we cannot use them here.""" + return torch.ops.vllm.transformers_moe_forward(hidden_states, + top_k_index, + top_k_weights, + self.layer_name) + + +def transformers_moe_forward(hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + layer_name: str) -> torch.Tensor: + """Store the `top_k_index` in the layer and call the actual forward.""" + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self._top_k_index = top_k_index + # Clone hidden_states because it will be mutated in-place in FusedMoE + return self.forward_impl(hidden_states.clone(), top_k_weights) + + +def transformers_moe_forward_fake(hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + layer_name: str) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +direct_register_custom_op( + op_name="transformers_moe_forward", + op_func=transformers_moe_forward, + mutates_args=["hidden_states"], + fake_impl=transformers_moe_forward_fake, + dispatch_key=current_platform.dispatch_key, + tags=(torch.Tag.needs_fixed_stride_order, ), +) + + +class TransformersMoEBase(TransformersBase): + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + """ + Params for weights, fp8 weight scales, fp8 activation scales + (param_name, weight_name, expert_id, shard_id) + """ + ckpt_names = [ + # (ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name) + ("gate_proj", "down_proj", "up_proj"), # Most common MoE style + ("w1", "w2", "w3"), # Granite, Mixtral, Phi MoE style + ("linear", "linear_1", "linear_v"), # Grok1 style + ] + expert_mapping = [] + for gate_proj, down_proj, up_proj in ckpt_names: + expert_mapping.extend( + FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name=gate_proj, + ckpt_down_proj_name=down_proj, + ckpt_up_proj_name=up_proj, + num_experts=self.model_config.get_num_experts(), + num_redundant_experts=0, # TODO: enable EPLB + )) + return expert_mapping + + def init_hook(self): + """Initialize the MoE layers. + + It is important that this happens: + + - After pipeline parallelism so only the `FusedMoE` layers for this + pipeline stage are created. + - Before tensor parallelism so vLLM `Linear` layers are not created for + the unfused `Linear` layers inside the Transformers MoE layers. + """ + text_config = self.text_config + + # Positional arguments + num_experts = self.model_config.get_num_experts() + top_k = text_config.num_experts_per_tok + hidden_size = text_config.hidden_size + names = ["moe_intermediate_size", "intermediate_size"] + intermediate_size = getattr_iter(text_config, names, None) + + # Reduction kwargs + names = ["num_experts_shared", "shared_expert_intermediate_size"] + reduce_results = getattr_iter(text_config, names, 0) == 0 + + def reduce_results_hook(module, _, output): + """Forward hook that performs all-reduce on a nn.Module's output if + tensor parallel or expert parallel is enabled. 
This is used for + models with shared experts where the all reduce happens after any + shared experts have been added to the hidden state.""" + if isinstance(output, tuple): + output = output[0] + return module.experts.maybe_all_reduce_tensor_model_parallel( + output) + + # Unused kwargs since we use custom_routing_function: + # - `scoring_func` and `e_score_correction_bias` only used for grouped + # topk routing inside vLLM and are non-trivial to infer + # and hard code `use_grouped_topk=False` + # - `renormalize` passed anyway because it's easy to infer + # - `num_expert_group` and `topk_group` used for inferring expert + # placement strategy in FusedMoE + # - `apply_router_weight_on_input` is already applied in Transformers + renormalize = getattr(text_config, "norm_topk_prob", top_k > 1) + num_expert_group = getattr(text_config, "n_group", None) + topk_group = getattr(text_config, "topk_group", None) + + # MoE activation function + activation = "silu" + wrapped_arch = self.config.architectures[0].lower() + if "gptoss" in wrapped_arch: + activation = "swigluoai" + elif "grok1" in wrapped_arch: + activation = "gelu" + + # Expert parallel load balancing kwargs + parallel_config = self.parallel_config + eplb_config = parallel_config.eplb_config + + enable_eplb = parallel_config.enable_eplb + num_redundant_experts = eplb_config.num_redundant_experts + + # Recursively fuse MoE layers + def _fused_moe(module: nn.Module, prefix: str = ""): + for child_name, child_module in module.named_children(): + qual_name = maybe_prefix(prefix, child_name) + if (child_name == "experts" + and isinstance(child_module, nn.ModuleList)): + # Do the experts have biases + has_bias = False + for param_name, _ in child_module.named_parameters(): + if "bias" in param_name: + has_bias = True + break + # Replace experts module with FusedMoE + new_module = TransformersFusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + reduce_results=reduce_results, + renormalize=renormalize, + # Hard coded because topk happens in Transformers + use_grouped_topk=False, + num_expert_group=num_expert_group, + topk_group=topk_group, + quant_config=self.quant_config, + prefix=qual_name, + activation=activation, + enable_eplb=enable_eplb, + num_redundant_experts=num_redundant_experts, + has_bias=has_bias, + ) + setattr(module, child_name, new_module) + log_replacement(qual_name, child_module, new_module) + # If results are not all-reduced in FusedMoE, ensure they + # are all-reduced at the end of module.forward() + if not reduce_results and (new_module.tp_size > 1 + or new_module.ep_size > 1): + module.register_forward_hook(reduce_results_hook) + else: + _fused_moe(child_module, prefix=qual_name) + + _fused_moe(self.model) + + def load_weights( + self, + weights: Iterable[tuple[str, torch.Tensor]], + ) -> set[str]: + loader = AutoWeightsLoader(self, skip_prefixes=self.skip_prefixes) + return loader.load_weights( + weights, + mapper=self.hf_to_vllm_mapper, + expert_mapping=self.get_expert_mapping(), + ) + + +@support_torch_compile(enable_if=can_enable_torch_compile) +class TransformersMoEModel(TransformersMoEBase, TransformersModel): + pass + + +@support_torch_compile(enable_if=can_enable_torch_compile) +class TransformersMoEForCausalLM(TransformersMoEBase, TransformersForCausalLM): + pass + + +@support_torch_compile( + # set `positions` to last dim to support Qwen-mrope + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 
0, + }, + enable_if=can_enable_torch_compile) +class TransformersMoEForMultimodalLM(TransformersMoEForCausalLM, + TransformersForMultimodalLM): + pass From 40809328726d6149d4d166faafec83672192441a Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 25 Sep 2025 15:34:59 +0000 Subject: [PATCH 37/64] Handle MXFP4 in linears and attentions Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index a6b2b8ecf06d..32e1be397caf 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -124,7 +124,7 @@ def can_enable_torch_compile(vllm_config: VllmConfig) -> bool: def replace_linear_class( linear: nn.Linear, style: Literal["colwise", "rowwise"], - quant_config: QuantizationConfig, + quant_config: Optional[QuantizationConfig], *, prefix: str = "", ) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]: @@ -132,17 +132,21 @@ def replace_linear_class( Replace nn.Linear with one of vLLM's tensor parallel linear classes. Args: - linear (nn.Linear): `nn.Linear` to be replaced. - style (str): Tensor parallel style of the new linear, e.g. "colwise". - quant_config (QuantConfig): Quantization config for the new linear. + linear: `nn.Linear` to be replaced. + style: Tensor parallel style of the new linear, e.g. "colwise". + quant_config: Quantization config for the new linear. Returns: - Union[ColumnParallelLinear, RowParallelLinear]: The new linear. + The new linear. """ if not isinstance(style, str): raise ValueError( f"Unsupported parallel style type {type(style)}, expected str") + # MXFP4 linear layer is not implemented + if quant_config is not None and quant_config.get_name() == "mxfp4": + quant_config = None + vllm_linear_cls, vllm_linear_kwargs = { "colwise": (ColumnParallelLinear, {}), "colwise_rep": (ColumnParallelLinear, { @@ -620,6 +624,11 @@ def create_attention_instances( start, end = get_pp_indices(self.text_config.num_hidden_layers, self.pp_rank, self.pp_size) + # MXFP4 attention layer is not implemented + quant_config = self.quant_config + if quant_config is not None and quant_config.get_name() == "mxfp4": + quant_config = None + attention_instances = {} for i in range(start, end): # Handle interleaved sliding window attention @@ -636,7 +645,7 @@ def create_attention_instances( scale=head_size**-0.5, num_kv_heads=num_kv_heads, cache_config=self.cache_config, - quant_config=self.quant_config, + quant_config=quant_config, per_layer_sliding_window=per_layer_sliding_window, prefix=f"{i}.attn", attn_type=attn_type) From e14205815cf7f215a9272aa668ffe547c4186653 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 25 Sep 2025 16:31:23 +0000 Subject: [PATCH 38/64] Add error if user tries MXFP4 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 32e1be397caf..4f32702d9298 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -144,7 +144,7 @@ def replace_linear_class( f"Unsupported parallel style type {type(style)}, expected 
str") # MXFP4 linear layer is not implemented - if quant_config is not None and quant_config.get_name() == "mxfp4": + if quant_config and quant_config.get_name() == "mxfp4": quant_config = None vllm_linear_cls, vllm_linear_kwargs = { @@ -451,6 +451,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.quant_config: Optional[ QuantizationConfig] = vllm_config.quant_config + if self.quant_config and self.quant_config.get_name() == "mxfp4": + raise ValueError( + "Transformers backend does not support MXFP4 yet.") + self.pp_group = get_pp_group() self.pp_size = self.pp_group.world_size self.pp_rank = self.pp_group.rank_in_group @@ -626,7 +630,7 @@ def create_attention_instances( # MXFP4 attention layer is not implemented quant_config = self.quant_config - if quant_config is not None and quant_config.get_name() == "mxfp4": + if quant_config and quant_config.get_name() == "mxfp4": quant_config = None attention_instances = {} From 957af8914647d9639cdc0ece64cac48c32be0a90 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 25 Sep 2025 17:08:18 +0000 Subject: [PATCH 39/64] Add errors for enable expert parallel Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- docs/models/supported_models.md | 1 - vllm/model_executor/models/transformers.py | 5 +++-- vllm/model_executor/models/transformers_moe.py | 11 +++++++++++ 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 7d8581f73aa1..bde2e8f4e932 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -32,7 +32,6 @@ If the Transformers model implementation follows all the steps in [writing a cus - All the features listed in the [compatibility matrix](../features/compatibility_matrix.md#feature-x-feature) - Any combination of the following vLLM parallelisation schemes: - Data parallel - - Expert parallel - Pipeline parallel - Tensor parallel diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 4f32702d9298..8cb75f48e221 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -452,8 +452,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): QuantizationConfig] = vllm_config.quant_config if self.quant_config and self.quant_config.get_name() == "mxfp4": - raise ValueError( - "Transformers backend does not support MXFP4 yet.") + raise NotImplementedError( + "Transformers backend does not support MXFP4 quantization yet." 
+ ) self.pp_group = get_pp_group() self.pp_size = self.pp_group.world_size diff --git a/vllm/model_executor/models/transformers_moe.py b/vllm/model_executor/models/transformers_moe.py index 6a57e104cd4b..0ab320ba7497 100644 --- a/vllm/model_executor/models/transformers_moe.py +++ b/vllm/model_executor/models/transformers_moe.py @@ -93,6 +93,17 @@ def transformers_moe_forward_fake(hidden_states: torch.Tensor, class TransformersMoEBase(TransformersBase): + def __init__(self, *, vllm_config, prefix=""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + if self.parallel_config.enable_expert_parallel: + raise NotImplementedError( + "Transformers backend does not support expert parallel yet.") + if self.parallel_config.enable_eplb: + raise NotImplementedError( + "Transformers backend does not support expert parallel load " + "balancing yet.") + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: """ Params for weights, fp8 weight scales, fp8 activation scales From 8f53035c20c03c350793bc7c94a1c6c75e4cc396 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 26 Sep 2025 14:59:13 +0000 Subject: [PATCH 40/64] Make minimum version checking part of `TransformersBase` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 19 +++++++++++-------- .../model_executor/models/transformers_moe.py | 1 + 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index cb34c2adb546..1c46680b5e7c 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -22,6 +22,8 @@ import regex as re import torch +import transformers +from packaging.version import Version from torch import nn from transformers import (AutoModel, BatchFeature, PretrainedConfig, PreTrainedModel) @@ -725,6 +727,14 @@ def load_weights( ) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + def check_version(self, min_version: str, feature: str): + installed = Version(transformers.__version__) + required = Version(min_version) + if installed < required: + raise ImportError( + f"Transformers backend requires transformers>={required} " + f"for {feature}, but got {installed}") + @support_torch_compile(enable_if=can_enable_torch_compile) class TransformersModel(TransformersBase): @@ -781,14 +791,7 @@ def create_attention_instances( # Check minimum transformers version for encoder models support if attn_type == AttentionType.ENCODER_ONLY: - import transformers - from packaging.version import Version - installed = Version(transformers.__version__) - required = Version("4.57.0.dev0") - if installed < required: - raise ValueError( - "Encoder models with the Transformers backend require " - f"transformers>={required}, but got {installed}") + self.check_version("4.57.0.dev0", "encoder models support") return super().create_attention_instances(attn_type) diff --git a/vllm/model_executor/models/transformers_moe.py b/vllm/model_executor/models/transformers_moe.py index 0ab320ba7497..30a9842c68a6 100644 --- a/vllm/model_executor/models/transformers_moe.py +++ b/vllm/model_executor/models/transformers_moe.py @@ -94,6 +94,7 @@ def transformers_moe_forward_fake(hidden_states: torch.Tensor, class TransformersMoEBase(TransformersBase): def __init__(self, *, vllm_config, prefix=""): + self.check_version("4.57.0.dev0", "MoE models support") 
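# Doing the version check before calling the base constructor means an
# unsupported Transformers install fails fast, before TransformersBase.__init__
# starts any model construction work.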
super().__init__(vllm_config=vllm_config, prefix=prefix) if self.parallel_config.enable_expert_parallel: From 9faf3d2bb1b2da16c72b20871be8ded414621f37 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 26 Sep 2025 15:01:18 +0000 Subject: [PATCH 41/64] Slghtly improve `FusedMoE` kwarg detection Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- .../model_executor/models/transformers_moe.py | 46 +++++++++++++------ 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/models/transformers_moe.py b/vllm/model_executor/models/transformers_moe.py index 30a9842c68a6..98d03f4e0b1a 100644 --- a/vllm/model_executor/models/transformers_moe.py +++ b/vllm/model_executor/models/transformers_moe.py @@ -142,14 +142,20 @@ def init_hook(self): # Positional arguments num_experts = self.model_config.get_num_experts() - top_k = text_config.num_experts_per_tok + top_k = getattr_iter(text_config, ["num_experts_per_tok", "top_k"], + None) + assert top_k is not None hidden_size = text_config.hidden_size - names = ["moe_intermediate_size", "intermediate_size"] - intermediate_size = getattr_iter(text_config, names, None) + intermediate_size = getattr_iter( + text_config, ["moe_intermediate_size", "intermediate_size"], None) + assert intermediate_size is not None - # Reduction kwargs - names = ["num_experts_shared", "shared_expert_intermediate_size"] - reduce_results = getattr_iter(text_config, names, 0) == 0 + # If there are shared experts, the results are + # reduced after mlp.forward() not inside FusedMoE + num_experts_shared = getattr_iter(text_config, [ + "num_experts_shared", "n_shared_experts", "moe_num_shared_experts" + ], 0) + reduce_results = num_experts_shared == 0 def reduce_results_hook(module, _, output): """Forward hook that performs all-reduce on a nn.Module's output if @@ -194,14 +200,24 @@ def _fused_moe(module: nn.Module, prefix: str = ""): qual_name = maybe_prefix(prefix, child_name) if (child_name == "experts" and isinstance(child_module, nn.ModuleList)): + # Alias for readability + mlp = module + experts = child_module # Do the experts have biases has_bias = False - for param_name, _ in child_module.named_parameters(): - if "bias" in param_name: + for experts_param_name, _ in experts.named_parameters(): + if "bias" in experts_param_name: has_bias = True break + # Double check there are no shared experts + nonlocal reduce_results + if reduce_results: + for mlp_param_name, _ in mlp.named_parameters(): + if "shared_expert" in mlp_param_name: + reduce_results = False + break # Replace experts module with FusedMoE - new_module = TransformersFusedMoE( + fused_experts = TransformersFusedMoE( num_experts=num_experts, top_k=top_k, hidden_size=hidden_size, @@ -219,13 +235,13 @@ def _fused_moe(module: nn.Module, prefix: str = ""): num_redundant_experts=num_redundant_experts, has_bias=has_bias, ) - setattr(module, child_name, new_module) - log_replacement(qual_name, child_module, new_module) + setattr(mlp, child_name, fused_experts) + log_replacement(qual_name, experts, fused_experts) # If results are not all-reduced in FusedMoE, ensure they - # are all-reduced at the end of module.forward() - if not reduce_results and (new_module.tp_size > 1 - or new_module.ep_size > 1): - module.register_forward_hook(reduce_results_hook) + # are all-reduced at the end of mlp.forward() + if not reduce_results and (fused_experts.tp_size > 1 + or fused_experts.ep_size > 1): + mlp.register_forward_hook(reduce_results_hook) 
else: _fused_moe(child_module, prefix=qual_name) From 42ce295cf1edc7f168717ef1a86865c5cdcaa232 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 26 Sep 2025 17:25:15 +0000 Subject: [PATCH 42/64] Fix experts loading Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/utils.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 76bbe890eaa2..628bee2125b9 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -249,15 +249,17 @@ def _load_module( if expert_mapping is not None and child_prefix == "experts": for expert_name, loaded_weight in child_weights: - for (param_name, weight_name, expert_id, - shard_id) in expert_mapping: - if weight_name not in f"experts.{expert_name}": + for e_m in expert_mapping: + param_name, weight_name, expert_id, shard_id = e_m + experts_name = f"experts.{expert_name}" + if weight_name not in experts_name: continue fused_moe = child_modules[child_prefix] - param_name = ( - f"{param_name.removeprefix('experts.')}weight") - param = getattr(fused_moe, param_name) - weight_name = maybe_prefix(prefix, param_name) + mapped_name = experts_name.replace( + weight_name, + param_name).removeprefix("experts.") + param = getattr(fused_moe, mapped_name) + weight_name = maybe_prefix(prefix, mapped_name) fused_moe.weight_loader( param=param, loaded_weight=loaded_weight, @@ -266,7 +268,7 @@ def _load_module( expert_id=expert_id, ) logger.debug("Loaded %s for expert %d into %s", - param_name, expert_id, prefix) + mapped_name, expert_id, prefix) yield weight_name yield from self._load_module(prefix, From 3725c30b5ac2a8c70582e44ce7c2cafc0d37e2df Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 26 Sep 2025 17:51:23 +0000 Subject: [PATCH 43/64] Add defaults to `replace_linear_class` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 1c46680b5e7c..cff22b9e9b3f 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -125,8 +125,8 @@ def can_enable_torch_compile(vllm_config: VllmConfig) -> bool: def replace_linear_class( linear: nn.Linear, - style: Literal["colwise", "rowwise"], - quant_config: Optional[QuantizationConfig], + style: Literal["colwise", "rowwise"] = "replicate", + quant_config: Optional[QuantizationConfig] = None, *, prefix: str = "", ) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]: From 6142cdb0fb73665715fbfffa771bc7d71b32589e Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 26 Sep 2025 17:59:45 +0000 Subject: [PATCH 44/64] Handle `AutoGPTQ` not quantising `gate` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers_moe.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/transformers_moe.py b/vllm/model_executor/models/transformers_moe.py index 98d03f4e0b1a..01669cbf2d4c 100644 --- a/vllm/model_executor/models/transformers_moe.py +++ b/vllm/model_executor/models/transformers_moe.py @@ -31,7 +31,8 @@ from 
.transformers import (TransformersBase, TransformersForCausalLM, TransformersForMultimodalLM, TransformersModel, - can_enable_torch_compile, log_replacement) + can_enable_torch_compile, log_replacement, + replace_linear_class) from .utils import AutoWeightsLoader, maybe_prefix @@ -187,10 +188,12 @@ def reduce_results_hook(module, _, output): elif "grok1" in wrapped_arch: activation = "gelu" - # Expert parallel load balancing kwargs + # Configs + quant_config = self.quant_config parallel_config = self.parallel_config eplb_config = parallel_config.eplb_config + # Expert parallel load balancing kwargs enable_eplb = parallel_config.enable_eplb num_redundant_experts = eplb_config.num_redundant_experts @@ -203,6 +206,13 @@ def _fused_moe(module: nn.Module, prefix: str = ""): # Alias for readability mlp = module experts = child_module + # GPTQ configs ddo not have lists of ignored modules + # and AutoGPTQ seems to avoid gate quantization. + if (quant_config and "gptq" in quant_config.get_name() + and (gate := getattr(mlp, "gate", None))): + new_gate = replace_linear_class( + gate, prefix=f"{prefix}.gate") + mlp.gate = new_gate # Do the experts have biases has_bias = False for experts_param_name, _ in experts.named_parameters(): @@ -235,7 +245,7 @@ def _fused_moe(module: nn.Module, prefix: str = ""): num_redundant_experts=num_redundant_experts, has_bias=has_bias, ) - setattr(mlp, child_name, fused_experts) + mlp.experts = fused_experts log_replacement(qual_name, experts, fused_experts) # If results are not all-reduced in FusedMoE, ensure they # are all-reduced at the end of mlp.forward() From 7ab3832df27e1c0bb7b32401bd32bbfec82ed8a8 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 26 Sep 2025 18:11:11 +0000 Subject: [PATCH 45/64] Add test Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- tests/models/test_transformers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index 1817d4aeee9f..c68f37ad72ae 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -59,6 +59,7 @@ def check_implementation( [ ("meta-llama/Llama-3.2-1B-Instruct", "transformers"), ("hmellor/Ilama-3.2-1B", "auto"), # CUSTOM CODE + ("Qwen/Qwen1.5-MoE-A2.7B-Chat", "transformers"), # MoE ]) # trust_remote_code=True by default def test_models( hf_runner: type[HfRunner], From 928e9d5ff19365cd9758b274386d35fc0f9166f7 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Sat, 27 Sep 2025 14:16:20 +0200 Subject: [PATCH 46/64] Move expert weight loading to `FusedMoE.load_weights` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/layers/fused_moe/layer.py | 31 +++++++++++++++++++ .../model_executor/models/transformers_moe.py | 22 +++++-------- vllm/model_executor/models/utils.py | 31 ++----------------- 3 files changed, 40 insertions(+), 44 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index eccae8b2a7af..abb728c3fbd9 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -959,6 +959,7 @@ def __init__( is_sequence_parallel=False, zero_expert_num: Optional[int] = 0, zero_expert_type: Optional[str] = None, + expert_mapping: Optional[list[tuple[str, str, int, str]]] = None, ): super().__init__() if params_dtype is None: @@ -996,6 +997,9 @@ def 
__init__( self.zero_expert_num = zero_expert_num self.zero_expert_type = zero_expert_type + # Expert mapping used in self.load_weights + self.expert_mapping = expert_mapping + # Round up hidden size if needed. hidden_size = maybe_roundup_hidden_size(hidden_size, moe_in_dtype, quant_config, @@ -1617,6 +1621,33 @@ def weight_loader(self, return False if return_success else None + def load_weights( + self, weights: Iterable[tuple[str, + torch.Tensor]]) -> Iterable[str]: + if (expert_mapping := self.expert_mapping) is None: + raise ValueError("`self.expert_mapping` must be provided to " + "load weights using `self.load_weights`.") + for expert_name, loaded_weight in weights: + qual_name = f"{self.layer_name}.{expert_name}" + for param_name, weight_name, expert_id, shard_id in expert_mapping: + if weight_name not in qual_name: + continue + weight_name = qual_name.replace(weight_name, param_name) + param_name = weight_name.removeprefix(f"{self.layer_name}.") + param = getattr(self, param_name) + success = self.weight_loader( + param=param, + loaded_weight=loaded_weight, + weight_name=weight_name, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + logger.debug("Loaded %s for expert %d into %s", param_name, + expert_id, self.layer_name) + yield param_name + def get_expert_weights(self) -> Iterable[torch.Tensor]: weights = list(self.named_parameters()) assert all(weight.is_contiguous() for _, weight in weights) diff --git a/vllm/model_executor/models/transformers_moe.py b/vllm/model_executor/models/transformers_moe.py index 01669cbf2d4c..60cf07d00fe4 100644 --- a/vllm/model_executor/models/transformers_moe.py +++ b/vllm/model_executor/models/transformers_moe.py @@ -15,7 +15,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
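# A minimal sketch of the name resolution performed per checkpoint tensor by
# the new FusedMoE.load_weights above; the layer/weight names and the
# ("experts.w13_", "experts.0.gate_proj.") mapping entry are illustrative
# assumptions, not values copied from make_expert_params_mapping.
layer_name = "model.layers.0.mlp.experts"   # what FusedMoE.layer_name might look like
expert_name = "0.gate_proj.weight"          # tensor name yielded by the checkpoint
param_name, weight_name = "experts.w13_", "experts.0.gate_proj."
expert_id, shard_id = 0, "w1"

qual_name = f"{layer_name}.{expert_name}"            # ...mlp.experts.0.gate_proj.weight
assert weight_name in qual_name                      # this mapping entry matches the tensor
mapped = qual_name.replace(weight_name, param_name)  # ...mlp.experts.w13_weight
local_name = mapped.removeprefix(f"{layer_name}.")   # "w13_weight"
# getattr(self, local_name) is the fused parameter handed to weight_loader
# together with shard_id and expert_id.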
"""Wrapper around `transformers` MoE models.""" -from collections.abc import Iterable from typing import Any import torch @@ -33,7 +32,7 @@ TransformersForMultimodalLM, TransformersModel, can_enable_torch_compile, log_replacement, replace_linear_class) -from .utils import AutoWeightsLoader, maybe_prefix +from .utils import maybe_prefix @CustomOp.register("transformers_fused_moe") @@ -188,6 +187,9 @@ def reduce_results_hook(module, _, output): elif "grok1" in wrapped_arch: activation = "gelu" + # Expert mapping for `AutoWeightsLoader` + expert_mapping = self.get_expert_mapping() + # Configs quant_config = self.quant_config parallel_config = self.parallel_config @@ -198,7 +200,7 @@ def reduce_results_hook(module, _, output): num_redundant_experts = eplb_config.num_redundant_experts # Recursively fuse MoE layers - def _fused_moe(module: nn.Module, prefix: str = ""): + def _fused_moe(module: nn.Module, prefix: str): for child_name, child_module in module.named_children(): qual_name = maybe_prefix(prefix, child_name) if (child_name == "experts" @@ -244,6 +246,7 @@ def _fused_moe(module: nn.Module, prefix: str = ""): enable_eplb=enable_eplb, num_redundant_experts=num_redundant_experts, has_bias=has_bias, + expert_mapping=expert_mapping, ) mlp.experts = fused_experts log_replacement(qual_name, experts, fused_experts) @@ -255,18 +258,7 @@ def _fused_moe(module: nn.Module, prefix: str = ""): else: _fused_moe(child_module, prefix=qual_name) - _fused_moe(self.model) - - def load_weights( - self, - weights: Iterable[tuple[str, torch.Tensor]], - ) -> set[str]: - loader = AutoWeightsLoader(self, skip_prefixes=self.skip_prefixes) - return loader.load_weights( - weights, - mapper=self.hf_to_vllm_mapper, - expert_mapping=self.get_expert_mapping(), - ) + _fused_moe(self.model, prefix="model") @support_torch_compile(enable_if=can_enable_torch_compile) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 628bee2125b9..51cd41c864f0 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -210,7 +210,6 @@ def _load_module( base_prefix: str, module: nn.Module, weights: Iterable[tuple[str, torch.Tensor]], - expert_mapping: Optional[list[tuple[str, str, int, str]]], ) -> Iterable[str]: if isinstance(module, PPMissingLayer): return @@ -247,33 +246,9 @@ def _load_module( continue - if expert_mapping is not None and child_prefix == "experts": - for expert_name, loaded_weight in child_weights: - for e_m in expert_mapping: - param_name, weight_name, expert_id, shard_id = e_m - experts_name = f"experts.{expert_name}" - if weight_name not in experts_name: - continue - fused_moe = child_modules[child_prefix] - mapped_name = experts_name.replace( - weight_name, - param_name).removeprefix("experts.") - param = getattr(fused_moe, mapped_name) - weight_name = maybe_prefix(prefix, mapped_name) - fused_moe.weight_loader( - param=param, - loaded_weight=loaded_weight, - weight_name=weight_name, - shard_id=shard_id, - expert_id=expert_id, - ) - logger.debug("Loaded %s for expert %d into %s", - mapped_name, expert_id, prefix) - yield weight_name - yield from self._load_module(prefix, child_modules[child_prefix], - child_weights, expert_mapping) + child_weights) elif child_prefix in child_params: if self._can_skip(prefix): logger.debug("Skipping param %s", prefix) @@ -306,7 +281,6 @@ def load_weights( weights: Iterable[tuple[str, torch.Tensor]], *, mapper: Optional[WeightsMapper] = None, - expert_mapping: Optional[list[tuple[str, str, int, str]]] = None, ) 
-> set[str]: if mapper is not None: weights = mapper.apply(weights) @@ -314,8 +288,7 @@ def load_weights( weights = ((name, weight) for name, weight in weights if not self._can_skip(name)) - autoloaded_weights = set( - self._load_module("", self.module, weights, expert_mapping)) + autoloaded_weights = set(self._load_module("", self.module, weights)) return autoloaded_weights From 220d749b0910e4ae0dc18b4166d243c51affa46c Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 29 Sep 2025 12:12:22 +0000 Subject: [PATCH 47/64] Convert forward hook to wrapper class Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- .../model_executor/models/transformers_moe.py | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/transformers_moe.py b/vllm/model_executor/models/transformers_moe.py index 60cf07d00fe4..1f6ef6c25c74 100644 --- a/vllm/model_executor/models/transformers_moe.py +++ b/vllm/model_executor/models/transformers_moe.py @@ -157,15 +157,17 @@ def init_hook(self): ], 0) reduce_results = num_experts_shared == 0 - def reduce_results_hook(module, _, output): - """Forward hook that performs all-reduce on a nn.Module's output if - tensor parallel or expert parallel is enabled. This is used for - models with shared experts where the all reduce happens after any - shared experts have been added to the hidden state.""" - if isinstance(output, tuple): - output = output[0] - return module.experts.maybe_all_reduce_tensor_model_parallel( - output) + def add_all_reduce(mlp: nn.Module): + """Adds an all-reduce to the output of `mlp.forward()`.""" + + class MLPWithAllReduce(mlp.__class__): + + def forward(self, *args, **kwargs): + output = super().forward(*args, **kwargs) + return self.experts.maybe_all_reduce_tensor_model_parallel( + output) + + mlp.__class__ = MLPWithAllReduce # Unused kwargs since we use custom_routing_function: # - `scoring_func` and `e_score_correction_bias` only used for grouped @@ -251,10 +253,11 @@ def _fused_moe(module: nn.Module, prefix: str): mlp.experts = fused_experts log_replacement(qual_name, experts, fused_experts) # If results are not all-reduced in FusedMoE, ensure they - # are all-reduced at the end of mlp.forward() + # are all-reduced at the end of mlp.forward() if tensor + # parallel or expert parallel is enabled if not reduce_results and (fused_experts.tp_size > 1 or fused_experts.ep_size > 1): - mlp.register_forward_hook(reduce_results_hook) + add_all_reduce(mlp) else: _fused_moe(child_module, prefix=qual_name) From 1c6b2cd0752873dc1996bbea16b5992146743897 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 29 Sep 2025 14:32:17 +0000 Subject: [PATCH 48/64] Remove GPTQ workaround Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers_moe.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/vllm/model_executor/models/transformers_moe.py b/vllm/model_executor/models/transformers_moe.py index 1f6ef6c25c74..7d245e43be1a 100644 --- a/vllm/model_executor/models/transformers_moe.py +++ b/vllm/model_executor/models/transformers_moe.py @@ -30,8 +30,7 @@ from .transformers import (TransformersBase, TransformersForCausalLM, TransformersForMultimodalLM, TransformersModel, - can_enable_torch_compile, log_replacement, - replace_linear_class) + can_enable_torch_compile, log_replacement) from .utils import maybe_prefix @@ 
-193,7 +192,6 @@ def forward(self, *args, **kwargs): expert_mapping = self.get_expert_mapping() # Configs - quant_config = self.quant_config parallel_config = self.parallel_config eplb_config = parallel_config.eplb_config @@ -210,13 +208,6 @@ def _fused_moe(module: nn.Module, prefix: str): # Alias for readability mlp = module experts = child_module - # GPTQ configs ddo not have lists of ignored modules - # and AutoGPTQ seems to avoid gate quantization. - if (quant_config and "gptq" in quant_config.get_name() - and (gate := getattr(mlp, "gate", None))): - new_gate = replace_linear_class( - gate, prefix=f"{prefix}.gate") - mlp.gate = new_gate # Do the experts have biases has_bias = False for experts_param_name, _ in experts.named_parameters(): From 14e091fd30b9023b613a7673e78236af9e4d01b5 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 29 Sep 2025 14:43:30 +0000 Subject: [PATCH 49/64] Use smaller model in test Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- tests/models/test_transformers.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index af08cefa7313..c57836180e6e 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -59,7 +59,7 @@ def check_implementation( [ ("meta-llama/Llama-3.2-1B-Instruct", "transformers"), ("hmellor/Ilama-3.2-1B", "auto"), # CUSTOM CODE - ("Qwen/Qwen1.5-MoE-A2.7B-Chat", "transformers"), # MoE + ("allenai/OLMoE-1B-7B-0924", "transformers"), # MoE ]) # trust_remote_code=True by default def test_models( hf_runner: type[HfRunner], @@ -68,6 +68,14 @@ def test_models( model: str, model_impl: str, ) -> None: + import transformers + from packaging.version import Version + installed = Version(transformers.__version__) + required = Version("4.57.0.dev0") + if model == "allenai/OLMoE-1B-7B-0924" and installed < required: + pytest.skip("MoE models with the Transformers backend require " + f"transformers>={required}, but got {installed}") + check_implementation(hf_runner, vllm_runner, example_prompts, From e58006445abacb94bc9dd4fbb2ceab814e45111e Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 30 Sep 2025 08:38:14 +0100 Subject: [PATCH 50/64] Add MoE test registry entry Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- tests/models/registry.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/registry.py b/tests/models/registry.py index 37ee474d3ecb..1e48c4bfda91 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -659,6 +659,7 @@ def check_available_online( "TransformersModel": _HfExamplesInfo("Qwen/Qwen3-Embedding-0.6B"), "TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501 "TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), + "TransformersMoEForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924"), } _EXAMPLE_MODELS = { From 3a6c08c5d74c1a57a7a69ec79bed30a26657c2c9 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 30 Sep 2025 10:30:01 +0200 Subject: [PATCH 51/64] Leave only the MXFP4 not implemented error Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git 
a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 7a0add59a807..8e90e292d842 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -145,10 +145,6 @@ def replace_linear_class( raise ValueError( f"Unsupported parallel style type {type(style)}, expected str") - # MXFP4 linear layer is not implemented - if quant_config and quant_config.get_name() == "mxfp4": - quant_config = None - vllm_linear_cls, vllm_linear_kwargs = { "colwise": (ColumnParallelLinear, {}), "colwise_rep": (ColumnParallelLinear, { @@ -642,11 +638,6 @@ def create_attention_instances( start, end = get_pp_indices(self.text_config.num_hidden_layers, self.pp_rank, self.pp_size) - # MXFP4 attention layer is not implemented - quant_config = self.quant_config - if quant_config and quant_config.get_name() == "mxfp4": - quant_config = None - attention_instances = {} for i in range(start, end): # Handle interleaved sliding window attention @@ -663,7 +654,7 @@ def create_attention_instances( scale=head_size**-0.5, num_kv_heads=num_kv_heads, cache_config=self.cache_config, - quant_config=quant_config, + quant_config=self.quant_config, per_layer_sliding_window=per_layer_sliding_window, prefix=f"{i}.attn", attn_type=attn_type) From d327a67bc4f5f4d1524c1c7a9714ff0e62342e96 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 30 Sep 2025 23:29:02 +0100 Subject: [PATCH 52/64] Add MoE versions of the new pooling classes Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config/model.py | 4 ++-- vllm/model_executor/models/registry.py | 7 ++++--- vllm/model_executor/models/transformers_moe.py | 7 +------ vllm/model_executor/models/transformers_pooling.py | 13 +++++++++++++ 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/vllm/config/model.py b/vllm/config/model.py index 50244d622242..a5729f67e7c2 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -687,9 +687,9 @@ def _get_transformers_backend_cls(self) -> str: # Resolve Transformers backend pooling classes if runner == "pooling": if convert == "embed": - return "TransformersEmbeddingModel" + return prefix + "EmbeddingModel" if convert == "classify": - return "TransformersForSequenceClassification" + return prefix + "ForSequenceClassification" # Resolve Transformers backend generate classes if self.hf_config != self.hf_text_config: # If 'hf_text_config' is the same as 'hf_config'. 
If not, it is diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 285423612cf8..94744fe558bd 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -307,13 +307,14 @@ } _TRANSFORMERS_BACKEND_MODELS = { - "TransformersEmbeddingModel": ("transformers_pooling", "TransformersEmbeddingModel"), # noqa: E501 - "TransformersForSequenceClassification": ("transformers_pooling", "TransformersForSequenceClassification"), # noqa: E501 "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"), "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501 - "TransformersMoEModel": ("transformers_moe", "TransformersMoEModel"), "TransformersMoEForCausalLM": ("transformers_moe", "TransformersMoEForCausalLM"), # noqa: E501 "TransformersMoEForMultimodalLM": ("transformers_moe", "TransformersMoEForMultimodalLM"), # noqa: E501 + "TransformersEmbeddingModel": ("transformers_pooling", "TransformersEmbeddingModel"), # noqa: E501 + "TransformersForSequenceClassification": ("transformers_pooling", "TransformersForSequenceClassification"), # noqa: E501 + "TransformersMoEForSequenceClassification": ("transformers_pooling", "TransformersMoEForSequenceClassification"), # noqa: E501 + "TransformersMoEEmbeddingModel": ("transformers_pooling", "TransformersMoEEmbeddingModel"), # noqa: E501 } # yapf: enable diff --git a/vllm/model_executor/models/transformers_moe.py b/vllm/model_executor/models/transformers_moe.py index 7d245e43be1a..5398135c034b 100644 --- a/vllm/model_executor/models/transformers_moe.py +++ b/vllm/model_executor/models/transformers_moe.py @@ -29,7 +29,7 @@ from vllm.utils import direct_register_custom_op from .transformers import (TransformersBase, TransformersForCausalLM, - TransformersForMultimodalLM, TransformersModel, + TransformersForMultimodalLM, can_enable_torch_compile, log_replacement) from .utils import maybe_prefix @@ -255,11 +255,6 @@ def _fused_moe(module: nn.Module, prefix: str): _fused_moe(self.model, prefix="model") -@support_torch_compile(enable_if=can_enable_torch_compile) -class TransformersMoEModel(TransformersMoEBase, TransformersModel): - pass - - @support_torch_compile(enable_if=can_enable_torch_compile) class TransformersMoEForCausalLM(TransformersMoEBase, TransformersForCausalLM): pass diff --git a/vllm/model_executor/models/transformers_pooling.py b/vllm/model_executor/models/transformers_pooling.py index 03a73461d6c7..c58a40dc4d2a 100644 --- a/vllm/model_executor/models/transformers_pooling.py +++ b/vllm/model_executor/models/transformers_pooling.py @@ -29,6 +29,7 @@ from .interfaces_base import VllmModelForPooling from .transformers import TransformersBase, can_enable_torch_compile +from .transformers_moe import TransformersMoEBase from .utils import WeightsMapper @@ -191,3 +192,15 @@ def forward(self, *args, **kwargs): vllm_config.model_config), ), }) + + +@support_torch_compile(enable_if=can_enable_torch_compile) +class TransformersMoEEmbeddingModel(TransformersMoEBase, + TransformersEmbeddingModel): + pass + + +@support_torch_compile(enable_if=can_enable_torch_compile) +class TransformersMoEForSequenceClassification( + TransformersMoEBase, TransformersForSequenceClassification): + pass From 48d599395e6d8c4ec1495a9580928cb753df4a1b Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 2 Oct 2025 19:16:04 +0200 Subject: [PATCH 53/64] Type hint `TransformersPoolingBase.create_attention_instances` 
properly Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers_pooling.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/transformers_pooling.py b/vllm/model_executor/models/transformers_pooling.py index c58a40dc4d2a..27fd40999fe2 100644 --- a/vllm/model_executor/models/transformers_pooling.py +++ b/vllm/model_executor/models/transformers_pooling.py @@ -20,7 +20,7 @@ import torch from transformers import AutoModelForSequenceClassification -from vllm.attention import AttentionType +from vllm.attention import Attention, AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool, @@ -80,7 +80,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.padding_idx = self.text_config.pad_token_id def create_attention_instances( - self, attn_type: AttentionType = AttentionType.DECODER): + self, + attn_type: AttentionType = AttentionType.DECODER + ) -> dict[int, Attention]: # TODO(hmellor): Better way to detect encoder models # In encoder models, the attention layers will have `is_causal=False` is_encoder = lambda m: not getattr(m, "is_causal", True) From 2da61406b86b66a5d18af2e48e71e042e1890a4b Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 2 Oct 2025 19:30:25 +0200 Subject: [PATCH 54/64] Merge `init_hook` and `tensor_parallel` into `recursive_replace` (also add `RMSNorm` replacement) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 111 ++++++++++-------- .../model_executor/models/transformers_moe.py | 20 ++-- 2 files changed, 70 insertions(+), 61 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 0e3cf0fdb8c1..484c9cdb0f68 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -37,6 +37,7 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed.utils import get_pp_indices from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, ReplicatedLinear, RowParallelLinear) @@ -168,6 +169,28 @@ def replace_linear_class( ) +def replace_rms_norm_class(rms_norm: nn.Module) -> RMSNorm: + """Replace a Transformers RMSNorm with vLLM's RMSNorm. + + This method assumes: + - Weight is stored as `weight`. + - Epsilon is stored as `eps` or `variance_epsilon`. + - `with_scale` indicates whether the layer has a weight (Gemma3n only). + - `var_hidden_size` is used only ever used for Intern vision encoder. 
+ """ + weight = getattr(rms_norm, "weight", None) + weight: Optional[torch.Tensor] = getattr(weight, "data", weight) + # Return early if weight not found + if weight is None: + return rms_norm + # Construct the new RMSNorm + return RMSNorm(hidden_size=weight.numel(), + eps=getattr_iter(rms_norm, ("eps", "variance_epsilon"), + 1e-6), + has_weight=getattr(rms_norm, "with_scale", True), + dtype=weight.dtype) + + # Copied from `accelerate` @contextmanager def init_on_device_without_buffers(device: torch.device): @@ -486,9 +509,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): trust_remote_code=self.model_config.trust_remote_code, ) + # Remove layers not on this pipeline parallel rank self.pipeline_parallel() - self.init_hook() - self.tensor_parallel() + # Substitute remaining layers with vLLM's layers as needed + self.recursive_replace() + # Create attention instances for KV cache allocation + self.attention_instances = self.create_attention_instances() # Input embeddings if not isinstance(self.model.get_input_embeddings(), PPMissingLayer): @@ -503,12 +529,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config=self.quant_config, )) - # Attention layers - self.attention_instances = self.create_attention_instances() - # Initialize any parameters that have not had their modules replaced self.init_parameters(self.model) + # Pipeline parallel intermediate tensors self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states"], self.text_config.hidden_size)) @@ -567,62 +591,47 @@ def pipeline_parallel(self): if not self.pp_group.is_last_rank: setattr(self.model, name, PPMissingLayer()) - def init_hook(self): - """Method to be overridden by child classes if necessary. + def recursive_replace(self): + """Recursively replace modules in the model as needed. - Called after `pipeline_parallel()` but before `tensor_parallel()`.""" - pass + Currently, this replaces: - def tensor_parallel(self): - """ - Apply the model's tensor parallelization plan. - Currently only supports linear layers. + - `nn.Linear` with vLLM's tensor parallel linear classes + - `*RMSNorm` with vLLM's `RMSNorm` """ - # Look for tp plans in all of the PreTrainedModels found in self.model - is_pretrained_model = lambda m: isinstance(m, PreTrainedModel) - supports_tp_plan = lambda m: m.config.base_model_tp_plan is not None - pretrained_models = filter(is_pretrained_model, self.model.modules()) - models_with_tp_plan = filter(supports_tp_plan, pretrained_models) + tp_plan = self.model.tp_plan - if not any(models_with_tp_plan) and self.tp_size > 1: + if not tp_plan and self.tp_size > 1: tip = get_feature_request_tip(self.model_config.model, self.model_config.trust_remote_code) raise ValueError( f"{type(self.model)} does not support tensor parallel. 
{tip}") - def _tensor_parallel(module: nn.Module, prefix: str, tp_plan=None): - tp_plan = tp_plan or {} - - # If the current module is a PreTrainedModel, set the tp_plan for - # all of its children - if isinstance(module, PreTrainedModel): - tp_plan = module.config.base_model_tp_plan or {} - tp_plan = { - maybe_prefix(prefix, k): v - for k, v in tp_plan.items() - } - - # Some weight loaders expect linear layers to inherit from vLLM's - # LinearBase class, so we set a default style which causes any - # unspecified linear layers to be replaced with ReplicatedLinear + def _recursive_replace(module: nn.Module, prefix: str): for child_name, child_module in module.named_children(): + new_module = child_module qual_name = maybe_prefix(prefix, child_name) if isinstance(child_module, nn.Linear): generator = (p for p in tp_plan if re.match(p, qual_name)) pattern = next(generator, None) + # Some weight loaders expect all linear layers to inherit + # LinearBase, so we set a default style which causes any + # unspecified layers to be replaced with ReplicatedLinear style = tp_plan.get(pattern, "replicate") new_module = replace_linear_class(child_module, style, self.quant_config, prefix=qual_name) + elif child_module.__class__.__name__.endswith("RMSNorm"): + new_module = replace_rms_norm_class(child_module) + else: + _recursive_replace(child_module, prefix=qual_name) + + if new_module is not child_module: setattr(module, child_name, new_module) log_replacement(qual_name, child_module, new_module) - else: - _tensor_parallel(child_module, - prefix=qual_name, - tp_plan=tp_plan) - _tensor_parallel(self.model, prefix="model") + _recursive_replace(self.model, prefix="model") def create_attention_instances( self, @@ -672,15 +681,21 @@ def init_parameters(self, self.model: PreTrainedModel = AutoModel.from_config(...) ``` """ - for name, param in module.named_parameters(recurse=False): - if param.device == torch.device("meta"): - new_param = nn.Parameter( - torch.empty_like(param.data, - dtype=dtype or self.model_config.dtype, - device=self.device_config.device)) - setattr(module, name, new_param) - for child in module.children(): - self.init_parameters(child, dtype) + + def _init_parameters(module: nn.Module, dtype: Optional[torch.dtype]): + for name, param in module.named_parameters(recurse=False): + if param.device == torch.device("meta"): + new_param = nn.Parameter( + torch.empty_like( + param.data, + dtype=dtype or self.model_config.dtype, + device=self.device_config.device, + )) + setattr(module, name, new_param) + for child in module.children(): + _init_parameters(child, dtype) + + _init_parameters(module, dtype) def forward( self, diff --git a/vllm/model_executor/models/transformers_moe.py b/vllm/model_executor/models/transformers_moe.py index 5398135c034b..cb966256b350 100644 --- a/vllm/model_executor/models/transformers_moe.py +++ b/vllm/model_executor/models/transformers_moe.py @@ -127,16 +127,8 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: )) return expert_mapping - def init_hook(self): - """Initialize the MoE layers. - - It is important that this happens: - - - After pipeline parallelism so only the `FusedMoE` layers for this - pipeline stage are created. - - Before tensor parallelism so vLLM `Linear` layers are not created for - the unfused `Linear` layers inside the Transformers MoE layers. 
- """ + def recursive_replace(self): + """Initialize the MoE layers.""" text_config = self.text_config # Positional arguments @@ -200,7 +192,7 @@ def forward(self, *args, **kwargs): num_redundant_experts = eplb_config.num_redundant_experts # Recursively fuse MoE layers - def _fused_moe(module: nn.Module, prefix: str): + def _recursive_replace(module: nn.Module, prefix: str): for child_name, child_module in module.named_children(): qual_name = maybe_prefix(prefix, child_name) if (child_name == "experts" @@ -250,9 +242,11 @@ def _fused_moe(module: nn.Module, prefix: str): or fused_experts.ep_size > 1): add_all_reduce(mlp) else: - _fused_moe(child_module, prefix=qual_name) + _recursive_replace(child_module, prefix=qual_name) - _fused_moe(self.model, prefix="model") + _recursive_replace(self.model, prefix="model") + # Continue with the replacement of layers in TransformersBase + super().recursive_replace() @support_torch_compile(enable_if=can_enable_torch_compile) From e55e5d60ba8792d810db5d881643b885044a3e46 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 2 Oct 2025 19:33:31 +0200 Subject: [PATCH 55/64] Add min transformers version to skip the init tests Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- tests/models/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 4712226fd832..04a2ffadf69e 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -661,7 +661,7 @@ def check_available_online( "TransformersForSequenceClassification": _HfExamplesInfo("papluca/xlm-roberta-base-language-detection", min_transformers_version="4.57.0.dev0"), # noqa: E501 "TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501 "TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), - "TransformersMoEForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924"), + "TransformersMoEForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924", min_transformers_version="4.57.0.dev0"), # noqa: E501 } _EXAMPLE_MODELS = { From 3c1b8f8e2c1bad81e4430c16d7beb547a55a0349 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 2 Oct 2025 19:54:09 +0200 Subject: [PATCH 56/64] Add edge case for Ernie Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config/model.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/config/model.py b/vllm/config/model.py index a5729f67e7c2..2bf6a1671188 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -1220,7 +1220,12 @@ def get_num_experts(self) -> int: "n_routed_experts", # DeepSeek "num_local_experts", # Mixtral ] - return getattr_iter(self.hf_text_config, num_expert_names, 0) + num_experts = getattr_iter(self.hf_text_config, num_expert_names, 0) + if isinstance(num_experts, list): + # Ernie VL's remote code uses list[int]... + # The values are always the same so we just take the first one. 
+ return num_experts[0] + return num_experts def get_layers_start_end_indices( self, parallel_config: ParallelConfig) -> tuple[int, int]: From 839ef4bf12d860900f4cf4ca02e7dce410de5617 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 2 Oct 2025 20:02:31 +0200 Subject: [PATCH 57/64] Add missing classes to test registry Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- tests/models/registry.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/models/registry.py b/tests/models/registry.py index 04a2ffadf69e..86a835975227 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -662,6 +662,9 @@ def check_available_online( "TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501 "TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), "TransformersMoEForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924", min_transformers_version="4.57.0.dev0"), # noqa: E501 + "TransformersMoEForMultimodalLM": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="4.57.0.dev0"), # noqa: E501 + "TransformersMoEEmbeddingModel": _HfExamplesInfo("Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"), # noqa: E501 + "TransformersMoEForSequenceClassification": _HfExamplesInfo("Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"), # noqa: E501 } _EXAMPLE_MODELS = { From 006902e12bb2c3c8158bb1590ae6fc60501de958 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 2 Oct 2025 19:11:57 +0100 Subject: [PATCH 58/64] Update vllm/model_executor/models/transformers.py Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 484c9cdb0f68..613417f609bc 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -176,7 +176,7 @@ def replace_rms_norm_class(rms_norm: nn.Module) -> RMSNorm: - Weight is stored as `weight`. - Epsilon is stored as `eps` or `variance_epsilon`. - `with_scale` indicates whether the layer has a weight (Gemma3n only). - - `var_hidden_size` is used only ever used for Intern vision encoder. + - `var_hidden_size` is only ever used for Intern vision encoder. """ weight = getattr(rms_norm, "weight", None) weight: Optional[torch.Tensor] = getattr(weight, "data", weight) From 1c324311dfcd2d1a9c440156f5a4db2105c0f881 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 2 Oct 2025 21:02:11 +0200 Subject: [PATCH 59/64] Always return RMSNorm Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 33 ++++++++++++++-------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 613417f609bc..612bb1c194dd 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -176,19 +176,23 @@ def replace_rms_norm_class(rms_norm: nn.Module) -> RMSNorm: - Weight is stored as `weight`. - Epsilon is stored as `eps` or `variance_epsilon`. - `with_scale` indicates whether the layer has a weight (Gemma3n only). - - `var_hidden_size` is only ever used for Intern vision encoder. 
+ - `var_hidden_size` is only ever used for Intern vision encoder in vLLM + and Transformers doesn't appear to have the same concept. """ - weight = getattr(rms_norm, "weight", None) - weight: Optional[torch.Tensor] = getattr(weight, "data", weight) - # Return early if weight not found - if weight is None: - return rms_norm - # Construct the new RMSNorm - return RMSNorm(hidden_size=weight.numel(), - eps=getattr_iter(rms_norm, ("eps", "variance_epsilon"), - 1e-6), - has_weight=getattr(rms_norm, "with_scale", True), - dtype=weight.dtype) + kwargs = { + "eps": getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6), + "has_weight": getattr(rms_norm, "with_scale", True) + } + if (weight := getattr(rms_norm, "weight", None)) is not None: + # If weight is a Parameter, get its data tensor + weight = getattr(weight, "data", weight) + kwargs["hidden_size"] = weight.numel() + kwargs["dtype"] = weight.dtype + else: + # No weight, fall back to weightless RMSNorm + kwargs["hidden_size"] = 1 + kwargs["has_weight"] = False + return RMSNorm(**kwargs) # Copied from `accelerate` @@ -607,8 +611,13 @@ def recursive_replace(self): raise ValueError( f"{type(self.model)} does not support tensor parallel. {tip}") + # Prefix the patterns because we always start from `self.model` + tp_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()} + def _recursive_replace(module: nn.Module, prefix: str): for child_name, child_module in module.named_children(): + if child_name == "language_model": + print("here") new_module = child_module qual_name = maybe_prefix(prefix, child_name) if isinstance(child_module, nn.Linear): From b01bd6e1cd8d998ff75c8a5cd5c7584b06c6962c Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 2 Oct 2025 21:03:50 +0200 Subject: [PATCH 60/64] Type hint replace_linear_class correctly Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 612bb1c194dd..594998e5a45e 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -124,9 +124,13 @@ def can_enable_torch_compile(vllm_config: VllmConfig) -> bool: return enable +Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", + "replicate"] + + def replace_linear_class( linear: nn.Linear, - style: Literal["colwise", "rowwise"] = "replicate", + style: Style = "replicate", quant_config: Optional[QuantizationConfig] = None, *, prefix: str = "", From b1b62bef81163b4fbabe56649864f96d35a3a6c1 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 2 Oct 2025 21:11:37 +0200 Subject: [PATCH 61/64] Can't use 1 because vLLM checks hidden size agains input Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 594998e5a45e..22086e6b774a 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -173,7 +173,7 @@ def replace_linear_class( ) -def replace_rms_norm_class(rms_norm: nn.Module) -> RMSNorm: +def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm: """Replace a Transformers 
RMSNorm with vLLM's RMSNorm. This method assumes: @@ -184,17 +184,16 @@ def replace_rms_norm_class(rms_norm: nn.Module) -> RMSNorm: and Transformers doesn't appear to have the same concept. """ kwargs = { + "hidden_size": hidden_size, "eps": getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6), "has_weight": getattr(rms_norm, "with_scale", True) } if (weight := getattr(rms_norm, "weight", None)) is not None: # If weight is a Parameter, get its data tensor weight = getattr(weight, "data", weight) - kwargs["hidden_size"] = weight.numel() kwargs["dtype"] = weight.dtype else: # No weight, fall back to weightless RMSNorm - kwargs["hidden_size"] = 1 kwargs["has_weight"] = False return RMSNorm(**kwargs) @@ -636,7 +635,8 @@ def _recursive_replace(module: nn.Module, prefix: str): self.quant_config, prefix=qual_name) elif child_module.__class__.__name__.endswith("RMSNorm"): - new_module = replace_rms_norm_class(child_module) + new_module = replace_rms_norm_class( + child_module, self.config.hidden_size) else: _recursive_replace(child_module, prefix=qual_name) From 1d4f3f9a845f81c075c4f35112dc401b89fd7cfd Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 3 Oct 2025 00:42:40 +0200 Subject: [PATCH 62/64] Fix test util making everything MoE... Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- tests/models/utils.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/tests/models/utils.py b/tests/models/utils.py index 7e731cffc047..50936114865a 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -430,17 +430,26 @@ def dummy_hf_overrides( update_dict = { "num_layers": num_layers, - "num_experts": num_experts, - "num_experts_per_tok": 2, - "num_local_experts": num_experts, - # Otherwise there will not be any expert layers - "first_k_dense_replace": 0, - # To avoid OOM on DeepSeek-V3 - "n_routed_experts": num_experts, # For Gemma-3n "num_kv_shared_layers": 1, } + class DummyConfig: + hf_text_config = text_config + + # Only set MoE related config when the model has MoE layers. + # Otherwise all models detected as MoE by _get_transformers_backend_cls. 
+ if ModelConfig.get_num_experts(DummyConfig) > 0: + update_dict.update({ + "num_experts": num_experts, + "num_experts_per_tok": 2, + "num_local_experts": num_experts, + # Otherwise there will not be any expert layers + "first_k_dense_replace": 0, + # To avoid OOM on DeepSeek-V3 + "n_routed_experts": num_experts, + }) + # Update num_hidden_layers for non-Longcat architectures if model_arch != "LongcatFlashForCausalLM" \ and model_arch != "LongCatFlashMTPModel": From 954e163568908c31cbfa6744ad982a078846fe6c Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 3 Oct 2025 01:00:41 +0200 Subject: [PATCH 63/64] remove print Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 22086e6b774a..561eb3db986c 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -619,8 +619,6 @@ def recursive_replace(self): def _recursive_replace(module: nn.Module, prefix: str): for child_name, child_module in module.named_children(): - if child_name == "language_model": - print("here") new_module = child_module qual_name = maybe_prefix(prefix, child_name) if isinstance(child_module, nn.Linear): From adca299e4c49788773697f7dbb2dac44501b55e1 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 3 Oct 2025 01:06:09 +0200 Subject: [PATCH 64/64] Disable RMSNorm swapping for now Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 561eb3db986c..18a0dafd001d 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -632,9 +632,11 @@ def _recursive_replace(module: nn.Module, prefix: str): style, self.quant_config, prefix=qual_name) - elif child_module.__class__.__name__.endswith("RMSNorm"): - new_module = replace_rms_norm_class( - child_module, self.config.hidden_size) + # TODO(hmellor): Enable RMSNorm replacement once we have a way + # to choose RMSNorm vs GemmaRMSNorm + # elif child_module.__class__.__name__.endswith("RMSNorm"): + # new_module = replace_rms_norm_class( + # child_module, self.config.hidden_size) else: _recursive_replace(child_module, prefix=qual_name)
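
The patches above all lean on the same recursive module-replacement pattern: walk `named_children()`, swap a matching child in place with `setattr` (nn.Linear -> tensor-parallel linear, the `experts` ModuleList -> `FusedMoE`, `*RMSNorm` -> vLLM's `RMSNorm`), and otherwise recurse. Below is a minimal, standalone sketch of that pattern for reference; it is illustrative only and not vLLM code, and the ReLU -> GELU substitution is a hypothetical stand-in for the real replacements performed in `recursive_replace`.

    import torch.nn as nn

    def recursive_replace(module: nn.Module, prefix: str = "model") -> None:
        for child_name, child_module in module.named_children():
            qual_name = f"{prefix}.{child_name}"
            if isinstance(child_module, nn.ReLU):
                # Swapping the attribute is enough; the parent's forward()
                # now calls the replacement module.
                setattr(module, child_name, nn.GELU())
                print(f"{qual_name}: ReLU -> GELU")
            else:
                # Recurse only into children that were not replaced
                recursive_replace(child_module, prefix=qual_name)

    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
    recursive_replace(model)  # prints "model.1: ReLU -> GELU"

Because replacing a value for an existing key in `_modules` does not change its size, mutating the parent with `setattr` while iterating `named_children()` is safe, which is why the same pattern is used directly in the diffs above.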