Commit 2400a79

jeejeelee authored and JC1DA committed

[Misc] Modify BNB parameter name (vllm-project#9997)

Signed-off-by: Jee Jee Li <[email protected]>
Signed-off-by: Loc Huynh <[email protected]>

1 parent 4e694f5 commit 2400a79
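The change renames the BitsAndBytes (BNB) quantized weight parameter from "qweight" to "weight" so that vLLM's registered parameter names match the keys a BNB checkpoint actually uses, which lets the loader stop rewriting ".weight" keys to ".qweight". Below is a minimal sketch of the idea; the QuantLinear class and checkpoint keys are invented for illustration (the real layer lives in vllm/model_executor/layers/quantization/bitsandbytes.py):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a BNB-quantized linear layer; names and
# shapes are illustrative only.
class QuantLinear(nn.Module):
    def __init__(self, out_features: int, in_features: int):
        super().__init__()
        # After this commit the parameter is registered as "weight",
        # matching the key a BNB checkpoint uses for the packed tensor.
        self.register_parameter(
            "weight",
            nn.Parameter(torch.empty(out_features, in_features,
                                     dtype=torch.uint8),
                         requires_grad=False))

# A BNB-style 8-bit checkpoint stores the packed tensor under ".weight"
# and its scales under ".SCB" (keys below are made up).
ckpt = {
    "layers.0.down_proj.weight": torch.zeros(8, 8, dtype=torch.uint8),
    "layers.0.down_proj.SCB": torch.ones(8),
}

layer = QuantLinear(8, 8)
# With the parameter named "weight", the checkpoint key maps directly
# onto the module's state_dict; no ".weight" -> ".qweight" rewrite.
layer.load_state_dict({"weight": ckpt["layers.0.down_proj.weight"]},
                      strict=False)
```

Before the rename, every loader path had to translate between the two spellings, which is exactly the bookkeeping the loader.py hunks below delete.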

File tree

3 files changed: 11 additions (+), 14 deletions (-)

  vllm/model_executor/layers/quantization/bitsandbytes.py
  vllm/model_executor/layers/resampler.py
  vllm/model_executor/model_loader/loader.py


vllm/model_executor/layers/quantization/bitsandbytes.py

Lines changed: 5 additions & 4 deletions

@@ -203,8 +203,9 @@ def create_qweight_for_4bit():
             qweight = create_qweight_for_8bit()
         else:
             qweight = create_qweight_for_4bit()
-
-        layer.register_parameter("qweight", qweight)
+        # Enable parameters to have the same name as in the BNB
+        # checkpoint format.
+        layer.register_parameter("weight", qweight)
         set_weight_attrs(qweight, extra_weight_attrs)
 
     def apply(self,
@@ -234,7 +235,7 @@ def _apply_8bit_weight(
         reshape_after_matmul = True
         bf_x = x.to(torch.bfloat16)
 
-        qweight = layer.qweight
+        qweight = layer.weight
         offsets = qweight.bnb_shard_offsets
         quant_states = qweight.bnb_quant_state
         matmul_states = qweight.matmul_state
@@ -313,7 +314,7 @@ def _apply_4bit_weight(
         reshape_after_matmul = True
         bf_x = x.to(torch.bfloat16)
 
-        qweight = layer.qweight
+        qweight = layer.weight
         quant_states = qweight.bnb_quant_state
         offsets = qweight.bnb_shard_offsets
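Both _apply_8bit_weight and _apply_4bit_weight now read the renamed parameter via layer.weight, together with BNB metadata (bnb_quant_state, bnb_shard_offsets) that set_weight_attrs hangs directly on the tensor. A simplified sketch of that attribute-attachment pattern; the helper here is a stripped-down stand-in for vLLM's set_weight_attrs and the attribute values are placeholders:

```python
import torch

def set_weight_attrs(weight: torch.Tensor, attrs: dict) -> None:
    # Stripped-down stand-in for vLLM's helper: copy metadata onto the
    # parameter object so apply() can read it back as plain attributes.
    for key, value in attrs.items():
        setattr(weight, key, value)

weight = torch.nn.Parameter(torch.zeros(4, 4, dtype=torch.uint8),
                            requires_grad=False)
set_weight_attrs(weight, {
    "bnb_shard_offsets": [0, 2, 4],  # placeholder shard boundaries
    "bnb_quant_state": {},           # placeholder per-shard quant state
})

# After this commit, apply() fetches the same object as layer.weight
# and reads the metadata straight off it:
assert weight.bnb_shard_offsets == [0, 2, 4]
```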
vllm/model_executor/layers/resampler.py

Lines changed: 1 addition & 1 deletion

@@ -177,7 +177,7 @@ def __init__(self,
                                             embed_dim,
                                             bias=False,
                                             quant_config=quant_config,
-                                            prefix=prefix)
+                                            prefix=f"{prefix}.kv_proj")
         else:
             # Maintain the same return value with ReplicatedLinear.forward
             self.kv_proj = lambda *args, **kwargs: (  # type: ignore # noqa
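The resampler fix matters because quantization code identifies layers by their full dotted module path; passing the parent's prefix unchanged would have given the inner kv_proj the same name as its parent. A hypothetical sketch of why the suffix matters for any prefix-based per-layer lookup (the helper below is invented for illustration, not vLLM's actual API):

```python
def is_layer_matched(prefix: str, module_names: list[str]) -> bool:
    # Invented helper: decide per-layer behavior from the dotted path,
    # the way prefix-based quantization lookups typically work.
    return any(prefix == m or prefix.endswith("." + m)
               for m in module_names)

parent_prefix = "resampler"

before = parent_prefix              # old: child reused parent's prefix
after = f"{parent_prefix}.kv_proj"  # new: child appends its own name

print(is_layer_matched(before, ["kv_proj"]))  # False: kv_proj invisible
print(is_layer_matched(after, ["kv_proj"]))   # True: kv_proj found
```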

vllm/model_executor/model_loader/loader.py

Lines changed: 5 additions & 9 deletions

@@ -892,7 +892,7 @@ def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
             if not weight_name.lower().endswith(".scb"):
                 continue
 
-            weight_key = weight_name.lower().replace(".scb", ".qweight")
+            weight_key = weight_name.lower().replace(".scb", ".weight")
             quant_state_dict[weight_key] = weight_tensor
 
         for weight_name, weight_tensor in self._hf_weight_iter(
@@ -901,11 +901,9 @@ def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
             if self._is_8bit_weight_name(weight_name):
                 continue
 
-            qweight_name = weight_name.replace(".weight", ".qweight")
-
-            if qweight_name in quant_state_dict:
+            if weight_name in quant_state_dict:
                 set_weight_attrs(weight_tensor, {"load_in_8bit": True})
-                yield qweight_name, weight_tensor
+                yield weight_name, weight_tensor
             else:
                 yield weight_name, weight_tensor
 
@@ -950,9 +948,8 @@ def _parse_quant_state(param_name: str,
                 (f"{weight_name}.quant_state.bitsandbytes__fp4" \
                     in temp_state_dict):
                 quant_state = _parse_quant_state(weight_name, temp_state_dict)
-                weight_name = weight_name.replace(".weight", ".qweight")
                 quant_state_dict[weight_name] = quant_state
-                yield weight_name.replace(".weight", ".qweight"), weight_tensor
+                yield weight_name, weight_tensor
             else:
                 yield weight_name, weight_tensor
 
@@ -967,7 +964,6 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
 
             if any(target_module in weight_name for target_module in
                    self.target_modules) and weight_name.endswith(".weight"):
-                weight_name = weight_name.replace(".weight", ".qweight")
                 # Without sharding
                 if any(
                     weight_name.startswith(module)
@@ -1093,7 +1089,7 @@ def _load_weights(self, model_config: ModelConfig,
             # Some models, such as MiniCPM V2.5/2.6, contain both
             # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
             # from being incorrectly identified as being present in
-            # 'vpm.encoder.layers.0.self_attn.qkv_proj.qweight
+            # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight
             if shard_pos > 0 and quant_param_name[shard_pos - 1] == ".":
                 shard_index = index
                 quant_param_name = quant_param_name.replace(
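After this change, _quantized_8bit_generator keys its quant-state dictionary by the checkpoint's own ".weight" names, so the membership test and the yielded names need no rewriting. A self-contained sketch of that two-pass flow with made-up checkpoint keys (the real logic is in the loader.py hunks above):

```python
import torch

# Made-up 8-bit BNB checkpoint entries: packed weight plus ".SCB" scales.
hf_weights = {
    "layers.0.down_proj.weight": torch.zeros(4, 4, dtype=torch.int8),
    "layers.0.down_proj.SCB": torch.ones(4),
}

def quantized_8bit_generator(weights):
    # Pass 1: map each ".SCB" entry to its matching ".weight" key. The
    # key now stays ".weight" -- no ".qweight" rewrite.
    quant_state_dict = {
        name.lower().replace(".scb", ".weight"): tensor
        for name, tensor in weights.items()
        if name.lower().endswith(".scb")
    }
    # Pass 2: yield weight tensors under their original checkpoint
    # names, flagging those that have quant state.
    for name, tensor in weights.items():
        if name.lower().endswith(".scb"):
            continue
        yield name, tensor, name in quant_state_dict

for name, tensor, is_quantized in quantized_8bit_generator(hf_weights):
    print(name, "load_in_8bit" if is_quantized else "unquantized")
```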
