
Commit af0dfcc

jeejeelee authored and hissu-hyvarinen committed
[Bugfix] Fix MiniCPMV and Mllama BNB bug (vllm-project#9917)
Signed-off-by: Jee Jee Li <[email protected]>
1 parent c33dfa1 commit af0dfcc

4 files changed, +145 -65 lines changed


vllm/model_executor/layers/resampler.py

Lines changed: 28 additions & 21 deletions
@@ -41,6 +41,7 @@
 from torch.nn.init import trunc_normal_
 
 from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.quantization import QuantizationConfig
 
 DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
 
@@ -154,15 +155,15 @@ class BaseResampler(nn.Module):
         A tensor with the shape of (grid_size**2, embed_dim)
     """
 
-    def __init__(
-        self,
-        num_queries: int,
-        embed_dim: int,
-        num_heads: int,
-        kv_dim: Optional[int] = None,
-        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
-        do_post_projection: bool = True,
-    ) -> None:
+    def __init__(self,
+                 num_queries: int,
+                 embed_dim: int,
+                 num_heads: int,
+                 kv_dim: Optional[int] = None,
+                 norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
+                 do_post_projection: bool = True,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
         super().__init__()
 
         self.num_queries = num_queries
@@ -172,7 +173,11 @@ def __init__(
         self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
         trunc_normal_(self.query, std=0.02)
         if kv_dim is not None and kv_dim != embed_dim:
-            self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
+            self.kv_proj = ReplicatedLinear(kv_dim,
+                                            embed_dim,
+                                            bias=False,
+                                            quant_config=quant_config,
+                                            prefix=prefix)
         else:
             # Maintain the same return value with ReplicatedLinear.forward
             self.kv_proj = lambda *args, **kwargs: (  # type: ignore # noqa
@@ -209,22 +214,24 @@ class Resampler2(BaseResampler):
     present in minicpmv2.0, but not qwen-vl.
     """
 
-    def __init__(
-        self,
-        grid_size: int,
-        embed_dim: int,
-        num_heads: int,
-        kv_dim: Optional[int] = None,
-        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
-        adaptive: bool = False,
-        do_post_projection: bool = True,
-    ) -> None:
+    def __init__(self,
+                 grid_size: int,
+                 embed_dim: int,
+                 num_heads: int,
+                 kv_dim: Optional[int] = None,
+                 norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
+                 adaptive: bool = False,
+                 do_post_projection: bool = True,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
         super().__init__(grid_size**2,
                          embed_dim,
                          num_heads,
                          kv_dim,
                          norm_layer,
-                         do_post_projection=do_post_projection)
+                         do_post_projection=do_post_projection,
+                         quant_config=quant_config,
+                         prefix=prefix)
 
         self.adaptive = adaptive
         pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
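
Note: a minimal sketch (not part of the commit) of how the new parameters are meant to flow. The caller forwards its quant_config and a module prefix, and BaseResampler hands both to the kv_proj ReplicatedLinear so a BitsAndBytes quant method can wrap that layer. The sizes below are placeholder values for illustration only; real values come from the model config.

from typing import Optional

from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import Resampler2


def build_resampler(quant_config: Optional[QuantizationConfig] = None):
    # Placeholder sizes, for illustration only. With quant_config=None the
    # kv_proj stays unquantized, exactly as before this commit.
    return Resampler2(grid_size=8,
                      embed_dim=1024,
                      num_heads=8,
                      kv_dim=1152,
                      quant_config=quant_config,
                      prefix="resampler")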

vllm/model_executor/model_loader/loader.py

Lines changed: 28 additions & 6 deletions
@@ -28,6 +28,7 @@
     get_tensor_model_parallel_world_size)
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
+from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.model_loader.tensorizer import (
@@ -771,6 +772,8 @@ def __init__(self, load_config: LoadConfig):
         with open(config_file_path, "r") as f:
             config = json.load(f)
             self.target_modules = config["target_modules"]
+        # Save the module names without sharding.
+        self.unsharded_weights_modules: List[str] = []
 
     def _get_config_file(self, qlora_adapter: str) -> str:
         is_local = os.path.isdir(qlora_adapter)
@@ -990,16 +993,21 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
            if any(target_module in weight_name for target_module in
                   self.target_modules) and weight_name.endswith(".weight"):
                weight_name = weight_name.replace(".weight", ".qweight")
-
-               if any(module in weight_name
-                      for module in self.column_parallel_weights_modules):
+               # Without sharding
+               if any(
+                       weight_name.startswith(module)
+                       for module in self.unsharded_weights_modules):
+                   weight_sub_tensor = weight_tensor
+               # Shard by column
+               elif any(module in weight_name
+                        for module in self.column_parallel_weights_modules):
 
                    total_size = weight_tensor.size(-1)
                    start_index = total_size // tp_size * tp_rank
                    end_index = total_size // tp_size * (tp_rank + 1)
                    weight_sub_tensor = weight_tensor[...,
                                                      start_index:end_index]
-
+               # Shard by row
                else:
                    total_size = weight_tensor.size(0)
                    start_index = total_size // tp_size * tp_rank
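
Note: a toy illustration (not from the commit) of the three paths in the hunk above. Unsharded modules keep the full tensor on every tensor-parallel rank, column-parallel modules slice the last dimension per rank, and the default path slices dim 0.

import torch

weight = torch.arange(24).reshape(4, 6)
tp_size, tp_rank = 2, 1

# Unsharded (e.g. a ReplicatedLinear weight): every rank keeps the full tensor.
unsharded = weight

# Column sharding: slice the last dimension per rank.
cols = weight.size(-1)
col_shard = weight[..., cols // tp_size * tp_rank:cols // tp_size * (tp_rank + 1)]

# Row sharding (default path): slice dim 0 per rank.
rows = weight.size(0)
row_shard = weight[rows // tp_size * tp_rank:rows // tp_size * (tp_rank + 1)]

print(unsharded.shape, col_shard.shape, row_shard.shape)
# torch.Size([4, 6]) torch.Size([4, 3]) torch.Size([2, 6])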
@@ -1053,7 +1061,15 @@ def _load_weights(self, model_config: ModelConfig,
                 model.column_parallel_weights_modules
         else:
             self.column_parallel_weights_modules = []
-
+        # Some modules like `ReplicatedLinear` should not have their weights
+        # sharded. The reason for implementing it this way is to avoid new
+        # static variable in the model implementation.
+        # TODO: Can we reduce the static variables needed for BNB based on
+        # model information?
+        self.unsharded_weights_modules = [
+            name for name, module in model.named_modules()
+            if isinstance(module, (ReplicatedLinear, ))
+        ]
        self.model_type = type(model).__name__
 
        logger.info("Loading weights with BitsAndBytes quantization. "
@@ -1100,7 +1116,13 @@ def _load_weights(self, model_config: ModelConfig,
             for shard_name, (
                     weight_name, index
             ) in model.bitsandbytes_stacked_params_mapping.items():
-                if shard_name in quant_param_name:
+
+                shard_pos = quant_param_name.find(shard_name)
+                # Some models, such as MiniCPM V2.5/2.6, contain both
+                # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
+                # from being incorrectly identified as being present in
+                # 'vpm.encoder.layers.0.self_attn.qkv_proj.qweight'
+                if shard_pos > 0 and quant_param_name[shard_pos - 1] == ".":
                     shard_index = index
                     quant_param_name = quant_param_name.replace(
                         shard_name, weight_name)
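
Note: the boundary check in isolation, using one weight name taken from the comment above and one assumed resampler weight name. The old substring test matched 'kv_proj' inside 'qkv_proj'; the new test requires a '.' immediately before the shard name, so only a genuine kv_proj module matches.

shard_name = "kv_proj"

qkv_weight = "vpm.encoder.layers.0.self_attn.qkv_proj.qweight"
kv_weight = "resampler.kv_proj.qweight"  # assumed name, for illustration

# Old check: false positive on the fused qkv projection.
print(shard_name in qkv_weight)  # True

# New check: only match at a module-name boundary.
for name in (qkv_weight, kv_weight):
    pos = name.find(shard_name)
    print(pos > 0 and name[pos - 1] == ".")  # False, then True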

vllm/model_executor/models/minicpmv.py

Lines changed: 83 additions & 37 deletions
@@ -131,16 +131,22 @@ class MiniCPMVImageEmbeddingInputs(TypedDict):
 
 class Resampler2_5(BaseResampler):
 
-    def __init__(
-        self,
-        num_queries: int,
-        embed_dim: int,
-        num_heads: int,
-        kv_dim: Optional[int] = None,
-        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
-        max_size: Tuple[int, int] = (70, 70),
-    ) -> None:
-        super().__init__(num_queries, embed_dim, num_heads, kv_dim, norm_layer)
+    def __init__(self,
+                 num_queries: int,
+                 embed_dim: int,
+                 num_heads: int,
+                 kv_dim: Optional[int] = None,
+                 norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
+                 max_size: Tuple[int, int] = (70, 70),
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
+        super().__init__(num_queries,
+                         embed_dim,
+                         num_heads,
+                         kv_dim,
+                         norm_layer,
+                         quant_config=quant_config,
+                         prefix=prefix)
 
         self.max_size = max_size
         self._set_2d_pos_cache(self.max_size)
@@ -404,7 +410,10 @@ def __init__(
         self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
                            self.vpm.embeddings.embed_dim)
         self.embed_dim = self.config.hidden_size
-        self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
+        self.resampler = self.init_resampler(self.embed_dim,
+                                             self.vision_dim,
+                                             quant_config=quant_config,
+                                             prefix="resampler")
         self.resampler.to(device="cuda", dtype=param_dtype)
         # TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm
         self.lm_head = ParallelLMHead(config.vocab_size,
@@ -666,7 +675,11 @@ def init_vision_module(
     ) -> nn.Module:
         raise NotImplementedError
 
-    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
+    def init_resampler(self,
+                       embed_dim: int,
+                       vision_dim: int,
+                       quant_config: Optional[QuantizationConfig] = None,
+                       prefix: str = "") -> nn.Module:
         raise NotImplementedError
 
     def get_vision_embedding(
@@ -743,16 +756,21 @@ def init_vision_module(
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_tokens(input_ids)
 
-    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
+    def init_resampler(self,
+                       embed_dim: int,
+                       vision_dim: int,
+                       quant_config: Optional[QuantizationConfig] = None,
+                       prefix: str = "") -> nn.Module:
         with set_default_torch_dtype(torch.float16):
-            resampler = Resampler2(
-                embed_dim=embed_dim,
-                num_heads=embed_dim // 128,
-                grid_size=int(math.sqrt(self.config.query_num)),
-                kv_dim=vision_dim,
-                adaptive=False,
-                do_post_projection=True,
-            )
+            resampler = Resampler2(embed_dim=embed_dim,
+                                   num_heads=embed_dim // 128,
+                                   grid_size=int(
+                                       math.sqrt(self.config.query_num)),
+                                   kv_dim=vision_dim,
+                                   adaptive=False,
+                                   do_post_projection=True,
+                                   quant_config=quant_config,
+                                   prefix=prefix)
 
         return resampler
 
@@ -825,9 +843,21 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
         ".k_proj.",
         ".v_proj.",
         ".o_proj.",
+        # vision encoder
+        ".fc1.",
+        ".fc2.",
+        # Currently, vllm does not support BNB quantization for the `out_proj`
+        # of the resampler, so it's necessary to distinguish between the
+        # vision encoder and the resampler's out_proj. The same applies to
+        # MiniCPMV2_6.
+        ".self_attn.out_proj.",  # vision encoder out_proj
+        # resampler
+        ".kv_proj.",
     ]
     # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    column_parallel_weights_modules = [
+        ".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2."
+    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -877,14 +907,18 @@ def init_vision_module(
             model.encoder.layers = model.encoder.layers[:-1]
         return model
 
-    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
+    def init_resampler(self,
+                       embed_dim: int,
+                       vision_dim: int,
+                       quant_config: Optional[QuantizationConfig] = None,
+                       prefix: str = "") -> nn.Module:
         with set_default_torch_dtype(torch.float16):
-            resampler = Resampler2_5(
-                num_queries=self.config.query_num,
-                embed_dim=embed_dim,
-                num_heads=embed_dim // 128,
-                kv_dim=vision_dim,
-            )
+            resampler = Resampler2_5(num_queries=self.config.query_num,
+                                     embed_dim=embed_dim,
+                                     num_heads=embed_dim // 128,
+                                     kv_dim=vision_dim,
+                                     quant_config=quant_config,
+                                     prefix=prefix)
         return resampler
 
     def get_vision_embedding(
@@ -967,9 +1001,17 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
         ".k_proj.",
         ".v_proj.",
         ".o_proj.",
+        # vision encoder
+        ".fc1.",
+        ".fc2.",
+        ".self_attn.out_proj.",
+        # resampler
+        ".kv_proj.",
     ]
     # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    column_parallel_weights_modules = [
+        ".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2."
+    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -1019,15 +1061,19 @@ def init_vision_module(
             model.encoder.layers = model.encoder.layers[:-1]
         return model
 
-    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
+    def init_resampler(self,
+                       embed_dim: int,
+                       vision_dim: int,
+                       quant_config: Optional[QuantizationConfig] = None,
+                       prefix: str = "") -> nn.Module:
         with set_default_torch_dtype(torch.float16):
             # The resampler in 2.6 remains consistent with the one in 2.5.
-            resampler = Resampler2_5(
-                num_queries=self.config.query_num,
-                embed_dim=embed_dim,
-                num_heads=embed_dim // 128,
-                kv_dim=vision_dim,
-            )
+            resampler = Resampler2_5(num_queries=self.config.query_num,
+                                     embed_dim=embed_dim,
+                                     num_heads=embed_dim // 128,
+                                     kv_dim=vision_dim,
+                                     quant_config=quant_config,
+                                     prefix=prefix)
         return resampler
 
     def get_vision_embedding(

vllm/model_executor/models/mllama.py

Lines changed: 6 additions & 1 deletion
@@ -1057,9 +1057,14 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
         ".k_proj.",
         ".v_proj.",
         ".o_proj.",
+        ".fc1.",
+        ".fc2.",
+        # The `multi_modal_projector` is at the top level of the model,
+        # so we can't add a dot in front of it.
+        "multi_modal_projector."
     ]
     # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    column_parallel_weights_modules = [".down_proj.", ".o_proj.", ".fc2."]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
