
Commit 6b2147f

Support FP8 FA from Quark format (#388)
* Support FP8 FA from Quark format
* Support FP8 FA from Quark format
* nit: update comment
1 parent 28b1ad9 commit 6b2147f

File tree

4 files changed (+44, -24 lines)


vllm/model_executor/layers/quantization/quark/quark.py

Lines changed: 27 additions & 20 deletions
@@ -1,5 +1,4 @@
 import fnmatch
-import re
 from typing import Any, Dict, List, Optional, cast
 
 import torch
@@ -122,6 +121,12 @@ def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig":
             for q_config in q_configs:
                 q_config["output_tensors"] = None
 
+            # In case q_proj output is also quantized, remove the configuration
+            # to keep qkv consistency.
+            q_proj_q_config = cast(Dict[str, Any],
+                                   layer_quant_config.get("*q_proj"))
+            q_proj_q_config["output_tensors"] = None
+
         return cls(quant_config=config,
                    kv_cache_group=kv_cache_group,
                    kv_cache_config=kv_cache_config,
@@ -148,6 +153,19 @@ def _check_scheme_supported(self,
         else:
             return False
 
+    def is_fp8_w8a8(self) -> bool:
+        # Returns True if all quantized layers in model are fp8 w8a8
+        global_quant_config = cast(
+            Dict[str, Any], self.quant_config.get("global_quant_config"))
+        layer_quant_configs = cast(Dict[str, Any],
+                                   self.quant_config.get("layer_quant_config"))
+        for config in (global_quant_config, *layer_quant_configs.values()):
+            weight_config = cast(Dict[str, Any], config.get("weight"))
+            input_config = cast(Dict[str, Any], config.get("input_tensors"))
+            if not self._is_fp8_w8a8(weight_config, input_config):
+                return False
+        return True
+
     def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]],
                      input_quant: Optional[Dict[str, Any]]) -> bool:
         # Confirm weights and input quantized.
@@ -286,25 +304,14 @@ def get_cache_scale(self, name: str) -> Optional[str]:
         :param name: param name
         :return: matching param name for KV cache scale in vLLM
         """
-        if self.kv_cache_group is None or len(self.kv_cache_group) == 0:
-            return None
-
-        kv_proj_names = [
-            re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group
-        ]
-        if name.endswith(".output_scale"):
-            if len(kv_proj_names) == 1 and kv_proj_names[0] in name:
-                kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale"
-                return name.replace(kv_output_scale_name, ".attn.k_scale")
-
-            elif len(kv_proj_names) == 2:
-                for kv_proj_name in kv_proj_names:
-                    if kv_proj_name in name and kv_proj_name == "k_proj":
-                        return name.replace(".k_proj.output_scale",
-                                            ".attn.k_scale")
-                    elif kv_proj_name in name and kv_proj_name == "v_proj":
-                        return name.replace(".v_proj.output_scale",
-                                            ".attn.v_scale")
+        if name.endswith(".output_scale") and ".k_proj" in name:
+            return name.replace(".k_proj.output_scale", ".attn.k_scale")
+        if name.endswith(".output_scale") and ".v_proj" in name:
+            return name.replace(".v_proj.output_scale", ".attn.v_scale")
+        if name.endswith(".output_scale") and ".q_proj" in name:
+            return name.replace(".q_proj.output_scale", ".attn.q_scale")
+        if name.endswith("self_attn.prob_output_scale"):
+            return name.replace(".prob_output_scale", ".attn.prob_scale")
 
         # If no matches, return None
         return None
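
The rewritten get_cache_scale no longer derives projection names from kv_cache_group via re.split; it matches q_proj/k_proj/v_proj output scales (and the softmax prob_output_scale) directly. Below is a minimal standalone sketch of that mapping; map_cache_scale_name and the parameter names in the asserts are hypothetical stand-ins, not vLLM code.

from typing import Optional

def map_cache_scale_name(name: str) -> Optional[str]:
    # Mirror of the simplified mapping in the diff: Quark attention output
    # scales are renamed to the generic ".attn.*_scale" parameters.
    if name.endswith(".output_scale") and ".k_proj" in name:
        return name.replace(".k_proj.output_scale", ".attn.k_scale")
    if name.endswith(".output_scale") and ".v_proj" in name:
        return name.replace(".v_proj.output_scale", ".attn.v_scale")
    if name.endswith(".output_scale") and ".q_proj" in name:
        return name.replace(".q_proj.output_scale", ".attn.q_scale")
    if name.endswith("self_attn.prob_output_scale"):
        return name.replace(".prob_output_scale", ".attn.prob_scale")
    return None  # not an attention scale parameter

# Illustrative checkpoint parameter names (hypothetical).
assert (map_cache_scale_name("model.layers.0.self_attn.k_proj.output_scale")
        == "model.layers.0.self_attn.attn.k_scale")
assert map_cache_scale_name("model.layers.0.mlp.down_proj.weight") is None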

vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,7 @@ def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
         self.qscheme = qscheme
         self.is_static_input_scheme = is_static_input_scheme
         self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.out_dtype = torch.get_default_dtype()
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -134,6 +135,7 @@ def apply_weights(self,
             input=x,
             weight=layer.weight,
             weight_scale=layer.weight_scale,
+            out_dtype=self.out_dtype,
             input_scale=layer.input_scale,
             bias=bias,
             cutlass_fp8_supported=self.cutlass_fp8_supported,
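
This change captures torch.get_default_dtype() once at construction and forwards it as out_dtype to the FP8 linear call, so activations are returned in the model's working dtype. A tiny sketch of the capture pattern, using a hypothetical class name rather than the vLLM scheme:

import torch

class _Fp8SchemeSketch:
    def __init__(self) -> None:
        # Record the dtype outputs should be cast back to after the FP8 matmul.
        self.out_dtype = torch.get_default_dtype()

torch.set_default_dtype(torch.bfloat16)
assert _Fp8SchemeSketch().out_dtype is torch.bfloat16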

vllm/model_executor/models/grok1.py

Lines changed: 4 additions & 1 deletion
@@ -37,6 +37,7 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -197,7 +198,9 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
-        self.use_fp8 = isinstance(quant_config, Fp8Config)
+        self.use_fp8 = isinstance(
+            quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig)
+                                         and quant_config.is_fp8_w8a8())
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
         self.attn = Grok1Attention(hidden_size=self.hidden_size,

vllm/model_executor/models/llama.py

Lines changed: 11 additions & 3 deletions
@@ -40,6 +40,7 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -84,7 +85,9 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.down_proj",
         )
-        self.use_fp8 = (isinstance(quant_config, Fp8Config)
+        self.use_fp8 = (isinstance(quant_config, Fp8Config) or
+                        (isinstance(quant_config, QuarkConfig)
+                         and quant_config.is_fp8_w8a8())
                         if current_platform.is_rocm() and not is_navi() else
                         False)
         if hidden_act != "silu":
@@ -196,10 +199,13 @@ def __init__(self,
             sliding_window = None
 
         # For CUDA devices and Navi4x, attn_fp8 will be set to false.
+        use_fp8 = isinstance(
+            quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig)
+                                         and quant_config.is_fp8_w8a8())
         self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \
                             and current_platform.is_rocm() \
                             and not is_navi() \
-                            and isinstance(quant_config, Fp8Config)
+                            and use_fp8
 
         self.attn = Attention(
             self.num_heads,
@@ -240,7 +246,9 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
-        self.use_fp8 = (isinstance(quant_config, Fp8Config)
+        self.use_fp8 = (isinstance(quant_config, Fp8Config) or
+                        (isinstance(quant_config, QuarkConfig)
+                         and quant_config.is_fp8_w8a8())
                         if current_platform.is_rocm() and not is_navi() else
                         False)
         rope_theta = getattr(config, "rope_theta", 10000)
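
Both grok1.py and llama.py now enable their FP8 attention paths for Quark checkpoints as well as native FP8 ones, via the new QuarkConfig.is_fp8_w8a8() check. The inline expression repeated across these diffs is equivalent to a helper like the following; quant_config_is_fp8 is a hypothetical function written only to spell out the condition, not part of this commit:

from typing import Optional

from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig

def quant_config_is_fp8(quant_config: Optional[QuantizationConfig]) -> bool:
    # True for native FP8 checkpoints, and for Quark checkpoints whose
    # quantized layers are all FP8 W8A8.
    return isinstance(quant_config, Fp8Config) or (
        isinstance(quant_config, QuarkConfig) and quant_config.is_fp8_w8a8())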
