26 changes: 26 additions & 0 deletions tests/quantization/test_quark.py
@@ -5,6 +5,7 @@
 """
 
 import pytest
+import torch
 
 from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
     QuarkLinearMethod, QuarkW8A8Fp8, QuarkW8A8Int8)
@@ -63,3 +64,28 @@ def check_model(model):
 
         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         assert output
+
+
+def test_quark_fp8_parity(vllm_runner):
+    quark_model_id = "amd-quark/llama-tiny-fp8-quark-quant-method"
+    fp8_model_id = "amd-quark/llama-tiny-fp8-quant-method"
+
+    llm_kwargs = {
+        "tensor_parallel_size": 1,
+        "enforce_eager": True,
+        "gpu_memory_utilization": 0.1
+    }
+    with (vllm_runner(quark_model_id, **llm_kwargs) as
+          quark_handle, vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle):
+        quark_model = (quark_handle.model.llm_engine.model_executor.
+                       driver_worker.model_runner.model)
+        quark_state_dict = quark_model.state_dict()
+
+        fp8_model = (fp8_handle.model.llm_engine.model_executor.driver_worker.
+                     model_runner.model)
+        fp8_state_dict = fp8_model.state_dict()
+
+    assert fp8_state_dict.keys() == quark_state_dict.keys()
+
+    for key in fp8_state_dict:
+        assert torch.equal(fp8_state_dict[key], quark_state_dict[key])
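For reference, the parity check above reduces to a tensor-by-tensor comparison of two state dicts. A minimal standalone sketch of the same assertions (the helper name assert_state_dicts_match is hypothetical, not part of vLLM or this PR):

import torch


def assert_state_dicts_match(model_a: torch.nn.Module,
                             model_b: torch.nn.Module) -> None:
    # Hypothetical helper mirroring the assertions in test_quark_fp8_parity.
    sd_a, sd_b = model_a.state_dict(), model_b.state_dict()
    # Both models must expose exactly the same parameter/buffer names.
    assert sd_a.keys() == sd_b.keys()
    # torch.equal returns True only when sizes and element values match
    # exactly, so this is a strict parity check, not an approximate one.
    for key in sd_a:
        assert torch.equal(sd_a[key], sd_b[key])

With the scheme change below, this is expected to pass for the two tiny Llama checkpoints, since the Quark and native FP8 weight-loading paths then produce identical tensors.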
@@ -34,21 +34,24 @@ def process_weights_after_loading(self, layer) -> None:
         # tensor scales (thus N scales being passed to the kernel),
         # requantize so we can always run per tensor
         if self.qscheme == "per_tensor":
-            max_w_scale, weight = requantize_with_max_scale(
-                weight=layer.weight,
-                weight_scale=layer.weight_scale,
-                logical_widths=layer.logical_widths,
-            )
-
-            if current_platform.is_fp8_fnuz():
+            if current_platform.is_rocm():
                 input_scale = getattr(layer, 'input_scale', None)
                 weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
-                    weight=weight,
-                    weight_scale=max_w_scale,
+                    weight=layer.weight,
+                    weight_scale=layer.weight_scale,
                     input_scale=input_scale)
                 if input_scale is not None:
                     layer.input_scale = Parameter(input_scale,
                                                   requires_grad=False)
+            else:
+                max_w_scale = layer.weight_scale
+                weight = layer.weight
+
+            max_w_scale, weight = requantize_with_max_scale(
+                weight=weight,
+                weight_scale=max_w_scale,
+                logical_widths=layer.logical_widths,
+            )
 
             layer.weight = Parameter(weight.t(), requires_grad=False)
             layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
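The substance of this change to process_weights_after_loading is ordering, not new logic: previously the per-shard scales were collapsed by requantize_with_max_scale first and the e4m3fn-to-e4m3fnuz conversion ran afterwards; now the fnuz normalization (gated on current_platform.is_rocm() rather than is_fp8_fnuz()) runs first and the requantization always runs last, so the Quark path yields the same tensors as the native FP8 path checked by the parity test above. A simplified, standalone restatement of the new flow (the function name process_per_tensor_fp8 is made up for illustration; the imports assume vLLM's w8a8_utils layout at the time of this PR):

import torch
from torch.nn import Parameter

from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
from vllm.platforms import current_platform


def process_per_tensor_fp8(layer: torch.nn.Module) -> None:
    # Sketch of the reordered per-tensor FP8 weight processing; `layer` is
    # assumed to already carry .weight, .weight_scale and .logical_widths.
    if current_platform.is_rocm():
        # Step 1 (ROCm only): convert OCP e4m3fn weights to e4m3fnuz and
        # adjust the scales to match, before any requantization happens.
        input_scale = getattr(layer, 'input_scale', None)
        weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=input_scale)
        if input_scale is not None:
            layer.input_scale = Parameter(input_scale, requires_grad=False)
    else:
        weight, weight_scale = layer.weight, layer.weight_scale

    # Step 2: a fused module (e.g. QKV) carries one scale per logical shard;
    # requantize everything to the single max scale so the kernel can always
    # run with one per-tensor scale.
    max_w_scale, weight = requantize_with_max_scale(
        weight=weight,
        weight_scale=weight_scale,
        logical_widths=layer.logical_widths)

    layer.weight = Parameter(weight.t(), requires_grad=False)
    layer.weight_scale = Parameter(max_w_scale, requires_grad=False)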