1 change: 1 addition & 0 deletions docs/models/supported_models.md
@@ -363,6 +363,7 @@ th {
| `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | ✅︎ |
| `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | | ✅︎ | ✅︎ |
| `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `FlexOlmoForCausalLM` | FlexOlmo | `allenai/FlexOlmo-7x7B-1T`, `allenai/FlexOlmo-7x7B-1T-RT`, etc. | | ✅︎ | ✅︎ |
| `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
1 change: 1 addition & 0 deletions tests/models/registry.py
@@ -250,6 +250,7 @@ def check_available_online(
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"),
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"FalconH1ForCausalLM": _HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"),
"FlexOlmoForCausalLM": _HfExamplesInfo("allenai/Flex-reddit-2x7B-1T"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
157 changes: 157 additions & 0 deletions vllm/model_executor/models/flex_olmo.py
@@ -0,0 +1,157 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only FlexOlmo model compatible with HuggingFace weights."""

from typing import Optional

import torch
from torch import nn

from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.models.olmoe import OlmoeAttention, OlmoeForCausalLM
from vllm.transformers_utils.configs import FlexOlmoConfig

logger = init_logger(__name__)


class FlexOlmoAttention(OlmoeAttention):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)

hf_config = vllm_config.model_config.hf_config
assert isinstance(hf_config, FlexOlmoConfig)

self.k_norm = RMSNorm(
self.total_num_kv_heads * self.head_dim, eps=hf_config.rms_norm_eps
)
self.q_norm = RMSNorm(
self.total_num_heads * self.head_dim, eps=hf_config.rms_norm_eps
)


class FlexOlmoMoE(nn.Module):
"""A tensor-parallel MoE implementation for FlexOlmo that shards each expert
across all ranks.

Each expert's weights are split across all ranks, a fused MoE kernel
computes the forward pass, and the outputs are then reduced across
ranks.
"""

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

hf_config = vllm_config.model_config.hf_config
assert isinstance(hf_config, FlexOlmoConfig)

tp_size = get_tensor_model_parallel_world_size()

# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(
hf_config.hidden_size,
hf_config.num_experts,
bias=False,
return_bias=False,
quant_config=None,
prefix=f"{prefix}.gate",
)

# Experts likewise run unquantized (quant_config=None) for now.
self.experts = FusedMoE(
num_experts=hf_config.num_experts,
top_k=hf_config.num_experts_per_tok,
hidden_size=hf_config.hidden_size,
intermediate_size=hf_config.intermediate_size,
reduce_results=True,
renormalize=False,
quant_config=None,
tp_size=tp_size,
prefix=f"{prefix}.experts",
)

self.top_k = hf_config.num_experts_per_tok

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim)

# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
# Warning: The experts mutate the hidden state input! This messes up
# basic things like the residual stream.
final_hidden_states = self.experts(
hidden_states=hidden_states.detach().clone(),
router_logits=router_logits.float(),
)

return final_hidden_states.view(orig_shape)


class FlexOlmoDecoderLayer(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
hf_config = vllm_config.model_config.hf_config
assert isinstance(hf_config, FlexOlmoConfig)

self.self_attn = FlexOlmoAttention(
vllm_config=vllm_config, prefix=f"{prefix}.self_attn"
)
self.post_attention_layernorm = RMSNorm(
hf_config.hidden_size, eps=hf_config.rms_norm_eps
)
self.post_feedforward_layernorm = RMSNorm(
hf_config.hidden_size, eps=hf_config.rms_norm_eps
)

self.mlp = FlexOlmoMoE(vllm_config=vllm_config, prefix=f"{prefix}.mlp")

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
# Attention block.
residual = hidden_states
hidden_states = self.self_attn(positions, hidden_states)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = hidden_states + residual

# MLP block.
residual = hidden_states
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states
return hidden_states, None


class FlexOlmoForCausalLM(OlmoeForCausalLM):
fall_back_to_pt_during_load = False

def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: type[nn.Module] = FlexOlmoDecoderLayer,
):
super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)
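For reference, a minimal usage sketch of the new model through vLLM's offline `LLM` API; the checkpoint name is taken from the supported-models table above, and the prompt and sampling settings are illustrative:

```python
from vllm import LLM, SamplingParams

# Checkpoint name from the supported_models.md entry added in this PR.
llm = LLM(model="allenai/FlexOlmo-7x7B-1T")
params = SamplingParams(temperature=0.7, max_tokens=64)

outputs = llm.generate(["FlexOlmo is a mixture-of-experts model that"], params)
print(outputs[0].outputs[0].text)
```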
92 changes: 46 additions & 46 deletions vllm/model_executor/models/olmoe.py
@@ -17,15 +17,14 @@
from collections.abc import Iterable
from functools import partial
from itertools import islice
from typing import Any, Optional, Union
from typing import Optional, Union

import torch
from torch import nn
from transformers import OlmoeConfig

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
@@ -117,20 +116,21 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:


class OlmoeAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 4096,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
self.hidden_size = hidden_size

config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config

self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 4096)

num_heads = config.num_attention_heads
num_kv_heads = config.num_key_value_heads

tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
@@ -145,15 +145,15 @@ def __init__(
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.head_dim = self.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings

self.qkv_proj = QKVParallelLinear(
hidden_size,
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
Expand All @@ -166,7 +166,7 @@ def __init__(
self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim, eps=1e-5)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
self.hidden_size,
bias=False,
quant_config=quant_config,
)
@@ -218,28 +218,15 @@ def forward(


class OlmoeDecoderLayer(nn.Module):
def __init__(
self,
config: OlmoeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config

self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 4096)

self.self_attn = OlmoeAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
vllm_config=vllm_config,
prefix=f"{prefix}.self_attn",
)

@@ -280,12 +267,16 @@ def forward(

@support_torch_compile
class OlmoeModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: type[nn.Module] = OlmoeDecoderLayer,
):
super().__init__()

config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config

self.vocab_size = config.vocab_size
self.config = config
Expand All @@ -295,9 +286,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: OlmoeDecoderLayer(
config, cache_config, quant_config, prefix=prefix
),
lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=1e-5)
@@ -339,7 +328,10 @@ def forward(
{"hidden_states": hidden_states, "residual": residual}
)

hidden_states, _ = self.norm(hidden_states, residual)
if residual is not None:
hidden_states, _ = self.norm(hidden_states, residual)
else:
hidden_states = self.norm(hidden_states)
return hidden_states

def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
@@ -455,14 +447,22 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
],
}

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: type[nn.Module] = OlmoeDecoderLayer,
):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = OlmoeModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"),
layer_type=layer_type,
)
self.lm_head = ParallelLMHead(
config.vocab_size,
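The point of threading `layer_type` through `OlmoeForCausalLM` and `OlmoeModel` is to let variants reuse the whole Olmoe stack while swapping only the decoder layer, which is exactly how `FlexOlmoForCausalLM` above plugs in. A hypothetical further variant (class names are illustrative, not part of this PR) would follow the same pattern:

```python
from vllm.config import VllmConfig
from vllm.model_executor.models.olmoe import OlmoeDecoderLayer, OlmoeForCausalLM


class MyOlmoeVariantDecoderLayer(OlmoeDecoderLayer):
    """Hypothetical layer that customizes one block but keeps the rest."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        # Override norms, attention, or the MoE block here as needed.


class MyOlmoeVariantForCausalLM(OlmoeForCausalLM):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # The base class forwards layer_type to OlmoeModel, which builds
        # every decoder layer from it via make_layers.
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            layer_type=MyOlmoeVariantDecoderLayer,
        )
```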
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -90,6 +90,7 @@
"Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"),
"FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
1 change: 1 addition & 0 deletions vllm/transformers_utils/config.py
@@ -74,6 +74,7 @@ def __getitem__(self, key):
deepseek_vl_v2="DeepseekVLV2Config",
deepseek_v3="DeepseekV3Config",
deepseek_v32="DeepseekV3Config",
flex_olmo="FlexOlmoConfig",
kimi_vl="KimiVLConfig",
Llama_Nemotron_Nano_VL="Nemotron_Nano_VL_Config",
RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct)
2 changes: 2 additions & 0 deletions vllm/transformers_utils/configs/__init__.py
@@ -17,6 +17,7 @@
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.flex_olmo import FlexOlmoConfig
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
from vllm.transformers_utils.configs.lfm2_moe import Lfm2MoeConfig
@@ -45,6 +46,7 @@
"DeepseekV3Config",
"DotsOCRConfig",
"EAGLEConfig",
"FlexOlmoConfig",
"RWConfig",
"JAISConfig",
"Lfm2MoeConfig",
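The config class registered above lives in `vllm/transformers_utils/configs/flex_olmo.py`, which is not shown in this diff. Based only on the attributes that `flex_olmo.py` reads (`hidden_size`, `intermediate_size`, `num_experts`, `num_experts_per_tok`, `rms_norm_eps`), a minimal sketch of such a `PretrainedConfig` subclass might look as follows; the real class almost certainly carries more fields and different defaults:

```python
from transformers import PretrainedConfig


class FlexOlmoConfig(PretrainedConfig):
    """Sketch only: attribute names mirror what the model code reads; defaults are guesses."""

    model_type = "flex_olmo"

    def __init__(
        self,
        hidden_size: int = 4096,
        intermediate_size: int = 11008,
        num_experts: int = 7,
        num_experts_per_tok: int = 2,
        rms_norm_eps: float = 1e-5,
        **kwargs,
    ) -> None:
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.rms_norm_eps = rms_norm_eps
        super().__init__(**kwargs)
```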