
Commit a3a3ee4

[Misc] Merge bitsandbytes_stacked_params_mapping and packed_modules_mapping (#11924)
Signed-off-by: Jee Jee Li <[email protected]>
1 parent 87054a5 commit a3a3ee4

24 files changed: +49, -200 lines

vllm/model_executor/model_loader/loader.py

Lines changed: 11 additions & 16 deletions
@@ -39,7 +39,8 @@
 from vllm.model_executor.model_loader.tensorizer import (
     TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
     serialize_vllm_model, tensorizer_weights_iterator)
-from vllm.model_executor.model_loader.utils import (get_model_architecture,
+from vllm.model_executor.model_loader.utils import (ParamMapping,
+                                                    get_model_architecture,
                                                     set_default_torch_dtype)
 from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf, download_weights_from_hf,
@@ -983,21 +984,11 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
 
     def _get_bnb_target_modules(self, model: nn.Module) -> None:
 
-        # TODO: Maybe we can replace bitsandbytes_stacked_params_mapping with
-        # packed_modules_mapping.
-        inverse_stacked_mapping: Dict[str, List[str]] = {}
-        for orig, (
-                packed,
-                idx,
-        ) in model.bitsandbytes_stacked_params_mapping.items():
-            if packed not in inverse_stacked_mapping:
-                inverse_stacked_mapping[packed] = []
-            inverse_stacked_mapping[packed].insert(idx, orig)
-
         for name, module in model.named_modules():
             if isinstance(module, (LinearBase, )):
                 last_name = name.split(".")[-1]
-                if sub_modules := inverse_stacked_mapping.get(last_name, []):
+                if sub_modules := self.modules_mapping.packed_mapping.get(
+                        last_name, []):
                     # Map vllm's names to transformers's names.
                     for sub_name in sub_modules:
                         self.target_modules.append(
@@ -1018,15 +1009,19 @@ def _load_weights(self, model_config: ModelConfig,
                 "The required method 'load_weights' is not defined in class"
                 f" {type(model).__name__}.")
 
-        if not hasattr(model, "bitsandbytes_stacked_params_mapping"):
+        if not hasattr(model, "packed_modules_mapping"):
             raise AttributeError(
                 f"Model {type(model).__name__} does not support BitsAndBytes "
-                "quantization yet.")
+                "quantization yet. No 'packed_modules_mapping' found.")
+
+        self.modules_mapping = ParamMapping(
+            copy.deepcopy(model.packed_modules_mapping))
 
         # For some models like Molmo, we need to use hf_to_vllm_mapper
         # to ensure correct loading of weights.
         if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
             self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
+
         # Modules whose weights might have fused on disk
         # we need their output_sizes to make shard in flight correctly with TP
         self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
@@ -1109,7 +1104,7 @@ def _load_weights(self, model_config: ModelConfig,
             for shard_name, (
                     weight_name,
                     index,
-            ) in model.bitsandbytes_stacked_params_mapping.items():
+            ) in self.modules_mapping.inverse_packed_mapping.items():
                 shard_pos = quant_param_name.find(shard_name)
                 # Some models, such as MiniCPM V2.5/2.6, contain both
                 # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
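
The net effect of these loader changes: everything the BitsAndBytes path needs is now derived from the model's packed_modules_mapping. Below is a minimal sketch (not vLLM source; the example mapping is illustrative and mirrors the Llama-style layout of the models edited further down) of the two lookup directions the loader consults:

# Sketch only: how the loader-side lookups behave once ParamMapping is
# built from a model's packed_modules_mapping.
from vllm.model_executor.model_loader.utils import ParamMapping

mapping = ParamMapping({
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
})

# Forward direction (_get_bnb_target_modules): packed vLLM module name
# -> checkpoint sub-module names.
assert mapping.packed_mapping["qkv_proj"] == ["q_proj", "k_proj", "v_proj"]

# Inverse direction (the shard loop in _load_weights): sub-module name
# -> (packed module name, shard index).
assert mapping.inverse_packed_mapping["k_proj"] == ("qkv_proj", 1)
assert mapping.inverse_packed_mapping["up_proj"] == ("gate_up_proj", 1)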

vllm/model_executor/model_loader/utils.py

Lines changed: 25 additions & 1 deletion
@@ -1,6 +1,7 @@
 """Utilities for selecting and loading models."""
 import contextlib
-from typing import Tuple, Type
+from dataclasses import dataclass, field
+from typing import Dict, List, Tuple, Type
 
 import torch
 from torch import nn
@@ -49,3 +50,26 @@ def get_model_architecture(
 
 def get_architecture_class_name(model_config: ModelConfig) -> str:
     return get_model_architecture(model_config)[1]
+
+
+@dataclass
+class ParamMapping:
+    """
+    A class to handle parameter mapping for model weight loading.
+    It creates a bidirectional mapping between packed parameters and their
+    constituent parts.
+    """
+    packed_mapping: Dict[str, List[str]]
+    inverse_packed_mapping: Dict[str, Tuple[str,
+                                            int]] = field(default_factory=dict)
+
+    def __post_init__(self):
+        for packed_name, sub_params in self.packed_mapping.items():
+            # Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
+            if len(sub_params) == 1 and sub_params[0] == packed_name:
+                continue
+            for index, param_name in enumerate(sub_params):
+                self.inverse_packed_mapping[param_name] = (
+                    packed_name,
+                    index,
+                )
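
ParamMapping is deliberately small: packed_mapping is taken verbatim from the model's packed_modules_mapping, and __post_init__ derives inverse_packed_mapping from it, skipping entries that map a name to itself. A quick hedged example (the W_pack entry comes from the docstring above; the rest is illustrative):

from vllm.model_executor.model_loader.utils import ParamMapping

mapping = ParamMapping({
    "W_pack": ["W_pack"],                      # self-contained: no inverse entry
    "gate_up_proj": ["gate_proj", "up_proj"],  # expanded into inverse entries
})
assert "W_pack" not in mapping.inverse_packed_mapping
assert mapping.inverse_packed_mapping == {
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}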

vllm/model_executor/models/baichuan.py

Lines changed: 0 additions & 7 deletions
@@ -350,13 +350,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     embedding_modules = {}
     embedding_padding_modules = []
 
-    # BitandBytes specific attributes
-    bitsandbytes_stacked_params_mapping = {
-        # shard_name, weight_name, index
-        "gate_proj": ("gate_up_proj", 0),
-        "up_proj": ("gate_up_proj", 1),
-    }
-
     def __init__(
         self,
         *,

vllm/model_executor/models/exaone.py

Lines changed: 0 additions & 8 deletions
@@ -430,14 +430,6 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         "lm_head": "output_embeddings",
     }
     embedding_padding_modules = ["lm_head"]
-    bitsandbytes_stacked_params_mapping = {
-        # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
-        "c_fc_0": ("gate_up_proj", 0),
-        "c_fc_1": ("gate_up_proj", 1),
-    }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()

vllm/model_executor/models/falcon.py

Lines changed: 3 additions & 3 deletions
@@ -409,9 +409,9 @@ def forward(
 
 
 class FalconForCausalLM(nn.Module, SupportsPP):
-
-    # BitandBytes specific attributes
-    bitsandbytes_stacked_params_mapping = {}
+    packed_modules_mapping = {
+        "query_key_value": ["query_key_value"],
+    }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
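
Note on this change: Falcon keeps its attention weights fused as a single query_key_value tensor on disk, so there is nothing to unpack. The self-referential entry appears to exist only to satisfy the loader's new packed_modules_mapping requirement, and ParamMapping.__post_init__ skips it when building the inverse table; the empty bitsandbytes_stacked_params_mapping it replaces points the same way.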

vllm/model_executor/models/gemma.py

Lines changed: 0 additions & 9 deletions
@@ -349,15 +349,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         "gate_up_proj",
         "down_proj",
     ]
-    # BitandBytes specific attributes
-    bitsandbytes_stacked_params_mapping = {
-        # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
-        "gate_proj": ("gate_up_proj", 0),
-        "up_proj": ("gate_up_proj", 1),
-    }
 
     # Gemma does not apply LoRA to the embedding layer.
     embedding_modules = {}

vllm/model_executor/models/gemma2.py

Lines changed: 0 additions & 10 deletions
@@ -399,16 +399,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     embedding_modules = {}
     embedding_padding_modules = []
 
-    # BitandBytes specific attributes
-    bitsandbytes_stacked_params_mapping = {
-        # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
-        "gate_proj": ("gate_up_proj", 0),
-        "up_proj": ("gate_up_proj", 1),
-    }
-
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config

vllm/model_executor/models/granite.py

Lines changed: 0 additions & 8 deletions
@@ -362,14 +362,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         "lm_head": "output_embeddings",
     }
     embedding_padding_modules = ["lm_head"]
-    bitsandbytes_stacked_params_mapping = {
-        # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
-        "gate_proj": ("gate_up_proj", 0),
-        "up_proj": ("gate_up_proj", 1),
-    }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()

vllm/model_executor/models/idefics3.py

Lines changed: 0 additions & 10 deletions
@@ -662,16 +662,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
         "down_proj",
     ]
 
-    # BitandBytes specific attributes
-    bitsandbytes_stacked_params_mapping = {
-        # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
-        "gate_proj": ("gate_up_proj", 0),
-        "up_proj": ("gate_up_proj", 1),
-    }
-
     embedding_modules = {}
     embedding_padding_modules = []

vllm/model_executor/models/llama.py

Lines changed: 0 additions & 10 deletions
@@ -478,16 +478,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     }
     embedding_padding_modules = ["lm_head"]
 
-    # BitandBytes specific attributes
-    bitsandbytes_stacked_params_mapping = {
-        # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
-        "gate_proj": ("gate_up_proj", 0),
-        "up_proj": ("gate_up_proj", 1),
-    }
-
     # Mistral/Llama models can also be loaded with --load-format mistral
     # from consolidated.safetensors checkpoints
     mistral_mapping = {
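
As with the other removals above (Baichuan, Exaone, Gemma, Gemma 2, Granite, Idefics3), this dict duplicated information LlamaForCausalLM already exposes through its packed_modules_mapping, which vLLM already uses for stacked-parameter handling (e.g., LoRA). The BitsAndBytes loader now recovers the same (packed name, shard index) pairs from that single source via ParamMapping, which is why the per-model BitsAndBytes-specific attribute can be dropped.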
