diff --git a/docs/features/quantization/quark.md b/docs/features/quantization/quark.md
old mode 100644
new mode 100755
index 385e3bbb8712..be0702f4c9e1
--- a/docs/features/quantization/quark.md
+++ b/docs/features/quantization/quark.md
@@ -281,4 +281,36 @@ python quantize_quark.py --model_dir Qwen/Qwen1.5-MoE-A2.7B-Chat \
                          --group_size 32
 ```
 
-The current integration supports [all combination of FP4, FP6_E3M2, FP6_E2M3](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py) used for either weights or activations. Eventually, some target hardware support mixed precision GEMM, as AMD Instinct MI350/MI355, for example using FP6 for activations and FP4 for weights.
+The current integration supports [all combinations of FP4, FP6_E3M2, FP6_E2M3](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py) used for either weights or activations.
+
+## Using Quark-Quantized Layerwise Auto Mixed Precision (AMP) Models
+
+vLLM also supports loading layerwise mixed-precision models quantized with AMD Quark. Currently, the mixed scheme {MXFP4, FP8} is supported, where FP8 denotes the FP8 per-tensor scheme. More mixed-precision schemes are planned for the near future, including:
+
+- Unquantized Linear and/or MoE layer(s) as an option for each layer, i.e., a mix of {MXFP4, FP8, BF16/FP16}
+- An MXFP6 quantization extension, i.e., {MXFP4, MXFP6, FP8, BF16/FP16}
+
+Although one can maximize serving throughput by using the lowest precision supported on a given device (e.g., MXFP4 on AMD Instinct MI355, FP8 on AMD Instinct MI300), such aggressive schemes can hurt the accuracy recovered after quantization on target tasks. Mixed precision makes it possible to strike a balance between accuracy and throughput.
+
+There are two steps to generate and deploy a mixed-precision model quantized with AMD Quark, as shown below.
+
+### 1. Quantize a model using mixed precision in AMD Quark
+
+First, a layerwise mixed-precision configuration is searched for a given LLM, and the model is then quantized using AMD Quark. A detailed tutorial on the Quark APIs will be provided later.
+
+As examples, we provide some ready-to-use quantized mixed-precision models that show the usage in vLLM and the accuracy benefits:
+
+- amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
+- amd/Mixtral-8x7B-Instruct-v0.1-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
+- amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8
+
+### 2. Run inference on the quantized mixed-precision model in vLLM
+
+Models quantized with AMD Quark using mixed precision can be natively reloaded in vLLM and, for example, evaluated with lm-evaluation-harness as follows:
+
+```bash
+lm_eval --model vllm \
+    --model_args pretrained=amd/Llama-2-70b-chat-hf-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8,tensor_parallel_size=4,dtype=auto,gpu_memory_utilization=0.8,trust_remote_code=False \
+    --tasks mmlu \
+    --batch_size auto
+```
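+
+The same checkpoints can also be loaded directly with vLLM's offline `LLM` API. The snippet below is a minimal sketch; the `tensor_parallel_size`, `kv_cache_dtype`, and sampling settings are illustrative and should be adapted to your hardware and workload:
+
+```python
+from vllm import LLM, SamplingParams
+
+# Load a Quark mixed-precision (AMP) checkpoint. The Quark quantization
+# config is read from the checkpoint itself.
+llm = LLM(
+    model="amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8",
+    kv_cache_dtype="fp8",  # the AMP checkpoints above also ship FP8 KV-cache scales
+    tensor_parallel_size=1,
+)
+
+outputs = llm.generate(
+    ["What is layerwise mixed-precision quantization?"],
+    SamplingParams(max_tokens=64),
+)
+print(outputs[0].outputs[0].text)
+```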
diff --git a/tests/quantization/test_mixed_precision.py b/tests/quantization/test_mixed_precision.py
new file mode 100755
index 000000000000..51526470b423
--- /dev/null
+++ b/tests/quantization/test_mixed_precision.py
@@ -0,0 +1,69 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Test quark-quantized {MXFP4, FP8} mixed precision models.
+
+Run `pytest tests/quantization/test_mixed_precision.py`.
+
+"""
+
+import importlib.metadata
+import importlib.util
+from dataclasses import dataclass
+
+import lm_eval
+import pytest
+from packaging import version
+
+QUARK_MXFP4_AVAILABLE = importlib.util.find_spec("quark") is not None and version.parse(
+    importlib.metadata.version("amd-quark")
+) >= version.parse("0.8.99")
+
+
+@dataclass
+class ModelCase:
+    model_id: str
+    tp: int
+
+
+@dataclass
+class EvaluationConfig:
+    model_name: str
+
+    def get_model_args(self) -> str:
+        return (
+            f"pretrained={self.model_name},"
+            "tensor_parallel_size=4,dtype=auto,gpu_memory_utilization=0.8,trust_remote_code=False"
+        )
+
+
+TEST_CONFIGS = {
+    # Mixed-precision (AMP) model
+    # - Demonstrates end-to-end pipeline functionality
+    "amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8": {"arc_challenge": 0.52, "mmlu": 0.72},
+    # Non-mixed-precision (PTQ) model
+    # - Reference for pipeline compatibility verification -> no conflicts or breakage
+    "amd/Llama-2-70b-chat-hf-FP8-MLPerf-fp8_attn_quark_format": {
+        "arc_challenge": 0.53,
+        "mmlu": 0.61,
+    },
+}
+
+
+@pytest.mark.parametrize("model_name, accuracy_numbers", TEST_CONFIGS.items())
+@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
+def test_mixed_precision_model_accuracies(model_name: str, accuracy_numbers: dict):
+    results = lm_eval.simple_evaluate(
+        model="vllm",
+        model_args=EvaluationConfig(model_name).get_model_args(),
+        tasks=list(accuracy_numbers.keys()),
+        batch_size=8,
+    )
+
+    rtol = 0.05
+
+    for task, expect_accuracy in accuracy_numbers.items():
+        measured_accuracy = results["results"][task]["acc,none"]
+        assert (
+            measured_accuracy - rtol < expect_accuracy
+            and measured_accuracy + rtol > expect_accuracy
+        ), f"Expected: {expect_accuracy} | Measured: {measured_accuracy}"
diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py
old mode 100644
new mode 100755
diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py
old mode 100644
new mode 100755
index d5459594b798..0303da175b97
--- a/vllm/model_executor/layers/quantization/quark/quark.py
+++ b/vllm/model_executor/layers/quantization/quark/quark.py
@@ -114,7 +114,14 @@ def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
             layer_quant_names = list(layer_quant_config.keys())
             layer_quant_set = set(layer_quant_names)
 
-            if not kv_cache_set.issubset(layer_quant_set):
+            if not (
+                kv_cache_set.issubset(layer_quant_set)
+                or any(
+                    fnmatch.fnmatchcase(layer_quant, pat)
+                    for layer_quant in list(layer_quant_set)
+                    for pat in list(kv_cache_set)
+                )
+            ):
                 raise ValueError(
                     "The Quark quantized model has the "
                     "kv_cache_group parameter setting, "
@@ -124,10 +131,15 @@ def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
                 )
 
             q_configs = [
-                cast(dict[str, Any], layer_quant_config.get(name))
-                for name in kv_cache_group
+                quant_cfg
+                for name, quant_cfg in layer_quant_config.items()
+                if any(fnmatch.fnmatchcase(name, pattern) for pattern in kv_cache_group)
             ]
-            if not all(deep_compare(q_config, q_configs[0]) for q_config in q_configs):
+
+            if not all(
+                deep_compare(q_config["output_tensors"], q_configs[0]["output_tensors"])
+                for q_config in q_configs
+            ):
                 raise ValueError(
                     "The quantization method used for kv_cache should "
                     "be the same, but the quantization method for the "
@@ -312,9 +324,9 @@ def _find_matched_config(
         layer_quant_config = cast(
             dict[str, Any], self.quant_config.get("layer_quant_config")
         )
-        for name_pattern in layer_quant_config:
-            if fnmatch.fnmatch(layer_name, name_pattern):
-                return layer_quant_config[name_pattern]
+        for name_pattern, config in layer_quant_config.items():
+            if layer_name in name_pattern:
+                return config
 
         layer_type = cast(str, type(module))
         layer_type_quant_config = cast(