diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 5ac8f2121f97..2310747c3a2d 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -363,6 +363,7 @@ th {
 | `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | ✅︎ |
 | `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | | ✅︎ | ✅︎ |
 | `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `FlexOlmoForCausalLM` | FlexOlmo | `allenai/FlexOlmo-7x7B-1T`, `allenai/FlexOlmo-7x7B-1T-RT`, etc. | | ✅︎ | ✅︎ |
 | `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
diff --git a/tests/models/registry.py b/tests/models/registry.py
index b581eb1851cb..219d7790e8ed 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -250,6 +250,7 @@ def check_available_online(
     "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"),
     "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
     "FalconH1ForCausalLM": _HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"),
+    "FlexOlmoForCausalLM": _HfExamplesInfo("allenai/Flex-reddit-2x7B-1T"),
     "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
     "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
     "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
diff --git a/vllm/model_executor/models/flex_olmo.py b/vllm/model_executor/models/flex_olmo.py
new file mode 100644
index 000000000000..b1fbbf086896
--- /dev/null
+++ b/vllm/model_executor/models/flex_olmo.py
@@ -0,0 +1,157 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only FlexOlmo model compatible with HuggingFace weights.""" + +from typing import Optional + +import torch +from torch import nn + +from vllm.config import VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.models.olmoe import OlmoeAttention, OlmoeForCausalLM +from vllm.transformers_utils.configs import FlexOlmoConfig + +logger = init_logger(__name__) + + +class FlexOlmoAttention(OlmoeAttention): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + hf_config = vllm_config.model_config.hf_config + assert isinstance(hf_config, FlexOlmoConfig) + + self.k_norm = RMSNorm( + self.total_num_kv_heads * self.head_dim, eps=hf_config.rms_norm_eps + ) + self.q_norm = RMSNorm( + self.total_num_heads * self.head_dim, eps=hf_config.rms_norm_eps + ) + + +class FlexOlmoMoE(nn.Module): + """A tensor-parallel MoE implementation for FlexOlmo that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + hf_config = vllm_config.model_config.hf_config + assert isinstance(hf_config, FlexOlmoConfig) + + tp_size = get_tensor_model_parallel_world_size() + + # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear( + hf_config.hidden_size, + hf_config.num_experts, + bias=False, + return_bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + # Gate always runs at half / full precision for now. + self.experts = FusedMoE( + num_experts=hf_config.num_experts, + top_k=hf_config.num_experts_per_tok, + hidden_size=hf_config.hidden_size, + intermediate_size=hf_config.intermediate_size, + reduce_results=True, + renormalize=False, + quant_config=None, + tp_size=tp_size, + prefix=f"{prefix}.experts", + ) + + self.top_k = hf_config.num_experts_per_tok + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + # router_logits: (num_tokens, n_experts) + router_logits = self.gate(hidden_states) + # Warning: The experts mutate the hidden state input! This messes up + # basic things like the residual stream. 
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states.detach().clone(),
+            router_logits=router_logits.float(),
+        )
+
+        return final_hidden_states.view(orig_shape)
+
+
+class FlexOlmoDecoderLayer(nn.Module):
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
+        super().__init__()
+        hf_config = vllm_config.model_config.hf_config
+        assert isinstance(hf_config, FlexOlmoConfig)
+
+        self.self_attn = FlexOlmoAttention(
+            vllm_config=vllm_config, prefix=f"{prefix}.self_attn"
+        )
+        self.post_attention_layernorm = RMSNorm(
+            hf_config.hidden_size, eps=hf_config.rms_norm_eps
+        )
+        self.post_feedforward_layernorm = RMSNorm(
+            hf_config.hidden_size, eps=hf_config.rms_norm_eps
+        )
+
+        self.mlp = FlexOlmoMoE(vllm_config=vllm_config, prefix=f"{prefix}.mlp")
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        # Attention block.
+        residual = hidden_states
+        hidden_states = self.self_attn(positions, hidden_states)
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = hidden_states + residual
+
+        # MLP block.
+        residual = hidden_states
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.post_feedforward_layernorm(hidden_states)
+        hidden_states = residual + hidden_states
+        return hidden_states, None
+
+
+class FlexOlmoForCausalLM(OlmoeForCausalLM):
+    fall_back_to_pt_during_load = False
+
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        layer_type: type[nn.Module] = FlexOlmoDecoderLayer,
+    ):
+        super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)
diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py
index 90ec1a890417..0e4b408775f5 100644
--- a/vllm/model_executor/models/olmoe.py
+++ b/vllm/model_executor/models/olmoe.py
@@ -17,15 +17,14 @@
 from collections.abc import Iterable
 from functools import partial
 from itertools import islice
-from typing import Any, Optional, Union
+from typing import Optional, Union
 
 import torch
 from torch import nn
-from transformers import OlmoeConfig
 
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.distributed import (
     get_pp_group,
     get_tensor_model_parallel_rank,
@@ -117,20 +116,21 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 
 class OlmoeAttention(nn.Module):
-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        num_kv_heads: int,
-        rope_theta: float = 10000,
-        rope_scaling: Optional[dict[str, Any]] = None,
-        max_position_embeddings: int = 4096,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
-        self.hidden_size = hidden_size
+
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 4096)
+
+        num_heads = config.num_attention_heads
+        num_kv_heads = config.num_key_value_heads
+
         tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
@@ -145,7 +145,7 @@ def __init__(
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        self.head_dim = self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -153,7 +153,7 @@ def __init__(
         self.max_position_embeddings = max_position_embeddings
 
         self.qkv_proj = QKVParallelLinear(
-            hidden_size,
+            self.hidden_size,
             self.head_dim,
             self.total_num_heads,
             self.total_num_kv_heads,
@@ -166,7 +166,7 @@ def __init__(
         self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim, eps=1e-5)
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
-            hidden_size,
+            self.hidden_size,
             bias=False,
             quant_config=quant_config,
         )
@@ -218,28 +218,15 @@ def forward(
 
 
 class OlmoeDecoderLayer(nn.Module):
-    def __init__(
-        self,
-        config: OlmoeConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+
         self.hidden_size = config.hidden_size
-        rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
-        max_position_embeddings = getattr(config, "max_position_embeddings", 4096)
 
         self.self_attn = OlmoeAttention(
-            hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            num_kv_heads=config.num_key_value_heads,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
-            quant_config=quant_config,
+            vllm_config=vllm_config,
             prefix=f"{prefix}.self_attn",
         )
@@ -280,12 +267,16 @@ def forward(
 
 @support_torch_compile
 class OlmoeModel(nn.Module):
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        layer_type: type[nn.Module] = OlmoeDecoderLayer,
+    ):
         super().__init__()
 
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
 
         self.vocab_size = config.vocab_size
         self.config = config
@@ -295,9 +286,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: OlmoeDecoderLayer(
-                config, cache_config, quant_config, prefix=prefix
-            ),
+            lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
             prefix=f"{prefix}.layers",
         )
         self.norm = RMSNorm(config.hidden_size, eps=1e-5)
@@ -339,7 +328,10 @@ def forward(
                 {"hidden_states": hidden_states, "residual": residual}
             )
 
-        hidden_states, _ = self.norm(hidden_states, residual)
+        if residual is not None:
+            hidden_states, _ = self.norm(hidden_states, residual)
+        else:
+            hidden_states = self.norm(hidden_states)
         return hidden_states
 
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
@@ -455,14 +447,22 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
         ],
     }
 
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        layer_type: type[nn.Module] = OlmoeDecoderLayer,
+    ):
         super().__init__()
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
         self.config = config
         self.quant_config = quant_config
         self.model = OlmoeModel(
-            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
+            vllm_config=vllm_config,
+            prefix=maybe_prefix(prefix, "model"),
+            layer_type=layer_type,
         )
         self.lm_head = ParallelLMHead(
             config.vocab_size,
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 32e50f9a8e48..080489f16eca 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -90,6 +90,7 @@
     "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
     "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
     "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
+    "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
     "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
     "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
     "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 87bbe73d834a..4a8bb8f8b41d 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -74,6 +74,7 @@ def __getitem__(self, key):
     deepseek_vl_v2="DeepseekVLV2Config",
     deepseek_v3="DeepseekV3Config",
     deepseek_v32="DeepseekV3Config",
+    flex_olmo="FlexOlmoConfig",
     kimi_vl="KimiVLConfig",
     Llama_Nemotron_Nano_VL="Nemotron_Nano_VL_Config",
     RefinedWeb="RWConfig",  # For tiiuae/falcon-40b(-instruct)
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index 6917123ce662..befe9cdae76a 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -17,6 +17,7 @@
 # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
 # `FalconConfig` class from the official HuggingFace transformers library.
 from vllm.transformers_utils.configs.falcon import RWConfig
+from vllm.transformers_utils.configs.flex_olmo import FlexOlmoConfig
 from vllm.transformers_utils.configs.jais import JAISConfig
 from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
 from vllm.transformers_utils.configs.lfm2_moe import Lfm2MoeConfig
@@ -45,6 +46,7 @@
     "DeepseekV3Config",
     "DotsOCRConfig",
     "EAGLEConfig",
+    "FlexOlmoConfig",
     "RWConfig",
     "JAISConfig",
     "Lfm2MoeConfig",
diff --git a/vllm/transformers_utils/configs/flex_olmo.py b/vllm/transformers_utils/configs/flex_olmo.py
new file mode 100644
index 000000000000..1f2f4d446288
--- /dev/null
+++ b/vllm/transformers_utils/configs/flex_olmo.py
@@ -0,0 +1,77 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class FlexOlmoConfig(PretrainedConfig):
+    model_type = "flex_olmo"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=100352,
+        hidden_size=4096,
+        intermediate_size=11008,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=None,
+        hidden_act="silu",
+        max_position_embeddings=4096,
+        initializer_range=0.02,
+        rms_norm_eps=1e-06,
+        use_cache=True,
+        pad_token_id=100277,
+        bos_token_id=None,
+        eos_token_id=100257,
+        tie_word_embeddings=False,
+        rope_theta=500000.0,
+        rope_scaling=None,
+        attention_bias=False,
+        attention_dropout=0.0,
+        num_experts_per_tok=5,
+        num_experts=7,
+        output_router_logits=False,
+        router_aux_loss_coef=0.01,
+        norm_topk_prob=False,
+        **kwargs,
+    ):
+        if "architectures" not in kwargs:
+            kwargs["architectures"] = ["FlexOlmoForCausalLM"]
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+
+        # For backward compatibility.
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+        self.num_experts_per_tok = num_experts_per_tok
+        self.num_experts = num_experts
+        self.output_router_logits = output_router_logits
+        self.router_aux_loss_coef = router_aux_loss_coef
+        self.norm_topk_prob = norm_topk_prob
+
+        # Backwards compatibility for rope_scaling: if there is a "type"
+        # field, move it to "rope_type".
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
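
For anyone who wants to exercise the new architecture end to end, here is a minimal offline-inference sketch. It assumes only the public `vllm.LLM` / `SamplingParams` API and the `allenai/FlexOlmo-7x7B-1T` checkpoint listed in the supported-models table above; the prompt and sampling settings are illustrative, not part of this change.

```python
from vllm import LLM, SamplingParams

# Checkpoint name taken from the supported-models table; any of the
# listed FlexOlmo checkpoints should load through FlexOlmoForCausalLM.
llm = LLM(model="allenai/FlexOlmo-7x7B-1T")

# Illustrative prompt and sampling settings.
sampling_params = SamplingParams(temperature=0.7, top_p=0.95, max_tokens=64)
outputs = llm.generate(["The FlexOlmo architecture is"], sampling_params)

for output in outputs:
    print(output.outputs[0].text)
```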