
Commit a12c16c

Fix BNB Bug
Signed-off-by: Jee Jee Li <[email protected]>
1 parent 5fadcf9 commit a12c16c

File tree

2 files changed (+26, -7 lines):
- vllm/model_executor/model_loader/loader.py
- vllm/model_executor/models/minicpmv.py


vllm/model_executor/model_loader/loader.py

Lines changed: 20 additions & 5 deletions
@@ -28,6 +28,7 @@
     get_tensor_model_parallel_world_size)
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
+from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.model_loader.tensorizer import (
@@ -786,6 +787,7 @@ def __init__(self, load_config: LoadConfig):
         with open(config_file_path, "r") as f:
             config = json.load(f)
         self.target_modules = config["target_modules"]
+        self.wo_sharded_weights_modules: List[str] = []

     def _get_config_file(self, qlora_adapter: str) -> str:
         is_local = os.path.isdir(qlora_adapter)
@@ -1005,16 +1007,21 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
             if any(target_module in weight_name for target_module in
                    self.target_modules) and weight_name.endswith(".weight"):
                 weight_name = weight_name.replace(".weight", ".qweight")
-
-                if any(module in weight_name
-                       for module in self.column_parallel_weights_modules):
+                # Without sharding
+                if any(
+                        weight_name.startswith(module)
+                        for module in self.wo_sharded_weights_modules):
+                    weight_sub_tensor = weight_tensor
+                # Shard by column
+                elif any(module in weight_name
+                         for module in self.column_parallel_weights_modules):

                     total_size = weight_tensor.size(-1)
                     start_index = total_size // tp_size * tp_rank
                     end_index = total_size // tp_size * (tp_rank + 1)
                     weight_sub_tensor = weight_tensor[...,
                                                       start_index:end_index]
-
+                # Shard by row
                 else:
                     total_size = weight_tensor.size(0)
                     start_index = total_size // tp_size * tp_rank
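
For reference, a minimal standalone sketch of the three branches the hunk above introduces: replicated weights are kept whole, column-parallel weights are sliced along the last dimension, and the row-parallel fallback slices along the first dimension. The helper name, list arguments, and example weight names below are invented for illustration and are not vLLM's actual API.

# Sketch only; mirrors the branch structure of _unquantized_generator above.
import torch

def shard_weight(weight_name: str, weight_tensor: torch.Tensor,
                 unsharded_modules: list, column_parallel_modules: list,
                 tp_rank: int, tp_size: int) -> torch.Tensor:
    # Replicated modules: every TP rank keeps the full tensor.
    if any(weight_name.startswith(m) for m in unsharded_modules):
        return weight_tensor
    # Column-parallel modules: slice along the last dimension.
    if any(m in weight_name for m in column_parallel_modules):
        size = weight_tensor.size(-1)
        start = size // tp_size * tp_rank
        end = size // tp_size * (tp_rank + 1)
        return weight_tensor[..., start:end]
    # Everything else: slice along the first dimension (row sharding).
    size = weight_tensor.size(0)
    start = size // tp_size * tp_rank
    end = size // tp_size * (tp_rank + 1)
    return weight_tensor[start:end]

# With tp_size=2, a hypothetical (8, 4) ".down_proj." weight becomes (8, 2)
# per rank, while a weight under a replicated module stays (8, 4) everywhere.
w = torch.zeros(8, 4)
print(shard_weight("model.layers.0.mlp.down_proj.qweight", w,
                   ["resampler.kv_proj"], [".down_proj."],
                   tp_rank=0, tp_size=2).shape)  # torch.Size([8, 2])
print(shard_weight("resampler.kv_proj.qweight", w,
                   ["resampler.kv_proj"], [".down_proj."],
                   tp_rank=0, tp_size=2).shape)  # torch.Size([8, 4])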
@@ -1068,7 +1075,15 @@ def _load_weights(self, model_config: ModelConfig,
                 model.column_parallel_weights_modules
         else:
             self.column_parallel_weights_modules = []
-
+        # Some modules, like `ReplicatedLinear`, should not have their weights
+        # sharded. It is implemented this way to avoid adding a new static
+        # variable to the model implementation.
+        # TODO: Can we reduce the static variables needed for BNB based on
+        # model information?
+        self.wo_sharded_weights_modules = [
+            name for name, module in model.named_modules()
+            if isinstance(module, (ReplicatedLinear, ))
+        ]
         self.model_type = type(model).__name__

         logger.info("Loading weights with BitsAndBytes quantization. "
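
To make the new list comprehension concrete, a hedged toy example follows. ToyModel, its layer names, and the stand-in ReplicatedLinear subclass are invented; vLLM's real ReplicatedLinear comes from vllm.model_executor.layers.linear (the import added at the top of this commit). model.named_modules() yields dotted module names, and the isinstance filter keeps only the replicated layers, whose checkpoint weights the generator then loads whole on every rank via the startswith check.

# Toy sketch of how wo_sharded_weights_modules is built and later matched.
import torch.nn as nn

class ReplicatedLinear(nn.Linear):
    """Stand-in for vllm.model_executor.layers.linear.ReplicatedLinear."""

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(16, 16)          # sharded across TP ranks
        self.kv_proj = ReplicatedLinear(16, 16)  # replicated, must not be sharded

model = ToyModel()
wo_sharded_weights_modules = [
    name for name, module in model.named_modules()
    if isinstance(module, (ReplicatedLinear, ))
]
print(wo_sharded_weights_modules)  # ['kv_proj']

# The generator skips sharding for any checkpoint weight whose name starts
# with one of these module names, e.g. "kv_proj.qweight".
print(any("kv_proj.qweight".startswith(m)
          for m in wo_sharded_weights_modules))  # True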

vllm/model_executor/models/minicpmv.py

Lines changed: 6 additions & 2 deletions
@@ -850,7 +850,9 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
         ".kv_proj.",
     ]
     # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".down_proj.", ".o_proj.", ".fc2."]
+    column_parallel_weights_modules = [
+        ".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2."
+    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -1002,7 +1004,9 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
         ".kv_proj.",
     ]
     # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".down_proj.", ".o_proj.", ".fc2."]
+    column_parallel_weights_modules = [
+        ".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2."
+    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
