1 change: 1 addition & 0 deletions tests/models/test_models.py
@@ -17,6 +17,7 @@
     "stabilityai/stablelm-3b-4e1t",
     # "allenai/OLMo-1B", # Broken
     "bigcode/starcoder2-3b",
+    "google/gemma-1.1-2b-it",
 ]


25 changes: 11 additions & 14 deletions vllm/model_executor/models/gemma.py
@@ -26,14 +26,14 @@
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import GeluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.layernorm import GemmaRMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -148,12 +148,14 @@ def __init__(self,
             quant_config=quant_config,
         )

-        self.rotary_emb = get_rope(
+        # TODO(woosuk): Use the `get_rope` interface.
+        self.rotary_emb = GemmaRotaryEmbedding(
             self.head_dim,
             rotary_dim=self.head_dim,
-            max_position=max_position_embeddings,
+            max_position_embeddings=max_position_embeddings,
             base=self.rope_theta,
             is_neox_style=True,
+            dtype=torch.get_default_dtype(),
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
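For context on the explicit `dtype=torch.get_default_dtype()` argument: the rotary layer precomputes a cos/sin cache, and the dtype presumably controls the precision that cache is stored in. Below is a minimal, standalone sketch of that general pattern; `build_rope_cache` is an illustrative helper invented for this note, not vLLM's actual implementation.

```python
import torch


def build_rope_cache(rotary_dim: int,
                     max_position: int,
                     base: float = 10000.0,
                     dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Precompute a cos/sin cache for rotary embeddings in a target dtype."""
    # Inverse frequency for each pair of rotary dimensions.
    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    positions = torch.arange(max_position, dtype=torch.float32)
    freqs = torch.einsum("i,j->ij", positions, inv_freq)
    # Do the trig in float32 and cast only at the end, so the cache dtype
    # (e.g. bfloat16 via torch.get_default_dtype()) does not degrade the math.
    return torch.cat([freqs.cos(), freqs.sin()], dim=-1).to(dtype)
```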
@@ -204,10 +206,10 @@ def __init__(
             hidden_activation=getattr(config, "hidden_activation", None),
             quant_config=quant_config,
         )
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                eps=config.rms_norm_eps)
+        self.input_layernorm = GemmaRMSNorm(config.hidden_size,
+                                            eps=config.rms_norm_eps)
+        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
+                                                     eps=config.rms_norm_eps)

     def forward(
         self,
@@ -257,7 +259,7 @@ def __init__(
             GemmaDecoderLayer(config, cache_config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

         # Normalize the embedding by sqrt(hidden_size)
         # The normalizer's data type should be downcasted to the model's
@@ -331,7 +333,6 @@ def __init__(
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

-    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -388,10 +389,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                # GemmaRMSNorm is different from Llama's in that it multiplies
-                # (1 + weight) to the output, instead of just weight.
-                if "norm.weight" in name:
-                    loaded_weight += 1.0
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
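Why the `loaded_weight += 1.0` hack can be dropped: as the removed comment notes, Gemma's RMSNorm scales the normalized activations by (1 + weight) rather than weight alone, so the checkpoint stores the norm weights as offsets from 1. With a dedicated GemmaRMSNorm layer applying that offset at runtime, the weights can be loaded as-is. A minimal reference sketch of the intended semantics, mirroring the Hugging Face Gemma norm; this is illustration only, not vLLM's actual (possibly fused) implementation:

```python
import torch
import torch.nn as nn


class GemmaRMSNormSketch(nn.Module):
    """Reference-only RMSNorm with Gemma's (1 + weight) scaling."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        # The checkpoint stores the weight as an offset from 1,
        # so it is initialized to zeros rather than ones.
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        x = x.float()
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        # Gemma multiplies by (1 + weight) instead of just weight.
        return (x * (1.0 + self.weight.float())).to(orig_dtype)
```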