@@ -131,16 +131,22 @@ class MiniCPMVImageEmbeddingInputs(TypedDict):
 
 class Resampler2_5(BaseResampler):
 
-    def __init__(
-        self,
-        num_queries: int,
-        embed_dim: int,
-        num_heads: int,
-        kv_dim: Optional[int] = None,
-        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
-        max_size: Tuple[int, int] = (70, 70),
-    ) -> None:
-        super().__init__(num_queries, embed_dim, num_heads, kv_dim, norm_layer)
+    def __init__(self,
+                 num_queries: int,
+                 embed_dim: int,
+                 num_heads: int,
+                 kv_dim: Optional[int] = None,
+                 norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
+                 max_size: Tuple[int, int] = (70, 70),
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
+        super().__init__(num_queries,
+                         embed_dim,
+                         num_heads,
+                         kv_dim,
+                         norm_layer,
+                         quant_config=quant_config,
+                         prefix=prefix)
 
         self.max_size = max_size
         self._set_2d_pos_cache(self.max_size)
@@ -404,7 +410,10 @@ def __init__(
         self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
                            self.vpm.embeddings.embed_dim)
         self.embed_dim = self.config.hidden_size
-        self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
+        self.resampler = self.init_resampler(self.embed_dim,
+                                             self.vision_dim,
+                                             quant_config=quant_config,
+                                             prefix="resampler")
         self.resampler.to(device="cuda", dtype=param_dtype)
         # TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm
         self.lm_head = ParallelLMHead(config.vocab_size,
@@ -666,7 +675,11 @@ def init_vision_module(
     ) -> nn.Module:
         raise NotImplementedError
 
-    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
+    def init_resampler(self,
+                       embed_dim: int,
+                       vision_dim: int,
+                       quant_config: Optional[QuantizationConfig] = None,
+                       prefix: str = "") -> nn.Module:
         raise NotImplementedError
 
     def get_vision_embedding(
@@ -743,16 +756,21 @@ def init_vision_module(
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_tokens(input_ids)
 
-    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
+    def init_resampler(self,
+                       embed_dim: int,
+                       vision_dim: int,
+                       quant_config: Optional[QuantizationConfig] = None,
+                       prefix: str = "") -> nn.Module:
         with set_default_torch_dtype(torch.float16):
-            resampler = Resampler2(
-                embed_dim=embed_dim,
-                num_heads=embed_dim // 128,
-                grid_size=int(math.sqrt(self.config.query_num)),
-                kv_dim=vision_dim,
-                adaptive=False,
-                do_post_projection=True,
-            )
+            resampler = Resampler2(embed_dim=embed_dim,
+                                   num_heads=embed_dim // 128,
+                                   grid_size=int(
+                                       math.sqrt(self.config.query_num)),
+                                   kv_dim=vision_dim,
+                                   adaptive=False,
+                                   do_post_projection=True,
+                                   quant_config=quant_config,
+                                   prefix=prefix)
 
         return resampler
 
@@ -825,9 +843,21 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
         ".k_proj.",
         ".v_proj.",
         ".o_proj.",
+        # vision encoder
+        ".fc1.",
+        ".fc2.",
+        # Currently, vllm does not support BNB quantization for the `out_proj`
+        # of the resampler, so it's necessary to distinguish between the
+        # vision encoder and the resampler's out_proj. The same applies to
+        # MiniCPMV2_6.
+        ".self_attn.out_proj.",  # vision encoder out_proj
+        # resampler
+        ".kv_proj.",
     ]
     # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    column_parallel_weights_modules = [
+        ".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2."
+    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -877,14 +907,18 @@ def init_vision_module(
             model.encoder.layers = model.encoder.layers[:-1]
         return model
 
-    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
+    def init_resampler(self,
+                       embed_dim: int,
+                       vision_dim: int,
+                       quant_config: Optional[QuantizationConfig] = None,
+                       prefix: str = "") -> nn.Module:
         with set_default_torch_dtype(torch.float16):
-            resampler = Resampler2_5(
-                num_queries=self.config.query_num,
-                embed_dim=embed_dim,
-                num_heads=embed_dim // 128,
-                kv_dim=vision_dim,
-            )
+            resampler = Resampler2_5(num_queries=self.config.query_num,
+                                     embed_dim=embed_dim,
+                                     num_heads=embed_dim // 128,
+                                     kv_dim=vision_dim,
+                                     quant_config=quant_config,
+                                     prefix=prefix)
         return resampler
 
     def get_vision_embedding(
@@ -967,9 +1001,17 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
         ".k_proj.",
         ".v_proj.",
         ".o_proj.",
+        # vision encoder
+        ".fc1.",
+        ".fc2.",
+        ".self_attn.out_proj.",
+        # resampler
+        ".kv_proj.",
     ]
     # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    column_parallel_weights_modules = [
+        ".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2."
+    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -1019,15 +1061,19 @@ def init_vision_module(
             model.encoder.layers = model.encoder.layers[:-1]
         return model
 
-    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
+    def init_resampler(self,
+                       embed_dim: int,
+                       vision_dim: int,
+                       quant_config: Optional[QuantizationConfig] = None,
+                       prefix: str = "") -> nn.Module:
         with set_default_torch_dtype(torch.float16):
             # The resampler in 2.6 remains consistent with the one in 2.5.
-            resampler = Resampler2_5(
-                num_queries=self.config.query_num,
-                embed_dim=embed_dim,
-                num_heads=embed_dim // 128,
-                kv_dim=vision_dim,
-            )
+            resampler = Resampler2_5(num_queries=self.config.query_num,
+                                     embed_dim=embed_dim,
+                                     num_heads=embed_dim // 128,
+                                     kv_dim=vision_dim,
+                                     quant_config=quant_config,
+                                     prefix=prefix)
         return resampler
 
     def get_vision_embedding(
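
With `quant_config` and `prefix` threaded through the resamplers, MiniCPM-V should load with vLLM's in-flight bitsandbytes quantization. A minimal usage sketch, assuming a vLLM build containing this change; the checkpoint name, prompt, and sampling settings are illustrative only:

from vllm import LLM, SamplingParams

llm = LLM(
    model="openbmb/MiniCPM-V-2_6",  # assumed checkpoint name
    trust_remote_code=True,  # MiniCPM-V ships custom modeling code
    quantization="bitsandbytes",  # quantize weights on the fly
    load_format="bitsandbytes",
)
params = SamplingParams(temperature=0.0, max_tokens=64)
out = llm.generate("Describe MiniCPM-V in one sentence.", params)
print(out[0].outputs[0].text)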