Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 24 additions & 18 deletions vllm/model_executor/models/molmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,24 +464,27 @@ def forward(
class MolmoMLP(nn.Module):
"""Molmo's LLM mlp."""

def __init__(
self,
config: PretrainedConfig,
input_dim: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
def __init__(self,
config: PretrainedConfig,
input_dim: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
proj_name: str = "gate_up_proj") -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size // 2

# Feed-forward input projection.
self.gate_up_proj = MergedColumnParallelLinear(
input_dim or self.hidden_size,
[self.intermediate_size] * 2,
bias=False,
quant_config=quant_config,
)

# Molmo's LLM proj weights are already merged into the disk, while
# image_projector proj is separate. If the same proj_name were used, it
# would create ambiguity and make it difficult to support BNB and LoRA.
self.proj_name = proj_name
setattr(
self, proj_name,
MergedColumnParallelLinear(
input_dim or self.hidden_size,
[self.intermediate_size] * 2,
bias=False,
quant_config=quant_config,
))
# Activation function.
self.act_fn = SiluAndMul()

Expand All @@ -497,7 +500,7 @@ def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(x)
gate_up, _ = getattr(self, self.proj_name)(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
Expand All @@ -520,7 +523,9 @@ def __init__(
prefix=f"{prefix}.self_attn")

# MLP block.
self.mlp = MolmoMLP(config, quant_config=quant_config)
self.mlp = MolmoMLP(config,
quant_config=quant_config,
proj_name="gate_up_proj")

# LayerNorm
assert config.layer_norm_type == "rms"
Expand Down Expand Up @@ -616,6 +621,7 @@ def __init__(
config,
input_dim=vision_config.image_emb_dim,
quant_config=quant_config,
proj_name="merged_linear",
)

image_dim = vision_config.image_emb_dim * len(self.vit_layers)
Expand Down Expand Up @@ -714,8 +720,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
("merged_linear", "gate_proj", 0),
("merged_linear", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
Expand Down
Loading