53 commits
d45f2be
extend quark to support mixed-precision quantization model
xuebwang-amd Sep 4, 2025
d984821
use an environment variable to support mxfp4 quantize regardless of p…
xuebwang-amd Sep 4, 2025
8976bbb
add a test for quark mixed precision models
xuebwang-amd Sep 4, 2025
4ed7a68
fix pre-commit issues
xuebwang-amd Sep 4, 2025
b217e06
add one section about mixed-precision usage in the Quark document
xuebwang-amd Sep 4, 2025
5f4b012
tiny update AMP document
xuebwang-amd Sep 5, 2025
24a0203
refactor test script and add a new model
xuebwang-amd Sep 5, 2025
55742f9
simplify layer_quant_configs matching
xuebwang-amd Sep 5, 2025
4ba49d3
update AMP section in the Quark document
xuebwang-amd Sep 5, 2025
f5dfdd2
Merge branch 'main' into xuebin/upstream_amd_quark_layerwise_mixed_pr…
xuebwang-amd Sep 15, 2025
8fcb1ad
Update docs/features/quantization/quark.md
xuebwang-amd Sep 15, 2025
d58e162
Update docs/features/quantization/quark.md
xuebwang-amd Sep 15, 2025
7072b0a
Update docs/features/quantization/quark.md
xuebwang-amd Sep 15, 2025
755d214
Update docs/features/quantization/quark.md
xuebwang-amd Sep 15, 2025
3a59917
Update docs/features/quantization/quark.md
xuebwang-amd Sep 15, 2025
b6aea4e
Update docs/features/quantization/quark.md
xuebwang-amd Sep 15, 2025
356e03c
Update docs/features/quantization/quark.md
xuebwang-amd Sep 15, 2025
89243f3
remove VLLM_QUARK_EMU_MEM_OPT and use_v0
xuebwang-amd Sep 17, 2025
1e6f959
correct and simplify layer_quant_config parsing in QuarkConfig
xuebwang-amd Sep 17, 2025
9475678
update one pre-commit issue
xuebwang-amd Sep 17, 2025
875d83a
fix markdownlint issue
xuebwang-amd Sep 18, 2025
812db0f
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Sep 18, 2025
7f9db09
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Sep 25, 2025
deb645d
update excepted accuracy numbers since the amp models in hf have been…
xuebwang-amd Sep 25, 2025
5ad31a0
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Sep 26, 2025
1fddee9
reduce test models to be one
xuebwang-amd Sep 26, 2025
00460e7
remove HF_HUB_AMD_ORG_ACCESS since model is public
xuebwang-amd Sep 28, 2025
6819a72
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Sep 28, 2025
64f3071
fix pre-commit issue in test_mixed_precision.py
xuebwang-amd Sep 29, 2025
3e98843
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Sep 29, 2025
9022e5b
merge main into current branch with fixing conflictions
xuebwang-amd Oct 9, 2025
ef93b22
update with fixing conflictions
xuebwang-amd Oct 9, 2025
0b15d87
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 9, 2025
2f0c1cb
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 9, 2025
f9704de
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 9, 2025
8d80088
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 9, 2025
d34e23e
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 10, 2025
575caf2
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 10, 2025
bd59e51
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 10, 2025
9ef2b8f
Resolved merge conflicts
xuebwang-amd Oct 13, 2025
1c2c4d5
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 13, 2025
20b23dc
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 14, 2025
9256a8e
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 14, 2025
6f8294c
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 14, 2025
8c719b3
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 14, 2025
3fac262
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 14, 2025
d4c7cc9
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 15, 2025
350cc15
add a non-mixed-precision (PTQ) model as a reference for pipeline com…
xuebwang-amd Oct 17, 2025
db3cc7e
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 17, 2025
8762dff
use a quark_format model as reference
xuebwang-amd Oct 18, 2025
687c2cd
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 18, 2025
decd8d7
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 18, 2025
bb39f1d
Merge remote-tracking branch 'origin/main' into xuebin/upstream_amd_q…
xuebwang-amd Oct 19, 2025
34 changes: 33 additions & 1 deletion docs/features/quantization/quark.md
100644 → 100755
@@ -281,4 +281,36 @@ python quantize_quark.py --model_dir Qwen/Qwen1.5-MoE-A2.7B-Chat \
--group_size 32
```

The current integration supports [all combination of FP4, FP6_E3M2, FP6_E2M3](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py) used for either weights or activations. Eventually, some target hardware support mixed precision GEMM, as AMD Instinct MI350/MI355, for example using FP6 for activations and FP4 for weights.
The current integration supports [all combinations of FP4, FP6_E3M2, FP6_E2M3](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py) used for either weights or activations.

## Using Quark-Quantized Layerwise Auto Mixed Precision (AMP) Models

vLLM also supports loading layerwise mixed-precision models quantized using AMD Quark. Currently, the mixed scheme {MXFP4, FP8} is supported, where FP8 denotes the FP8 per-tensor scheme. More mixed-precision schemes are planned for the near future, including:

- Unquantized Linear and/or MoE layer(s) as an option for each layer, i.e., mixed of {MXFP4, FP8, BF16/FP16}
Contributor:
Suggested change
- Unquantized Linear and/or MoE layer(s) as an option for each layer, i.e., mixed of {MXFP4, FP8, BF16/FP16}
- Unquantized Linear and/or MoE layer(s) for layers too sensitive to be quantized

Author:

It's better not to draw conclusions or provide guiding descriptions about why layers are quantized or not quantized; they're searched.

Contributor:

Just trying to make the doc less verbose

Author:

Let's make it accurate.

- MXFP6 quantization extension, i.e., {MXFP4, MXFP6, FP8, BF16/FP16}
Contributor:

Suggested change
- MXFP6 quantization extension, i.e., {MXFP4, MXFP6, FP8, BF16/FP16}
- MXFP6 quantization

Author:

"MXFP6" alone would be misleading, as it suggests single-scheme (single-bitwidth) quantization.
In mixed precision, the associated terminology is "bitwidth candidates", which is a set in the mathematical sense. That's why we need to list the full set of possible schemes (bitwidths), like {MXFP4, MXFP6, FP8, BF16/FP16}, rather than a single one.

Contributor:

Just trying to make the doc less verbose (mxfp4, bf16, fp8 usability is already implied above)

Author:

Let's make it accurate. {MXFP4, MXFP6, FP8, BF16/FP16} is a whole, a base unit for mixed-precision.


Although one can maximize serving throughput using the lowest precision supported on a given device (e.g. MXFP4 for AMD Instinct MI355, FP8 for AMD Instinct MI300), these aggressive schemes can be detrimental to the accuracy recovered after quantization on target tasks. Mixed precision allows striking a balance between maximizing accuracy and throughput.

There are two steps to generate and deploy a mixed precision model quantized with AMD Quark, as shown below.

### 1. Quantize a model using mixed precision in AMD Quark

First, the layerwise mixed-precision configuration for a given LLM is searched, and the model is then quantized using AMD Quark. A detailed tutorial with Quark APIs will be provided later.

As examples, we provide some ready-to-use quantized mixed-precision models to show their usage in vLLM and the accuracy benefits. They are:

- amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
- amd/Mixtral-8x7B-Instruct-v0.1-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
- amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
Comment on lines +301 to +305
Contributor:

Make these public + add link

Author:

They're going to be published.
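
Once the checkpoints are public, one quick way to confirm they ship a Quark quantization config before serving them is to inspect config.json on the Hub. This is a sketch, not part of the PR; the model ID is taken from the list above, and the field layout follows Quark's HF export format as read by vLLM's QuarkConfig:

```python
import json

from huggingface_hub import hf_hub_download

# Download only config.json and check the exported quantization metadata.
config_path = hf_hub_download(
    repo_id="amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8",
    filename="config.json",
)
with open(config_path) as f:
    quant_cfg = json.load(f).get("quantization_config", {})

# Quark-exported checkpoints advertise quant_method == "quark"; mixed-precision
# (AMP) exports additionally carry a per-layer "layer_quant_config" section.
print(quant_cfg.get("quant_method"))
print(sorted(quant_cfg.keys()))
```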


### 2. Run inference with the quantized mixed-precision model in vLLM

Models quantized with AMD Quark using mixed precision can natively be reloaded in vLLM and, for example, evaluated using lm-evaluation-harness as follows:

```bash
lm_eval --model vllm \
--model_args pretrained=amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8,tensor_parallel_size=4,dtype=auto,gpu_memory_utilization=0.8,trust_remote_code=False \
--tasks mmlu \
--batch_size auto
```
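
Outside lm-evaluation-harness, the same checkpoint can also be loaded directly with the vLLM Python API. Below is a minimal sketch (not part of the diff); the tensor_parallel_size and sampling settings are illustrative and should be adapted to your hardware:

```python
from vllm import LLM, SamplingParams

# Load the Quark mixed-precision (AMP) checkpoint; vLLM picks up the
# quantization config from the model's config.json automatically.
llm = LLM(
    model="amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8",
    tensor_parallel_size=4,
    dtype="auto",
    gpu_memory_utilization=0.8,
)

outputs = llm.generate(
    ["Explain mixed-precision quantization in one sentence."],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```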
69 changes: 69 additions & 0 deletions tests/quantization/test_mixed_precision.py
@@ -0,0 +1,69 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test quark-quantized {MXFP4, FP8} mixed precision models.

Run `pytest tests/quantization/test_mixed_precision.py`.

"""

import importlib
import importlib.metadata
from dataclasses import dataclass

import lm_eval
import pytest
from packaging import version

QUARK_MXFP4_AVAILABLE = importlib.util.find_spec("quark") is not None and version.parse(
importlib.metadata.version("amd-quark")
) >= version.parse("0.8.99")


@dataclass
class ModelCase:
model_id: str
tp: int


@dataclass
class EvaluationConfig:
model_name: str

def get_model_args(self) -> str:
return (
f"pretrained={self.model_name},"
"tensor_parallel_size=4,dtype=auto,gpu_memory_utilization=0.8,trust_remote_code=False"
)


TEST_CONFIGS = {
# Mixed-precision (AMP) model
# - Demonstrates end-to-end pipeline functionality
"amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8": {"arc_challenge": 0.52, "mmlu": 0.72},
# Non-mixed-precision (PTQ) model
# - Reference for pipeline compatibility verification -> No conflicts or breakings
"amd/Llama-2-70b-chat-hf-FP8-MLPerf-fp8_attn_quark_format": {
"arc_challenge": 0.53,
"mmlu": 0.61,
},
}


@pytest.mark.parametrize("model_name, accuracy_numbers", TEST_CONFIGS.items())
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
def test_mixed_precision_model_accuracies(model_name: str, accuracy_numbers: dict):
results = lm_eval.simple_evaluate(
model="vllm",
model_args=EvaluationConfig(model_name).get_model_args(),
tasks=list(accuracy_numbers.keys()),
batch_size=8,
)

rtol = 0.05

for task, expect_accuracy in accuracy_numbers.items():
measured_accuracy = results["results"][task]["acc,none"]
assert (
measured_accuracy - rtol < expect_accuracy
and measured_accuracy + rtol > expect_accuracy
), f"Expected: {expect_accuracy} | Measured: {measured_accuracy}"
Empty file modified vllm/model_executor/layers/fused_moe/utils.py
100644 → 100755
Empty file.
26 changes: 19 additions & 7 deletions vllm/model_executor/layers/quantization/quark/quark.py
100644 → 100755
@@ -114,7 +114,14 @@ def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
layer_quant_names = list(layer_quant_config.keys())
layer_quant_set = set(layer_quant_names)

if not kv_cache_set.issubset(layer_quant_set):
if not (
kv_cache_set.issubset(layer_quant_set)
or any(
fnmatch.fnmatchcase(layer_quant, pat)
for layer_quant in list(layer_quant_set)
for pat in list(kv_cache_set)
)
):
raise ValueError(
"The Quark quantized model has the "
"kv_cache_group parameter setting, "
@@ -124,10 +131,15 @@ def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
)

q_configs = [
cast(dict[str, Any], layer_quant_config.get(name))
for name in kv_cache_group
quant_cfg
for name, quant_cfg in layer_quant_config.items()
if any(fnmatch.fnmatchcase(name, pattern) for pattern in kv_cache_group)
]
if not all(deep_compare(q_config, q_configs[0]) for q_config in q_configs):

if not all(
deep_compare(q_config["output_tensors"], q_configs[0]["output_tensors"])
for q_config in q_configs
):
raise ValueError(
"The quantization method used for kv_cache should "
"be the same, but the quantization method for the "
@@ -312,9 +324,9 @@ def _find_matched_config(
layer_quant_config = cast(
dict[str, Any], self.quant_config.get("layer_quant_config")
)
for name_pattern in layer_quant_config:
if fnmatch.fnmatch(layer_name, name_pattern):
Contributor:

Is this change necessary? Also, layer_quant_configs seems unused: it appends the first matched config and immediately returns it.

Author:

Updated as also suggested in #24239 (comment).

Comment on lines -315 to -316
Contributor:

I am confused. What is this PR changing here?

Author:

The gemini-code-assist bot also had similar questions. Please see my comments above, e.g., #24239 (comment).

Contributor:

Thanks for referencing our previous discussion, @xuebwang-amd. I'd like to clarify the change in behavior introduced by replacing fnmatch.fnmatch(layer_name, name_pattern) with layer_name in name_pattern.

Key Difference:

  • fnmatch.fnmatch(layer_name, name_pattern): This function performs shell-style wildcard matching. For example, if name_pattern is "model.layers.*.attn.q_proj", it would match "model.layers.0.attn.q_proj", "model.layers.1.attn.q_proj", etc.
  • layer_name in name_pattern: This performs a simple substring check. For "model.layers.0.attn.q_proj" in "model.layers.*.attn.q_proj", it would return False because "model.layers.0.attn.q_proj" is not a substring of "model.layers.*.attn.q_proj".

This change fundamentally alters how layer names are matched against the layer_quant_config. If the intention is to support wildcard patterns for layer quantization (which is a common practice for applying the same quantization scheme to multiple layers), then the current change would likely break existing configurations or limit the flexibility of defining quantization patterns.

Unless there's a specific reason to remove wildcard matching, I recommend reverting to fnmatch.fnmatch to maintain the expected pattern matching behavior. The iteration over layer_quant_config.items() is correct for accessing the config value, but the matching logic should be adjusted.

Suggested change
for name_pattern in layer_quant_config:
if fnmatch.fnmatch(layer_name, name_pattern):
for name_pattern, config in layer_quant_config.items():
if fnmatch.fnmatch(layer_name, name_pattern):
return config
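
For reference, the behavioral difference described in this comment can be reproduced with a short snippet (the layer and pattern names are hypothetical):

```python
import fnmatch

pattern = "model.layers.*.self_attn.q_proj"  # wildcard key, as in non-AMP configs
layer = "model.layers.0.self_attn.q_proj"    # concrete layer name

print(fnmatch.fnmatch(layer, pattern))  # True: shell-style wildcard match
print(layer in pattern)                 # False: plain substring containment
```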

Contributor:

Gemini code bot is not useful here.

@xuebwang-amd I don't understand why this PR introduces handling different from e.g.
https://github.com/amd/Quark/blob/release/0.9/quark/torch/export/main_export/quant_config_parser.py#L67-L70
and e.g.
https://github.com/amd/Quark/blob/release/0.9/quark/torch/quantization/model_transformation.py#L80-L84

Why would the handling in vLLM be different from what we have in Quark, e.g. when reloading models through the Transformers library? I think it is not a good thing. Maybe existing models rely on fnmatch.fnmatch and things would break now.

Author:

There have been lots of discussions about this in the PR.
To emphasize: this is for AMP, in which layers are specified one by one, so name_pattern in layer_quant_config works as strict matching, while fnmatch.fnmatch doesn't fit here.

Contributor:

I don't understand why it is okay to make the change here but not with the Transformers backend (reloading Quark models through Transformers). Or maybe I misunderstand something.

Contributor:

I see, so you want precise matching to take precedence over wildcard matching.

I'd suggest keeping the wildcard matching logic after your exact match loop. Otherwise, it looks like the new code won't match with wildcards anymore for non-mixed-precision models.

Author (Oct 17, 2025):

I can fully understand your concern here. Please find my explanations above, e.g.:

#24239 (comment)
#24239 (comment)
#24239 (comment)

To ensure there is no breakage of or conflict with existing PTQ model matching, I added a non-mixed-precision (PTQ, public) model as a reference to demonstrate pipeline compatibility in tests/quantization/test_mixed_precision.py: https://github.com/xuebwang-amd/vllm/blob/db3cc7eba1609370e34b35f51c7a5fa3111bb868/tests/quantization/test_mixed_precision.py#L45

Conclusion: there are no conflicts or breakage with the precise substring-containment matching rule.

return layer_quant_config[name_pattern]
for name_pattern, config in layer_quant_config.items():
if layer_name in name_pattern:
Comment on lines +327 to +328
Contributor:

Do we make sure somewhere that e.g. q_proj from the checkpoint/Transformers gets correctly mapped to qkv_proj in vllm (https://github.com/ROCm/vllm/blob/eb9d4de9eb7649bdf36b2d0e4832fcaab8465153/vllm/model_executor/models/llama.py#L150) prior to doing the check layer_name in name_pattern?

Author:

Good question.
The Quark model/config is largely decoupled from vLLM's model implementation. q_proj, k_proj, and v_proj are merged in vLLM, while they are separate in the Quark-quantized model and its configs. In Quark's AMP, q_proj, k_proj, and v_proj are required to have the same bitwidth, i.e., the same quantization scheme, so the alignment is achieved.
Therefore, Quark's layerwise quant config matching is applied to q_proj, k_proj, and v_proj individually.

return config

layer_type = cast(str, type(module))
layer_type_quant_config = cast(
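
A sketch of the precedence the reviewers converge on (exact per-layer keys, as produced by the AMP search, first; shell-style wildcard patterns as a fallback). find_layer_config and its arguments are illustrative, not the final code of this PR:

```python
import fnmatch
from typing import Any, Optional


def find_layer_config(
    layer_name: str, layer_quant_config: dict[str, dict[str, Any]]
) -> Optional[dict[str, Any]]:
    # 1) Exact per-layer entries take precedence; mixed-precision (AMP)
    #    checkpoints list layers one by one.
    if layer_name in layer_quant_config:
        return layer_quant_config[layer_name]
    # 2) Fall back to wildcard patterns used by non-mixed-precision checkpoints.
    for name_pattern, config in layer_quant_config.items():
        if fnmatch.fnmatch(layer_name, name_pattern):
            return config
    return None
```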