9 changes: 5 additions & 4 deletions vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -203,8 +203,9 @@ def create_qweight_for_4bit():
             qweight = create_qweight_for_8bit()
         else:
             qweight = create_qweight_for_4bit()
-
-        layer.register_parameter("qweight", qweight)
+        # Enable parameters to have the same name as in the BNB
+        # checkpoint format.
+        layer.register_parameter("weight", qweight)
         set_weight_attrs(qweight, extra_weight_attrs)

     def apply(self,
@@ -234,7 +235,7 @@ def _apply_8bit_weight(
             reshape_after_matmul = True
         bf_x = x.to(torch.bfloat16)

-        qweight = layer.qweight
+        qweight = layer.weight
         offsets = qweight.bnb_shard_offsets
         quant_states = qweight.bnb_quant_state
         matmul_states = qweight.matmul_state
@@ -313,7 +314,7 @@ def _apply_4bit_weight(
             reshape_after_matmul = True
         bf_x = x.to(torch.bfloat16)

-        qweight = layer.qweight
+        qweight = layer.weight
         quant_states = qweight.bnb_quant_state
         offsets = qweight.bnb_shard_offsets

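Reviewer note on the rename above: the loader resolves checkpoint keys against the model's registered parameter names, so registering the packed tensor as `weight` lets bitsandbytes-format keys land directly on the layer. A minimal sketch under that assumption (plain PyTorch, hypothetical module names, not vLLM code):

```python
# Standalone sketch: why the registered parameter name must match the
# checkpoint key. Module and key names below are made up for illustration.
import torch
from torch import nn

class TinyProj(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Register the packed int8 weight as "weight", matching the key a
        # bitsandbytes checkpoint uses (its scale lives under ".SCB").
        qweight = nn.Parameter(torch.empty(128, 64, dtype=torch.uint8),
                               requires_grad=False)
        self.register_parameter("weight", qweight)

model = nn.ModuleDict({"proj": TinyProj()})
params = dict(model.named_parameters())

checkpoint_keys = ["proj.weight", "proj.SCB"]  # as found in a BNB checkpoint
print("proj.weight" in params)  # True: no ".qweight" rename needed any more
```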
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/resampler.py
@@ -177,7 +177,7 @@ def __init__(self,
                                             embed_dim,
                                             bias=False,
                                             quant_config=quant_config,
-                                            prefix=prefix)
+                                            prefix=f"{prefix}.kv_proj")
         else:
             # Maintain the same return value with ReplicatedLinear.forward
             self.kv_proj = lambda *args, **kwargs: (  # type: ignore # noqa
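Reviewer note: the linear layer identifies itself to the quantization machinery through its `prefix`, so the child has to append its own attribute name; reusing the parent's prefix hides the `kv_proj` part. A hedged sketch of the dotted-name matching (hypothetical helper and names, not the vLLM API):

```python
# Hypothetical illustration of dotted-prefix matching; is_quantized_layer
# stands in for whatever lookup the quantization config performs.
def is_quantized_layer(prefix: str, quantized_layers: set[str]) -> bool:
    return prefix in quantized_layers

quantized_layers = {"resampler.kv_proj"}     # names as stored in a checkpoint

parent_prefix = "resampler"
old_prefix = parent_prefix                   # before: "resampler"
new_prefix = f"{parent_prefix}.kv_proj"      # after:  "resampler.kv_proj"

print(is_quantized_layer(old_prefix, quantized_layers))   # False: no match
print(is_quantized_layer(new_prefix, quantized_layers))   # True: matches key
```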
14 changes: 5 additions & 9 deletions vllm/model_executor/model_loader/loader.py
@@ -917,7 +917,7 @@ def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
             if not weight_name.lower().endswith(".scb"):
                 continue

-            weight_key = weight_name.lower().replace(".scb", ".qweight")
+            weight_key = weight_name.lower().replace(".scb", ".weight")
             quant_state_dict[weight_key] = weight_tensor

         for weight_name, weight_tensor in self._hf_weight_iter(
@@ -926,11 +926,9 @@ def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
             if self._is_8bit_weight_name(weight_name):
                 continue

-            qweight_name = weight_name.replace(".weight", ".qweight")
-
-            if qweight_name in quant_state_dict:
+            if weight_name in quant_state_dict:
                 set_weight_attrs(weight_tensor, {"load_in_8bit": True})
-                yield qweight_name, weight_tensor
+                yield weight_name, weight_tensor
             else:
                 yield weight_name, weight_tensor

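Reviewer note: the two hunks above can be read as "collect the `.SCB` scales keyed by the plain `.weight` name, then flag any weight that has a matching scale". A self-contained sketch of that flow with made-up checkpoint entries (not the actual generator):

```python
# Toy 8-bit BNB checkpoint: quantized weights carry an ".SCB" scale entry.
checkpoint = {
    "model.layers.0.mlp.down_proj.weight": "int8 tensor",
    "model.layers.0.mlp.down_proj.SCB": "scale tensor",
    "model.embed_tokens.weight": "fp16 tensor",  # unquantized, no SCB
}

# First pass: key the quant-state dict by the plain ".weight" name
# (previously this produced ".qweight" keys).
quant_state_dict = {
    name.lower().replace(".scb", ".weight"): tensor
    for name, tensor in checkpoint.items()
    if name.lower().endswith(".scb")
}

# Second pass: a weight is 8-bit quantized iff its scale was recorded
# under the same name; otherwise it is passed through untouched.
for name, tensor in checkpoint.items():
    if name.lower().endswith(".scb"):
        continue
    tag = "load_in_8bit" if name in quant_state_dict else "plain"
    print(name, "->", tag)
```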
@@ -975,9 +973,8 @@ def _parse_quant_state(param_name: str,
                 (f"{weight_name}.quant_state.bitsandbytes__fp4" \
                 in temp_state_dict):
                 quant_state = _parse_quant_state(weight_name, temp_state_dict)
-                weight_name = weight_name.replace(".weight", ".qweight")
                 quant_state_dict[weight_name] = quant_state
-                yield weight_name.replace(".weight", ".qweight"), weight_tensor
+                yield weight_name, weight_tensor
             else:
                 yield weight_name, weight_tensor

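Reviewer note: in the 4-bit path, whether a tensor is quantized is decided by the presence of a serialized quant state keyed off the same `.weight` name, so the `.qweight` rename is the only thing this hunk removes. An illustrative sketch with made-up keys (not vLLM code):

```python
# Auxiliary entries a bitsandbytes 4-bit checkpoint stores next to a weight.
temp_state_dict = {
    "model.layers.0.self_attn.qkv_proj.weight.quant_state.bitsandbytes__nf4":
        "packed quant state",
}

def has_4bit_quant_state(weight_name: str) -> bool:
    # Mirrors the nf4/fp4 check above; weight_name already ends in ".weight".
    return (f"{weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict
            or f"{weight_name}.quant_state.bitsandbytes__fp4"
            in temp_state_dict)

weight_name = "model.layers.0.self_attn.qkv_proj.weight"
if has_4bit_quant_state(weight_name):
    # Both the quant state and the tensor are recorded/yielded under the
    # original ".weight" key; no rename to ".qweight" takes place.
    quant_state_dict = {weight_name: "parsed quant state"}
    print("quantized:", weight_name)
```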
@@ -992,7 +989,6 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,

             if any(target_module in weight_name for target_module in
                    self.target_modules) and weight_name.endswith(".weight"):
-                weight_name = weight_name.replace(".weight", ".qweight")
                 # Without sharding
                 if any(
                         weight_name.startswith(module)
@@ -1121,7 +1117,7 @@ def _load_weights(self, model_config: ModelConfig,
                 # Some models, such as MiniCPM V2.5/2.6, contain both
                 # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
                 # from being incorrectly identified as being present in
-                # 'vpm.encoder.layers.0.self_attn.qkv_proj.qweight
+                # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight
                 if shard_pos > 0 and quant_param_name[shard_pos - 1] == ".":
                     shard_index = index
                     quant_param_name = quant_param_name.replace(
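Reviewer note: the comment in the last hunk describes a substring pitfall that exists regardless of the suffix; only the quoted example changes from `.qweight` to `.weight`. A small sketch of the guard (illustrative only, not the loader code):

```python
# "kv_proj" is a substring of "qkv_proj", so a bare find() would match the
# wrong shard; requiring a "." right before the hit restricts the match to
# a whole module name.
quant_param_name = "vpm.encoder.layers.0.self_attn.qkv_proj.weight"

for shard_name in ("kv_proj", "qkv_proj"):
    shard_pos = quant_param_name.find(shard_name)
    is_whole_module = shard_pos > 0 and quant_param_name[shard_pos - 1] == "."
    print(shard_name, "->", "match" if is_whole_module else "substring only, skipped")
```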