diff --git a/tests/quantization/test_quark.py b/tests/quantization/test_quark.py
index ce918a324887..ae09ac58e675 100644
--- a/tests/quantization/test_quark.py
+++ b/tests/quantization/test_quark.py
@@ -5,6 +5,7 @@
 """
 
 import pytest
+import torch
 
 from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
     QuarkLinearMethod, QuarkW8A8Fp8, QuarkW8A8Int8)
@@ -63,3 +64,28 @@ def check_model(model):
 
         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         assert output
+
+
+def test_quark_fp8_parity(vllm_runner):
+    quark_model_id = "amd-quark/llama-tiny-fp8-quark-quant-method"
+    fp8_model_id = "amd-quark/llama-tiny-fp8-quant-method"
+
+    llm_kwargs = {
+        "tensor_parallel_size": 1,
+        "enforce_eager": True,
+        "gpu_memory_utilization": 0.1
+    }
+    with (vllm_runner(quark_model_id, **llm_kwargs) as
+          quark_handle, vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle):
+        quark_model = (quark_handle.model.llm_engine.model_executor.
+                       driver_worker.model_runner.model)
+        quark_state_dict = quark_model.state_dict()
+
+        fp8_model = (fp8_handle.model.llm_engine.model_executor.driver_worker.
+                     model_runner.model)
+        fp8_state_dict = fp8_model.state_dict()
+
+    assert fp8_state_dict.keys() == quark_state_dict.keys()
+
+    for key in fp8_state_dict:
+        assert torch.equal(fp8_state_dict[key], quark_state_dict[key])
diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
index afd4bb722dad..f8eb3611592e 100644
--- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
@@ -34,21 +34,24 @@ def process_weights_after_loading(self, layer) -> None:
         # tensor scales (thus N scales being passed to the kernel),
         # requantize so we can always run per tensor
         if self.qscheme == "per_tensor":
-            max_w_scale, weight = requantize_with_max_scale(
-                weight=layer.weight,
-                weight_scale=layer.weight_scale,
-                logical_widths=layer.logical_widths,
-            )
-
-            if current_platform.is_fp8_fnuz():
+            if current_platform.is_rocm():
                 input_scale = getattr(layer, 'input_scale', None)
                 weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
-                    weight=weight,
-                    weight_scale=max_w_scale,
+                    weight=layer.weight,
+                    weight_scale=layer.weight_scale,
                     input_scale=input_scale)
                 if input_scale is not None:
                     layer.input_scale = Parameter(input_scale,
                                                   requires_grad=False)
+            else:
+                max_w_scale = layer.weight_scale
+                weight = layer.weight
+
+            max_w_scale, weight = requantize_with_max_scale(
+                weight=weight,
+                weight_scale=max_w_scale,
+                logical_widths=layer.logical_widths,
+            )
 
             layer.weight = Parameter(weight.t(), requires_grad=False)
             layer.weight_scale = Parameter(max_w_scale, requires_grad=False)