From bc940168d9832c77aef14b885b12d8a9f9255faa Mon Sep 17 00:00:00 2001
From: mgoin
Date: Thu, 30 Jan 2025 20:07:00 +0000
Subject: [PATCH 01/19] Fix quantization for chatglm

Signed-off-by: mgoin
---
 vllm/model_executor/models/chatglm.py         | 10 +++++--
 .../models/glm4_vision_encoder.py             | 29 ++++++++++++++-----
 2 files changed, 29 insertions(+), 10 deletions(-)

diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py
index d5f9b4d19e5c..b1f8de45fad8 100644
--- a/vllm/model_executor/models/chatglm.py
+++ b/vllm/model_executor/models/chatglm.py
@@ -263,12 +263,14 @@ def __init__(
             self.total_num_kv_heads,
             bias=config.add_bias_linear or config.add_qkv_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.query_key_value",
         )
         self.dense = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             config.hidden_size,
             bias=config.add_bias_linear,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense",
         )

         # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
@@ -325,6 +327,7 @@ def __init__(
         self,
         config: ChatGLMConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()

@@ -336,6 +339,7 @@ def __init__(
             [config.ffn_hidden_size] * 2,
             bias=config.add_bias_linear,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_h_to_4h",
         )

         self.activation_func = SiluAndMul()
@@ -346,6 +350,7 @@ def __init__(
             config.hidden_size,
             bias=config.add_bias_linear,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_4h_to_h",
         )

     def forward(self, hidden_states):
@@ -394,7 +399,7 @@ def __init__(
             config.hidden_size, eps=config.layernorm_epsilon)

         # MLP
-        self.mlp = GLMMLP(config, quant_config)
+        self.mlp = GLMMLP(config, quant_config, prefix=f"{prefix}.mlp")

     def forward(
         self,
@@ -505,7 +510,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

         self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
                                                 config.hidden_size,
-                                                quant_config=quant_config)
+                                                quant_config=quant_config,
+                                                prefix=f"{prefix}.embedding")

         self.num_layers = config.num_layers
         self.multi_query_group_num = config.multi_query_group_num
diff --git a/vllm/model_executor/models/glm4_vision_encoder.py b/vllm/model_executor/models/glm4_vision_encoder.py
index 51922e6f2d03..7dc4531d0fcc 100644
--- a/vllm/model_executor/models/glm4_vision_encoder.py
+++ b/vllm/model_executor/models/glm4_vision_encoder.py
@@ -72,11 +72,13 @@ def __init__(
             self.head_dim,
             config.num_heads,
             quant_config=quant_config,
+            prefix=f"{prefix}.query_key_value",
         )
         self.dense = RowParallelLinear(
             config.hidden_size,
             config.hidden_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense",
         )

         self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim,
@@ -99,6 +101,7 @@ def __init__(
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = '',
     ):
         super().__init__()
         self.config = config
@@ -107,11 +110,13 @@ def __init__(
             config.hidden_size,
             config.intermediate_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.fc1",
         )
         self.fc2 = RowParallelLinear(
             config.intermediate_size,
             config.hidden_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.fc2",
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -135,7 +140,9 @@ def __init__(
         self.attention = Attention(config,
                                    quant_config=quant_config,
                                    prefix=f"{prefix}.attention")
-        self.mlp = MLP(config, quant_config=quant_config)
+        self.mlp = MLP(config,
+                       quant_config=quant_config,
+                       prefix=f"{prefix}.mlp")
         self.post_attention_layernorm =
LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -179,6 +186,7 @@ def __init__( config, in_features, quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', ): """ The original implementation is the same as: @@ -220,7 +228,8 @@ def __init__( self.linear_proj = ReplicatedLinear(in_features, config.hidden_size, bias=False, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.linear_proj") self.norm1 = nn.LayerNorm(config.hidden_size) self.act1 = nn.GELU() self.act2 = SiluAndMul() @@ -228,12 +237,15 @@ def __init__( self.merged_proj = MergedColumnParallelLinear( config.hidden_size, [config.ffn_hidden_size] * 2, bias=False, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.merged_proj") - self.dense_4h_to_h = RowParallelLinear(config.ffn_hidden_size, - config.hidden_size, - bias=False, - quant_config=quant_config) + self.dense_4h_to_h = RowParallelLinear( + config.ffn_hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.dense_4h_to_h") def forward(self, x): x, _ = self.linear_proj(x) @@ -260,7 +272,8 @@ def __init__( prefix=f"{prefix}.transformer") self.linear_proj = GLU(config, in_features=config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.linear_proj") self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, out_channels=config.hidden_size, kernel_size=2, From 10908f683ee738f334367ef3c23f17f9a64b9f88 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 30 Jan 2025 15:32:24 -0500 Subject: [PATCH 02/19] additional prefix fixes Signed-off-by: Kyle Sayers --- vllm/model_executor/models/glm4_vision_encoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/glm4_vision_encoder.py b/vllm/model_executor/models/glm4_vision_encoder.py index 7dc4531d0fcc..7e53c98dafe2 100644 --- a/vllm/model_executor/models/glm4_vision_encoder.py +++ b/vllm/model_executor/models/glm4_vision_encoder.py @@ -169,7 +169,7 @@ def __init__( self.layers = nn.ModuleList([ TransformerLayer(config, quant_config=quant_config, - prefix=f"{prefix}.layer.{layer_idx}") + prefix=f"{prefix}.layers.{layer_idx}") for layer_idx in range(config.num_hidden_layers) ]) @@ -238,7 +238,7 @@ def __init__( config.hidden_size, [config.ffn_hidden_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.merged_proj") + prefix=f"{prefix}.gate_proj") self.dense_4h_to_h = RowParallelLinear( config.ffn_hidden_size, From 1ef97f004dc1c9fb9da60fc5a38b6e845a1a495b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 30 Jan 2025 16:00:45 -0500 Subject: [PATCH 03/19] use merged_proj Signed-off-by: Kyle Sayers --- vllm/model_executor/models/glm4_vision_encoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/glm4_vision_encoder.py b/vllm/model_executor/models/glm4_vision_encoder.py index 7e53c98dafe2..d99701601011 100644 --- a/vllm/model_executor/models/glm4_vision_encoder.py +++ b/vllm/model_executor/models/glm4_vision_encoder.py @@ -238,7 +238,7 @@ def __init__( config.hidden_size, [config.ffn_hidden_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.gate_proj") + prefix=f"{prefix}.merged_proj") self.dense_4h_to_h = RowParallelLinear( config.ffn_hidden_size, From 6446c94181e2509dac215a9198c56b23033a17aa Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 30 Jan 2025 17:34:19 -0500 Subject: [PATCH 04/19] remove reliance on FUSED_LAYER_NAME_MAPPING in favor 
of packed_modules_mapping Signed-off-by: Kyle Sayers --- .../layers/quantization/base_config.py | 3 ++- .../compressed_tensors/compressed_tensors.py | 5 ++++- .../quantization/compressed_tensors/utils.py | 17 ++++++++------- .../layers/quantization/quark/quark.py | 11 +++++----- .../layers/quantization/quark/utils.py | 17 ++++++++------- .../layers/quantization/utils/quant_utils.py | 21 ++++++++----------- vllm/model_executor/model_loader/loader.py | 5 +++++ 7 files changed, 44 insertions(+), 35 deletions(-) diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 2fb2642dd515..a527d7da1f52 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -1,6 +1,6 @@ import inspect from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Mapping, Optional, Type import torch from torch import nn @@ -57,6 +57,7 @@ def method_has_implemented_embedding( class QuantizationConfig(ABC): """Base class for quantization configs.""" + packed_modules_mapping: Mapping[str, List[str]] = dict() @abstractmethod def get_name(self) -> str: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index dd2dd02eaf72..eebd0eb85378 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -78,7 +78,10 @@ def get_quant_method( # Check if the layer is skipped for quantization. # TODO (@robertgshaw2): support module names - if should_ignore_layer(prefix, ignore=self.ignore): + if should_ignore_layer( + prefix, + ignore=self.ignore, + packed_modules_mapping=self.packed_modules_mapping): return UnquantizedLinearMethod() if isinstance(layer, LinearBase): scheme = self.get_scheme(layer=layer, layer_name=prefix) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index 8fcbda377428..28bf716f6b00 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -1,12 +1,10 @@ import re -from typing import Iterable, Optional +from types import MappingProxyType +from typing import Iterable, List, Mapping, Optional from compressed_tensors import CompressionFormat from torch.nn import Module -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - FUSED_LAYER_NAME_MAPPING) - def is_activation_quantization_format(format: str) -> bool: _ACTIVATION_QUANTIZATION_FORMATS = [ @@ -17,8 +15,11 @@ def is_activation_quantization_format(format: str) -> bool: return format in _ACTIVATION_QUANTIZATION_FORMATS -def should_ignore_layer(layer_name: Optional[str], - ignore: Iterable[str]) -> bool: +def should_ignore_layer( + layer_name: Optional[str], + ignore: Iterable[str] = tuple(), + packed_modules_mapping: Mapping[str, List[str]] = MappingProxyType({}) +) -> bool: if layer_name is None: return False @@ -30,8 +31,8 @@ def should_ignore_layer(layer_name: Optional[str], # in the safetensors checkpoint. So, we convert the name # from the fused version to unfused + check to make sure that # each shard of the fused layer has the same scheme. 
- if proj_name in FUSED_LAYER_NAME_MAPPING and layer_name not in ignore: - shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name] + if proj_name in packed_modules_mapping and layer_name not in ignore: + shard_proj_names = packed_modules_mapping[proj_name] # Convert fused_name --> [shard_names] shard_names = [ diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index fc214255eca7..b67b5c2733a8 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -16,8 +16,6 @@ QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8) from vllm.model_executor.layers.quantization.quark.utils import ( deep_compare, should_ignore_layer) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - FUSED_LAYER_NAME_MAPPING) from vllm.platforms import current_platform __all__ = ["QuarkLinearMethod"] @@ -56,7 +54,10 @@ def get_quant_method(self, layer: torch.nn.Module, # Check if the layer is skipped for quantization. exclude_layers = cast(List[str], self.quant_config.get("exclude")) - if should_ignore_layer(prefix, ignore=exclude_layers): + if should_ignore_layer( + prefix, + ignore=exclude_layers, + packed_modules_mapping=self.packed_modules_mapping): return UnquantizedLinearMethod() if isinstance(layer, LinearBase): scheme = self.get_scheme(layer=layer, layer_name=prefix) @@ -199,8 +200,8 @@ def _find_matched_config(self, layer_name: str, module: torch.nn.Module) -> Dict[str, Any]: proj_name = layer_name.split(".")[-1] - if proj_name in FUSED_LAYER_NAME_MAPPING: - shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name] + if proj_name in self.packed_modules_mapping: + shard_proj_names = self.packed_modules_mapping[proj_name] # Convert fused_name --> [shard_names] shard_names = [ diff --git a/vllm/model_executor/layers/quantization/quark/utils.py b/vllm/model_executor/layers/quantization/quark/utils.py index 742a629bdb1c..3cc91ae04e21 100644 --- a/vllm/model_executor/layers/quantization/quark/utils.py +++ b/vllm/model_executor/layers/quantization/quark/utils.py @@ -1,8 +1,6 @@ import re -from typing import Any, Iterable, Optional - -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - FUSED_LAYER_NAME_MAPPING) +from types import MappingProxyType +from typing import Any, Iterable, List, Mapping, Optional def deep_compare(dict1: Any, dict2: Any) -> bool: @@ -18,8 +16,11 @@ def deep_compare(dict1: Any, dict2: Any) -> bool: return dict1 == dict2 -def should_ignore_layer(layer_name: Optional[str], - ignore: Iterable[str]) -> bool: +def should_ignore_layer( + layer_name: Optional[str], + ignore: Iterable[str], + packed_modules_mapping: Mapping[str, List[str]] = MappingProxyType({}) +) -> bool: if layer_name is None: return False @@ -31,8 +32,8 @@ def should_ignore_layer(layer_name: Optional[str], # in the safetensors checkpoint. So, we convert the name # from the fused version to unfused + check to make sure that # each shard of the fused layer has the same scheme. 
- if proj_name in FUSED_LAYER_NAME_MAPPING: - shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name] + if proj_name in packed_modules_mapping: + shard_proj_names = packed_modules_mapping[proj_name] # Convert fused_name --> [shard_names] shard_names = [ diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 83055d6000d8..25a6fdf6e8f8 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -1,5 +1,6 @@ """This file is used for /tests and /benchmarks""" -from typing import List, Optional +from types import MappingProxyType +from typing import List, Mapping, Optional import numpy import torch @@ -11,14 +12,6 @@ SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] -# Note: this is a hack. We should update each model to register the -# stacked params and get it from there instead in a future PR. -# fused_name: List[shard_name] -FUSED_LAYER_NAME_MAPPING = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] -} - def pack_quantized_values_into_int32(w_q: torch.Tensor, wtype: ScalarType, @@ -63,14 +56,18 @@ def unpack_quantized_values_into_int32(w_q: torch.Tensor, return res.permute(inv_perm) -def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: +def is_layer_skipped( + prefix: str, + ignored_layers: List[str], + packed_modules_mapping: Mapping[str, List[str]] = MappingProxyType({}) +) -> bool: # prefix: model.layers.0.self_attn.q_proj # proj_name: q_proj proj_name = prefix.split(".")[-1] - if proj_name in FUSED_LAYER_NAME_MAPPING: + if proj_name in packed_modules_mapping: shard_prefixes = [ prefix.replace(proj_name, shard_proj_name) - for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name] + for shard_proj_name in packed_modules_mapping[proj_name] ] is_skipped = None diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 712266ee4263..401cafd4ee19 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -110,6 +110,11 @@ def _initialize_model( model_config = vllm_config.model_config model_class, _ = get_model_architecture(model_config) + # share reference to packed_modules_mapping with quant_config + packed_mapping = hasattr(model_class, "packed_modules_mapping", None) + if packed_mapping is not None and vllm_config.quant_config is not None: + vllm_config.quant_config.packed_modules_mapping = packed_mapping + signatures = inspect.signature(model_class.__init__) all_params = [param.name for param in signatures.parameters.values()] if "vllm_config" in all_params and "prefix" in all_params: From 1b1e1808e6ffa671258e85085df88830167115e7 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 31 Jan 2025 14:25:45 -0500 Subject: [PATCH 05/19] update reference when model class is selected Signed-off-by: Kyle Sayers --- vllm/model_executor/model_loader/loader.py | 4 ++-- vllm/model_executor/models/chatglm.py | 14 +++++++++++--- vllm/model_executor/models/minicpmv.py | 12 +++++++++--- vllm/model_executor/models/qwen.py | 14 +++++++++++--- 4 files changed, 33 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 401cafd4ee19..4752ef93f73b 100644 --- a/vllm/model_executor/model_loader/loader.py +++ 
b/vllm/model_executor/model_loader/loader.py @@ -110,8 +110,8 @@ def _initialize_model( model_config = vllm_config.model_config model_class, _ = get_model_architecture(model_config) - # share reference to packed_modules_mapping with quant_config - packed_mapping = hasattr(model_class, "packed_modules_mapping", None) + # pass packed_modules_mapping by reference to quant_config + packed_mapping = getattr(model_class, "packed_modules_mapping", None) if packed_mapping is not None and vllm_config.quant_config is not None: vllm_config.quant_config.packed_modules_mapping = packed_mapping diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index b1f8de45fad8..e8a566e2eae6 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -770,6 +770,7 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, SupportsMultiModal): # Ensure that the LoRA support check passes when the class is not # initialized, but set all these attributes to empty. + # These will be updated when a model class is selected packed_modules_mapping = {} supported_lora_modules = [] embedding_modules = {} @@ -781,9 +782,16 @@ def __new__( prefix: str = "", ) -> None: config = vllm_config.model_config.hf_config + # Initialize VL - if hasattr(config, "vision_config"): - return ChatGLMV(vllm_config=vllm_config, prefix=prefix) + if hasattr(config, "vision_config"): # noqa: SIM108 + instance_cls = ChatGLMV # Initialize LLM else: - return ChatGLM(vllm_config=vllm_config, prefix=prefix) \ No newline at end of file + instance_cls = ChatGLM + + cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping) + cls.supported_lora_modules += instance_cls.supported_lora_modules + cls.embedding_modules.update(instance_cls.embedding_modules) + cls.embedding_padding_modules += instance_cls.embedding_padding_modules + return instance_cls(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index bf967d33a317..a33ac943feb9 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -1473,6 +1473,7 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA): """ # Ensure that the LoRA support check passes when the class is not # initialized, but set all these attributes to empty. 
+ # These will be updated when a model class is selected packed_modules_mapping = {} supported_lora_modules = [] embedding_modules = {} @@ -1489,8 +1490,13 @@ def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""): version = str(config.version).split(".") version = tuple([int(x) for x in version]) # Dispatch class based on version - instance_class = _SUPPORT_VERSION.get(version) - if instance_class is None: + instance_cls = _SUPPORT_VERSION.get(version) + if instance_cls is None: raise ValueError( "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6") - return instance_class(vllm_config=vllm_config, prefix=prefix) + + cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping) + cls.supported_lora_modules += instance_cls.supported_lora_modules + cls.embedding_modules.update(instance_cls.embedding_modules) + cls.embedding_padding_modules += instance_cls.embedding_padding_modules + return instance_cls(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 86a9d3089c3e..f0684b3025d3 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -1129,6 +1129,7 @@ class QWenLMHeadModel(QWenBaseModel, SupportsMultiModal, SupportsLoRA): """ # Ensure that the LoRA support check passes when the class is not # initialized, but set all these attributes to empty. + # These will be updated when a model class is selected packed_modules_mapping = {} supported_lora_modules = [] embedding_modules = {} @@ -1140,9 +1141,16 @@ def __new__( prefix: str = "", ) -> QWenBaseModel: config = vllm_config.model_config.hf_config + # Initialize VL - if hasattr(config, "visual"): - return QWenVL(vllm_config=vllm_config, prefix=prefix) + if hasattr(config, "visual"): # noqa: SIM108 + instance_cls = QWenVL # Initialize LLM else: - return QWenLLM(vllm_config=vllm_config, prefix=prefix) + instance_cls = QWenLLM + + cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping) + cls.supported_lora_modules += instance_cls.supported_lora_modules + cls.embedding_modules.update(instance_cls.embedding_modules) + cls.embedding_padding_modules += instance_cls.embedding_padding_modules + return instance_cls(vllm_config=vllm_config, prefix=prefix) From 41c756d31b66a5ad28c9a63347287bc2595159f0 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 31 Jan 2025 17:14:58 -0500 Subject: [PATCH 06/19] update comment Signed-off-by: Kyle Sayers --- vllm/model_executor/models/chatglm.py | 2 +- vllm/model_executor/models/minicpmv.py | 2 +- vllm/model_executor/models/qwen.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index e8a566e2eae6..bfce57fe84dc 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -770,7 +770,7 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, SupportsMultiModal): # Ensure that the LoRA support check passes when the class is not # initialized, but set all these attributes to empty. 
- # These will be updated when a model class is selected + # These will be updated when an instance class is selected packed_modules_mapping = {} supported_lora_modules = [] embedding_modules = {} diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index a33ac943feb9..e9e737b121ff 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -1473,7 +1473,7 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA): """ # Ensure that the LoRA support check passes when the class is not # initialized, but set all these attributes to empty. - # These will be updated when a model class is selected + # These will be updated when an instance class is selected packed_modules_mapping = {} supported_lora_modules = [] embedding_modules = {} diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index f0684b3025d3..4db74a756216 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -1129,7 +1129,7 @@ class QWenLMHeadModel(QWenBaseModel, SupportsMultiModal, SupportsLoRA): """ # Ensure that the LoRA support check passes when the class is not # initialized, but set all these attributes to empty. - # These will be updated when a model class is selected + # These will be updated when an instance class is selected packed_modules_mapping = {} supported_lora_modules = [] embedding_modules = {} From 0960a155dfbce296ed3f4bdb602b28e30a6d0e6d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sun, 2 Feb 2025 11:27:57 -0500 Subject: [PATCH 07/19] remove _handle_fused_layers Signed-off-by: Kyle Sayers --- .../quantization/compressed_tensors/utils.py | 47 ------------------- 1 file changed, 47 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index b3eecf1a21da..10630b6e8bdb 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -78,53 +78,6 @@ def check_equal_or_regex_match(layer_name: str, return False -def _handle_fused_layers(func): - """ - Decorator to handle fused layers by mapping vllm fused layer names - to their corresponding unfused layer names for quantization/pruning schemes. - """ - # fused_layer_name -> unfused_layer_name - fused_layer_map = { - "qkv_proj": "q_proj", - "gate_up_proj": "up_proj", - } - - def fused_layer_handler(layer_name: Optional[str], module: Module, - targets: Iterable[str]) -> Optional[str]: - """ - Wrapper function specifically designed to support the - find_matched_target function. - - It handles cases where the provided layer name corresponds to a - fused layer in vllm, mapping it to its equivalent unfused layer name - based on the predefined fused_layer_map. If the original layer name - raises a ValueError in the wrapped function, this handler - will attempt to resolve the issue by substituting with unfused - layer name. - - :param layer_name: Name of the layer, which may be fused. - :param module: An instance of torch.nn.Module. - :param targets: A list of target names or patterns to match. - :return: The result of the wrapped find_matched_target function with - the resolved layer name. - :raises ValueError: If the layer name cannot be resolved to a - valid target. 
- """ - try: - return func(layer_name, module, targets) - except ValueError: - if layer_name is None: - layer_name = "" - parent_name, fused_proj_name = layer_name.rsplit(".", 1) - unfused_proj_name = fused_layer_map.get(fused_proj_name, - fused_proj_name) - new_layer_name = f"{parent_name}.{unfused_proj_name}" - return func(new_layer_name, module, targets) - - return fused_layer_handler - - -@_handle_fused_layers def find_matched_target( layer_name: Optional[str], module: Module, From 1ceab2d55c3c9f3aa7d089d560fe021054f09645 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sun, 2 Feb 2025 11:33:24 -0500 Subject: [PATCH 08/19] shorten arg name Signed-off-by: Kyle Sayers --- .../model_executor/layers/quantization/quark/utils.py | 6 +++--- .../layers/quantization/utils/quant_utils.py | 11 ++++++++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/quantization/quark/utils.py b/vllm/model_executor/layers/quantization/quark/utils.py index 3cc91ae04e21..86f958b2fbb8 100644 --- a/vllm/model_executor/layers/quantization/quark/utils.py +++ b/vllm/model_executor/layers/quantization/quark/utils.py @@ -19,7 +19,7 @@ def deep_compare(dict1: Any, dict2: Any) -> bool: def should_ignore_layer( layer_name: Optional[str], ignore: Iterable[str], - packed_modules_mapping: Mapping[str, List[str]] = MappingProxyType({}) + mapping: Mapping[str, List[str]] = MappingProxyType({}) ) -> bool: if layer_name is None: return False @@ -32,8 +32,8 @@ def should_ignore_layer( # in the safetensors checkpoint. So, we convert the name # from the fused version to unfused + check to make sure that # each shard of the fused layer has the same scheme. - if proj_name in packed_modules_mapping: - shard_proj_names = packed_modules_mapping[proj_name] + if proj_name in mapping: + shard_proj_names = mapping[proj_name] # Convert fused_name --> [shard_names] shard_names = [ diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 6bb867dd0de2..93ac73d20b72 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -173,15 +173,20 @@ def unpack_quantized_values_into_int32(w_q: torch.Tensor, def is_layer_skipped( prefix: str, ignored_layers: List[str], - packed_modules_mapping: Mapping[str, List[str]] = MappingProxyType({}) + mapping: Mapping[str, List[str]] = MappingProxyType({}) ) -> bool: # prefix: model.layers.0.self_attn.q_proj # proj_name: q_proj proj_name = prefix.split(".")[-1] - if proj_name in packed_modules_mapping: + + # Fused layers like gate_up_proj or qkv_proj will not be fused + # in the safetensors checkpoint. So, we convert the name + # from the fused version to unfused + check to make sure that + # each shard of the fused layer has the same scheme. 
+ if proj_name in mapping: shard_prefixes = [ prefix.replace(proj_name, shard_proj_name) - for shard_proj_name in packed_modules_mapping[proj_name] + for shard_proj_name in mapping[proj_name] ] is_skipped = None From a134b92a0907f45d73ac2cfdeacb027c24af9179 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sun, 2 Feb 2025 12:04:52 -0500 Subject: [PATCH 09/19] break out to function, raise warning if is missing Signed-off-by: Kyle Sayers --- vllm/model_executor/model_loader/loader.py | 7 +++--- vllm/model_executor/model_loader/utils.py | 25 ++++++++++++++++++++++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 973514bd3b9b..e66e2c584750 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -41,6 +41,7 @@ TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, serialize_vllm_model, tensorizer_weights_iterator) from vllm.model_executor.model_loader.utils import (ParamMapping, + configure_quant_config, get_model_architecture, set_default_torch_dtype) from vllm.model_executor.model_loader.weight_utils import ( @@ -111,10 +112,8 @@ def _initialize_model( model_config = vllm_config.model_config model_class, _ = get_model_architecture(model_config) - # pass packed_modules_mapping by reference to quant_config - packed_mapping = getattr(model_class, "packed_modules_mapping", None) - if packed_mapping is not None and vllm_config.quant_config is not None: - vllm_config.quant_config.packed_modules_mapping = packed_mapping + if vllm_config.quant_config is not None: + configure_quant_config(vllm_config.quant_config, model_class) signatures = inspect.signature(model_class.__init__) all_params = [param.name for param in signatures.parameters.values()] diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 3f923d2f6632..f5d2fe9c14ec 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -7,11 +7,16 @@ from torch import nn from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.adapters import (as_classification_model, as_embedding_model, as_reward_model) +logger = init_logger(__name__) + @contextlib.contextmanager def set_default_torch_dtype(dtype: torch.dtype): @@ -73,3 +78,23 @@ def __post_init__(self): packed_name, index, ) + + +def configure_quant_config(quant_config: QuantizationConfig, + model_class: Type[nn.Module]): + """ + Pass packed_modules_mapping by reference to quant_config so that + quant_config can properly match fused modules + + Note that model attributes are passed by reference to quant_config, + enabling them to be updated by model_class.__new__ (ex. 
chatglm, qwen) + """ + packed_mapping = hasattr(model_class, "packed_modules_mapping", None) + if packed_mapping is not None: + # pass packed_modules_mapping by reference to quant_config + quant_config.packed_modules_mapping = packed_mapping + else: + logger.warning( + "The model class %s has not defined `packed_modules_mapping`, " + "this may lead to incorrect mapping of quantized or ignored " + "modules", model_class.__name__) From 68a73e23cc6c1c1d19dbc476b4eeaa878be3e2a5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sun, 2 Feb 2025 12:20:54 -0500 Subject: [PATCH 10/19] typos Signed-off-by: Kyle Sayers --- .../compressed_tensors/compressed_tensors.py | 17 ++++++++--------- .../layers/quantization/quark/quark.py | 7 +++---- vllm/model_executor/model_loader/utils.py | 2 +- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index a9e8a8bb5686..18abff133e6c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -81,10 +81,9 @@ def get_quant_method( # Check if the layer is skipped for quantization. # TODO (@robertgshaw2): support module names - if should_ignore_layer( - prefix, - ignore=self.ignore, - packed_modules_mapping=self.packed_modules_mapping): + if should_ignore_layer(prefix, + ignore=self.ignore, + mapping=self.packed_modules_mapping): return UnquantizedLinearMethod() if isinstance(layer, LinearBase): scheme = self.get_scheme(layer=layer, layer_name=prefix) @@ -395,11 +394,11 @@ def get_scheme(self, if self.sparsity_scheme_map: is_ignored = False with suppress(ValueError): - is_ignored = find_matched_target( - layer_name=layer_name, - module=layer, - targets=self.sparsity_ignore_list, - mapping=self.packed_modules_mapping) + find_matched_target(layer_name=layer_name, + module=layer, + targets=self.sparsity_ignore_list, + mapping=self.packed_modules_mapping) + is_ignored = True # if the layer is in the sparsity ignore list, # we should not apply any sparsity scheme diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index b67b5c2733a8..07b3854665f7 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -54,10 +54,9 @@ def get_quant_method(self, layer: torch.nn.Module, # Check if the layer is skipped for quantization. exclude_layers = cast(List[str], self.quant_config.get("exclude")) - if should_ignore_layer( - prefix, - ignore=exclude_layers, - packed_modules_mapping=self.packed_modules_mapping): + if should_ignore_layer(prefix, + ignore=exclude_layers, + mapping=self.packed_modules_mapping): return UnquantizedLinearMethod() if isinstance(layer, LinearBase): scheme = self.get_scheme(layer=layer, layer_name=prefix) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index f5d2fe9c14ec..e8e33d54fe73 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -89,7 +89,7 @@ def configure_quant_config(quant_config: QuantizationConfig, Note that model attributes are passed by reference to quant_config, enabling them to be updated by model_class.__new__ (ex. 
chatglm, qwen) """ - packed_mapping = hasattr(model_class, "packed_modules_mapping", None) + packed_mapping = getattr(model_class, "packed_modules_mapping", None) if packed_mapping is not None: # pass packed_modules_mapping by reference to quant_config quant_config.packed_modules_mapping = packed_mapping From bb8bc2037e478a4ca6d856231dcf55becef9440c Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sun, 2 Feb 2025 12:45:03 -0500 Subject: [PATCH 11/19] WIP: need testing Signed-off-by: Kyle Sayers --- .../quantization/compressed_tensors/utils.py | 29 +++++-------------- 1 file changed, 8 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index 10630b6e8bdb..26452401c6a2 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -165,7 +165,7 @@ def _match_fused_layer(layer_name: str, target_layers: Iterable[str], mapping: Mapping[str, List[str]]) -> Optional[str]: """ Match a fused layer name to its corresponding individual layer in - target_layers. + target_layers. Returns first value in mapping which matches targets Examples: layer_name = "model.layers.0.self_attn.qkv_proj" @@ -173,27 +173,14 @@ def _match_fused_layer(layer_name: str, target_layers: Iterable[str], "model.layers.0.self_attn.k_proj", "model.layers.0.self_attn.v_proj"] """ - # Split into parent path and layer type - # e.g., "model.layers.0.self_attn" and "qkv_proj" - parent_path = ".".join(layer_name.split(".")[:-1]) - layer_type = layer_name.split(".")[-1] + unfused_paths = sum( + (layer_name.replace(fused, unfused) + for (fused, unfused) in mapping if layer_name.endswith(fused)), + start=[]) - if layer_type not in mapping: - return None - - possible_layer_types = mapping[layer_type] - - # Look for a target layer that: - # 1. Has the same parent path - # 2. Ends with one of the possible individual layer types for target in target_layers: - is_same_parent = parent_path in target - is_matching_type = any(type_suffix in target - for type_suffix in possible_layer_types) - - if is_same_parent and is_matching_type and all( - '.'.join([parent_path, type_suffix]) - for type_suffix in possible_layer_types): - return target + for unfused_path in unfused_paths: + if _is_equal_or_regex_match(unfused_path, target): + return target return None From b70cf8c4afb36fda70cbd482b71c27fbc5af3818 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sun, 2 Feb 2025 23:59:17 -0500 Subject: [PATCH 12/19] matching logic, comments Signed-off-by: Kyle Sayers --- .../quantization/compressed_tensors/utils.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index 26452401c6a2..a32122df4ef5 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -97,6 +97,9 @@ def find_matched_target( First, we try to match the layer_name with a target Second, we try to match the module's name with a target + Third, we try to map the layer_name to a list of fused module names. + All fused module names must match in order for a match to be + successful. 
A successful match returns the first module name :param layer_name: layer name :param module: torch.nn.Module @@ -178,9 +181,13 @@ def _match_fused_layer(layer_name: str, target_layers: Iterable[str], for (fused, unfused) in mapping if layer_name.endswith(fused)), start=[]) - for target in target_layers: - for unfused_path in unfused_paths: - if _is_equal_or_regex_match(unfused_path, target): - return target + if len(unfused_paths) <= 0: + return None - return None + for unfused_path in unfused_paths: + if not any( + _is_equal_or_regex_match(unfused_path, target) + for target in target_layers): + return None + + return unfused_paths[0] From 28b95402e913420617ef7e7d0c2a5cdbf43faf87 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 3 Feb 2025 11:58:51 -0500 Subject: [PATCH 13/19] implement fused strategy Signed-off-by: Kyle Sayers --- .../compressed_tensors/compressed_tensors.py | 9 ++- .../quantization/compressed_tensors/utils.py | 67 +++++++++++++------ 2 files changed, 52 insertions(+), 24 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 7d5fe3cd0982..56555fdd657b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -387,7 +387,8 @@ def get_scheme(self, layer_name=layer_name, module=layer, targets=self.target_scheme_map.keys(), - mapping=self.packed_modules_mapping) + fused_mapping=self.packed_modules_mapping, + fused_strategy="all") scheme_dict = self.target_scheme_map[matched_target] weight_quant = scheme_dict.get("weights") @@ -399,7 +400,8 @@ def get_scheme(self, find_matched_target(layer_name=layer_name, module=layer, targets=self.sparsity_ignore_list, - mapping=self.packed_modules_mapping) + fused_mapping=self.packed_modules_mapping, + fused_strategy="all") is_ignored = True # if the layer is in the sparsity ignore list, @@ -410,7 +412,8 @@ def get_scheme(self, layer_name=layer_name, module=layer, targets=self.sparsity_scheme_map.keys(), - mapping=self.packed_modules_mapping) + fused_mapping=self.packed_modules_mapping, + fused_strategy="any") sparsity_scheme = self.sparsity_scheme_map.get(matched_target) if self.supports_cutlass_24(weight_quant=weight_quant, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index 1eba21bbe59d..c5c6b85f4dec 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -80,12 +80,13 @@ def check_equal_or_regex_match(layer_name: str, return False -def find_matched_target( - layer_name: Optional[str], - module: Module, - targets: Iterable[str], - mapping: Mapping[str, List[str]] = MappingProxyType({}) -) -> str: +def find_matched_target(layer_name: Optional[str], + module: Module, + targets: Iterable[str], + fused_mapping: Mapping[str, + List[str]] = MappingProxyType( + {}), + fused_strategy: str = "all") -> str: """ Helper function to look up which "target" in the compressed-tensors config that a layer corresponds to. 
@@ -106,6 +107,9 @@ def find_matched_target( :param layer_name: layer name :param module: torch.nn.Module :param targets: list of targets to match the layer against + :param fused_mapping: map from fused layer names to its components + :param fused_strategy: either "all" or "any". If using "all", fused + layers match if "all" of its components match """ if layer_name is None: @@ -114,7 +118,8 @@ def find_matched_target( matched_target = (_find_first_match(layer_name, targets) or _find_first_match(module.__class__.__name__, targets, True) - or _match_fused_layer(layer_name, targets, mapping)) + or _match_fused_layer(layer_name, targets, fused_mapping, + fused_strategy)) if matched_target is None: raise ValueError( @@ -166,30 +171,50 @@ def _is_equal_or_regex_match(value: str, return False -def _match_fused_layer(layer_name: str, target_layers: Iterable[str], - mapping: Mapping[str, List[str]]) -> Optional[str]: +def _match_fused_layer(layer_name: str, + target_layers: Iterable[str], + mapping: Mapping[str, List[str]], + fused_strategy: str = "all") -> Optional[str]: """ Match a fused layer name to its corresponding individual layer in target_layers. Returns first value in mapping which matches targets + :param layer_name: layer name + :param target_layers: list of targets to match the layer against + :param mapping: map from fused layer names to its components + :param fused_strategy: either "all" or "any". If using "all", fused + layers match if "all" of its components match + Examples: layer_name = "model.layers.0.self_attn.qkv_proj" target_layers = ["model.layers.0.self_attn.q_proj", "model.layers.0.self_attn.k_proj", "model.layers.0.self_attn.v_proj"] """ - unfused_paths = sum( - (layer_name.replace(fused, unfused) - for (fused, unfused) in mapping if layer_name.endswith(fused)), - start=[]) - - if len(unfused_paths) <= 0: + # find layer_name in mapping + fused = next((key for key in mapping if layer_name.endswith(key)), None) + if fused is None: return None - for unfused_path in unfused_paths: - if not any( - _is_equal_or_regex_match(unfused_path, target) - for target in target_layers): - return None + # expand path of unfused components + unfused_paths = [ + layer_name.replace(fused, unfused) for unfused in mapping[fused] + ] + + # for each unfused component, find a match in targets + unfused_matches = [] + for unfused in unfused_paths: + for target in target_layers: + if _is_equal_or_regex_match(unfused, target): + unfused_matches.append(target) + break + else: + unfused_matches.append(None) + + # use strategy to aggregate results + if fused_strategy == "all" and all(unfused_matches): # noqa: SIM114 + return unfused_matches[0] + elif fused_strategy == "any" and any(unfused_matches): + return next(match for match in unfused_matches if match is not None) - return unfused_paths[0] + return None From e7ded255dfd8d1cfdee883a90c247efcdeaf60fd Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 3 Feb 2025 13:28:44 -0500 Subject: [PATCH 14/19] rename to fused_mapping, logic clarity Signed-off-by: Kyle Sayers --- .../compressed_tensors/compressed_tensors.py | 38 +++++--------- .../quantization/compressed_tensors/utils.py | 51 ++++++++----------- .../layers/quantization/quark/quark.py | 2 +- .../layers/quantization/quark/utils.py | 6 +-- 4 files changed, 37 insertions(+), 60 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 
56555fdd657b..872937be5476 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from contextlib import suppress from typing import Any, Dict, List, Literal, Optional, Tuple, cast import torch @@ -85,7 +84,7 @@ def get_quant_method( # TODO (@robertgshaw2): support module names if should_ignore_layer(prefix, ignore=self.ignore, - mapping=self.packed_modules_mapping): + fused_mapping=self.packed_modules_mapping): return UnquantizedLinearMethod() if isinstance(layer, LinearBase): scheme = self.get_scheme(layer=layer, layer_name=prefix) @@ -381,40 +380,27 @@ def get_scheme(self, # Will be empty for models with only sparsity weight_quant = input_quant = None - sparsity_scheme: Optional[SparsityCompressionConfig] = None if self.target_scheme_map: matched_target = find_matched_target( layer_name=layer_name, module=layer, targets=self.target_scheme_map.keys(), - fused_mapping=self.packed_modules_mapping, - fused_strategy="all") + fused_mapping=self.packed_modules_mapping) scheme_dict = self.target_scheme_map[matched_target] weight_quant = scheme_dict.get("weights") input_quant = scheme_dict.get("input_activations") - if self.sparsity_scheme_map: - is_ignored = False - with suppress(ValueError): - find_matched_target(layer_name=layer_name, - module=layer, - targets=self.sparsity_ignore_list, - fused_mapping=self.packed_modules_mapping, - fused_strategy="all") - is_ignored = True - - # if the layer is in the sparsity ignore list, - # we should not apply any sparsity scheme - - if not is_ignored: - matched_target = find_matched_target( - layer_name=layer_name, - module=layer, - targets=self.sparsity_scheme_map.keys(), - fused_mapping=self.packed_modules_mapping, - fused_strategy="any") - sparsity_scheme = self.sparsity_scheme_map.get(matched_target) + # Find the sparsity scheme of the layer + # assume that fused layers inerhit first component's sparsity scheme + sparsity_targets = (self.sparsity_scheme_map.keys() - + set(self.sparsity_ignore_list)) + matched_target = find_matched_target( + layer_name=layer_name, + module=layer, + targets=sparsity_targets, + fused_mapping=self.packed_modules_mapping) + sparsity_scheme = self.sparsity_scheme_map.get(matched_target, None) if self.supports_cutlass_24(weight_quant=weight_quant, input_quant=input_quant, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index c5c6b85f4dec..9442abc114bf 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -20,7 +20,7 @@ def is_activation_quantization_format(format: str) -> bool: def should_ignore_layer( layer_name: Optional[str], ignore: Iterable[str] = tuple(), - mapping: Mapping[str, List[str]] = MappingProxyType({}) + fused_mapping: Mapping[str, List[str]] = MappingProxyType({}) ) -> bool: if layer_name is None: return False @@ -33,8 +33,8 @@ def should_ignore_layer( # in the safetensors checkpoint. So, we convert the name # from the fused version to unfused + check to make sure that # each shard of the fused layer has the same scheme. 
- if proj_name in mapping and layer_name not in ignore: - shard_proj_names = mapping[proj_name] + if proj_name in fused_mapping and layer_name not in ignore: + shard_proj_names = fused_mapping[proj_name] # Convert fused_name --> [shard_names] shard_names = [ @@ -80,13 +80,12 @@ def check_equal_or_regex_match(layer_name: str, return False -def find_matched_target(layer_name: Optional[str], - module: Module, - targets: Iterable[str], - fused_mapping: Mapping[str, - List[str]] = MappingProxyType( - {}), - fused_strategy: str = "all") -> str: +def find_matched_target( + layer_name: Optional[str], + module: Module, + targets: Iterable[str], + fused_mapping: Mapping[str, List[str]] = MappingProxyType({}) +) -> str: """ Helper function to look up which "target" in the compressed-tensors config that a layer corresponds to. @@ -101,8 +100,8 @@ def find_matched_target(layer_name: Optional[str], First, we try to match the layer_name with a target Second, we try to match the module's name with a target Third, we try to map the layer_name to a list of fused module names. - All fused module names must match in order for a match to be - successful. A successful match returns the first module name + *All* component module names must match in order for a match to be + successful. A successful match returns the first component target :param layer_name: layer name :param module: torch.nn.Module @@ -115,11 +114,10 @@ def find_matched_target(layer_name: Optional[str], if layer_name is None: layer_name = "" - matched_target = (_find_first_match(layer_name, targets) - or _find_first_match(module.__class__.__name__, targets, - True) - or _match_fused_layer(layer_name, targets, fused_mapping, - fused_strategy)) + matched_target = ( + _find_first_match(layer_name, targets) + or _find_first_match(module.__class__.__name__, targets, True) + or _match_fused_layer(layer_name, targets, fused_mapping)) if matched_target is None: raise ValueError( @@ -171,19 +169,18 @@ def _is_equal_or_regex_match(value: str, return False -def _match_fused_layer(layer_name: str, - target_layers: Iterable[str], - mapping: Mapping[str, List[str]], - fused_strategy: str = "all") -> Optional[str]: +def _match_fused_layer(layer_name: str, target_layers: Iterable[str], + mapping: Mapping[str, List[str]]) -> Optional[str]: """ Match a fused layer name to its corresponding individual layer in target_layers. Returns first value in mapping which matches targets + Implements an "all" matching strategy where a fused layer matches iff + "all" of its components match + :param layer_name: layer name :param target_layers: list of targets to match the layer against :param mapping: map from fused layer names to its components - :param fused_strategy: either "all" or "any". 
If using "all", fused - layers match if "all" of its components match Examples: layer_name = "model.layers.0.self_attn.qkv_proj" @@ -211,10 +208,4 @@ def _match_fused_layer(layer_name: str, else: unfused_matches.append(None) - # use strategy to aggregate results - if fused_strategy == "all" and all(unfused_matches): # noqa: SIM114 - return unfused_matches[0] - elif fused_strategy == "any" and any(unfused_matches): - return next(match for match in unfused_matches if match is not None) - - return None + return unfused_matches[0] if all(unfused_matches) else None diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index d7f60dfaf9cf..ba123565a0ec 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -58,7 +58,7 @@ def get_quant_method(self, layer: torch.nn.Module, exclude_layers = cast(List[str], self.quant_config.get("exclude")) if should_ignore_layer(prefix, ignore=exclude_layers, - mapping=self.packed_modules_mapping): + fused_mapping=self.packed_modules_mapping): return UnquantizedLinearMethod() if isinstance(layer, LinearBase): scheme = self.get_scheme(layer=layer, layer_name=prefix) diff --git a/vllm/model_executor/layers/quantization/quark/utils.py b/vllm/model_executor/layers/quantization/quark/utils.py index 0061b902d9d4..17e0df021085 100644 --- a/vllm/model_executor/layers/quantization/quark/utils.py +++ b/vllm/model_executor/layers/quantization/quark/utils.py @@ -21,7 +21,7 @@ def deep_compare(dict1: Any, dict2: Any) -> bool: def should_ignore_layer( layer_name: Optional[str], ignore: Iterable[str], - mapping: Mapping[str, List[str]] = MappingProxyType({}) + fused_mapping: Mapping[str, List[str]] = MappingProxyType({}) ) -> bool: if layer_name is None: return False @@ -34,8 +34,8 @@ def should_ignore_layer( # in the safetensors checkpoint. So, we convert the name # from the fused version to unfused + check to make sure that # each shard of the fused layer has the same scheme. 
- if proj_name in mapping: - shard_proj_names = mapping[proj_name] + if proj_name in fused_mapping: + shard_proj_names = fused_mapping[proj_name] # Convert fused_name --> [shard_names] shard_names = [ From fe0e1ec2f5b1401e6ab43e946d17c4d5ff5f9875 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 3 Feb 2025 13:43:24 -0500 Subject: [PATCH 15/19] fix suppress Signed-off-by: Kyle Sayers --- .../compressed_tensors/compressed_tensors.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 872937be5476..d7d7d06c2667 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +from contextlib import suppress from typing import Any, Dict, List, Literal, Optional, Tuple, cast import torch @@ -395,12 +396,14 @@ def get_scheme(self, # assume that fused layers inerhit first component's sparsity scheme sparsity_targets = (self.sparsity_scheme_map.keys() - set(self.sparsity_ignore_list)) - matched_target = find_matched_target( - layer_name=layer_name, - module=layer, - targets=sparsity_targets, - fused_mapping=self.packed_modules_mapping) - sparsity_scheme = self.sparsity_scheme_map.get(matched_target, None) + sparsity_scheme: Optional[SparsityCompressionConfig] = None + with suppress(ValueError): + matched_target = find_matched_target( + layer_name=layer_name, + module=layer, + targets=sparsity_targets, + fused_mapping=self.packed_modules_mapping) + sparsity_scheme = self.sparsity_scheme_map[matched_target] if self.supports_cutlass_24(weight_quant=weight_quant, input_quant=input_quant, From 2fce7f742939ff0ec0ec48faa73d4a2ff8e3336a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 3 Feb 2025 17:27:02 -0500 Subject: [PATCH 16/19] type hint Signed-off-by: Kyle Sayers --- .../layers/quantization/compressed_tensors/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index 9442abc114bf..0a3298576660 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -199,7 +199,7 @@ def _match_fused_layer(layer_name: str, target_layers: Iterable[str], ] # for each unfused component, find a match in targets - unfused_matches = [] + unfused_matches: List[Optional[str]] = [] for unfused in unfused_paths: for target in target_layers: if _is_equal_or_regex_match(unfused, target): From b898dd6013c92b9a939d856e0198af0336c2f23b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 3 Feb 2025 18:35:01 -0500 Subject: [PATCH 17/19] rename mapping to fused_mapping Signed-off-by: Kyle Sayers --- .../quantization/compressed_tensors/utils.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index 0a3298576660..85ae1d5cb787 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -169,18 +169,19 @@ def _is_equal_or_regex_match(value: str, 
return False -def _match_fused_layer(layer_name: str, target_layers: Iterable[str], - mapping: Mapping[str, List[str]]) -> Optional[str]: +def _match_fused_layer( + layer_name: str, target_layers: Iterable[str], + fused_mapping: Mapping[str, List[str]]) -> Optional[str]: """ Match a fused layer name to its corresponding individual layer in - target_layers. Returns first value in mapping which matches targets + target_layers. Returns first value in fused_mapping which matches targets Implements an "all" matching strategy where a fused layer matches iff "all" of its components match :param layer_name: layer name :param target_layers: list of targets to match the layer against - :param mapping: map from fused layer names to its components + :param fused_mapping: map from fused layer names to its components Examples: layer_name = "model.layers.0.self_attn.qkv_proj" @@ -189,13 +190,14 @@ def _match_fused_layer(layer_name: str, target_layers: Iterable[str], "model.layers.0.self_attn.v_proj"] """ # find layer_name in mapping - fused = next((key for key in mapping if layer_name.endswith(key)), None) + fused = next((key for key in fused_mapping if layer_name.endswith(key)), + None) if fused is None: return None # expand path of unfused components unfused_paths = [ - layer_name.replace(fused, unfused) for unfused in mapping[fused] + layer_name.replace(fused, unfused) for unfused in fused_mapping[fused] ] # for each unfused component, find a match in targets From 526c93d29ebb157ee990129c6f62afb8097d2791 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 3 Feb 2025 22:14:06 -0500 Subject: [PATCH 18/19] rename to fused_mapping 2 Signed-off-by: Kyle Sayers --- .../model_executor/layers/quantization/utils/quant_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 62689f46c111..c7ce3a42c81f 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -174,7 +174,7 @@ def unpack_quantized_values_into_int32(w_q: torch.Tensor, def is_layer_skipped( prefix: str, ignored_layers: List[str], - mapping: Mapping[str, List[str]] = MappingProxyType({}) + fused_mapping: Mapping[str, List[str]] = MappingProxyType({}) ) -> bool: # prefix: model.layers.0.self_attn.q_proj # proj_name: q_proj @@ -184,10 +184,10 @@ def is_layer_skipped( # in the safetensors checkpoint. So, we convert the name # from the fused version to unfused + check to make sure that # each shard of the fused layer has the same scheme. 
- if proj_name in mapping: + if proj_name in fused_mapping: shard_prefixes = [ prefix.replace(proj_name, shard_proj_name) - for shard_proj_name in mapping[proj_name] + for shard_proj_name in fused_mapping[proj_name] ] is_skipped = None From 35d609f3a698c47eac875423329697b9c523ac36 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 4 Feb 2025 10:44:02 -0500 Subject: [PATCH 19/19] add comment Signed-off-by: Kyle Sayers --- vllm/model_executor/models/chatglm.py | 2 ++ vllm/model_executor/models/minicpmv.py | 2 ++ vllm/model_executor/models/qwen.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 9463d4fef9c1..a31648675259 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -792,6 +792,8 @@ def __new__( else: instance_cls = ChatGLM + # quant_config references base class members, + # so update values before init is called cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping) cls.supported_lora_modules += instance_cls.supported_lora_modules cls.embedding_modules.update(instance_cls.embedding_modules) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 3eb99371080e..3606d81dbcee 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -1497,6 +1497,8 @@ def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""): raise ValueError( "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6") + # quant_config references base class members, + # so update values before init is called cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping) cls.supported_lora_modules += instance_cls.supported_lora_modules cls.embedding_modules.update(instance_cls.embedding_modules) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 2364b72b31db..89503459b0c1 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -1151,6 +1151,8 @@ def __new__( else: instance_cls = QWenLLM + # quant_config references base class members, + # so update values before init is called cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping) cls.supported_lora_modules += instance_cls.supported_lora_modules cls.embedding_modules.update(instance_cls.embedding_modules)