 from vllm.model_executor.model_loader.tensorizer import (
     TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
     serialize_vllm_model, tensorizer_weights_iterator)
-from vllm.model_executor.model_loader.utils import (get_model_architecture,
+from vllm.model_executor.model_loader.utils import (ParamMapping,
+                                                    get_model_architecture,
                                                     set_default_torch_dtype)
 from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf, download_weights_from_hf,
@@ -983,21 +984,11 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
 
     def _get_bnb_target_modules(self, model: nn.Module) -> None:
 
-        # TODO: Maybe we can replace bitsandbytes_stacked_params_mapping with
-        # packed_modules_mapping.
-        inverse_stacked_mapping: Dict[str, List[str]] = {}
-        for orig, (
-                packed,
-                idx,
-        ) in model.bitsandbytes_stacked_params_mapping.items():
-            if packed not in inverse_stacked_mapping:
-                inverse_stacked_mapping[packed] = []
-            inverse_stacked_mapping[packed].insert(idx, orig)
-
         for name, module in model.named_modules():
             if isinstance(module, (LinearBase, )):
                 last_name = name.split(".")[-1]
-                if sub_modules := inverse_stacked_mapping.get(last_name, []):
+                if sub_modules := self.modules_mapping.packed_mapping.get(
+                        last_name, []):
                     # Map vllm's names to transformers's names.
                     for sub_name in sub_modules:
                         self.target_modules.append(
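The hand-built inverse_stacked_mapping removed above is what ParamMapping now provides. The real class lives in vllm/model_executor/model_loader/utils.py; the sketch below is only an assumption about its shape, inferred from the two attributes this diff uses (packed_mapping and inverse_packed_mapping), and is not the upstream implementation.

```python
# Hedged sketch only: field names mirror the attributes referenced in this
# diff, but the body is illustrative rather than vLLM's actual ParamMapping.
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class ParamMappingSketch:
    # Forward view, e.g. {"qkv_proj": ["q_proj", "k_proj", "v_proj"],
    #                     "gate_up_proj": ["gate_proj", "up_proj"]}
    packed_mapping: Dict[str, List[str]]
    # Inverse view, e.g. {"q_proj": ("qkv_proj", 0), "k_proj": ("qkv_proj", 1)}
    inverse_packed_mapping: Dict[str, Tuple[str, int]] = field(
        default_factory=dict)

    def __post_init__(self):
        # Build the sub-module -> (packed module, shard index) view that was
        # previously hand-rolled as bitsandbytes_stacked_params_mapping.
        for packed_name, sub_names in self.packed_mapping.items():
            for index, sub_name in enumerate(sub_names):
                self.inverse_packed_mapping[sub_name] = (packed_name, index)
```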
@@ -1018,15 +1009,19 @@ def _load_weights(self, model_config: ModelConfig,
                 "The required method 'load_weights' is not defined in class"
                 f" {type(model).__name__}.")
 
-        if not hasattr(model, "bitsandbytes_stacked_params_mapping"):
+        if not hasattr(model, "packed_modules_mapping"):
             raise AttributeError(
                 f"Model {type(model).__name__} does not support BitsAndBytes "
-                "quantization yet.")
+                "quantization yet. No 'packed_modules_mapping' found.")
+
+        self.modules_mapping = ParamMapping(
+            copy.deepcopy(model.packed_modules_mapping))
 
         # For some models like Molmo, we need to use hf_to_vllm_mapper
         # to ensure correct loading of weights.
         if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
             self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
+
         # Modules whose weights might have fused on disk
         # we need their output_sizes to make shard in flight correctly with TP
         self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
@@ -1109,7 +1104,7 @@ def _load_weights(self, model_config: ModelConfig,
             for shard_name, (
                     weight_name,
                     index,
-            ) in model.bitsandbytes_stacked_params_mapping.items():
+            ) in self.modules_mapping.inverse_packed_mapping.items():
                 shard_pos = quant_param_name.find(shard_name)
                 # Some models, such as MiniCPM V2.5/2.6, contain both
                 # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
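Putting the two call sites together: _get_bnb_target_modules reads the forward view (fused module name to shard names), while _load_weights reads the inverse view (shard name to fused name plus index). A minimal usage sketch, reusing the hypothetical ParamMappingSketch from the earlier sketch with an illustrative LLaMA-style packed_modules_mapping (the mapping values are a typical example, not taken from a specific vLLM model definition):

```python
# Assumes ParamMappingSketch from the previous sketch is in scope.
mapping = ParamMappingSketch({
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
})

# _get_bnb_target_modules: expand a fused vLLM module name into the
# per-shard names that appear in the HF checkpoint.
assert mapping.packed_mapping["qkv_proj"] == ["q_proj", "k_proj", "v_proj"]

# _load_weights: go the other way, from a shard name to the fused module
# it belongs to and its position within that module.
assert mapping.inverse_packed_mapping["k_proj"] == ("qkv_proj", 1)
```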