diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index 9686231fb4bd..5989d877ffdd 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -11,12 +11,16 @@
 
 from vllm.config import ModelConfig, ModelImpl
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.bitsandbytes import (
+    BitsAndBytesConfig)
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.models.adapters import (as_classification_model,
                                                  as_embedding_model,
                                                  as_reward_model)
+from vllm.model_executor.models.utils import WeightsMapper
 
 logger = init_logger(__name__)
 
@@ -153,19 +157,57 @@ def get_sub_modules(self,
 
 
 def configure_quant_config(quant_config: QuantizationConfig,
                            model_class: Type[nn.Module]):
-    """
-    Pass packed_modules_mapping by reference to quant_config so that
-    quant_config can properly match fused modules
-    Note that model attributes are passed by reference to quant_config,
-    enabling them to be updated by model_class.__new__ (ex. chatglm, qwen)
-    """
-    packed_mapping = getattr(model_class, "packed_modules_mapping", None)
-    if packed_mapping is not None:
-        # pass packed_modules_mapping by reference to quant_config
-        quant_config.packed_modules_mapping = packed_mapping
-    else:
-        logger.warning(
-            "The model class %s has not defined `packed_modules_mapping`, "
-            "this may lead to incorrect mapping of quantized or ignored "
-            "modules", model_class.__name__)
+    def _configure_packed_modules_mapping():
+        """
+        Pass packed_modules_mapping by reference to quant_config so that
+        quant_config can properly match fused modules
+
+        Note that model attributes are passed by reference to quant_config,
+        enabling them to be updated by model_class.__new__ (ex. chatglm, qwen)
+        """
+        packed_mapping = getattr(model_class, "packed_modules_mapping", None)
+        if packed_mapping is not None:
+            # pass packed_modules_mapping by reference to quant_config
+            quant_config.packed_modules_mapping = packed_mapping
+        else:
+            logger.warning(
+                "The model class %s has not defined `packed_modules_mapping`, "
+                "this may lead to incorrect mapping of quantized or ignored "
+                "modules", model_class.__name__)
+
+    def _configure_quant_skip_modules():
+        """
+        Configures the quantization skip modules for the model based on the
+        provided quantization configuration.
+
+        This function checks if the model class has a `hf_to_vllm_mapper`
+        attribute. If it does, the mapper is used to translate the list of
+        modules to skip for the different quantization
+        configurations:
+        - For `BitsAndBytesConfig`, it updates `llm_int8_skip_modules`.
+        - For `AWQConfig`, it updates `modules_to_not_convert`.
+        """
+        if getattr(model_class, "hf_to_vllm_mapper", None) is None:
+            return
+        hf_to_vllm_mapper: WeightsMapper = model_class.hf_to_vllm_mapper
+
+        # BitsAndBytes
+        if (isinstance(quant_config, BitsAndBytesConfig)
+                and quant_config.llm_int8_skip_modules):
+            quant_config.llm_int8_skip_modules = [
+                hf_to_vllm_mapper._map_name(module)
+                for module in quant_config.llm_int8_skip_modules
+            ]
+        # AWQ
+        elif (isinstance(quant_config, AWQConfig)
+              and quant_config.modules_to_not_convert):
+            quant_config.modules_to_not_convert = [
+                hf_to_vllm_mapper._map_name(module)
+                for module in quant_config.modules_to_not_convert
+            ]
+        # TODO: Support more quantization types.
+
+    _configure_packed_modules_mapping()
+    _configure_quant_skip_modules()
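
For context, the sketch below illustrates the kind of renaming that `hf_to_vllm_mapper._map_name` performs on a skip list such as `llm_int8_skip_modules`. It uses a simplified stand-in mapper rather than vLLM's actual `WeightsMapper`, and the prefix and module names are assumptions made up for illustration.

```python
# Minimal sketch of the skip-module renaming idea. `ToyWeightsMapper` is a
# hypothetical stand-in for vllm.model_executor.models.utils.WeightsMapper
# (which the diff relies on via its `_map_name` method); the prefixes and
# module names below are illustrative, not taken from a real model.
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyWeightsMapper:
    """Renames module prefixes from HF checkpoint names to vLLM names."""
    orig_to_new_prefix: Dict[str, str] = field(default_factory=dict)

    def _map_name(self, name: str) -> str:
        # Rewrite the first matching prefix; leave other names untouched.
        for old, new in self.orig_to_new_prefix.items():
            if name.startswith(old):
                return new + name[len(old):]
        return name


# Suppose an HF checkpoint keeps submodules under "model." while the vLLM
# implementation nests them under "language_model.model." (assumed names).
mapper = ToyWeightsMapper(
    orig_to_new_prefix={"model.": "language_model.model."})

# A bitsandbytes-style skip list written against the HF names:
llm_int8_skip_modules: List[str] = ["model.vision_tower", "lm_head"]

# The same remapping the new _configure_quant_skip_modules() applies:
remapped = [mapper._map_name(m) for m in llm_int8_skip_modules]
print(remapped)  # ['language_model.model.vision_tower', 'lm_head']
```

Without such a translation, a skip list written against HF-style names would not match vLLM's module names, so modules meant to stay unquantized could be quantized anyway.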