
Conversation

jeejeelee
Collaborator

@jeejeelee jeejeelee commented Feb 19, 2025

Motivation

Some models, such as Qwen2.5-VL, have a modified layer hierarchy compared to their original transformers implementation. This change causes the quantization skip modules to become ineffective, leading to incorrect initialization of linear methods.

Reproduce code

import vllm
llm = vllm.LLM(
    "unsloth/Qwen2.5-VL-72B-Instruct-unsloth-bnb-4bit",
    max_model_len=3200,
    quantization="bitsandbytes",
    load_format="bitsandbytes",
    trust_remote_code=True,
)

TODO

  • Investigate other quantization methods (e.g. AWQ)

  • Optimize the implementation logic

Signed-off-by: Jee Jee Li <[email protected]>
@jeejeelee jeejeelee marked this pull request as draft February 19, 2025 17:39

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which covers a small and essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@jeejeelee jeejeelee marked this pull request as ready for review March 5, 2025 12:49
@jeejeelee jeejeelee requested a review from mgoin March 5, 2025 12:49
Comment on lines +196 to +210
# BitsAndBytes
if (isinstance(quant_config, BitsAndBytesConfig)
        and quant_config.llm_int8_skip_modules):
    quant_config.llm_int8_skip_modules = [
        hf_to_vllm_mapper._map_name(module)
        for module in quant_config.llm_int8_skip_modules
    ]
# AWQ
elif (isinstance(quant_config, AWQConfig)
        and quant_config.modules_to_not_convert):
    quant_config.modules_to_not_convert = [
        hf_to_vllm_mapper._map_name(module)
        for module in quant_config.modules_to_not_convert
    ]
# TODO: Support more quantization types.
Member

Maybe we should introduce a common ignored_modules or ignored_prefixes to QuantizationConfig, like packed_modules_mapping: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/base_config.py#L60-L66

Then each quant config can convert its specific llm_int8_skip_modules, modules_to_not_convert, etc. into the canonical format in ignored_modules. This would also allow us to generalize the is_layer_skipped function.
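
A minimal standalone sketch of that idea (the class and helper definitions below are simplified stand-ins for illustration, not vLLM's actual implementations):

from typing import Dict, List, Optional


class QuantizationConfig:
    # Canonical list of modules that should keep unquantized linear methods.
    ignored_modules: List[str] = []
    packed_modules_mapping: Dict[str, List[str]] = {}


class BitsAndBytesConfig(QuantizationConfig):
    def __init__(self, llm_int8_skip_modules: Optional[List[str]] = None):
        # Translate the BnB-specific field into the canonical one.
        self.ignored_modules = list(llm_int8_skip_modules or [])


class AWQConfig(QuantizationConfig):
    def __init__(self, modules_to_not_convert: Optional[List[str]] = None):
        # Translate the AWQ-specific field into the canonical one.
        self.ignored_modules = list(modules_to_not_convert or [])


def is_layer_skipped(prefix: str, quant_config: QuantizationConfig) -> bool:
    # A generalized check only needs to consult the canonical list.
    return any(prefix == m or prefix.startswith(m + ".")
               for m in quant_config.ignored_modules)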

Contributor

@kylesayrs kylesayrs Mar 10, 2025

I'd support an implementation like this as well. The current implementation could fail to properly map module names in nested models, for example:

modules_to_not_convert = ["SubModel.A"]
SubModel.hf_to_vllm_mapper = Mapper(orig_to_new_prefix={"A": "B"})

Note that "SubModel.A" will not be remapped, because "SubModel.A" does not start with "A".

This is a fairly minor issue, but something to keep in mind.
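
A self-contained illustration of that mismatch (the map_name function here is a stand-in for the mapper's prefix remapping, not the real vLLM class):

# Stand-in for a WeightsMapper-style prefix remap; illustrative only.
def map_name(name: str, orig_to_new_prefix: dict) -> str:
    for orig, new in orig_to_new_prefix.items():
        if name.startswith(orig):
            return new + name[len(orig):]
    return name  # unchanged if no prefix matches


modules_to_not_convert = ["SubModel.A"]
submodel_mapping = {"A": "B"}  # the submodel's own hf_to_vllm_mapper prefixes

# The entry is left untouched, because "SubModel.A" does not start with the
# submodel-local prefix "A".
print([map_name(m, submodel_mapping) for m in modules_to_not_convert])
# -> ['SubModel.A']  (the skip would need 'SubModel.B' to keep working)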

Another implementation could look like this:

  1. Add a mutable ignored_modules attribute to QuantizationConfig
  2. At construction time, use the method-specific constructor to populate the ignored_modules attribute from disk
  3. At initialization time, within SupportsQuant, use the given model prefix and mapper to update the ignored_modules list with the proper model-specific mapping
    a. ignored_modules = [prefix + hf_to_vllm_mapper[module - prefix] for module in ignored_modules]

This has the advantage of further standardizing around the QuantizationConfig base, as well as supporting mapping for nested models.
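
A rough sketch of the initialize-time update in step 3 (the function and argument names are assumptions for illustration, not the actual SupportsQuant interface):

# Hypothetical sketch: remap ignored_modules using the model prefix and the
# model's hf_to_vllm prefix mapping. Names here are assumptions, not vLLM API.
def remap_ignored_modules(ignored_modules, prefix, orig_to_new_prefix):
    remapped = []
    for module in ignored_modules:
        if prefix and module.startswith(prefix + "."):
            local = module[len(prefix) + 1:]          # the "module - prefix" part
            for orig, new in orig_to_new_prefix.items():
                if local.startswith(orig):
                    local = new + local[len(orig):]
                    break
            remapped.append(f"{prefix}.{local}")      # prefix + mapped local name
        else:
            remapped.append(module)
    return remapped


# The nested case from above now maps correctly.
print(remap_ignored_modules(["SubModel.A"], "SubModel", {"A": "B"}))
# -> ['SubModel.B']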

Contributor

@jeejeelee Here's a WIP of what that might look like: #14635

Collaborator Author

@kylesayrs Can you provide an example?

Comment on lines +161 to +177
def _configure_packed_modules_mapping():
    """
    Pass packed_modules_mapping by reference to quant_config so that
    quant_config can properly match fused modules.
    Note that model attributes are passed by reference to quant_config,
    enabling them to be updated by model_class.__new__ (ex. chatglm, qwen)
    """
    packed_mapping = getattr(model_class, "packed_modules_mapping", None)
    if packed_mapping is not None:
        # pass packed_modules_mapping by reference to quant_config
        quant_config.packed_modules_mapping = packed_mapping
    else:
        logger.warning(
            "The model class %s has not defined `packed_modules_mapping`, "
            "this may lead to incorrect mapping of quantized or ignored "
            "modules", model_class.__name__)
Member

Why is this needed after we added SupportsQuant (#13104)? I thought getting the packed_modules_mapping from the model to the quant config was the main purpose of that. cc @kylesayrs

Contributor

The _configure_packed_modules_mapping function needs to remain in place until SupportsQuant has been added to all applicable models.

@jeejeelee
Collaborator Author

Closing in favor of #14635.

@jeejeelee jeejeelee closed this Mar 13, 2025
@jeejeelee jeejeelee deleted the fix-quant-skip-modules branch March 14, 2025 01:29