@@ -464,24 +464,27 @@ def forward(
 class MolmoMLP(nn.Module):
     """Molmo's LLM mlp."""
 
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        input_dim: Optional[int] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
+    def __init__(self,
+                 config: PretrainedConfig,
+                 input_dim: Optional[int] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 proj_name: str = "gate_up_proj") -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size // 2
 
-        # Feed-forward input projection.
-        self.gate_up_proj = MergedColumnParallelLinear(
-            input_dim or self.hidden_size,
-            [self.intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config,
-        )
-
+        # Molmo's LLM proj weights are already merged on disk, while the
+        # image_projector proj is separate. If the same proj_name were used, it
+        # would create ambiguity and make it difficult to support BNB and LoRA.
+        self.proj_name = proj_name
+        setattr(
+            self, proj_name,
+            MergedColumnParallelLinear(
+                input_dim or self.hidden_size,
+                [self.intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+            ))
         # Activation function.
         self.act_fn = SiluAndMul()
 
@@ -497,7 +500,7 @@ def forward(
         self,
         x: torch.Tensor,
     ) -> torch.Tensor:
-        gate_up, _ = self.gate_up_proj(x)
+        gate_up, _ = getattr(self, self.proj_name)(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
         return x
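
A minimal, self-contained sketch of the proj_name indirection introduced above, using plain PyTorch (the TinyMLP class and nn.Linear stand in for the real Molmo modules and MergedColumnParallelLinear, so the names here are illustrative only). It shows how the same MLP class can register its merged projection under different attribute names, which keeps the two instances distinguishable in named_parameters() and therefore to weight loading, BNB, and LoRA:

import torch
from torch import nn


class TinyMLP(nn.Module):
    """Illustrative stand-in for MolmoMLP's dynamic projection name."""

    def __init__(self, hidden: int, inter: int,
                 proj_name: str = "gate_up_proj") -> None:
        super().__init__()
        self.proj_name = proj_name
        # Register the merged gate/up projection under the caller-chosen name.
        setattr(self, proj_name, nn.Linear(hidden, 2 * inter, bias=False))
        self.down_proj = nn.Linear(inter, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up = getattr(self, self.proj_name)(x)
        gate, up = gate_up.chunk(2, dim=-1)
        return self.down_proj(nn.functional.silu(gate) * up)


llm_mlp = TinyMLP(8, 16, proj_name="gate_up_proj")
img_mlp = TinyMLP(8, 16, proj_name="merged_linear")
print(sorted(n for n, _ in llm_mlp.named_parameters()))
# ['down_proj.weight', 'gate_up_proj.weight']
print(sorted(n for n, _ in img_mlp.named_parameters()))
# ['down_proj.weight', 'merged_linear.weight']
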
@@ -520,7 +523,9 @@ def __init__(
                                         prefix=f"{prefix}.self_attn")
 
         # MLP block.
-        self.mlp = MolmoMLP(config, quant_config=quant_config)
+        self.mlp = MolmoMLP(config,
+                            quant_config=quant_config,
+                            proj_name="gate_up_proj")
 
         # LayerNorm
         assert config.layer_norm_type == "rms"
@@ -616,6 +621,7 @@ def __init__(
             config,
             input_dim=vision_config.image_emb_dim,
             quant_config=quant_config,
+            proj_name="merged_linear",
         )
 
         image_dim = vision_config.image_emb_dim * len(self.vit_layers)
@@ -714,8 +720,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
+            ("merged_linear", "gate_proj", 0),
+            ("merged_linear", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
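
The mapping rename matters because stacked_params_mapping is how per-shard checkpoint names get redirected onto a merged parameter during weight loading. The sketch below is a simplified mimic of that convention, not vLLM's actual load_weights loop, and the checkpoint key names (image_projector.gate_proj.weight and so on) are hypothetical; it shows gate_proj/up_proj shards landing in the two halves of the parameter now registered as merged_linear, while keys without a shard suffix fall through to a plain copy:

from typing import Dict, Iterable, Tuple

import torch


def load_merged(params: Dict[str, torch.Tensor],
                weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
    # Simplified mimic of the stacked_params_mapping convention.
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("merged_linear", "gate_proj", 0),
        ("merged_linear", "up_proj", 1),
    ]
    for name, loaded in weights:
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name not in name:
                continue
            # Redirect the shard onto the merged parameter, e.g.
            # "image_projector.gate_proj.weight" ->
            # "image_projector.merged_linear.weight".
            param = params[name.replace(shard_name, param_name)]
            rows = loaded.shape[0]
            # Each shard fills its half of the merged output dimension.
            param.data[shard_id * rows:(shard_id + 1) * rows].copy_(loaded)
            break
        else:
            # Keys that are not split shards load directly.
            params[name].data.copy_(loaded)


# Hypothetical usage: one merged parameter backed by two checkpoint shards.
params = {
    "image_projector.merged_linear.weight": torch.empty(8, 4),
    "image_projector.down_proj.weight": torch.empty(4, 4),
}
checkpoint = [
    ("image_projector.gate_proj.weight", torch.randn(4, 4)),
    ("image_projector.up_proj.weight", torch.randn(4, 4)),
    ("image_projector.down_proj.weight", torch.randn(4, 4)),
]
load_merged(params, checkpoint)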