
Conversation

@mgoin (Member) commented May 12, 2025

FIX #16850 (comment)

Addresses issues with both the FP8 W8A8 and FP8 Marlin paths.

The Marlin path needed to tolerate layers that do not have a `weight_block_size` attribute, and both paths needed their ignored-modules check updated to work with fused modules; a sketch of the latter follows.
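
For illustration, here is a minimal sketch of a fused-module-aware ignore check. This is an assumption-laden simplification, not the actual vLLM patch; `FUSED_SHARDS` and `is_layer_skipped` are hypothetical names.

```python
# Sketch only: map fused vLLM module names to the per-shard names that a
# checkpoint's ignore list actually contains (assumed layout, not the real code).
FUSED_SHARDS = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def is_layer_skipped(prefix: str, ignored_layers: list[str]) -> bool:
    """Return True if the module at `prefix` should stay unquantized,
    resolving fused modules back to the shard names in the ignore list."""
    base, _, last = prefix.rpartition(".")
    if last in FUSED_SHARDS:
        shard_skipped = [
            f"{base}.{shard}" in ignored_layers for shard in FUSED_SHARDS[last]
        ]
        # All shards of a fused weight must agree: we cannot quantize only
        # half of a fused qkv_proj tensor.
        if any(shard_skipped) and not all(shard_skipped):
            raise ValueError(f"Detected mixed quantization within {prefix}")
        return all(shard_skipped)
    return prefix in ignored_layers
```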

Manually verified evals:

FP8 W8A8

WARNING 05-12 16:46:30 [marlin_utils_fp8.py:81] Your GPU does not have native support for FP8 computation but FP8 quantization is being used. Weight-only FP8 compression will be used leveraging the Marlin kernel. This may degrade performance for compute-heavy workloads.
vllm (pretrained=nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform,enforce_eager=True,max_model_len=2048,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7513|±  |0.0119|
|     |       |strict-match    |     5|exact_match|↑  |0.7536|±  |0.0119|

FP8 Marlin

vllm (pretrained=nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform,enforce_eager=True,max_model_len=2048,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7559|±  |0.0118|
|     |       |strict-match    |     5|exact_match|↑  |0.7574|±  |0.0118|

====

Original failure found in CI: https://buildkite.com/vllm/ci/builds/19754#0196bd80-d18b-4c8a-bf76-50e9a53c5a6c/43-353

ERROR 05-10 22:15:39 [multiproc_executor.py:487]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/quantization/fbgemm_fp8.py", line 144, in process_weights_after_loading
ERROR 05-10 22:15:39 [multiproc_executor.py:487]     prepare_fp8_layer_for_marlin(layer)
ERROR 05-10 22:15:39 [multiproc_executor.py:487]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py", line 122, in prepare_fp8_layer_for_marlin
ERROR 05-10 22:15:39 [multiproc_executor.py:487]     if layer.weight_block_size is None:
ERROR 05-10 22:15:39 [multiproc_executor.py:487]        ^^^^^^^^^^^^^^^^^^^^^^^
ERROR 05-10 22:15:39 [multiproc_executor.py:487]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1940, in __getattr__
ERROR 05-10 22:15:39 [multiproc_executor.py:487]     raise AttributeError(
ERROR 05-10 22:15:39 [multiproc_executor.py:487] AttributeError: 'QKVParallelLinear' object has no attribute 'weight_block_size'
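
The root cause is that `torch.nn.Module.__getattr__` raises `AttributeError` for any name that was never registered on the module, so direct `layer.weight_block_size` access only works on layers whose quantization method actually set that attribute. A minimal sketch of the shape of the fix, using a plain `nn.Linear` as a stand-in for the fused layer:

```python
import torch.nn as nn

layer = nn.Linear(16, 16)  # stand-in for a QKVParallelLinear with no block-quant info

# Direct access reproduces the CI failure:
#   layer.weight_block_size  ->  AttributeError: 'Linear' object has no attribute ...
# getattr() with a default tolerates layers that never set the attribute:
weight_block_size = getattr(layer, "weight_block_size", None)
if weight_block_size is None:
    print("per-tensor/per-channel FP8 weights: no block-wise repacking needed")
```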

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@jinzhen-lin (Contributor):

The second CI failure seems to be a different problem, related to compressed-tensors.

@robertgshaw2-redhat (Collaborator):

Thanks Michael

@mgoin (Member, Author) commented May 12, 2025

@jinzhen-lin This actually is not sufficient. The FBGEMM model runs, but produces incorrect results:

lm_eval --model vllm --model_args pretrained=nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform,enforce_eager=True,max_model_len=2048,tensor_parallel_size=4 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto

vllm (pretrained=nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform,enforce_eager=True,max_model_len=2048,tensor_parallel_size=4,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |    0|±  |     0|
|     |       |strict-match    |     5|exact_match|↑  |    0|±  |     0|

@jinzhen-lin (Contributor):

> @jinzhen-lin This actually is not sufficient. The FBGEMM model runs, but produces incorrect results:
> […]

I will look into the reason tomorrow.

@mgoin (Member, Author) commented May 12, 2025

I actually found that I get the same result using FBGEMM models with or without Marlin (i.e., W8A8), so something is broken in the base case.

EDIT: This was due to incorrect skip-module checks; a toy demonstration follows.
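
As a toy demonstration of that failure mode, reusing the hypothetical `is_layer_skipped` sketch from the PR description above (all module names are illustrative):

```python
ignored = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
]
fused = "model.layers.0.self_attn.qkv_proj"

# A naive membership test misses the fused name, so the layer is quantized
# even though every one of its shards was meant to be skipped:
print(fused in ignored)                  # False
print(is_layer_skipped(fused, ignored))  # True with the fused-aware check
```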

@mgoin changed the title from "Fix FP8 Marlin for FBGEMM integration" to "Fix FBGEMM integration" May 12, 2025
@mgoin added this to the v0.9.0 milestone May 12, 2025
@robertgshaw2-redhat enabled auto-merge (squash) May 12, 2025 16:55
@github-actions bot added the ready label (ONLY add when PR is ready to merge/full CI is needed) May 12, 2025
@mgoin changed the title from "Fix FBGEMM integration" to "[Bugfix] Fix FBGEMM integration" May 12, 2025
@mgoin added the bug label (Something isn't working) May 12, 2025
@robertgshaw2-redhat merged commit f065de4 into vllm-project:main May 12, 2025
97 checks passed
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request May 14, 2025
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025