Commit eda4d30

Support FP8 FA from Quark format
1 parent 8e87b08 commit eda4d30

4 files changed, +31 −24 lines changed

4 files changed

+31
-24
lines changed

vllm/model_executor/layers/quantization/quark/quark.py

Lines changed: 14 additions & 20 deletions
@@ -1,5 +1,4 @@
 import fnmatch
-import re
 from typing import Any, Dict, List, Optional, cast
 
 import torch

@@ -122,6 +121,12 @@ def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig":
             for q_config in q_configs:
                 q_config["output_tensors"] = None
 
+            # In case q_proj output is also quantized, remove the configuration
+            # to keep qkv consistency.
+            q_proj_q_config = cast(Dict[str, Any],
+                                   layer_quant_config.get("*q_proj"))
+            q_proj_q_config["output_tensors"] = None
+
         return cls(quant_config=config,
                    kv_cache_group=kv_cache_group,
                    kv_cache_config=kv_cache_config,

@@ -286,25 +291,14 @@ def get_cache_scale(self, name: str) -> Optional[str]:
         :param name: param name
         :return: matching param name for KV cache scale in vLLM
         """
-        if self.kv_cache_group is None or len(self.kv_cache_group) == 0:
-            return None
-
-        kv_proj_names = [
-            re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group
-        ]
-        if name.endswith(".output_scale"):
-            if len(kv_proj_names) == 1 and kv_proj_names[0] in name:
-                kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale"
-                return name.replace(kv_output_scale_name, ".attn.k_scale")
-
-            elif len(kv_proj_names) == 2:
-                for kv_proj_name in kv_proj_names:
-                    if kv_proj_name in name and kv_proj_name == "k_proj":
-                        return name.replace(".k_proj.output_scale",
-                                            ".attn.k_scale")
-                    elif kv_proj_name in name and kv_proj_name == "v_proj":
-                        return name.replace(".v_proj.output_scale",
-                                            ".attn.v_scale")
+        if name.endswith(".output_scale") and ".k_proj" in name:
+            return name.replace(".k_proj.output_scale", ".attn.k_scale")
+        if name.endswith(".output_scale") and ".v_proj" in name:
+            return name.replace(".v_proj.output_scale", ".attn.v_scale")
+        if name.endswith(".output_scale") and ".q_proj" in name:
+            return name.replace(".q_proj.output_scale", ".attn.q_scale")
+        if name.endswith("self_attn.prob_output_scale"):
+            return name.replace(".prob_output_scale", ".attn.prob_scale")
 
         # If no matches, return None
         return None
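
With this change, get_cache_scale reduces to a plain string rewrite on checkpoint parameter names, covering q/k/v output scales and the softmax (prob) output scale. A minimal standalone sketch of the mapping it performs (the layer path in the example is hypothetical, not taken from a specific model):

    from typing import Optional

    def map_cache_scale(name: str) -> Optional[str]:
        # Mirrors the simplified branches above: rename per-projection
        # output scales to the attention-layer scale names vLLM loads.
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")
        return None  # no attention/KV-cache scale in this parameter

    # Hypothetical Quark checkpoint parameter -> vLLM attention parameter.
    print(map_cache_scale("model.layers.0.self_attn.k_proj.output_scale"))
    # model.layers.0.self_attn.attn.k_scale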

vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,7 @@ def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
         self.qscheme = qscheme
         self.is_static_input_scheme = is_static_input_scheme
         self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.out_dtype = torch.get_default_dtype()
 
     @classmethod
     def get_min_capability(cls) -> int:

@@ -134,6 +135,7 @@ def apply_weights(self,
             input=x,
             weight=layer.weight,
             weight_scale=layer.weight_scale,
+            out_dtype=self.out_dtype,
             input_scale=layer.input_scale,
             bias=bias,
             cutlass_fp8_supported=self.cutlass_fp8_supported,
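
out_dtype is read once from torch.get_default_dtype() at construction and then forwarded to the FP8 linear op, so the GEMM result comes back in the model's working dtype rather than the kernel default. A small sketch of the assumption behind that (setting the default dtype here stands in for vLLM configuring the model dtype before the scheme is built):

    import torch

    # Assumption for this sketch: the engine sets the model dtype before
    # quantization schemes are constructed.
    torch.set_default_dtype(torch.bfloat16)

    out_dtype = torch.get_default_dtype()  # captured in __init__
    assert out_dtype == torch.bfloat16     # later passed as out_dtype= to the FP8 linear call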

vllm/model_executor/models/grok1.py

Lines changed: 4 additions & 1 deletion
@@ -37,6 +37,7 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (

@@ -197,7 +198,9 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
-        self.use_fp8 = isinstance(quant_config, Fp8Config)
+        self.use_fp8 = isinstance(
+            quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig)
+                                         and quant_config._is_fp8_w8a8)
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
         self.attn = Grok1Attention(hidden_size=self.hidden_size,
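
grok1.py and llama.py now share the same detection: the FP8 paths are enabled either for the native Fp8Config or for a Quark checkpoint flagged as FP8 W8A8 via the internal _is_fp8_w8a8 attribute used in the diff. A hedged sketch of that check factored into a helper (use_fp8_path is a hypothetical name, not part of vLLM; the commit keeps the expression inline):

    from typing import Optional

    from vllm.model_executor.layers.quantization.base_config import (
        QuantizationConfig)
    from vllm.model_executor.layers.quantization.fp8 import Fp8Config
    from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig

    def use_fp8_path(quant_config: Optional[QuantizationConfig]) -> bool:
        # Same condition the commit repeats in the Grok-1 and Llama modules.
        return isinstance(quant_config, Fp8Config) or (
            isinstance(quant_config, QuarkConfig)
            and quant_config._is_fp8_w8a8)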

vllm/model_executor/models/llama.py

Lines changed: 11 additions & 3 deletions
@@ -40,6 +40,7 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (

@@ -84,7 +85,9 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.down_proj",
         )
-        self.use_fp8 = (isinstance(quant_config, Fp8Config)
+        self.use_fp8 = (isinstance(quant_config, Fp8Config) or
+                        (isinstance(quant_config, QuarkConfig)
+                         and quant_config._is_fp8_w8a8)
                         if current_platform.is_rocm() and not is_navi() else
                         False)
         if hidden_act != "silu":

@@ -196,10 +199,13 @@ def __init__(self,
             sliding_window = None
 
         # For CUDA devices and Navi4x, attn_fp8 will be set to false.
+        use_fp8 = isinstance(
+            quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig)
+                                         and quant_config._is_fp8_w8a8)
         self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \
             and current_platform.is_rocm() \
             and not is_navi() \
-            and isinstance(quant_config, Fp8Config)
+            and use_fp8
 
         self.attn = Attention(
             self.num_heads,

@@ -240,7 +246,9 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
-        self.use_fp8 = (isinstance(quant_config, Fp8Config)
+        self.use_fp8 = (isinstance(quant_config, Fp8Config) or
+                        (isinstance(quant_config, QuarkConfig)
+                         and quant_config._is_fp8_w8a8)
                         if current_platform.is_rocm() and not is_navi() else
                         False)
         rope_theta = getattr(config, "rope_theta", 10000)
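
In the attention module the same check now also feeds attn_fp8_out, so the ROCm custom paged-attention FP8 output path is no longer limited to Fp8Config checkpoints. A condensed boolean sketch of that gate (enable_attn_fp8_out is a hypothetical helper; the real code evaluates the conditions inline with envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT, current_platform.is_rocm() and is_navi()):

    def enable_attn_fp8_out(env_opt_in: bool, is_rocm: bool, is_navi: bool,
                            use_fp8: bool) -> bool:
        # FP8 attention output needs the env opt-in, a ROCm (non-Navi) device,
        # and an FP8 quant config (native Fp8Config or Quark FP8 W8A8).
        return env_opt_in and is_rocm and not is_navi and use_fp8

    # e.g. ROCm GPU, opt-in env var set, Quark FP8 W8A8 checkpoint:
    assert enable_attn_fp8_out(True, True, False, True) is True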
