[ROCm] [Feature] [Doc] [Dockerfile] [BugFix] Support Per-Token-Activation Per-Channel-Weight FP8 Quantization Inferencing #12501
Merged
Commits (9):
- 8c61544  Add ptpc-fp8 quantization (kliuae)
- d86df2d  Enable torch._scaled_mm rowwise gemm fp8 (tjtanaa)
- 4d57881  Update PyTorch version in Dockerfile.rocm_base; Update AMD GPU instal… (tjtanaa)
- ef98cef  add ptpc fp8 unittests (tjtanaa)
- 0f309c2  fix test_fp8.py::test_kv_cache_model_load_and_run; remove unnecessary… (tjtanaa)
- 004dadb  Merge remote-tracking branch 'origin/main' into ptpc-fp8-rocm-2 (tjtanaa)
- 73d7bd1  format lint code (tjtanaa)
- 12f42de  Merge remote-tracking branch 'origin/main' into ptpc-fp8-rocm-2 (tjtanaa)
- 881ce38  introduce USE_ROWWISE_TORCH_SCALED_MM (tjtanaa)
tests/quantization/test_ptpc_fp8.py (new file):

```python
"""Tests whether PTPC w8a8 FP8 computation is enabled correctly.

Run `pytest tests/quantization/test_ptpc_fp8.py --forked`.
"""
import pytest
import torch

from tests.quantization.utils import is_quant_method_supported
from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod
from vllm.model_executor.layers.quantization.ptpc_fp8 import (
    PTPCFp8LinearMethod)
from vllm.platforms import current_platform


@pytest.mark.skipif(not is_quant_method_supported("ptpc_fp8"),
                    reason="PTPC FP8 is not supported on this GPU type.")
@pytest.mark.skipif(not current_platform.is_rocm(),
                    reason="This test is for ROCm GPU.")
@pytest.mark.parametrize("dtype", ["auto", "bfloat16", "float16"])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"])
def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:

    try:
        with vllm_runner("facebook/opt-125m",
                         dtype=dtype,
                         quantization="ptpc_fp8",
                         kv_cache_dtype=kv_cache_dtype) as llm:

            model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
            fc1 = model.model.decoder.layers[0].fc1
            assert isinstance(fc1.quant_method, PTPCFp8LinearMethod)
            if kv_cache_dtype == "ptpc_fp8":
                attn = model.model.decoder.layers[0].self_attn.attn
                assert isinstance(attn.quant_method, Fp8KVCacheMethod)
                assert attn._k_scale == 1.0
                assert attn._v_scale == 1.0

            if current_platform.has_device_capability(94):
                # For GPUs with hardware support, we keep weights in fp8
                assert fc1.weight.dtype == torch.float8_e4m3fnuz
            else:
                pytest.skip()

            output = llm.generate_greedy("Hello my name is", max_tokens=20)
            assert output
    except AssertionError as e:
        if str(
                e
        ) == "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. torch.float16 is specified.":  # noqa: E501
            # If the error message matches, the test passes
            pass
        else:
            # If the error message does not match, re-raise the exception
            raise
```
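For context, a minimal end-user sketch of the path this test exercises is shown below. It is not part of the PR's diff; it assumes a ROCm build of vLLM containing this change, and the model and prompt are borrowed from the test purely for illustration. `dtype="bfloat16"` is chosen because the rowwise `torch._scaled_mm` path currently only supports bf16 output, as the assertion in the test indicates.

```python
from vllm import LLM, SamplingParams

# Usage sketch only: assumes a ROCm GPU with FP8 support and this PR applied.
llm = LLM(model="facebook/opt-125m",
          dtype="bfloat16",        # rowwise torch._scaled_mm currently needs bf16
          quantization="ptpc_fp8",
          kv_cache_dtype="fp8")

outputs = llm.generate(["Hello my name is"],
                       SamplingParams(temperature=0.0, max_tokens=20))
print(outputs[0].outputs[0].text)
```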
vllm/model_executor/layers/quantization/ptpc_fp8.py (new file):

```python
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizeMethodBase)
from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
                                                         Fp8KVCacheMethod,
                                                         Fp8LinearMethod)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    apply_fp8_linear)
from vllm.platforms import current_platform

ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)


class PTPCFp8Config(Fp8Config):
    """Config class for Per-Token-Per-Channel Dynamic Quantization Fp8."""

    def __init__(
        self,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[List[str]] = None,
    ) -> None:
        if not current_platform.is_rocm():
            raise ValueError(
                "ptpc_fp8 quantization is supported only on ROCm.")

        if not current_platform.has_device_capability(94):
            raise ValueError(
                "ptpc_fp8 quantization is supported only on AMD Instinct MI300 GPUs and newer."  # noqa: E501
            )
        if activation_scheme == "static":
            raise ValueError(
                "ptpc_fp8 as of now only support dynamic quantization.")

        super().__init__(is_checkpoint_fp8_serialized=False,
                         activation_scheme=activation_scheme,
                         ignored_layers=ignored_layers)

    @classmethod
    def get_name(cls) -> str:
        return "ptpc_fp8"

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "PTPCFp8Config":
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        return cls(activation_scheme=activation_scheme,
                   ignored_layers=ignored_layers)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return PTPCFp8LinearMethod(self)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None


class PTPCFp8LinearMethod(Fp8LinearMethod):
    """Linear method for Per-Token and Per-Channel FP8 Quantization.
    Only supports loading quantized BF16 model checkpoints with dynamic
    activation scaling. To load FP16 model checkpoints, user must specify
    to convert the FP16 model weight loading into BF16.
    The weight scaling factor will be initialized after
    the model weights are loaded.

    Limitations:
    1. Only support float8_e4m3fnuz data type due to the limitation of
       torch._scaled_mm (https://github.com/ROCm/pytorch/blob/8c0504d7f3fb0ee4c278c096a5c3caedb01129fa/aten/src/ATen/native/cuda/Blas.cpp#L1041)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: PTPCFp8Config):
        super().__init__(quant_config=quant_config)
        # Force weight quantization
        self.quant_config.is_checkpoint_fp8_serialized = False

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.weight = torch.nn.Parameter(layer.weight.data,
                                          requires_grad=False)

        assert layer.weight.data.dtype == torch.bfloat16, \
            f"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. {str(layer.weight.data.dtype)} is specified."  # noqa: E501
        # Quantize the weights.
        qweight, weight_scale = ops.scaled_fp8_quant(
            layer.weight, scale=None, use_per_token_if_dynamic=True)

        # Update the layer with the new values.
        layer.weight = Parameter(
            qweight.t(), requires_grad=False)  # Pretranspose the weight
        layer.weight_scale = Parameter(weight_scale, requires_grad=False)
        layer.input_scale = None

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        return apply_fp8_linear(input=x,
                                weight=layer.weight,
                                weight_scale=layer.weight_scale,
                                input_scale=None,
                                input_scale_ub=None,
                                bias=bias,
                                cutlass_fp8_supported=False,
                                use_per_token_if_dynamic=True)
```
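To make the naming concrete, the standalone sketch below (not vLLM code; `FP8_MAX` and the helper names are illustrative, and the constant assumes the float8_e4m3fnuz range used on ROCm) shows what per-token-activation, per-channel-weight scaling means: every activation row (token) and every weight output channel gets its own scale, and the rowwise-scaled GEMM is dequantized with the outer product of the two scale vectors. The example skips the actual FP8 cast so it runs on any device.

```python
import torch

FP8_MAX = 240.0  # assumed max finite value of float8_e4m3fnuz (ROCm FP8 format)


def per_token_scale(x: torch.Tensor) -> torch.Tensor:
    """One scale per activation row (token), mapping the row absmax to FP8_MAX."""
    return x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX


def per_channel_scale(w: torch.Tensor) -> torch.Tensor:
    """One scale per output channel of a [out_features, in_features] weight."""
    return w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX


x = torch.randn(4, 16)   # activations: [tokens, in_features]
w = torch.randn(8, 16)   # weight:      [out_features, in_features]

s_x, s_w = per_token_scale(x), per_channel_scale(w)
x_q = (x / s_x).clamp(-FP8_MAX, FP8_MAX)   # would be cast to float8_e4m3fnuz
w_q = (w / s_w).clamp(-FP8_MAX, FP8_MAX)   # would be cast to float8_e4m3fnuz

# Rowwise-scaled GEMM: dequantize with the outer product of the two scale vectors.
y = (x_q @ w_q.t()) * (s_x @ s_w.t())
assert torch.allclose(y, x @ w.t(), atol=1e-4)
```

Because both scale vectors are derived from the data at runtime, no calibration pass is needed, which matches the config above accepting only `activation_scheme="dynamic"`.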