[Bugfix] Fix quantization skip modules logic #13562
Changes from all commits
vllm/model_executor/model_loader/utils.py

@@ -11,12 +11,16 @@
 from vllm.config import ModelConfig, ModelImpl
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.bitsandbytes import (
+    BitsAndBytesConfig)
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.models.adapters import (as_classification_model,
                                                  as_embedding_model,
                                                  as_reward_model)
+from vllm.model_executor.models.utils import WeightsMapper

 logger = init_logger(__name__)
@@ -153,19 +157,57 @@ def get_sub_modules(self,

 def configure_quant_config(quant_config: QuantizationConfig,
                            model_class: Type[nn.Module]):
-    """
-    Pass packed_modules_mapping by reference to quant_config so that
-    quant_config can properly match fused modules
-
-    Note that model attributes are passed by reference to quant_config,
-    enabling them to be updated by model_class.__new__ (ex. chatglm, qwen)
-    """
-    packed_mapping = getattr(model_class, "packed_modules_mapping", None)
-    if packed_mapping is not None:
-        # pass packed_modules_mapping by reference to quant_config
-        quant_config.packed_modules_mapping = packed_mapping
-    else:
-        logger.warning(
-            "The model class %s has not defined `packed_modules_mapping`, "
-            "this may lead to incorrect mapping of quantized or ignored "
-            "modules", model_class.__name__)
+
+    def _configure_packed_modules_mapping():
+        """
+        Pass packed_modules_mapping by reference to quant_config so that
+        quant_config can properly match fused modules
+
+        Note that model attributes are passed by reference to quant_config,
+        enabling them to be updated by model_class.__new__ (ex. chatglm, qwen)
+        """
+        packed_mapping = getattr(model_class, "packed_modules_mapping", None)
+        if packed_mapping is not None:
+            # pass packed_modules_mapping by reference to quant_config
+            quant_config.packed_modules_mapping = packed_mapping
+        else:
+            logger.warning(
+                "The model class %s has not defined `packed_modules_mapping`, "
+                "this may lead to incorrect mapping of quantized or ignored "
+                "modules", model_class.__name__)
+
+    def _configure_quant_skip_modules():
+        """
+        Configure the quantization skip modules for the model based on the
+        provided quantization configuration.
+
+        This function checks whether the model class has a `hf_to_vllm_mapper`
+        attribute. If it does, the mapper is used to update the list of
+        modules to skip for the different quantization configurations:
+        - For `BitsAndBytesConfig`, it updates `llm_int8_skip_modules`.
+        - For `AWQConfig`, it updates `modules_to_not_convert`.
+        """
+        if getattr(model_class, "hf_to_vllm_mapper", None) is None:
+            return
+        hf_to_vllm_mapper: WeightsMapper = model_class.hf_to_vllm_mapper
+
+        # BitsAndBytes
+        if (isinstance(quant_config, BitsAndBytesConfig)
+                and quant_config.llm_int8_skip_modules):
+            quant_config.llm_int8_skip_modules = [
+                hf_to_vllm_mapper._map_name(module)
+                for module in quant_config.llm_int8_skip_modules
+            ]
+        # AWQ
+        elif (isinstance(quant_config, AWQConfig)
+              and quant_config.modules_to_not_convert):
+            quant_config.modules_to_not_convert = [
+                hf_to_vllm_mapper._map_name(module)
+                for module in quant_config.modules_to_not_convert
+            ]
+        # TODO: Support more quantization types.
Comment on lines +196 to +210

Maybe we should introduce a common …. Then each quant config can convert their specific ….

I'd support an implementation like this as well. This current implementation could fail to properly map module names in nested models. This is a fairly minor issue, but something to keep in mind. Another implementation could look like this: …. This has the advantage of further standardizing around the QuantizationConfig base, as well as supporting mapping with nested models.

@jeejeelee Here's a WIP of what that might look like: #14635

@kylesayrs Can you provide an example?
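As a rough illustration of the idea discussed in this thread (a shared hook on the quantization config base that each quant config implements for its own skip-module field), here is a hypothetical sketch. The names `apply_vllm_mapper`, `get_skip_modules`, and the stub classes are invented for this example; they are not vLLM's API, the reviewer's suggested code, or the implementation in #14635.

```python
# Hypothetical sketch only: hook and class names are invented for illustration.
from typing import Dict, List, Optional


class WeightsMapperStub:
    """Minimal stand-in for an HF-name -> vLLM-name mapper."""

    def __init__(self, orig_to_new_prefix: Dict[str, str]) -> None:
        self.orig_to_new_prefix = orig_to_new_prefix

    def map_name(self, name: str) -> str:
        for old, new in self.orig_to_new_prefix.items():
            if name.startswith(old):
                return new + name[len(old):]
        return name


class QuantizationConfigStub:
    """Base config exposing one shared hook instead of per-config isinstance branches."""

    def get_skip_modules(self) -> Optional[List[str]]:
        return None  # subclasses report their own "do not quantize" list

    def set_skip_modules(self, modules: List[str]) -> None:
        raise NotImplementedError

    def apply_vllm_mapper(self, mapper: WeightsMapperStub) -> None:
        # Single place where HF-style names are rewritten to vLLM names.
        skip = self.get_skip_modules()
        if skip:
            self.set_skip_modules([mapper.map_name(m) for m in skip])


class BnBConfigStub(QuantizationConfigStub):
    def __init__(self, llm_int8_skip_modules: List[str]) -> None:
        self.llm_int8_skip_modules = llm_int8_skip_modules

    def get_skip_modules(self) -> Optional[List[str]]:
        return self.llm_int8_skip_modules

    def set_skip_modules(self, modules: List[str]) -> None:
        self.llm_int8_skip_modules = modules


cfg = BnBConfigStub(["transformer.output_layer"])
cfg.apply_vllm_mapper(WeightsMapperStub({"transformer.": "model."}))
print(cfg.llm_int8_skip_modules)  # ['model.output_layer']
```

With a hook along these lines, `configure_quant_config` would only need a single `quant_config.apply_vllm_mapper(...)` call, and each config class would own the names of its skip fields, which is roughly the standardization argument made above.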
+
+    _configure_packed_modules_mapping()
+    _configure_quant_skip_modules()
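To make concrete what the new `_configure_quant_skip_modules` helper accomplishes, the snippet below walks through the renaming step with a toy prefix mapper. The mapping table and module names are made up for the example, and a minimal stand-in is used rather than vLLM's actual `WeightsMapper`.

```python
# Illustration only: a toy prefix mapping, not vLLM's real WeightsMapper, and
# the module names below are invented for the example.
from typing import Dict, List


def map_name(name: str, orig_to_new_prefix: Dict[str, str]) -> str:
    """Rewrite a Hugging Face-style module name to its vLLM-side name."""
    for old, new in orig_to_new_prefix.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name


# Suppose the HF checkpoint nests weights under "transformer." while the vLLM
# implementation of the model uses "model." (the kind of difference that a
# hf_to_vllm_mapper encodes).
hf_to_vllm_prefixes = {"transformer.": "model."}

# Skip list as it appears in the HF quantization config, e.g.
# llm_int8_skip_modules for BitsAndBytes.
skip_modules: List[str] = ["transformer.output_layer", "lm_head"]

# After remapping, the names match the modules vLLM actually builds, so the
# quantization method can correctly leave them unquantized.
print([map_name(m, hf_to_vllm_prefixes) for m in skip_modules])
# ['model.output_layer', 'lm_head']
```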
Why is this needed after we added `SupportsQuant` (#13104)? I thought getting the packed_modules_mapping from the model to the quant config was the main purpose of that. cc @kylesayrs

The `_configure_packed_modules_mapping` function needs to remain in place until `SupportsQuant` has been added to all applicable models.
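The thread above refers to vLLM's `SupportsQuant` interface from #13104. As a simplified sketch of the underlying idea (the model shares its `packed_modules_mapping` with the quant config by reference), consider the stand-in classes below; they are illustrative only and do not reproduce the real interface.

```python
# Simplified stand-in, not vLLM's actual SupportsQuant interface from #13104;
# class and attribute names here are illustrative only.
from typing import Dict, List, Optional


class QuantConfigStub:
    def __init__(self) -> None:
        self.packed_modules_mapping: Dict[str, List[str]] = {}


class SupportsQuantStub:
    """Models declare how fused vLLM modules map back to per-weight HF modules."""

    packed_modules_mapping: Dict[str, List[str]] = {}

    def __init__(self, quant_config: Optional[QuantConfigStub] = None) -> None:
        if quant_config is not None:
            # Shared by reference: later in-place updates made by the model
            # class remain visible to the quant config.
            quant_config.packed_modules_mapping = self.packed_modules_mapping


class ToyModel(SupportsQuantStub):
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }


cfg = QuantConfigStub()
ToyModel(cfg)
print(cfg.packed_modules_mapping["qkv_proj"])  # ['q_proj', 'k_proj', 'v_proj']
```

Until every model class opts into an interface like this, the fallback in `_configure_packed_modules_mapping` still has to copy the attribute manually, which is the point made in the reply above.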