Commit 09500f7

[Model] Add BNB quantization support for Mllama (#9720)
1 parent ef7865b commit 09500f7

3 files changed: +84 −12 lines

vllm/model_executor/layers/quantization/bitsandbytes.py

Lines changed: 31 additions & 4 deletions
```diff
@@ -3,6 +3,7 @@
 import torch
 
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -23,7 +24,7 @@ def __init__(
         bnb_4bit_use_double_quant: bool = False,
         llm_int8_enable_fp32_cpu_offload: bool = False,
         llm_int8_has_fp16_weight: bool = False,
-        llm_int8_skip_modules: Optional[Any] = None,
+        llm_int8_skip_modules: Optional[List[str]] = None,
         llm_int8_threshold: float = 0.0,
     ) -> None:
 
@@ -34,11 +35,15 @@ def __init__(
         self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
         self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
         self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
-        self.llm_int8_skip_modules = llm_int8_skip_modules
+        self.llm_int8_skip_modules = llm_int8_skip_modules or []
         self.llm_int8_threshold = llm_int8_threshold
 
     def __repr__(self) -> str:
-        return "BitsAndBytesConfig"
+        return (f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, "
+                f"load_in_4bit={self.load_in_4bit}, "
+                f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, "
+                f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, "
+                f"llm_int8_skip_modules={self.llm_int8_skip_modules})")
 
     @classmethod
     def get_name(self) -> str:
@@ -102,15 +107,21 @@ def get_safe_value(config, keys, default_value=None):
             llm_int8_threshold=llm_int8_threshold)
 
     def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["BitsAndBytesLinearMethod"]:
+                         prefix: str) -> Optional["LinearMethodBase"]:
         if isinstance(layer, LinearBase):
+            if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
+                return UnquantizedLinearMethod()
             return BitsAndBytesLinearMethod(self)
         return None
 
     def get_scaled_act_names(self) -> List[str]:
         return []
 
 
+def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]):
+    return any(module_name in prefix for module_name in llm_int8_skip_modules)
+
+
 class BitsAndBytesLinearMethod(LinearMethodBase):
     """Linear method for BitsAndBytes.
```
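
The new `llm_int8_skip_modules` handling means that any layer whose prefix matches a skip entry falls back to `UnquantizedLinearMethod` instead of being quantized. A minimal standalone sketch of that check; the layer prefixes and skip list below are hypothetical, not taken from the commit:

```python
from typing import List


def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]) -> bool:
    # A layer is skipped when any skip-module name occurs in its prefix.
    return any(module_name in prefix for module_name in llm_int8_skip_modules)


# Hypothetical example: with "lm_head" in the skip list, that layer keeps an
# unquantized linear method while attention projections are still quantized.
skip = ["lm_head"]
print(is_layer_skipped_bnb("lm_head", skip))                              # True
print(is_layer_skipped_bnb("language_model.layers.0.self_attn.q_proj", skip))  # False
```

The remaining hunks in this file add handling for inputs with more than two dimensions, shown next.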
```diff
@@ -211,6 +222,11 @@ def _apply_8bit_weight(
         from bitsandbytes import MatmulLtState, matmul
 
         original_type = x.dtype
+        original_shape = x.shape
+        reshape_after_matmul = False
+        if x.ndim > 2:
+            x = x.reshape(-1, x.size(-1))
+            reshape_after_matmul = True
         bf_x = x.to(torch.bfloat16)
 
         qweight = layer.qweight
@@ -265,6 +281,9 @@ def _apply_8bit_weight(
 
         out = out.to(original_type)
 
+        if reshape_after_matmul:
+            out = out.view(*original_shape[:-1], out.size(-1))
+
         if bias is not None:
             out += bias
 
@@ -282,6 +301,11 @@ def _apply_4bit_weight(
         from bitsandbytes import matmul_4bit
 
         original_type = x.dtype
+        original_shape = x.shape
+        reshape_after_matmul = False
+        if x.ndim > 2:
+            x = x.reshape(-1, x.size(-1))
+            reshape_after_matmul = True
         bf_x = x.to(torch.bfloat16)
 
         qweight = layer.qweight
@@ -310,6 +334,9 @@ def _apply_4bit_weight(
 
         out = out.to(original_type)
 
+        if reshape_after_matmul:
+            out = out.view(*original_shape[:-1], out.size(-1))
+
         if bias is not None:
             out += bias
```
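
Both `_apply_8bit_weight` and `_apply_4bit_weight` now flatten inputs with more than two dimensions before the bitsandbytes matmul and restore the leading dimensions afterwards, presumably so the vision tower's batched multi-dimensional activations can pass through quantized layers. A minimal sketch of that reshape pattern, with a plain `torch.nn.functional.linear` standing in for the bitsandbytes kernel:

```python
import torch
import torch.nn.functional as F


def apply_with_flatten(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Flatten leading dims to 2-D, run the matmul, then restore the shape."""
    original_shape = x.shape
    reshape_after_matmul = False
    if x.ndim > 2:
        x = x.reshape(-1, x.size(-1))
        reshape_after_matmul = True

    out = F.linear(x, weight)  # stand-in for bitsandbytes matmul / matmul_4bit

    if reshape_after_matmul:
        out = out.view(*original_shape[:-1], out.size(-1))
    return out


x = torch.randn(2, 16, 32)             # (batch, tokens, hidden)
w = torch.randn(64, 32)                # (out_features, in_features)
print(apply_with_flatten(x, w).shape)  # torch.Size([2, 16, 64])
```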

vllm/model_executor/model_loader/loader.py

Lines changed: 16 additions & 3 deletions
```diff
@@ -899,6 +899,19 @@ def _get_quantized_weights_iterator(
         return self._unquantized_generator(hf_weights_files, use_safetensors,
                                            quant_state_dict), quant_state_dict
 
+    def _is_8bit_weight_name(self, weight_name: str):
+        quantized_suffix = {".scb", ".weight_format"}
+        return any(weight_name.lower().endswith(suffix)
+                   for suffix in quantized_suffix)
+
+    def _is_4bit_weight_name(self, weight_name: str):
+        quantized_suffix = {
+            "absmax", "quant_map", "nested_absmax", "nested_quant_map",
+            "bitsandbytes"
+        }
+        suffix = weight_name.split(".")[-1]
+        return any(q_suffix in suffix for q_suffix in quantized_suffix)
+
     def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
                                   quant_state_dict) -> Generator:
         for weight_name, weight_tensor in self._hf_weight_iter(
@@ -912,7 +925,7 @@ def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
         for weight_name, weight_tensor in self._hf_weight_iter(
                 hf_weights_files, use_safetensors):
 
-            if not weight_name.endswith((".weight", ".bias")):
+            if self._is_8bit_weight_name(weight_name):
                 continue
 
             qweight_name = weight_name.replace(".weight", ".qweight")
@@ -932,7 +945,7 @@ def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
                                            use_safetensors)
         temp_state_dict = {}
         for weight_name, weight_tensor in weight_iterator:
-            if weight_name.endswith((".weight", ".bias")):
+            if not self._is_4bit_weight_name(weight_name):
                 continue
             # bitsandbytes library requires
             # weight.quant_state.bitsandbytes__* in CPU
@@ -956,7 +969,7 @@ def _parse_quant_state(param_name: str,
         for weight_name, weight_tensor in self._hf_weight_iter(
                 hf_weights_files, use_safetensors):
 
-            if not weight_name.endswith((".weight", ".bias")):
+            if self._is_4bit_weight_name(weight_name):
                 continue
 
             if (f"{weight_name}.quant_state.bitsandbytes__nf4" \
```
vllm/model_executor/models/mllama.py

Lines changed: 37 additions & 5 deletions
```diff
@@ -325,7 +325,10 @@ def forward(self, hidden_state: torch.Tensor,
 # TODO: support other attention backends for attention in vision model
 class MllamaVisionSdpaAttention(nn.Module):
 
-    def __init__(self, config: config_mllama.MllamaVisionConfig):
+    def __init__(self,
+                 config: config_mllama.MllamaVisionConfig,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__()
 
         model_parallel_size = get_tensor_model_parallel_world_size()
@@ -341,12 +344,16 @@ def __init__(self, config: config_mllama.MllamaVisionConfig):
             self.head_dim,
             self.num_heads,
             bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.num_heads * self.head_dim,
             self.embed_dim,
             bias=False,
             input_is_parallel=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
 
     def forward(
@@ -393,7 +400,8 @@ def __init__(
         self.is_gated = is_gated
         self.intermediate_size = config.intermediate_size
 
-        self.self_attn = MllamaVisionSdpaAttention(config)
+        self.self_attn = MllamaVisionSdpaAttention(
+            config, quant_config=quant_config, prefix=f"{prefix}.self_attn")
         self.mlp = CLIPMLP(config,
                            quant_config=quant_config,
                            prefix=f"{prefix}.mlp")
@@ -1002,6 +1010,7 @@ def __init__(
             org_num_embeddings=config.vocab_size,
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
             quant_config=quant_config,
+            prefix=f"{prefix}.lm_head",
         )
 
     def forward(
@@ -1037,6 +1046,26 @@ def forward(
 @INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
 class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    # in TP, these weights are partitioned along the column dimension (dim=-1)
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
 
     def __init__(self,
                  config: config_mllama.MllamaConfig,
```
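
The class attributes above tell the bitsandbytes loader which Mllama modules to quantize, which weights are column-partitioned under tensor parallelism, and how per-shard checkpoint names map into vLLM's fused parameters. A small illustration of the name rewrite implied by `bitsandbytes_stacked_params_mapping`; the helper and checkpoint key are illustrative, not part of the commit:

```python
# shard_name -> (fused weight_name, shard index), as declared on the model class.
bitsandbytes_stacked_params_mapping = {
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}


def map_to_fused_param(checkpoint_name: str):
    """Illustrative helper: rewrite a per-shard checkpoint name to the fused
    vLLM parameter it belongs to, plus its shard index."""
    for shard_name, (fused_name, index) in bitsandbytes_stacked_params_mapping.items():
        if shard_name in checkpoint_name:
            return checkpoint_name.replace(shard_name, fused_name), index
    return checkpoint_name, None


# Hypothetical checkpoint key:
print(map_to_fused_param("language_model.model.layers.0.self_attn.k_proj.weight"))
# ('language_model.model.layers.0.self_attn.qkv_proj.weight', 1)
```

The remaining hunks below switch the multi-modal projector to a parallel linear layer and adjust its call sites accordingly.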
```diff
@@ -1061,10 +1090,13 @@ def __init__(self,
             quant_config=quant_config,
             prefix="language_model",
         )
-        self.multi_modal_projector = nn.Linear(
+        self.multi_modal_projector = ColumnParallelLinear(
             config.vision_config.vision_output_dim,
             config.text_config.hidden_size,
             bias=True,
+            quant_config=quant_config,
+            gather_output=True,
+            prefix="multi_modal_projector",
         )
         self.logits_processor = LogitsProcessor(config.output_hidden_states,
                                                 config.text_config.vocab_size)
@@ -1128,7 +1160,7 @@ def _parse_and_validate_image_input(self, **kwargs: object):
                 raise ValueError("No images provided.")
             max_num_tiles = max(
                 max([len(x) for x in y[0]]) for y in pixel_values)
-            device = self.multi_modal_projector.weight.device
+            device = next(self.multi_modal_projector.parameters()).device
             bsz = len(pixel_values)
             out_num_tiles = []
             out_images = torch.zeros(
@@ -1204,7 +1236,7 @@ def get_cross_attention_states(
         cross_attention_states = self.vision_model(pixel_values,
                                                    aspect_ratio_ids,
                                                    aspect_ratio_mask)
-        cross_attention_states = self.multi_modal_projector(
+        cross_attention_states, _ = self.multi_modal_projector(
             cross_attention_states)
 
         bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
```
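
With these changes, an Mllama checkpoint can be loaded with in-flight bitsandbytes quantization through vLLM's usual entry point. A hedged usage sketch: the model ID, `enforce_eager`, and context length are illustrative, and the `quantization`/`load_format` flags follow vLLM's bitsandbytes convention at the time of this commit:

```python
from vllm import LLM, SamplingParams

# Illustrative settings; adjust to your hardware and checkpoint.
llm = LLM(
    model="meta-llama/Llama-3.2-11B-Vision-Instruct",  # an Mllama checkpoint
    quantization="bitsandbytes",
    load_format="bitsandbytes",
    max_model_len=4096,
    enforce_eager=True,  # assumption: keep the encoder-decoder path simple
)

outputs = llm.generate(
    "Describe what a vision-language model does in one sentence.",
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```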
