From eda4d30c3d4fae5a821564736452e28c2b0f7221 Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Fri, 24 Jan 2025 20:39:52 +0000 Subject: [PATCH 1/3] Support FP8 FA from Quark format --- .../layers/quantization/quark/quark.py | 34 ++++++++----------- .../quark/schemes/quark_w8a8_fp8.py | 2 ++ vllm/model_executor/models/grok1.py | 5 ++- vllm/model_executor/models/llama.py | 14 ++++++-- 4 files changed, 31 insertions(+), 24 deletions(-) diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index fc214255eca7..8cdd65225a11 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -1,5 +1,4 @@ import fnmatch -import re from typing import Any, Dict, List, Optional, cast import torch @@ -122,6 +121,12 @@ def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig": for q_config in q_configs: q_config["output_tensors"] = None + # In case q_proj output is also quantized, remove the configuration + # to keep qkv consistency. + q_proj_q_config = cast(Dict[str, Any], + layer_quant_config.get("*q_proj")) + q_proj_q_config["output_tensors"] = None + return cls(quant_config=config, kv_cache_group=kv_cache_group, kv_cache_config=kv_cache_config, @@ -286,25 +291,14 @@ def get_cache_scale(self, name: str) -> Optional[str]: :param name: param name :return: matching param name for KV cache scale in vLLM """ - if self.kv_cache_group is None or len(self.kv_cache_group) == 0: - return None - - kv_proj_names = [ - re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group - ] - if name.endswith(".output_scale"): - if len(kv_proj_names) == 1 and kv_proj_names[0] in name: - kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale" - return name.replace(kv_output_scale_name, ".attn.k_scale") - - elif len(kv_proj_names) == 2: - for kv_proj_name in kv_proj_names: - if kv_proj_name in name and kv_proj_name == "k_proj": - return name.replace(".k_proj.output_scale", - ".attn.k_scale") - elif kv_proj_name in name and kv_proj_name == "v_proj": - return name.replace(".v_proj.output_scale", - ".attn.v_scale") + if name.endswith(".output_scale") and ".k_proj" in name: + return name.replace(".k_proj.output_scale", ".attn.k_scale") + if name.endswith(".output_scale") and ".v_proj" in name: + return name.replace(".v_proj.output_scale", ".attn.v_scale") + if name.endswith(".output_scale") and ".q_proj" in name: + return name.replace(".q_proj.output_scale", ".attn.q_scale") + if name.endswith("self_attn.prob_output_scale"): + return name.replace(".prob_output_scale", ".attn.prob_scale") # If no matches, return None return None diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index 206931ea2ffc..447911a64863 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -21,6 +21,7 @@ def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]): self.qscheme = qscheme self.is_static_input_scheme = is_static_input_scheme self.cutlass_fp8_supported = cutlass_fp8_supported() + self.out_dtype = torch.get_default_dtype() @classmethod def get_min_capability(cls) -> int: @@ -134,6 +135,7 @@ def apply_weights(self, input=x, weight=layer.weight, weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, input_scale=layer.input_scale, bias=bias, 
cutlass_fp8_supported=self.cutlass_fp8_supported, diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py index 99940c547f3d..bc30184ebdb6 100644 --- a/vllm/model_executor/models/grok1.py +++ b/vllm/model_executor/models/grok1.py @@ -37,6 +37,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.fp8 import Fp8Config +from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -197,7 +198,9 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - self.use_fp8 = isinstance(quant_config, Fp8Config) + self.use_fp8 = isinstance( + quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig) + and quant_config._is_fp8_w8a8) # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 10000) self.attn = Grok1Attention(hidden_size=self.hidden_size, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 5bb2718dd949..e340638b4016 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -40,6 +40,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.fp8 import Fp8Config +from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -84,7 +85,9 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.down_proj", ) - self.use_fp8 = (isinstance(quant_config, Fp8Config) + self.use_fp8 = (isinstance(quant_config, Fp8Config) or + (isinstance(quant_config, QuarkConfig) + and quant_config._is_fp8_w8a8) if current_platform.is_rocm() and not is_navi() else False) if hidden_act != "silu": @@ -196,10 +199,13 @@ def __init__(self, sliding_window = None # For CUDA devices and Navi4x, attn_fp8 will be set to false. 
+ use_fp8 = isinstance( + quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig) + and quant_config._is_fp8_w8a8) self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \ and current_platform.is_rocm() \ and not is_navi() \ - and isinstance(quant_config, Fp8Config) + and use_fp8 self.attn = Attention( self.num_heads, @@ -240,7 +246,9 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - self.use_fp8 = (isinstance(quant_config, Fp8Config) + self.use_fp8 = (isinstance(quant_config, Fp8Config) or + (isinstance(quant_config, QuarkConfig) + and quant_config._is_fp8_w8a8) if current_platform.is_rocm() and not is_navi() else False) rope_theta = getattr(config, "rope_theta", 10000) From 7cd245840f73273e181a2cdb38ee0a4b1d504c14 Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Fri, 24 Jan 2025 20:39:52 +0000 Subject: [PATCH 2/3] Support FP8 FA from Quark format --- .../model_executor/layers/quantization/quark/quark.py | 11 +++++++++++ vllm/model_executor/models/grok1.py | 2 +- vllm/model_executor/models/llama.py | 6 +++--- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index 8cdd65225a11..f48fb898d144 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -153,6 +153,17 @@ def _check_scheme_supported(self, else: return False + def is_fp8_w8a8(self) -> bool: + # Returns True if all quantized layers in model are fp8 w8a8. + global_quant_config = self.quant_config.get("global_quant_config") + layer_quant_configs = self.quant_config.get("layer_quant_config") + for quant_config in (global_quant_config, + *layer_quant_configs.values()): + if not self._is_fp8_w8a8(quant_config.get("weight"), + quant_config.get("input_tensors")): + return False + return True + def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]], input_quant: Optional[Dict[str, Any]]) -> bool: # Confirm weights and input quantized. diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py index bc30184ebdb6..fb082fbf29b1 100644 --- a/vllm/model_executor/models/grok1.py +++ b/vllm/model_executor/models/grok1.py @@ -200,7 +200,7 @@ def __init__( self.hidden_size = config.hidden_size self.use_fp8 = isinstance( quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig) - and quant_config._is_fp8_w8a8) + and quant_config.is_fp8_w8a8()) # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 10000) self.attn = Grok1Attention(hidden_size=self.hidden_size, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index e340638b4016..a82fb4398ebb 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -87,7 +87,7 @@ def __init__( ) self.use_fp8 = (isinstance(quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig) - and quant_config._is_fp8_w8a8) + and quant_config.is_fp8_w8a8()) if current_platform.is_rocm() and not is_navi() else False) if hidden_act != "silu": @@ -201,7 +201,7 @@ def __init__(self, # For CUDA devices and Navi4x, attn_fp8 will be set to false. 
use_fp8 = isinstance( quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig) - and quant_config._is_fp8_w8a8) + and quant_config.is_fp8_w8a8()) self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \ and current_platform.is_rocm() \ and not is_navi() \ @@ -248,7 +248,7 @@ def __init__( self.hidden_size = config.hidden_size self.use_fp8 = (isinstance(quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig) - and quant_config._is_fp8_w8a8) + and quant_config.is_fp8_w8a8()) if current_platform.is_rocm() and not is_navi() else False) rope_theta = getattr(config, "rope_theta", 10000) From 271265478abd583a883656d9d1be25a067fafce4 Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Tue, 28 Jan 2025 00:09:05 +0000 Subject: [PATCH 3/3] nit: update comment --- .../layers/quantization/quark/quark.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index f48fb898d144..144036814faf 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -154,13 +154,15 @@ def _check_scheme_supported(self, return False def is_fp8_w8a8(self) -> bool: - # Returns True if all quantized layers in model are fp8 w8a8. - global_quant_config = self.quant_config.get("global_quant_config") - layer_quant_configs = self.quant_config.get("layer_quant_config") - for quant_config in (global_quant_config, - *layer_quant_configs.values()): - if not self._is_fp8_w8a8(quant_config.get("weight"), - quant_config.get("input_tensors")): + # Returns True if all quantized layers in model are fp8 w8a8 + global_quant_config = cast( + Dict[str, Any], self.quant_config.get("global_quant_config")) + layer_quant_configs = cast(Dict[str, Any], + self.quant_config.get("layer_quant_config")) + for config in (global_quant_config, *layer_quant_configs.values()): + weight_config = cast(Dict[str, Any], config.get("weight")) + input_config = cast(Dict[str, Any], config.get("input_tensors")) + if not self._is_fp8_w8a8(weight_config, input_config): return False return True
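
A standalone sketch of the two behaviors this series adds: the simplified
checkpoint-name remapping in QuarkConfig.get_cache_scale() (patch 1) and the
new public QuarkConfig.is_fp8_w8a8() gate consumed by grok1.py and llama.py
(patch 2). This is not the vLLM implementation: the toy config keys
("dtype": "fp8_e4m3") and the dtype-only FP8 check are assumptions for
illustration, whereas the real _is_fp8_w8a8() in quark.py inspects more of
the Quark quantization config (scheme, static vs. dynamic, symmetry).

from typing import Any, Dict, Optional


def get_cache_scale(name: str) -> Optional[str]:
    """Map Quark checkpoint output-scale names to vLLM attention scale
    names, mirroring the simplified logic added in patch 1."""
    if name.endswith(".output_scale") and ".k_proj" in name:
        return name.replace(".k_proj.output_scale", ".attn.k_scale")
    if name.endswith(".output_scale") and ".v_proj" in name:
        return name.replace(".v_proj.output_scale", ".attn.v_scale")
    if name.endswith(".output_scale") and ".q_proj" in name:
        return name.replace(".q_proj.output_scale", ".attn.q_scale")
    if name.endswith("self_attn.prob_output_scale"):
        return name.replace(".prob_output_scale", ".attn.prob_scale")
    return None  # this param name carries no attention scale


def is_fp8_w8a8(quant_config: Dict[str, Any]) -> bool:
    """Approximation of QuarkConfig.is_fp8_w8a8(): True only if the global
    config and every per-layer override quantize both weights and
    activations to FP8 (dtype check only -- an assumption, see note above)."""

    def _is_fp8(cfg: Optional[Dict[str, Any]]) -> bool:
        return cfg is not None and cfg.get("dtype") == "fp8_e4m3"

    configs = [
        quant_config["global_quant_config"],
        *quant_config.get("layer_quant_config", {}).values(),
    ]
    return all(
        _is_fp8(cfg.get("weight")) and _is_fp8(cfg.get("input_tensors"))
        for cfg in configs)


if __name__ == "__main__":
    # K-projection output scale is remapped onto the attention module.
    print(get_cache_scale("model.layers.0.self_attn.k_proj.output_scale"))
    # -> model.layers.0.self_attn.attn.k_scale

    # Toy Quark-style config: FP8 weights and activations everywhere.
    toy_config = {
        "global_quant_config": {
            "weight": {"dtype": "fp8_e4m3"},
            "input_tensors": {"dtype": "fp8_e4m3"},
        },
        "layer_quant_config": {
            "*q_proj": {
                "weight": {"dtype": "fp8_e4m3"},
                "input_tensors": {"dtype": "fp8_e4m3"},
            },
        },
    }
    print(is_fp8_w8a8(toy_config))  # -> True

Exposing is_fp8_w8a8() as a public method (patch 2) lets model code gate its
FP8 attention paths on a single call instead of referencing the private
per-layer _is_fp8_w8a8() helper, which is what the grok1.py and llama.py
hunks rely on.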