[ROCm][Quantization] extend AMD Quark to support mixed-precision quantized model #24239
Conversation
Code Review
This pull request extends Quark to support mixed-precision models, specifically for {MXFP4, FP8} schemes. The changes involve updating the quantization configuration logic to handle mixed-precision setups and adding new tests to validate model accuracies. My review identified two high-severity issues. First, in the new test file, environment variables are not handled safely, which could lead to test state leakage; I've recommended using pytest's monkeypatch fixture for robust cleanup. Second, in the Quark configuration logic, a fragile substring check is used for matching layer names, which could result in applying incorrect quantization schemes; I've suggested a more robust pattern-matching approach to ensure correctness. Addressing these issues will improve the reliability and correctness of the new mixed-precision quantization feature.
```python
def test_mixed_precision_model_accuracies(config: EvaluationConfig, task: str):
    os.environ["VLLM_QUARK_EMU_MEM_OPT"] = "1"

    results = lm_eval.simple_evaluate(model="vllm",
                                      model_args=config.get_model_args(),
                                      tasks=task,
                                      batch_size="auto")

    rtol = 0.05

    EXPECTED_VALUE = config.excepted_value
    measured_value = results["results"][task]["acc,none"]
    assert (measured_value - rtol < EXPECTED_VALUE
            and measured_value + rtol > EXPECTED_VALUE
            ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"

    del os.environ["VLLM_QUARK_EMU_MEM_OPT"]
```
Setting and deleting an environment variable directly via os.environ can lead to state leakage between tests if an exception occurs before the del statement. This can cause subsequent tests to fail or behave unexpectedly. It's safer to use pytest's monkeypatch fixture, which is already used in this file, to manage environment variables. monkeypatch ensures that the environment is restored to its original state after the test function completes, regardless of whether it passes or fails.
Suggested change:

```python
def test_mixed_precision_model_accuracies(config: EvaluationConfig, task: str,
                                          monkeypatch):
    monkeypatch.setenv("VLLM_QUARK_EMU_MEM_OPT", "1")
    results = lm_eval.simple_evaluate(model="vllm",
                                      model_args=config.get_model_args(),
                                      tasks=task,
                                      batch_size="auto")
    rtol = 0.05
    EXPECTED_VALUE = config.excepted_value
    measured_value = results["results"][task]["acc,none"]
    assert (measured_value - rtol < EXPECTED_VALUE
            and measured_value + rtol > EXPECTED_VALUE
            ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
```
Updated, along with some refactoring.
```python
layer_quant_configs = list()
for name_pattern in layer_quant_config:
    if fnmatch.fnmatch(layer_name, name_pattern):
        return layer_quant_config[name_pattern]
    if layer_name in name_pattern:
        layer_quant_configs.append(
            layer_quant_config[name_pattern])
return layer_quant_configs[0]
```
The matching logic if layer_name in name_pattern: is likely incorrect and can lead to bugs. It performs a substring check, which can cause unintended matches (e.g., attention.q_proj would incorrectly match a pattern for cross_attention.q_proj). This is likely not the intended behavior for matching layer configurations and could lead to applying the wrong quantization scheme to a layer.

Given that fnmatch is used elsewhere in this file for pattern matching, it seems the intention is to support glob patterns. If name_pattern can be a comma-separated list of patterns, the logic should be updated to split the string and apply fnmatch to each part. This ensures accurate matching of layer configurations and prevents applying the wrong quantization scheme.

The current implementation also unnecessarily creates a list layer_quant_configs only to immediately return its first element. This can be simplified by returning directly upon finding a match.
Suggested change:

```python
for name_pattern in layer_quant_config:
    patterns = [p.strip() for p in name_pattern.split(',')]
    for p in patterns:
        if fnmatch.fnmatch(layer_name, p):
            return layer_quant_config[name_pattern]
```
This code snippet suggested by gemini-code-assist is problematic. A name_pattern here looks like model.layers.0.block_sparse_moe.experts.0.w1, for example, so name_pattern.split(',') doesn't make sense and the subsequent fnmatch.fnmatch is also irrelevant.
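For illustration, here is a minimal sketch of a matcher that handles both exact per-layer keys, like the example above, and glob-style keys used elsewhere in the file; match_layer_config is a hypothetical helper, not the function in this PR:

```python
import fnmatch


def match_layer_config(layer_name: str, layer_quant_config: dict):
    """Return the config whose key equals or glob-matches layer_name, else None."""
    # Exact per-layer keys, e.g. "model.layers.0.block_sparse_moe.experts.0.w1".
    if layer_name in layer_quant_config:
        return layer_quant_config[layer_name]
    # Glob-style keys, e.g. "*.experts.*.w1".
    for name_pattern, cfg in layer_quant_config.items():
        if fnmatch.fnmatch(layer_name, name_pattern):
            return cfg
    return None
```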
Thanks, great start!
```python
            dict[str, Any], self.quant_config.get("layer_quant_config"))
        layer_quant_configs = list()
        for name_pattern in layer_quant_config:
            if fnmatch.fnmatch(layer_name, name_pattern):
```
Is this change necessary? Also, layer_quant_configs seems unused: it appends the first matched config and immediately returns it.
Updated, as also suggested in #24239 (comment).
```python
) -> tuple[torch.Tensor, None]:
    assert block_shape is None
    if not current_platform.supports_mx():
        VLLM_QUARK_EMU_MEM_OPT = (os.environ.get("VLLM_QUARK_EMU_MEM_OPT",
```
In general, for env flags it is better to add them to vllm/vllm/envs.py with comments on their effect.
Can you keep this change local? In particular, we want to move away from simulation to Triton kernels as we move forward. cc @fxmarty-amd
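For reference, env flags in vllm/vllm/envs.py generally follow a pattern along these lines; this is only a sketch of that convention (the exact structure is an assumption), and as discussed below the flag was ultimately kept out of envs.py:

```python
# Sketch of the envs.py convention, not the file's actual contents.
import os
from typing import Any, Callable

VLLM_QUARK_EMU_MEM_OPT: bool = False

environment_variables: dict[str, Callable[[], Any]] = {
    # If "1", force the memory-optimized MXFP4 emulation path in Quark.
    "VLLM_QUARK_EMU_MEM_OPT":
    lambda: bool(int(os.getenv("VLLM_QUARK_EMU_MEM_OPT", "0"))),
}
```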
Totally agree on that.
The reason VLLM_QUARK_EMU_MEM_OPT is not added to vllm/vllm/envs.py is that it's better to keep it as a local, temporary environment variable, just to make things work for now. After non-emulation kernels such as Triton or AITER implementations are integrated, we can remove it entirely.
@xuebwang-amd this variable that I added previously has been removed per @mgoin's request, in order to avoid adding a new, unnecessary env variable to vllm, especially given that we have a decently fast mxfp4 dequantization kernel.
Please avoid adding this environment variable; keep it local for testing if needed.
I appreciate your previous effort on this emulation approach; it played a larger role than local testing, and the functionality continues in what I'm doing here.

With VLLM_QUARK_EMU_MEM_OPT=1 enabled, execution indeed goes to mx.qdq_mxfp4, defined at https://github.com/vllm-project/vllm/blob/8de261b04a0a0e916d3d25d528d0f2ddeede2a6b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py#L94C5-L94C25.

The real motivation for this environment variable is to route execution to the emulation path regardless of the platform's MX support, because the non-emulation kernels haven't been integrated into this flow yet. Therefore, the solution here is to remove the if-else statement:

```python
if not current_platform.supports_mx():
    A = quant_dequant_mxfp4(A)
else:
    raise NotImplementedError()
```

and always use A = quant_dequant_mxfp4(A).
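Concretely, the simplified path would then look roughly like the sketch below; the real signature of _mxfp4_quantize in vLLM may differ, and the import path is assumed from the link above:

```python
import torch
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
    quant_dequant_mxfp4)


def _mxfp4_quantize_sketch(A: torch.Tensor,
                           block_shape=None) -> tuple[torch.Tensor, None]:
    assert block_shape is None
    # Always route through the MXFP4 emulation (quantize-dequantize),
    # regardless of current_platform.supports_mx(), until native MXFP4
    # kernels are wired into this path.
    A = quant_dequant_mxfp4(A)
    return A, None
```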
```diff
         layer_quant_set = set(layer_quant_names)
 
-        if not kv_cache_set.issubset(layer_quant_set):
+        if not (kv_cache_set.issubset(layer_quant_set) or \
```
Could you explain the goal of these changes around the KV cache?
For AMP models, are KV caches still uniformly quantized the same way across all layers?
Yes, currently mixed precision is not applied along the KV-cache dimension; the KV cache is quantized the same way across all KV layers.
The changes here aim to correctly verify whether a KV-cache pattern such as {'*v_proj', '*k_proj'} can match, in other words can be found in, at least one of the layer_quant_set keys (i.e., layer names).
This is essential in AMP scenarios where layer_quant_names are specified one by one, rather than being captured by fuzzy matching.
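A minimal sketch of the check described here, assuming kv_cache_set holds glob patterns such as '*k_proj' and layer_quant_set holds concrete layer names; the helper name is hypothetical:

```python
import fnmatch


def kv_cache_patterns_covered(kv_cache_set: set[str],
                              layer_quant_set: set[str]) -> bool:
    # Either the KV-cache entries appear literally among the quantized layer
    # names, or every KV-cache glob pattern matches at least one layer name.
    if kv_cache_set.issubset(layer_quant_set):
        return True
    return all(
        any(fnmatch.fnmatch(name, pattern) for name in layer_quant_set)
        for pattern in kv_cache_set)
```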
```python
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """
    This module relies on V0 internals, so set VLLM_USE_V1=0.
    """
    monkeypatch.setenv('VLLM_USE_V1', '0')
```
Let's avoid using v0
For test purposes, especially accuracy tests, using V0 is safe. Even for later hardware-metric tests, V0 is still the safer choice while remaining valuable for demonstrations.
vllm v0 is deprecated: #18571
V1 is reported to have issues, as you can see. Since mixed-precision quantization does not depend on the V0/V1 engine, it's safe to use V0.
use_v0_only has been removed, as the V0 backend was deprecated very recently in #25351. Thanks @fxmarty-amd
```python
try:
    huggingface_hub.list_repo_refs(
        "amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8")
    HF_HUB_AMD_ORG_ACCESS = True
except huggingface_hub.errors.RepositoryNotFoundError:
    HF_HUB_AMD_ORG_ACCESS = False
```
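For context, a gate like this typically feeds a skipif marker, roughly as sketched below; the marker name is hypothetical, while the reason string mirrors the one used in the test further down:

```python
import pytest

# HF_HUB_AMD_ORG_ACCESS comes from the try/except gate above.
requires_amd_org_access = pytest.mark.skipif(
    not HF_HUB_AMD_ORG_ACCESS,
    reason="Read access to huggingface.co/amd is required for this test.")
```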
Let's use public models
These models are in the process of being published.
Do you have an ETA for when we can expect these models to be published?
AMD colleagues are speeding up the process; hopefully they can make it happen sometime next week.
@xuebwang-amd I meant that for unit testing you can probably use small models, just for integration-test purposes, as e.g. in vllm/tests/kernels/moe/test_mxfp4_moe.py, lines 51 to 55 at 58c360d:

```python
@pytest.mark.parametrize('model_case', [
    ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1),
    ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),
    ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1)
])
```
@fxmarty-amd your motivation here is to reduce the CI time cost, which is good. We can consider picking one public model for the CI test. @gshtras @SageMoore
Do you have an ETA for when we can expect these models to be published?
They have now been published.
reason="Read access to huggingface.co/amd is required for this test.") | ||
def test_mixed_precision_model_accuracies(model_name: str, | ||
accuracy_numbers: dict, monkeypatch): | ||
monkeypatch.setenv("VLLM_QUARK_EMU_MEM_OPT", "1") |
This environment variable has no effect - it has been removed from vllm.
Then we need to remove the if-else statement in _mxfp4_quantize, as commented above in #24239 (comment).
As examples, we provide some ready-to-use mixed-precision quantized models to show the usage in vLLM and the accuracy benefits. They are:

- amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
- amd/Mixtral-8x7B-Instruct-v0.1-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
- amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
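For illustration, loading one of these checkpoints in vLLM would look roughly like the sketch below (offline inference; exact arguments may vary, and the quark quantization config is expected to be picked up from the checkpoint):

```python
from vllm import LLM, SamplingParams

# Sketch: one of the mixed-precision (AMP) checkpoints listed above.
llm = LLM(model="amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```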
Make these public + add link
They're going to be published.
Can you provide:
One can check the detailed layerwise MXFP8/FP8 configuration in the
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from 3f92445 to 575caf2.
I have reverted it back to a good commit, 575caf2, so there are no conflicts anymore.
Purpose
This PR aims to support inference of layerwise mixed-precision quantized models, extending beyond models quantized with a single scheme such as MXFP4 or FP8 (a.k.a. PTQ models).
Here, the layerwise mixed-precision configuration for a given model is searched for, and the model is then quantized by amd-quark. Specifically, in this PR we focus on the mixed scheme {MXFP4, FP8}, where FP8 denotes the FP8 per-tensor scheme.
With a mixed-precision quantized model, one can achieve a better balance between accuracy and hardware metrics.
To demonstrate the benefits of mixed-precision models in this PR, we show model accuracies on several commonly used tasks, using only the Quark emulation kernel for MXFP4 and the Triton kernel for FP8.
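To make the layerwise idea concrete, the per-layer configuration conceptually maps each layer name to a scheme, along the lines of the sketch below; the field names are illustrative only and not the actual Quark config schema:

```python
# Hypothetical shape of a layerwise mixed-precision mapping; the real
# layer_quant_config produced by amd-quark may use a different schema.
layer_quant_config = {
    # FP8 per-tensor for accuracy-sensitive layers
    "model.layers.0.self_attn.q_proj": {"weight": "fp8", "activation": "fp8"},
    # MXFP4 for layers that tolerate lower precision
    "model.layers.0.mlp.down_proj": {"weight": "mxfp4", "activation": "mxfp4"},
    # ... one entry per quantized layer, mixing {MXFP4, FP8}
}
```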
Test Plan
Test on
Test Result
List of TODO items