
Commit 6b12b2e

hmellor authored and yewentao256 committed
FusedMoE support for the Transformers backend (#22650)
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: yewentao256 <[email protected]>
1 parent bbeace2 commit 6b12b2e

10 files changed: +485 −91 lines


docs/models/supported_models.md

Lines changed: 3 additions & 2 deletions
@@ -17,12 +17,12 @@ These models are what we list in [supported-text-models][supported-text-models]
 
 ### Transformers
 
-vLLM also supports model implementations that are available in Transformers. You should expect the performance of a Transformers model implementation used in vLLM to be within <1% of the performance of a dedicated vLLM model implementation. We call this feature the "Transformers backend".
+vLLM also supports model implementations that are available in Transformers. You should expect the performance of a Transformers model implementation used in vLLM to be within <5% of the performance of a dedicated vLLM model implementation. We call this feature the "Transformers backend".
 
 Currently, the Transformers backend works for the following:
 
 - Modalities: embedding models, language models and vision-language models*
-- Architectures: encoder-only, decoder-only
+- Architectures: encoder-only, decoder-only, mixture-of-experts
 - Attention types: full attention and/or sliding attention
 
 _*Vision-language models currently accept only image inputs. Support for video inputs will be added in a future release._
@@ -31,6 +31,7 @@ If the Transformers model implementation follows all the steps in [writing a cus
 
 - All the features listed in the [compatibility matrix](../features/README.md#feature-x-feature)
 - Any combination of the following vLLM parallelisation schemes:
+    - Data parallel
     - Pipeline parallel
     - Tensor parallel
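For readers of the docs hunk above, a minimal usage sketch of the Transformers backend with an MoE checkpoint. It assumes the OLMoE model used in this commit's tests and, per the test changes below, transformers >= 4.57.0.dev0; `model_impl="transformers"` forces the backend explicitly.

```python
from vllm import LLM

# Force the Transformers backend; with this commit an MoE architecture resolves
# to the TransformersMoE* classes. "auto" would prefer a native vLLM
# implementation when one exists.
llm = LLM(
    model="allenai/OLMoE-1B-7B-0924",
    model_impl="transformers",
)
print(llm.generate("The capital of France is")[0].outputs[0].text)
```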

tests/models/registry.py

Lines changed: 4 additions & 0 deletions
@@ -661,6 +661,10 @@ def check_available_online(
     "TransformersForSequenceClassification": _HfExamplesInfo("papluca/xlm-roberta-base-language-detection", min_transformers_version="4.57.0.dev0"),  # noqa: E501
     "TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True),  # noqa: E501
     "TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
+    "TransformersMoEForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924", min_transformers_version="4.57.0.dev0"),  # noqa: E501
+    "TransformersMoEForMultimodalLM": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="4.57.0.dev0"),  # noqa: E501
+    "TransformersMoEEmbeddingModel": _HfExamplesInfo("Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"),  # noqa: E501
+    "TransformersMoEForSequenceClassification": _HfExamplesInfo("Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"),  # noqa: E501
 }
 
 _EXAMPLE_MODELS = {

tests/models/test_transformers.py

Lines changed: 9 additions & 0 deletions
@@ -66,6 +66,7 @@ def check_implementation(
     [
         ("meta-llama/Llama-3.2-1B-Instruct", "transformers"),
         ("hmellor/Ilama-3.2-1B", "auto"),  # CUSTOM CODE
+        ("allenai/OLMoE-1B-7B-0924", "transformers"),  # MoE
     ])  # trust_remote_code=True by default
 def test_models(
     hf_runner: type[HfRunner],
@@ -74,6 +75,14 @@ def test_models(
     model: str,
     model_impl: str,
 ) -> None:
+    import transformers
+    from packaging.version import Version
+    installed = Version(transformers.__version__)
+    required = Version("4.57.0.dev0")
+    if model == "allenai/OLMoE-1B-7B-0924" and installed < required:
+        pytest.skip("MoE models with the Transformers backend require "
+                    f"transformers>={required}, but got {installed}")
+
     check_implementation(hf_runner,
                          vllm_runner,
                          example_prompts,

tests/models/utils.py

Lines changed: 16 additions & 7 deletions
@@ -430,17 +430,26 @@ def dummy_hf_overrides(
 
     update_dict = {
         "num_layers": num_layers,
-        "num_experts": num_experts,
-        "num_experts_per_tok": 2,
-        "num_local_experts": num_experts,
-        # Otherwise there will not be any expert layers
-        "first_k_dense_replace": 0,
-        # To avoid OOM on DeepSeek-V3
-        "n_routed_experts": num_experts,
         # For Gemma-3n
         "num_kv_shared_layers": 1,
     }
 
+    class DummyConfig:
+        hf_text_config = text_config
+
+    # Only set MoE related config when the model has MoE layers.
+    # Otherwise all models detected as MoE by _get_transformers_backend_cls.
+    if ModelConfig.get_num_experts(DummyConfig) > 0:
+        update_dict.update({
+            "num_experts": num_experts,
+            "num_experts_per_tok": 2,
+            "num_local_experts": num_experts,
+            # Otherwise there will not be any expert layers
+            "first_k_dense_replace": 0,
+            # To avoid OOM on DeepSeek-V3
+            "n_routed_experts": num_experts,
+        })
+
     # Update num_hidden_layers for non-Longcat architectures
     if model_arch != "LongcatFlashForCausalLM" \
             and model_arch != "LongCatFlashMTPModel":
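The gate above reuses the new `ModelConfig.get_num_experts` by duck-typing: the method only reads `self.hf_text_config`, so a bare class exposing that attribute can stand in for a full `ModelConfig`. A small illustrative sketch of the same pattern (the shim and helper names are hypothetical):

```python
from vllm.config import ModelConfig


def has_moe_layers(text_config) -> bool:
    # get_num_experts only touches `self.hf_text_config`, so any object that
    # exposes that attribute works as its `self` argument.
    class _Shim:
        hf_text_config = text_config

    return ModelConfig.get_num_experts(_Shim) > 0
```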

vllm/config/model.py

Lines changed: 23 additions & 16 deletions
@@ -20,7 +20,7 @@
                               MultiModalConfig)
 from vllm.config.pooler import PoolerConfig
 from vllm.config.scheduler import RunnerType
-from vllm.config.utils import assert_hashable, config
+from vllm.config.utils import assert_hashable, config, getattr_iter
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.transformers_utils.config import (
@@ -667,6 +667,8 @@ def validate_model_config_after(self: "ModelConfig") -> "ModelConfig":
     def _get_transformers_backend_cls(self) -> str:
         """Determine which Transformers backend class will be used if
         `model_impl` is set to `transformers` or `auto`."""
+        prefix = "Transformers"
+        prefix += "MoE" if self.get_num_experts() > 1 else ""
         # Check if the architecture we're wrapping has defaults
         runner = None
         convert = None
@@ -685,15 +687,15 @@ def _get_transformers_backend_cls(self) -> str:
         # Resolve Transformers backend pooling classes
         if runner == "pooling":
             if convert == "embed":
-                return "TransformersEmbeddingModel"
+                return prefix + "EmbeddingModel"
             if convert == "classify":
-                return "TransformersForSequenceClassification"
+                return prefix + "ForSequenceClassification"
         # Resolve Transformers backend generate classes
         if self.hf_config != self.hf_text_config:
             # If 'hf_text_config' is the same as 'hf_config'. If not, it is
             # probably a composite config, i.e. multimodal
-            return "TransformersForMultimodalLM"
-        return "TransformersForCausalLM"
+            return prefix + "ForMultimodalLM"
+        return prefix + "ForCausalLM"
 
     def using_transformers_backend(self) -> bool:
         """Check if the model is using the Transformers backend class."""
@@ -1025,17 +1027,7 @@ def _verify_bnb_config(self) -> None:
             self.enforce_eager = True
 
     def _verify_with_expert_parallelism(self) -> None:
-        num_expert_names = [
-            "moe_num_experts",  # Dbrx
-            "num_experts",  # Jamba
-            "n_routed_experts",  # DeepSeek
-            "num_local_experts",  # Mixtral
-        ]
-        num_experts = 0
-        for name in num_expert_names:
-            num_experts = getattr(self.hf_text_config, name, 0)
-            if num_experts > 0:
-                break
+        num_experts = self.get_num_experts()
         if num_experts < 1:
             raise ValueError(
                 "Number of experts in the model must be greater than 0 "
@@ -1220,6 +1212,21 @@ def get_num_attention_heads(self, parallel_config: ParallelConfig) -> int:
         num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
         return num_heads // parallel_config.tensor_parallel_size
 
+    def get_num_experts(self) -> int:
+        """Returns the number of experts in the model."""
+        num_expert_names = [
+            "num_experts",  # Jamba
+            "moe_num_experts",  # Dbrx
+            "n_routed_experts",  # DeepSeek
+            "num_local_experts",  # Mixtral
+        ]
+        num_experts = getattr_iter(self.hf_text_config, num_expert_names, 0)
+        if isinstance(num_experts, list):
+            # Ernie VL's remote code uses list[int]...
+            # The values are always the same so we just take the first one.
+            return num_experts[0]
+        return num_experts
+
     def get_layers_start_end_indices(
             self, parallel_config: ParallelConfig) -> tuple[int, int]:
         from vllm.distributed.utils import get_pp_indices
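The resolution logic above composes the backend class name from a prefix and a suffix. A standalone sketch of the naming scheme for illustration only; the helper and its arguments are invented, and only the resulting class names come from this diff:

```python
from typing import Optional


# Illustrative sketch, not the vLLM API: mirror how _get_transformers_backend_cls
# now inserts a "MoE" infix whenever the config reports more than one expert.
def transformers_backend_cls(num_experts: int,
                             runner: Optional[str],
                             convert: Optional[str],
                             is_multimodal: bool) -> str:
    prefix = "Transformers"
    prefix += "MoE" if num_experts > 1 else ""
    if runner == "pooling":
        if convert == "embed":
            return prefix + "EmbeddingModel"
        if convert == "classify":
            return prefix + "ForSequenceClassification"
    # In ModelConfig, hf_config != hf_text_config marks a composite
    # (multimodal) config.
    if is_multimodal:
        return prefix + "ForMultimodalLM"
    return prefix + "ForCausalLM"


# e.g. an MoE text-generation model resolves to the new MoE class name.
assert transformers_backend_cls(64, None, None, False) == "TransformersMoEForCausalLM"
assert transformers_backend_cls(0, "pooling", "embed", False) == "TransformersEmbeddingModel"
```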

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 31 additions & 0 deletions
@@ -960,6 +960,7 @@ def __init__(
         is_sequence_parallel=False,
         zero_expert_num: Optional[int] = 0,
         zero_expert_type: Optional[str] = None,
+        expert_mapping: Optional[list[tuple[str, str, int, str]]] = None,
     ):
         super().__init__()
         if params_dtype is None:
@@ -996,6 +997,9 @@
         self.zero_expert_num = zero_expert_num
         self.zero_expert_type = zero_expert_type
 
+        # Expert mapping used in self.load_weights
+        self.expert_mapping = expert_mapping
+
         # Round up hidden size if needed.
         hidden_size = maybe_roundup_hidden_size(hidden_size, moe_in_dtype,
                                                 quant_config,
@@ -1617,6 +1621,33 @@ def weight_loader(self,
 
         return False if return_success else None
 
+    def load_weights(
+            self, weights: Iterable[tuple[str,
+                                          torch.Tensor]]) -> Iterable[str]:
+        if (expert_mapping := self.expert_mapping) is None:
+            raise ValueError("`self.expert_mapping` must be provided to "
+                             "load weights using `self.load_weights`.")
+        for expert_name, loaded_weight in weights:
+            qual_name = f"{self.layer_name}.{expert_name}"
+            for param_name, weight_name, expert_id, shard_id in expert_mapping:
+                if weight_name not in qual_name:
+                    continue
+                weight_name = qual_name.replace(weight_name, param_name)
+                param_name = weight_name.removeprefix(f"{self.layer_name}.")
+                param = getattr(self, param_name)
+                success = self.weight_loader(
+                    param=param,
+                    loaded_weight=loaded_weight,
+                    weight_name=weight_name,
+                    shard_id=shard_id,
+                    expert_id=expert_id,
+                    return_success=True,
+                )
+                if success:
+                    logger.debug("Loaded %s for expert %d into %s", param_name,
+                                 expert_id, self.layer_name)
+                    yield param_name
+
     def get_expert_weights(self) -> Iterable[torch.Tensor]:
         weights = list(self.named_parameters())
         assert all(weight.is_contiguous() for _, weight in weights)
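The new `load_weights` above is driven entirely by `expert_mapping`, a list of `(param_name, weight_name, expert_id, shard_id)` tuples. A hypothetical caller-side sketch, assuming a `gate_proj`/`up_proj`/`down_proj` checkpoint layout and the fused `w13`/`w2` parameter prefixes (both are assumptions for illustration, not taken from this diff):

```python
from typing import Iterable

import torch


# Hypothetical helper: build the (param_name, weight_name, expert_id, shard_id)
# tuples that FusedMoE.load_weights expects for `num_experts` routed experts.
def make_expert_mapping(num_experts: int) -> list[tuple[str, str, int, str]]:
    mapping: list[tuple[str, str, int, str]] = []
    for eid in range(num_experts):
        mapping.append(("experts.w13_", f"experts.{eid}.gate_proj.", eid, "w1"))
        mapping.append(("experts.w13_", f"experts.{eid}.up_proj.", eid, "w3"))
        mapping.append(("experts.w2_", f"experts.{eid}.down_proj.", eid, "w2"))
    return mapping


def load_moe_block(moe, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    # `moe` is assumed to be a FusedMoE constructed with
    # expert_mapping=make_expert_mapping(n); load_weights yields the names of
    # the parameters it actually managed to load.
    return set(moe.load_weights(weights))
```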

vllm/model_executor/models/registry.py

Lines changed: 6 additions & 2 deletions
@@ -307,10 +307,14 @@
 }
 
 _TRANSFORMERS_BACKEND_MODELS = {
-    "TransformersEmbeddingModel": ("transformers_pooling", "TransformersEmbeddingModel"),  # noqa: E501
-    "TransformersForSequenceClassification": ("transformers_pooling", "TransformersForSequenceClassification"),  # noqa: E501
     "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
     "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"),  # noqa: E501
+    "TransformersMoEForCausalLM": ("transformers_moe", "TransformersMoEForCausalLM"),  # noqa: E501
+    "TransformersMoEForMultimodalLM": ("transformers_moe", "TransformersMoEForMultimodalLM"),  # noqa: E501
+    "TransformersEmbeddingModel": ("transformers_pooling", "TransformersEmbeddingModel"),  # noqa: E501
+    "TransformersForSequenceClassification": ("transformers_pooling", "TransformersForSequenceClassification"),  # noqa: E501
+    "TransformersMoEForSequenceClassification": ("transformers_pooling", "TransformersMoEForSequenceClassification"),  # noqa: E501
+    "TransformersMoEEmbeddingModel": ("transformers_pooling", "TransformersMoEEmbeddingModel"),  # noqa: E501
 }
 # yapf: enable
 