Merged
Changes from 6 commits

Commits (25)
bc94016  Fix quantization for chatglm (mgoin, Jan 30, 2025)
10908f6  additional prefix fixes (kylesayrs, Jan 30, 2025)
1ef97f0  use merged_proj (kylesayrs, Jan 30, 2025)
6446c94  remove reliance on FUSED_LAYER_NAME_MAPPING in favor of packed_module… (kylesayrs, Jan 30, 2025)
1b1e180  update reference when model class is selected (kylesayrs, Jan 31, 2025)
41c756d  update comment (kylesayrs, Jan 31, 2025)
568e317  Merge remote-tracking branch 'origin' into redhat/fix-packed-mapping-… (kylesayrs, Feb 2, 2025)
adaf93b  Merge remote-tracking branch 'upstream/main' into redhat/fix-packed-m… (kylesayrs, Feb 2, 2025)
0960a15  remove _handle_fused_layers (kylesayrs, Feb 2, 2025)
1ceab2d  shorten arg name (kylesayrs, Feb 2, 2025)
a134b92  break out to function, raise warning if is missing (kylesayrs, Feb 2, 2025)
68a73e2  typos (kylesayrs, Feb 2, 2025)
bb8bc20  WIP: need testing (kylesayrs, Feb 2, 2025)
b70cf8c  matching logic, comments (kylesayrs, Feb 3, 2025)
2100c56  Merge remote-tracking branch 'origin' into redhat/fix-packed-mapping-… (kylesayrs, Feb 3, 2025)
28b9540  implement fused strategy (kylesayrs, Feb 3, 2025)
e7ded25  rename to fused_mapping, logic clarity (kylesayrs, Feb 3, 2025)
fe0e1ec  fix suppress (kylesayrs, Feb 3, 2025)
7078b6d  Merge remote-tracking branch 'upstream/main' into redhat/fix-packed-m… (kylesayrs, Feb 3, 2025)
2fce7f7  type hint (kylesayrs, Feb 3, 2025)
b898dd6  rename mapping to fused_mapping (kylesayrs, Feb 3, 2025)
526c93d  rename to fused_mapping 2 (kylesayrs, Feb 4, 2025)
35d609f  add comment (kylesayrs, Feb 4, 2025)
2ca5887  Merge remote-tracking branch 'upstream/main' into redhat/fix-packed-m… (kylesayrs, Feb 4, 2025)
33f4bcc  Merge remote-tracking branch 'upstream/main' into redhat/fix-packed-m… (kylesayrs, Feb 4, 2025)
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/quantization/base_config.py
@@ -1,6 +1,6 @@
import inspect
from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Type
+from typing import Any, Dict, List, Mapping, Optional, Type

import torch
from torch import nn
@@ -57,6 +57,7 @@ def method_has_implemented_embedding(

class QuantizationConfig(ABC):
"""Base class for quantization configs."""
+packed_modules_mapping: Mapping[str, List[str]] = dict()

@abstractmethod
def get_name(self) -> str:
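This hunk gives every QuantizationConfig a class-level `packed_modules_mapping`, so the quantization backends below no longer need the hard-coded FUSED_LAYER_NAME_MAPPING constant (removed from quant_utils.py later in this diff). As a rough sketch, the mapping takes fused-module names to the shard names that actually appear in checkpoints; the two entries shown mirror the removed constant, and any others would come from the model class itself.

```python
from typing import Dict, List

# Sketch only: the fused-name -> shard-names shape a quant config is expected
# to hold after this change. Entries mirror the removed constant.
packed_modules_mapping: Dict[str, List[str]] = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],  # fused attention projection
    "gate_up_proj": ["gate_proj", "up_proj"],    # fused MLP projection
}
```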
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -78,7 +78,10 @@ def get_quant_method(

# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
-if should_ignore_layer(prefix, ignore=self.ignore):
+if should_ignore_layer(
+        prefix,
+        ignore=self.ignore,
+        packed_modules_mapping=self.packed_modules_mapping):
return UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix)
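With the config-level mapping forwarded here, the ignore check can expand a fused module name into the shard names that the checkpoint and its ignore list actually use. A minimal illustration of that expansion; the layer name is made up, and the replace-based expansion follows the pattern visible in the utils changes below.

```python
# Illustrative only: expanding a fused prefix before consulting the ignore
# list, roughly what should_ignore_layer does with the forwarded mapping.
prefix = "model.layers.0.self_attn.qkv_proj"  # hypothetical layer name
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

proj_name = prefix.split(".")[-1]
shard_names = [
    prefix.replace(proj_name, shard)
    for shard in packed_modules_mapping.get(proj_name, [proj_name])
]
print(shard_names)
# ['model.layers.0.self_attn.q_proj', 'model.layers.0.self_attn.k_proj',
#  'model.layers.0.self_attn.v_proj']
```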
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
@@ -1,12 +1,10 @@
import re
-from typing import Iterable, Optional
+from types import MappingProxyType
+from typing import Iterable, List, Mapping, Optional

from compressed_tensors import CompressionFormat
from torch.nn import Module

-from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    FUSED_LAYER_NAME_MAPPING)


def is_activation_quantization_format(format: str) -> bool:
_ACTIVATION_QUANTIZATION_FORMATS = [
@@ -17,8 +15,11 @@ def is_activation_quantization_format(format: str) -> bool:
return format in _ACTIVATION_QUANTIZATION_FORMATS


-def should_ignore_layer(layer_name: Optional[str],
-                        ignore: Iterable[str]) -> bool:
+def should_ignore_layer(
+        layer_name: Optional[str],
+        ignore: Iterable[str] = tuple(),
+        packed_modules_mapping: Mapping[str, List[str]] = MappingProxyType({})
+) -> bool:
if layer_name is None:
return False

@@ -30,8 +31,8 @@ def should_ignore_layer(layer_name: Optional[str],
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
-if proj_name in FUSED_LAYER_NAME_MAPPING and layer_name not in ignore:
-    shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
+if proj_name in packed_modules_mapping and layer_name not in ignore:
+    shard_proj_names = packed_modules_mapping[proj_name]

# Convert fused_name --> [shard_names]
shard_names = [
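For readers following the signature change, here is a self-contained sketch of the matching logic `should_ignore_layer` performs once `packed_modules_mapping` is a parameter. The helper `_matches` and the `_sketch` suffix are not vLLM code, and the `re:` handling is a simplification of the real pattern matching; the point is the shard expansion and the consistency requirement.

```python
import re
from types import MappingProxyType
from typing import Iterable, List, Mapping, Optional


def _matches(layer_name: str, targets: Iterable[str]) -> bool:
    """Simplified matcher: literal names, or regex targets prefixed 're:'."""
    for target in targets:
        if target.startswith("re:"):
            if re.match(target[3:], layer_name):
                return True
        elif target == layer_name:
            return True
    return False


def should_ignore_layer_sketch(
    layer_name: Optional[str],
    ignore: Iterable[str] = tuple(),
    packed_modules_mapping: Mapping[str, List[str]] = MappingProxyType({}),
) -> bool:
    if layer_name is None:
        return False

    proj_name = layer_name.split(".")[-1]

    # Fused layers (e.g. qkv_proj) never appear in the checkpoint, so expand
    # the fused name into shard names and require a consistent decision.
    if proj_name in packed_modules_mapping and layer_name not in ignore:
        shard_names = [
            layer_name.replace(proj_name, shard)
            for shard in packed_modules_mapping[proj_name]
        ]
        decisions = {_matches(name, ignore) for name in shard_names}
        if len(decisions) > 1:
            raise ValueError(
                f"Shards of {layer_name} have inconsistent ignore status")
        return decisions.pop()

    return _matches(layer_name, ignore)


# All three shards of the fused projection are ignored, so the fused layer is.
assert should_ignore_layer_sketch(
    "model.layers.0.self_attn.qkv_proj",
    ignore=[r"re:.*self_attn\.(q|k|v)_proj$"],
    packed_modules_mapping={"qkv_proj": ["q_proj", "k_proj", "v_proj"]})
```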
11 changes: 6 additions & 5 deletions vllm/model_executor/layers/quantization/quark/quark.py
@@ -16,8 +16,6 @@
QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8)
from vllm.model_executor.layers.quantization.quark.utils import (
deep_compare, should_ignore_layer)
-from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    FUSED_LAYER_NAME_MAPPING)
from vllm.platforms import current_platform

__all__ = ["QuarkLinearMethod"]
@@ -56,7 +54,10 @@ def get_quant_method(self, layer: torch.nn.Module,

# Check if the layer is skipped for quantization.
exclude_layers = cast(List[str], self.quant_config.get("exclude"))
-if should_ignore_layer(prefix, ignore=exclude_layers):
+if should_ignore_layer(
+        prefix,
+        ignore=exclude_layers,
+        packed_modules_mapping=self.packed_modules_mapping):
return UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix)
@@ -199,8 +200,8 @@ def _find_matched_config(self, layer_name: str,
module: torch.nn.Module) -> Dict[str, Any]:

proj_name = layer_name.split(".")[-1]
-if proj_name in FUSED_LAYER_NAME_MAPPING:
-    shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
+if proj_name in self.packed_modules_mapping:
+    shard_proj_names = self.packed_modules_mapping[proj_name]

# Convert fused_name --> [shard_names]
shard_names = [
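`_find_matched_config` uses the same expansion to resolve a fused layer's quantization config from its shards, which only makes sense if the shards agree. A hedged sketch of that consistency check; `find_config` is a hypothetical stand-in for Quark's per-layer lookup, and plain `==` stands in for the `deep_compare` helper imported above.

```python
from typing import Any, Callable, Dict, List, Mapping


def fused_layer_config(
    layer_name: str,
    packed_modules_mapping: Mapping[str, List[str]],
    find_config: Callable[[str], Dict[str, Any]],
) -> Dict[str, Any]:
    """Resolve a fused layer's config by requiring its shards to agree."""
    proj_name = layer_name.split(".")[-1]
    if proj_name not in packed_modules_mapping:
        return find_config(layer_name)

    shard_names = [
        layer_name.replace(proj_name, shard)
        for shard in packed_modules_mapping[proj_name]
    ]
    shard_configs = [find_config(name) for name in shard_names]
    if any(cfg != shard_configs[0] for cfg in shard_configs[1:]):
        raise ValueError(
            f"Shards of fused layer {layer_name} must share one scheme")
    return shard_configs[0]
```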
17 changes: 9 additions & 8 deletions vllm/model_executor/layers/quantization/quark/utils.py
@@ -1,8 +1,6 @@
import re
-from typing import Any, Iterable, Optional
-
-from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    FUSED_LAYER_NAME_MAPPING)
+from types import MappingProxyType
+from typing import Any, Iterable, List, Mapping, Optional


def deep_compare(dict1: Any, dict2: Any) -> bool:
@@ -18,8 +16,11 @@ def deep_compare(dict1: Any, dict2: Any) -> bool:
return dict1 == dict2


-def should_ignore_layer(layer_name: Optional[str],
-                        ignore: Iterable[str]) -> bool:
+def should_ignore_layer(
+        layer_name: Optional[str],
+        ignore: Iterable[str],
+        packed_modules_mapping: Mapping[str, List[str]] = MappingProxyType({})
+) -> bool:
if layer_name is None:
return False

@@ -31,8 +32,8 @@ def should_ignore_layer(layer_name: Optional[str],
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
-if proj_name in FUSED_LAYER_NAME_MAPPING:
-    shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
+if proj_name in packed_modules_mapping:
+    shard_proj_names = packed_modules_mapping[proj_name]

# Convert fused_name --> [shard_names]
shard_names = [
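A side note on the `MappingProxyType({})` default used in the new signatures above: it is a read-only view over an empty dict, which makes it safe as a default argument (no mutable-default surprises) while still supporting `in`, indexing, and `.get`. A tiny demonstration; the `lookup` function is made up for illustration.

```python
from types import MappingProxyType

EMPTY = MappingProxyType({})  # read-only view over an empty dict


def lookup(name: str, mapping=EMPTY):  # safe as a default argument
    return mapping.get(name, [name])


print(lookup("qkv_proj"))  # ['qkv_proj']
print(lookup("qkv_proj", {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}))

try:
    EMPTY["x"] = []  # writes are rejected
except TypeError as exc:
    print("read-only:", exc)
```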
21 changes: 9 additions & 12 deletions vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -1,5 +1,6 @@
"""This file is used for /tests and /benchmarks"""
-from typing import List, Optional
+from types import MappingProxyType
+from typing import List, Mapping, Optional

import numpy
import torch
@@ -11,14 +12,6 @@
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

-# Note: this is a hack. We should update each model to register the
-# stacked params and get it from there instead in a future PR.
-# fused_name: List[shard_name]
-FUSED_LAYER_NAME_MAPPING = {
-    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-    "gate_up_proj": ["gate_proj", "up_proj"]
-}


def pack_quantized_values_into_int32(w_q: torch.Tensor,
wtype: ScalarType,
@@ -63,14 +56,18 @@ def unpack_quantized_values_into_int32(w_q: torch.Tensor,
return res.permute(inv_perm)


-def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
+def is_layer_skipped(
+        prefix: str,
+        ignored_layers: List[str],
+        packed_modules_mapping: Mapping[str, List[str]] = MappingProxyType({})
+) -> bool:
# prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj
proj_name = prefix.split(".")[-1]
-if proj_name in FUSED_LAYER_NAME_MAPPING:
+if proj_name in packed_modules_mapping:
shard_prefixes = [
prefix.replace(proj_name, shard_proj_name)
-for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
+for shard_proj_name in packed_modules_mapping[proj_name]
]

is_skipped = None
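The tail of `is_layer_skipped` is cut off by the diff view at `is_skipped = None`, so the sketch below assumes the rest of the function requires every shard of a fused layer to agree, which is what the accumulator suggests. Names carry a `_sketch` suffix to make clear this is not the vLLM implementation itself.

```python
from types import MappingProxyType
from typing import List, Mapping


def is_layer_skipped_sketch(
    prefix: str,
    ignored_layers: List[str],
    packed_modules_mapping: Mapping[str, List[str]] = MappingProxyType({}),
) -> bool:
    """Sketch: skip a fused layer only if all of its shards are skipped."""
    proj_name = prefix.split(".")[-1]
    if proj_name not in packed_modules_mapping:
        return prefix in ignored_layers

    shard_prefixes = [
        prefix.replace(proj_name, shard)
        for shard in packed_modules_mapping[proj_name]
    ]
    is_skipped = None
    for shard_prefix in shard_prefixes:
        shard_skipped = shard_prefix in ignored_layers
        if is_skipped is None:
            is_skipped = shard_skipped
        elif shard_skipped != is_skipped:
            raise ValueError(f"Some but not all shards of {prefix} are "
                             "skipped; mixed fused layers are not supported")
    return bool(is_skipped)


# Both shards of gate_up_proj are in the ignore list, so the fused layer is.
print(is_layer_skipped_sketch(
    "model.layers.0.mlp.gate_up_proj",
    ignored_layers=["model.layers.0.mlp.gate_proj",
                    "model.layers.0.mlp.up_proj"],
    packed_modules_mapping={"gate_up_proj": ["gate_proj", "up_proj"]}))  # True
```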
5 changes: 5 additions & 0 deletions vllm/model_executor/model_loader/loader.py
@@ -110,6 +110,11 @@ def _initialize_model(
model_config = vllm_config.model_config
model_class, _ = get_model_architecture(model_config)

+# pass packed_modules_mapping by reference to quant_config
+packed_mapping = getattr(model_class, "packed_modules_mapping", None)
+if packed_mapping is not None and vllm_config.quant_config is not None:
+    vllm_config.quant_config.packed_modules_mapping = packed_mapping

signatures = inspect.signature(model_class.__init__)
all_params = [param.name for param in signatures.parameters.values()]
if "vllm_config" in all_params and "prefix" in all_params:
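The comment in this hunk is the important part: the mapping is assigned by reference, not copied, so a model whose `__new__` later updates its class-level `packed_modules_mapping` in place (as ChatGLM does below) is automatically visible to the quant config. A toy illustration of that aliasing, using stand-in classes rather than vLLM's real ones.

```python
class ToyQuantConfig:
    packed_modules_mapping = {}  # class-level default, as in base_config.py


class ToyModelClass:
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}


quant_config = ToyQuantConfig()

# Mirror of the loader hunk: share the model class's dict by reference.
packed_mapping = getattr(ToyModelClass, "packed_modules_mapping", None)
if packed_mapping is not None:
    quant_config.packed_modules_mapping = packed_mapping

# A later in-place update on the model class is visible through the config.
ToyModelClass.packed_modules_mapping.update(
    {"gate_up_proj": ["gate_proj", "up_proj"]})
print("gate_up_proj" in quant_config.packed_modules_mapping)  # True
```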
24 changes: 19 additions & 5 deletions vllm/model_executor/models/chatglm.py
@@ -263,12 +263,14 @@ def __init__(
self.total_num_kv_heads,
bias=config.add_bias_linear or config.add_qkv_bias,
quant_config=quant_config,
+    prefix=f"{prefix}.query_key_value",
)
self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim,
config.hidden_size,
bias=config.add_bias_linear,
quant_config=quant_config,
+    prefix=f"{prefix}.dense",
)

# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
@@ -325,6 +327,7 @@ def __init__(
self,
config: ChatGLMConfig,
quant_config: Optional[QuantizationConfig] = None,
+    prefix: str = "",
):
super().__init__()

Expand All @@ -336,6 +339,7 @@ def __init__(
[config.ffn_hidden_size] * 2,
bias=config.add_bias_linear,
quant_config=quant_config,
+    prefix=f"{prefix}.dense_h_to_4h",
)

self.activation_func = SiluAndMul()
@@ -346,6 +350,7 @@ def __init__(
config.hidden_size,
bias=config.add_bias_linear,
quant_config=quant_config,
+    prefix=f"{prefix}.dense_4h_to_h",
)

def forward(self, hidden_states):
@@ -394,7 +399,7 @@ def __init__(
config.hidden_size, eps=config.layernorm_epsilon)

# MLP
-self.mlp = GLMMLP(config, quant_config)
+self.mlp = GLMMLP(config, quant_config, prefix=f"{prefix}.mlp")

def forward(
self,
@@ -505,7 +510,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
config.hidden_size,
-    quant_config=quant_config)
+    quant_config=quant_config,
+    prefix=f"{prefix}.embedding")

self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num
@@ -764,6 +770,7 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
SupportsMultiModal):
# Ensure that the LoRA support check passes when the class is not
# initialized, but set all these attributes to empty.
+# These will be updated when an instance class is selected
packed_modules_mapping = {}
supported_lora_modules = []
embedding_modules = {}
@@ -775,9 +782,16 @@ def __new__(
prefix: str = "",
) -> None:
config = vllm_config.model_config.hf_config
+
# Initialize VL
-if hasattr(config, "vision_config"):
-    return ChatGLMV(vllm_config=vllm_config, prefix=prefix)
+if hasattr(config, "vision_config"):  # noqa: SIM108
+    instance_cls = ChatGLMV
# Initialize LLM
else:
-    return ChatGLM(vllm_config=vllm_config, prefix=prefix)
+    instance_cls = ChatGLM
+
+cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
+cls.supported_lora_modules += instance_cls.supported_lora_modules
+cls.embedding_modules.update(instance_cls.embedding_modules)
+cls.embedding_padding_modules += instance_cls.embedding_padding_modules
+return instance_cls(vllm_config=vllm_config, prefix=prefix)
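`ChatGLMForCausalLM.__new__` now merges the selected concrete class's metadata into the dispatcher class before constructing it, which is what keeps the reference installed by the loader up to date. A stripped-down sketch of the pattern; the variant classes and their entries are placeholders, not the real ChatGLM/ChatGLMV attributes.

```python
class _TextVariant:
    packed_modules_mapping = {"gate_up_proj": ["gate_proj", "up_proj"]}


class _VisionVariant:
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}


class Dispatcher:
    # Empty default so the attribute exists before a variant is chosen,
    # mirroring ChatGLMForCausalLM.packed_modules_mapping = {}.
    packed_modules_mapping = {}

    def __new__(cls, has_vision: bool):
        # Pick the concrete class, as __new__ above does via
        # hasattr(config, "vision_config").
        instance_cls = _VisionVariant if has_vision else _TextVariant

        # Update the dispatcher's dict in place so any earlier reference to
        # it (e.g. quant_config.packed_modules_mapping) sees the entries.
        cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
        return instance_cls()


shared_ref = Dispatcher.packed_modules_mapping  # e.g. held by a quant config
model = Dispatcher(has_vision=False)
print(type(model).__name__)   # _TextVariant
print(shared_ref)             # {'gate_up_proj': ['gate_proj', 'up_proj']}
```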
31 changes: 22 additions & 9 deletions vllm/model_executor/models/glm4_vision_encoder.py
@@ -72,11 +72,13 @@ def __init__(
self.head_dim,
config.num_heads,
quant_config=quant_config,
+    prefix=f"{prefix}.query_key_value",
)
self.dense = RowParallelLinear(
config.hidden_size,
config.hidden_size,
quant_config=quant_config,
+    prefix=f"{prefix}.dense",
)

self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim,
@@ -99,6 +101,7 @@ def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
+    prefix: str = '',
):
super().__init__()
self.config = config
Expand All @@ -107,11 +110,13 @@ def __init__(
config.hidden_size,
config.intermediate_size,
quant_config=quant_config,
+    prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
quant_config=quant_config,
+    prefix=f"{prefix}.fc2",
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -135,7 +140,9 @@ def __init__(
self.attention = Attention(config,
quant_config=quant_config,
prefix=f"{prefix}.attention")
-self.mlp = MLP(config, quant_config=quant_config)
+self.mlp = MLP(config,
+               quant_config=quant_config,
+               prefix=f"{prefix}.mlp")
self.post_attention_layernorm = LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)

@@ -162,7 +169,7 @@ def __init__(
self.layers = nn.ModuleList([
TransformerLayer(config,
quant_config=quant_config,
-    prefix=f"{prefix}.layer.{layer_idx}")
+    prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(config.num_hidden_layers)
])

@@ -179,6 +186,7 @@ def __init__(
config,
in_features,
quant_config: Optional[QuantizationConfig] = None,
+    prefix: str = '',
):
"""
The original implementation is the same as:
@@ -220,20 +228,24 @@ def __init__(
self.linear_proj = ReplicatedLinear(in_features,
config.hidden_size,
bias=False,
-    quant_config=quant_config)
+    quant_config=quant_config,
+    prefix=f"{prefix}.linear_proj")
self.norm1 = nn.LayerNorm(config.hidden_size)
self.act1 = nn.GELU()
self.act2 = SiluAndMul()

self.merged_proj = MergedColumnParallelLinear(
config.hidden_size, [config.ffn_hidden_size] * 2,
bias=False,
-    quant_config=quant_config)
+    quant_config=quant_config,
+    prefix=f"{prefix}.merged_proj")

-self.dense_4h_to_h = RowParallelLinear(config.ffn_hidden_size,
-                                       config.hidden_size,
-                                       bias=False,
-                                       quant_config=quant_config)
+self.dense_4h_to_h = RowParallelLinear(
+    config.ffn_hidden_size,
+    config.hidden_size,
+    bias=False,
+    quant_config=quant_config,
+    prefix=f"{prefix}.dense_4h_to_h")

def forward(self, x):
x, _ = self.linear_proj(x)
Expand All @@ -260,7 +272,8 @@ def __init__(
prefix=f"{prefix}.transformer")
self.linear_proj = GLU(config,
in_features=config.hidden_size,
-    quant_config=quant_config)
+    quant_config=quant_config,
+    prefix=f"{prefix}.linear_proj")
self.conv = nn.Conv2d(in_channels=vision_config.hidden_size,
out_channels=config.hidden_size,
kernel_size=2,
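The vision-encoder changes are all prefix plumbing: quantization ignore lists and per-layer overrides match on a module's full dotted name, so every sub-layer has to report the name it actually has in the model tree (note the `.layer.{i}` to `.layers.{i}` correction, which matches the `self.layers` ModuleList). A small sketch of how such prefixes compose; the module names here are illustrative, not the exact GLM-4V ones.

```python
from typing import List


def child_prefixes(prefix: str, num_layers: int) -> List[str]:
    """Illustrative dotted names produced by prefix plumbing like the above."""
    names = [f"{prefix}.linear_proj", f"{prefix}.merged_proj"]
    names += [f"{prefix}.layers.{i}.attention.query_key_value"
              for i in range(num_layers)]
    return names


# A typo such as "layer" instead of "layers" would silently miss every
# ignore-list entry or override keyed on these names.
print(child_prefixes("transformer.vision", 2))
```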