Merged

78 commits
3b3356b
Add MoE base class to Transformers backend
hmellor Aug 11, 2025
ee32a6e
Add MoE model classes to Transformers backend
hmellor Aug 11, 2025
1f0b49e
Merge branch 'main' into transformers-backend-fused-moe
hmellor Aug 13, 2025
64e827e
Make expert mapping comprehensive
hmellor Aug 13, 2025
48ab51a
Add guard against EPLB until it's tested
hmellor Aug 13, 2025
80f15e2
Add FusedMoE weight loading to `AutoWeightLoader`
hmellor Aug 13, 2025
68a132f
More compact expert mapping generation
hmellor Aug 13, 2025
781d15d
Use forward hook to perform all reduce after experts in TP/EP
hmellor Aug 13, 2025
8b1f3e8
Remove `params_dtype` because it's not used for most MoEs
hmellor Aug 13, 2025
5b6f5fd
Better debug log for expert loading
hmellor Aug 13, 2025
c19ae9b
Update transformers backend doc
hmellor Aug 14, 2025
a9272bb
Small doc tweak
hmellor Aug 14, 2025
690df42
Add docstring to `reduce_results`
hmellor Aug 14, 2025
fd8bddb
Set `renormalize` correctly
hmellor Aug 14, 2025
87473f7
Set `top_k` correctly
hmellor Aug 14, 2025
4d7e41c
Add support for grouped topk expert selection
hmellor Aug 14, 2025
a6c0483
Make `use_grouped_topk` a bool
hmellor Aug 14, 2025
827540c
Merge branch 'main' into transformers-backend-fused-moe
hmellor Aug 25, 2025
cc45642
Add note for removal of `get_num_experts`
hmellor Aug 25, 2025
d711ab5
Better handling of shared experts and renormalisation
hmellor Aug 25, 2025
88154b6
Add util which does `getattr` for a list of possible names
hmellor Aug 25, 2025
bf26f60
Use new util in ModelConfig
hmellor Aug 25, 2025
982643c
Set `intermediate_size` properly
hmellor Aug 25, 2025
583d5c6
Move reduction kwargs to be together
hmellor Aug 25, 2025
e4fad6f
Set `e_score_correction_bias` correctly
hmellor Aug 25, 2025
b5a5cbc
Add EPLB support
hmellor Aug 25, 2025
dbd0c4e
Label remaining missing features with the models that require them
hmellor Aug 25, 2025
3ff86cb
Merge branch 'main' into transformers-backend-fused-moe
hmellor Aug 28, 2025
40c097b
Fix `support_torch_compile` for MoE classes
hmellor Aug 28, 2025
55fb46f
Fix typo
hmellor Aug 28, 2025
4297d9f
Merge branch 'main' into transformers-backend-fused-moe
hmellor Sep 17, 2025
7585379
Merge branch 'main' into transformers-backend-fused-moe
hmellor Sep 22, 2025
6277c32
Merge branch 'main' of https://github.com/vllm-project/vllm into tran…
hmellor Sep 22, 2025
47b9d53
CustomFusedMoE
hmellor Sep 23, 2025
8b35117
Use custom op for top_k_index bypass
hmellor Sep 25, 2025
524b4c4
Fix reduce_results handling
hmellor Sep 25, 2025
8b2bd59
Reorganise the kwargs a little
hmellor Sep 25, 2025
c4120fb
Make docs claim more conservative
hmellor Sep 25, 2025
0134d62
Better transformers backend class resolution
hmellor Sep 25, 2025
dbe1352
Allow MoE pooling
hmellor Sep 25, 2025
152d7ff
Remove TODO
hmellor Sep 25, 2025
199d1db
Extract `TransformersMoE` classes to `transformers_moe`
hmellor Sep 25, 2025
4080932
Handle MXFP4 in linears and attentions
hmellor Sep 25, 2025
e142058
Add error if user tries MXFP4
hmellor Sep 25, 2025
957af89
Add errors for enable expert parallel
hmellor Sep 25, 2025
64574ea
Merge branch 'main' into transformers-backend-fused-moe
hmellor Sep 26, 2025
8f53035
Make minimum version checking part of `TransformersBase`
hmellor Sep 26, 2025
9faf3d2
Slightly improve `FusedMoE` kwarg detection
hmellor Sep 26, 2025
42ce295
Fix experts loading
hmellor Sep 26, 2025
3725c30
Add defaults to `replace_linear_class`
hmellor Sep 26, 2025
6142cdb
Handle `AutoGPTQ` not quantising `gate`
hmellor Sep 26, 2025
7ab3832
Add test
hmellor Sep 26, 2025
928e9d5
Move expert weight loading to `FusedMoE.load_weights`
hmellor Sep 27, 2025
fffa989
Merge branch 'main' into transformers-backend-fused-moe
hmellor Sep 27, 2025
220d749
Convert forward hook to wrapper class
hmellor Sep 29, 2025
e86b359
Merge branch 'main' into transformers-backend-fused-moe
hmellor Sep 29, 2025
1c6b2cd
Remove GPTQ workaround
hmellor Sep 29, 2025
14e091f
Use smaller model in test
hmellor Sep 29, 2025
a47791e
Merge branch 'main' into transformers-backend-fused-moe
hmellor Sep 30, 2025
e580064
Add MoE test registry entry
hmellor Sep 30, 2025
3a6c08c
Leave only the MXFP4 not implemented error
hmellor Sep 30, 2025
32666f6
Merge branch 'main' into transformers-backend-fused-moe
hmellor Sep 30, 2025
d327a67
Add MoE versions of the new pooling classes
hmellor Sep 30, 2025
077f62f
Merge branch 'main' into transformers-backend-fused-moe
hmellor Oct 2, 2025
2ff8022
Merge branch 'main' into transformers-backend-fused-moe
hmellor Oct 2, 2025
48d5993
Type hint `TransformersPoolingBase.create_attention_instances` properly
hmellor Oct 2, 2025
2da6140
Merge `init_hook` and `tensor_parallel` into `recursive_replace` (als…
hmellor Oct 2, 2025
e55e5d6
Add min transformers version to skip the init tests
hmellor Oct 2, 2025
3c1b8f8
Add edge case for Ernie
hmellor Oct 2, 2025
839ef4b
Add missing classes to test registry
hmellor Oct 2, 2025
006902e
Update vllm/model_executor/models/transformers.py
hmellor Oct 2, 2025
1c32431
Always return RMSNorm
hmellor Oct 2, 2025
b01bd6e
Type hint replace_linear_class correctly
hmellor Oct 2, 2025
b1b62be
Can't use 1 because vLLM checks hidden size against input
hmellor Oct 2, 2025
1d4f3f9
Fix test util making everything MoE...
hmellor Oct 2, 2025
761c377
Merge branch 'main' into transformers-backend-fused-moe
hmellor Oct 2, 2025
954e163
remove print
hmellor Oct 2, 2025
adca299
Disable RMSNorm swapping for now
hmellor Oct 2, 2025
5 changes: 3 additions & 2 deletions docs/models/supported_models.md
@@ -17,12 +17,12 @@ These models are what we list in [supported-text-models][supported-text-models]

### Transformers

vLLM also supports model implementations that are available in Transformers. You should expect the performance of a Transformers model implementation used in vLLM to be within <1% of the performance of a dedicated vLLM model implementation. We call this feature the "Transformers backend".
vLLM also supports model implementations that are available in Transformers. You should expect the performance of a Transformers model implementation used in vLLM to be within <5% of the performance of a dedicated vLLM model implementation. We call this feature the "Transformers backend".

Currently, the Transformers backend works for the following:

- Modalities: embedding models, language models and vision-language models*
- Architectures: encoder-only, decoder-only
- Architectures: encoder-only, decoder-only, mixture-of-experts
- Attention types: full attention and/or sliding attention

_*Vision-language models currently accept only image inputs. Support for video inputs will be added in a future release._
@@ -31,6 +31,7 @@ If the Transformers model implementation follows all the steps in [writing a cus

- All the features listed in the [compatibility matrix](../features/README.md#feature-x-feature)
- Any combination of the following vLLM parallelisation schemes:
- Data parallel
- Pipeline parallel
- Tensor parallel

4 changes: 4 additions & 0 deletions tests/models/registry.py
@@ -661,6 +661,10 @@ def check_available_online(
"TransformersForSequenceClassification": _HfExamplesInfo("papluca/xlm-roberta-base-language-detection", min_transformers_version="4.57.0.dev0"), # noqa: E501
"TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
"TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
"TransformersMoEForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924", min_transformers_version="4.57.0.dev0"), # noqa: E501
"TransformersMoEForMultimodalLM": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="4.57.0.dev0"), # noqa: E501
"TransformersMoEEmbeddingModel": _HfExamplesInfo("Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"), # noqa: E501
"TransformersMoEForSequenceClassification": _HfExamplesInfo("Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"), # noqa: E501
}

_EXAMPLE_MODELS = {
9 changes: 9 additions & 0 deletions tests/models/test_transformers.py
@@ -66,6 +66,7 @@ def check_implementation(
[
("meta-llama/Llama-3.2-1B-Instruct", "transformers"),
("hmellor/Ilama-3.2-1B", "auto"), # CUSTOM CODE
("allenai/OLMoE-1B-7B-0924", "transformers"), # MoE
]) # trust_remote_code=True by default
def test_models(
hf_runner: type[HfRunner],
@@ -74,6 +75,14 @@ def test_models(
model: str,
model_impl: str,
) -> None:
import transformers
from packaging.version import Version
installed = Version(transformers.__version__)
required = Version("4.57.0.dev0")
if model == "allenai/OLMoE-1B-7B-0924" and installed < required:
pytest.skip("MoE models with the Transformers backend require "
f"transformers>={required}, but got {installed}")

check_implementation(hf_runner,
vllm_runner,
example_prompts,
23 changes: 16 additions & 7 deletions tests/models/utils.py
@@ -430,17 +430,26 @@ def dummy_hf_overrides(

update_dict = {
"num_layers": num_layers,
"num_experts": num_experts,
"num_experts_per_tok": 2,
"num_local_experts": num_experts,
# Otherwise there will not be any expert layers
"first_k_dense_replace": 0,
# To avoid OOM on DeepSeek-V3
"n_routed_experts": num_experts,
# For Gemma-3n
"num_kv_shared_layers": 1,
}

class DummyConfig:
hf_text_config = text_config

# Only set MoE related config when the model has MoE layers.
# Otherwise every model is detected as MoE by _get_transformers_backend_cls.
if ModelConfig.get_num_experts(DummyConfig) > 0:
update_dict.update({
"num_experts": num_experts,
"num_experts_per_tok": 2,
"num_local_experts": num_experts,
# Otherwise there will not be any expert layers
"first_k_dense_replace": 0,
# To avoid OOM on DeepSeek-V3
"n_routed_experts": num_experts,
})

# Update num_hidden_layers for non-Longcat architectures
if model_arch != "LongcatFlashForCausalLM" \
and model_arch != "LongCatFlashMTPModel":
39 changes: 23 additions & 16 deletions vllm/config/model.py
@@ -20,7 +20,7 @@
MultiModalConfig)
from vllm.config.pooler import PoolerConfig
from vllm.config.scheduler import RunnerType
from vllm.config.utils import assert_hashable, config
from vllm.config.utils import assert_hashable, config, getattr_iter
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.transformers_utils.config import (
@@ -667,6 +667,8 @@ def validate_model_config_after(self: "ModelConfig") -> "ModelConfig":
def _get_transformers_backend_cls(self) -> str:
"""Determine which Transformers backend class will be used if
`model_impl` is set to `transformers` or `auto`."""
prefix = "Transformers"
prefix += "MoE" if self.get_num_experts() > 1 else ""
# Check if the architecture we're wrapping has defaults
runner = None
convert = None
@@ -685,15 +687,15 @@ def _get_transformers_backend_cls(self) -> str:
# Resolve Transformers backend pooling classes
if runner == "pooling":
if convert == "embed":
return "TransformersEmbeddingModel"
return prefix + "EmbeddingModel"
if convert == "classify":
return "TransformersForSequenceClassification"
return prefix + "ForSequenceClassification"
# Resolve Transformers backend generate classes
if self.hf_config != self.hf_text_config:
# If 'hf_text_config' is the same as 'hf_config'. If not, it is
# probably a composite config, i.e. multimodal
return "TransformersForMultimodalLM"
return "TransformersForCausalLM"
return prefix + "ForMultimodalLM"
return prefix + "ForCausalLM"

def using_transformers_backend(self) -> bool:
"""Check if the model is using the Transformers backend class."""
@@ -1025,17 +1027,7 @@ def _verify_bnb_config(self) -> None:
self.enforce_eager = True

def _verify_with_expert_parallelism(self) -> None:
num_expert_names = [
"moe_num_experts", # Dbrx
"num_experts", # Jamba
"n_routed_experts", # DeepSeek
"num_local_experts", # Mixtral
]
num_experts = 0
for name in num_expert_names:
num_experts = getattr(self.hf_text_config, name, 0)
if num_experts > 0:
break
num_experts = self.get_num_experts()
if num_experts < 1:
raise ValueError(
"Number of experts in the model must be greater than 0 "
@@ -1220,6 +1212,21 @@ def get_num_attention_heads(self, parallel_config: ParallelConfig) -> int:
num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
return num_heads // parallel_config.tensor_parallel_size

def get_num_experts(self) -> int:
"""Returns the number of experts in the model."""
num_expert_names = [
"num_experts", # Jamba
"moe_num_experts", # Dbrx
"n_routed_experts", # DeepSeek
"num_local_experts", # Mixtral
]
num_experts = getattr_iter(self.hf_text_config, num_expert_names, 0)
if isinstance(num_experts, list):
# Ernie VL's remote code uses list[int]...
# The values are always the same so we just take the first one.
return num_experts[0]
return num_experts

def get_layers_start_end_indices(
self, parallel_config: ParallelConfig) -> tuple[int, int]:
from vllm.distributed.utils import get_pp_indices
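
The `getattr_iter` helper imported above lives in `vllm/config/utils.py` and is not shown in this diff. A hedged sketch of its likely behaviour, based on how `get_num_experts` uses it (return the value of the first attribute name the config actually defines, otherwise the default):

```python
# Sketch only; the real getattr_iter in vllm/config/utils.py may differ in detail.
from collections.abc import Iterable
from typing import Any


def getattr_iter(obj: object, names: Iterable[str], default: Any) -> Any:
    """Return the first attribute in `names` that `obj` defines, else `default`."""
    for name in names:
        if hasattr(obj, name):
            return getattr(obj, name)
    return default
```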
31 changes: 31 additions & 0 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -960,6 +960,7 @@ def __init__(
is_sequence_parallel=False,
zero_expert_num: Optional[int] = 0,
zero_expert_type: Optional[str] = None,
expert_mapping: Optional[list[tuple[str, str, int, str]]] = None,
):
super().__init__()
if params_dtype is None:
@@ -996,6 +997,9 @@ def __init__(
self.zero_expert_num = zero_expert_num
self.zero_expert_type = zero_expert_type

# Expert mapping used in self.load_weights
self.expert_mapping = expert_mapping

# Round up hidden size if needed.
hidden_size = maybe_roundup_hidden_size(hidden_size, moe_in_dtype,
quant_config,
@@ -1617,6 +1621,33 @@ def weight_loader(self,

return False if return_success else None

def load_weights(
self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Iterable[str]:
if (expert_mapping := self.expert_mapping) is None:
raise ValueError("`self.expert_mapping` must be provided to "
"load weights using `self.load_weights`.")
for expert_name, loaded_weight in weights:
qual_name = f"{self.layer_name}.{expert_name}"
for param_name, weight_name, expert_id, shard_id in expert_mapping:
if weight_name not in qual_name:
continue
weight_name = qual_name.replace(weight_name, param_name)
param_name = weight_name.removeprefix(f"{self.layer_name}.")
param = getattr(self, param_name)
success = self.weight_loader(
param=param,
loaded_weight=loaded_weight,
weight_name=weight_name,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if success:
logger.debug("Loaded %s for expert %d into %s", param_name,
expert_id, self.layer_name)
yield param_name

def get_expert_weights(self) -> Iterable[torch.Tensor]:
weights = list(self.named_parameters())
assert all(weight.is_contiguous() for _, weight in weights)
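
Each `expert_mapping` entry consumed by the new `FusedMoE.load_weights` is a `(param_name, weight_name, expert_id, shard_id)` tuple. The following is an illustration only; the exact substrings depend on the checkpoint layout and on how the calling model builds the mapping (for example via `FusedMoE.make_expert_params_mapping`), so treat the names as assumptions:

```python
# Illustrative expert mapping for a 2-expert layer; names are assumptions,
# not copied from a specific checkpoint. Passed to FusedMoE(...,
# expert_mapping=expert_mapping) so load_weights can route checkpoint tensors.
expert_mapping = [
    # (param_name,  weight_name,            expert_id, shard_id)
    ("w13_weight",  "experts.0.gate_proj.", 0,         "w1"),
    ("w13_weight",  "experts.0.up_proj.",   0,         "w3"),
    ("w2_weight",   "experts.0.down_proj.", 0,         "w2"),
    ("w13_weight",  "experts.1.gate_proj.", 1,         "w1"),
    ("w13_weight",  "experts.1.up_proj.",   1,         "w3"),
    ("w2_weight",   "experts.1.down_proj.", 1,         "w2"),
]
```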
8 changes: 6 additions & 2 deletions vllm/model_executor/models/registry.py
@@ -307,10 +307,14 @@
}

_TRANSFORMERS_BACKEND_MODELS = {
"TransformersEmbeddingModel": ("transformers_pooling", "TransformersEmbeddingModel"), # noqa: E501
"TransformersForSequenceClassification": ("transformers_pooling", "TransformersForSequenceClassification"), # noqa: E501
"TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
"TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
"TransformersMoEForCausalLM": ("transformers_moe", "TransformersMoEForCausalLM"), # noqa: E501
"TransformersMoEForMultimodalLM": ("transformers_moe", "TransformersMoEForMultimodalLM"), # noqa: E501
"TransformersEmbeddingModel": ("transformers_pooling", "TransformersEmbeddingModel"), # noqa: E501
"TransformersForSequenceClassification": ("transformers_pooling", "TransformersForSequenceClassification"), # noqa: E501
"TransformersMoEForSequenceClassification": ("transformers_pooling", "TransformersMoEForSequenceClassification"), # noqa: E501
"TransformersMoEEmbeddingModel": ("transformers_pooling", "TransformersMoEEmbeddingModel"), # noqa: E501
}
# yapf: enable
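
For context, each registry value is a `(module, class)` pair. A rough sketch of how such an entry might be resolved (the `vllm.model_executor.models` package prefix is an assumption based on the file paths in this diff, not the registry's actual loader code):

```python
# Rough sketch of resolving a registry entry; not vLLM's actual implementation.
from importlib import import_module


def resolve_model_cls(module_name: str, class_name: str) -> type:
    module = import_module(f"vllm.model_executor.models.{module_name}")
    return getattr(module, class_name)


# e.g. resolve_model_cls("transformers_moe", "TransformersMoEForCausalLM")
```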
