Commit f57ee56

[Model] Modify MolmoForCausalLM MLP (#11510)
Signed-off-by: Jee Jee Li <[email protected]>
1 parent dcb1a94 commit f57ee56

File tree

1 file changed (+24, -18 lines)


vllm/model_executor/models/molmo.py

Lines changed: 24 additions & 18 deletions
@@ -464,24 +464,27 @@ def forward(
 class MolmoMLP(nn.Module):
     """Molmo's LLM mlp."""

-    def __init__(
-        self,
-        config: PretrainedConfig,
-        input_dim: Optional[int] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
+    def __init__(self,
+                 config: PretrainedConfig,
+                 input_dim: Optional[int] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 proj_name: str = "gate_up_proj") -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size // 2

-        # Feed-forward input projection.
-        self.gate_up_proj = MergedColumnParallelLinear(
-            input_dim or self.hidden_size,
-            [self.intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config,
-        )
-
+        # Molmo's LLM proj weights are already merged into the disk, while
+        # image_projector proj is separate. If the same proj_name were used, it
+        # would create ambiguity and make it difficult to support BNB and LoRA.
+        self.proj_name = proj_name
+        setattr(
+            self, proj_name,
+            MergedColumnParallelLinear(
+                input_dim or self.hidden_size,
+                [self.intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+            ))
         # Activation function.
         self.act_fn = SiluAndMul()

@@ -497,7 +500,7 @@ def forward(
         self,
         x: torch.Tensor,
     ) -> torch.Tensor:
-        gate_up, _ = self.gate_up_proj(x)
+        gate_up, _ = getattr(self, self.proj_name)(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x
@@ -520,7 +523,9 @@ def __init__(
                                        prefix=f"{prefix}.self_attn")

         # MLP block.
-        self.mlp = MolmoMLP(config, quant_config=quant_config)
+        self.mlp = MolmoMLP(config,
+                            quant_config=quant_config,
+                            proj_name="gate_up_proj")

         # LayerNorm
         assert config.layer_norm_type == "rms"
@@ -616,6 +621,7 @@ def __init__(
             config,
             input_dim=vision_config.image_emb_dim,
             quant_config=quant_config,
+            proj_name="merged_linear",
         )

         image_dim = vision_config.image_emb_dim * len(self.vit_layers)
@@ -714,8 +720,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
+            ("merged_linear", "gate_proj", 0),
+            ("merged_linear", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
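
The renamed stacked_params_mapping above is what routes separate gate_proj/up_proj checkpoint tensors into shard 0 and shard 1 of the image projector's merged_linear parameter during load_weights. A rough sketch of that name-rewriting step, simplified from vLLM's loader loop (the real code additionally calls each target parameter's weight_loader with the shard id; the checkpoint names used here are hypothetical):

# Sketch of (param_name, shard_name, shard_id) routing, simplified from the
# load_weights loop touched by this commit. Checkpoint names are made up.
from typing import Iterable, List, Tuple

stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("merged_linear", "gate_proj", 0),
    ("merged_linear", "up_proj", 1),
]


def route_names(checkpoint_names: Iterable[str]) -> List[Tuple[str, int]]:
    routed = []
    for name in checkpoint_names:
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name in name:
                # e.g. "...gate_proj.weight" -> "...merged_linear.weight", shard 0
                routed.append((name.replace(shard_name, param_name), shard_id))
                break
        else:
            routed.append((name, -1))  # not stacked: load under its own name
    return routed


print(route_names([
    "image_projector.gate_proj.weight",   # hypothetical checkpoint names
    "image_projector.up_proj.weight",
    "image_projector.down_proj.weight",
]))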
