[Core]Add Ascend Quantize #7
@@ -0,0 +1,58 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Tests whether Ascend quantization based on MindIE-Turbo is enabled correctly.

Run `pytest tests/quantization/test_mindie_turbo.py`.
"""

import pytest

from .utils import is_mindie_turbo_supported

MODELS = [
    "LLaMA3-8B_W8A8/",
]


@pytest.mark.skipif(not is_mindie_turbo_supported(),
                    reason="MindIE-Turbo is not installed.")
@pytest.mark.parametrize("model_name_or_path", MODELS)
@pytest.mark.parametrize("max_tokens", [5])
def test_mindie_turbo(
    model_name_or_path: str,
    max_tokens: int,
) -> None:
    import vllm  # noqa: F401
    from ..conftest import VllmRunner
    import vllm_ascend  # noqa: F401
    from vllm_ascend.quantization.quant_config import AscendLinearMethod  # noqa: F401
Review comment: Why are these imports inside the test?

Reply: Because mindie_turbo has to be imported before vllm in early versions of mindie_turbo. That conflict may have been resolved by now, in which case these imports can be moved to module level.
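A minimal sketch of the module-level layout the reply describes. This is hypothetical and only valid once the import-order conflict is resolved; for older MindIE-Turbo versions, mindie_turbo would still have to come first:

```python
# Hypothetical top-of-file layout if the imports are moved out of the test body.
# Older mindie_turbo versions require it to be imported before vllm.
import mindie_turbo  # noqa: F401
import vllm  # noqa: F401
import vllm_ascend  # noqa: F401

from vllm_ascend.quantization.quant_config import AscendLinearMethod  # noqa: F401
```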
    prompt = "What's deep learning?"
    example_prompts = [prompt]

    with VllmRunner(model_name_or_path,
                    max_model_len=8192,
                    dtype="bfloat16",
                    enforce_eager=False,
                    gpu_memory_utilization=0.7) as vllm_model:
        output = vllm_model.generate_greedy(example_prompts, max_tokens)
        assert output
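For reference, a minimal sketch of the same flow outside the pytest harness, assuming an Ascend environment with MindIE-Turbo installed. `LLaMA3-8B_W8A8/` is the placeholder checkpoint directory from `MODELS` above, not a published model:

```python
# Standalone sketch of what the test exercises; the checkpoint path is the
# test's placeholder, not a real published model.
from vllm import LLM, SamplingParams

llm = LLM(model="LLaMA3-8B_W8A8/",
          max_model_len=8192,
          dtype="bfloat16",
          enforce_eager=False,
          gpu_memory_utilization=0.7)

# Greedy decoding for a handful of tokens, mirroring generate_greedy above.
outputs = llm.generate(["What's deep learning?"],
                       SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)
```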
@@ -0,0 +1,26 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


def is_mindie_turbo_supported() -> bool:
    try:
        import mindie_turbo  # noqa: F401
    except ImportError:
        return False

    return True
@@ -20,12 +20,20 @@
import torch
from vllm.model_executor.layers.layernorm import RMSNorm

try:
    from mindie_turbo import RMSNormWithAntiOutlier
except Exception:
    pass


def forward_oot(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    if hasattr(self, "module"):
        return self.module.forward_anti_outlier(x, residual)

    import torch_npu

    if residual is not None:
@@ -37,4 +45,28 @@ def forward_oot(
    return x


def enable_rmsnorm_with_antioutlier():

    def init(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: Optional[int] = None,
        has_weight: bool = True,
    ) -> None:
        super(RMSNorm, self).__init__()
        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        self.variance_size_override = (None if var_hidden_size == hidden_size
                                       else var_hidden_size)
        self.has_weight = has_weight

        self.weight = torch.ones(hidden_size)
        if self.has_weight:
            self.weight = torch.nn.Parameter(self.weight)

        self.module = RMSNormWithAntiOutlier(self.hidden_size)

    RMSNorm.__init__ = init

    RMSNorm.forward_oot = forward_oot
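A brief usage sketch of the patch above. The call site and the module path are assumptions for illustration, not part of this diff; the only requirement implied by the code is that the patch runs before any RMSNorm layer is constructed:

```python
# Hypothetical call site; the module path vllm_ascend.ops.layernorm is assumed.
from vllm_ascend.ops.layernorm import enable_rmsnorm_with_antioutlier

# Re-binds RMSNorm.__init__ and RMSNorm.forward_oot, so every RMSNorm layer
# created after this point delegates to MindIE-Turbo's anti-outlier kernel
# through forward_anti_outlier().
enable_rmsnorm_with_antioutlier()
```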
@@ -16,7 +16,7 @@
#

import os
from typing import Optional, Tuple
from typing import TYPE_CHECKING, Optional, Tuple

import torch

@@ -27,6 +27,10 @@
from vllm.config import VllmConfig
from vllm.platforms import Platform, PlatformEnum
if TYPE_CHECKING:
    from vllm.utils import FlexibleArgumentParser
else:
    FlexibleArgumentParser = None

os.environ["RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES"] = "1"

@@ -53,6 +57,10 @@ class NPUPlatform(Platform):
    ray_device_key: str = "NPU"
    device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"

    supported_quantization: list[str] = [
        "ascend"
    ]

    @classmethod
    def get_device_capability(cls, device_id: int = 0):
        return None

@@ -86,6 +94,16 @@ def synchronize(cls):
    def mem_get_info(cls) -> Tuple[int, int]:
        return torch.npu.mem_get_info()

    # Relies on this pull request: https://github.com/vllm-project/vllm/pull/12432.
    @classmethod
    def pre_register_and_update(cls,
                                parser: Optional[FlexibleArgumentParser] = None
                                ) -> None:
        """
        Perform pre-registration or update actions for the Ascend platform.
        """
        from vllm_ascend.quantization.quant_config import AscendQuantConfig  # noqa: F401

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        # Register ops at setup.
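A rough illustration of what the import inside `pre_register_and_update` achieves: importing the module runs the `@register_quantization_config("ascend")` decorator, after which the method can be resolved by name. The sketch assumes the register/get helpers from vllm-project/vllm#12432 and that `NPUPlatform` lives at `vllm_ascend.platform`:

```python
# Sketch only: assumes vllm-project/vllm#12432 (register_quantization_config)
# is available and that NPUPlatform is importable from vllm_ascend.platform.
from vllm.model_executor.layers.quantization import get_quantization_config

from vllm_ascend.platform import NPUPlatform

NPUPlatform.pre_register_and_update()        # imports AscendQuantConfig, running the decorator
cfg_cls = get_quantization_config("ascend")  # now resolves to AscendQuantConfig
print(cfg_cls.get_name())                    # -> "ascend"
```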
@@ -0,0 +1,199 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Dict, List, Optional

import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               RowParallelLinear,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import (
    register_quantization_config)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           ChannelQuantScaleParameter,
                                           ModelWeightParameter)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    FUSED_LAYER_NAME_MAPPING)
from vllm.distributed import get_tensor_model_parallel_rank

from .quantizer import AscendQuantizer

logger = init_logger(__name__)


@register_quantization_config("ascend")
class AscendQuantConfig(QuantizationConfig):
    """Config class for Ascend quantization."""

    def __init__(self, quant_config: Dict[str, Any]):
        self.quant_description = quant_config.pop("quant_description")
        self.quant_config = quant_config

    def __repr__(self) -> str:
        return "AscendQuantConfig:\n" + super().__repr__()

    @classmethod
    def get_name(cls) -> str:
        return "ascend"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.int8, torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        raise NotImplementedError(
            "Ascend hardware does not support the \"get_min_capability\" feature.")

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AscendQuantConfig":
        return cls(config)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        dev_type = hf_quant_cfg.get("dev_type", None)
        if dev_type == "npu":
            return "ascend"
        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase):
            if self.is_layer_skipped_ascend(prefix):
                return UnquantizedLinearMethod()
            return AscendLinearMethod(self)
        return None

    def is_layer_skipped_ascend(self, prefix: str):
        # Adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
        proj_name = prefix.split(".")[-1]
        if proj_name in FUSED_LAYER_NAME_MAPPING:
            shard_prefixes = [
                prefix.replace(proj_name, shard_proj_name)
                for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
            ]

            is_skipped = None
            for shard_prefix in shard_prefixes:
                is_shard_skipped = self.quant_description[shard_prefix + '.weight'] == "FLOAT"

                if is_skipped is None:
                    is_skipped = is_shard_skipped
                elif is_shard_skipped != is_skipped:
                    raise ValueError(
                        f"Detected some but not all shards of {prefix} "
                        "are quantized. All shards of fused layers "
                        "are expected to have the same precision.")
        else:
            is_skipped = self.quant_description[prefix + '.weight'] == "FLOAT"

        assert is_skipped is not None
        return is_skipped

    def get_scaled_act_names(self) -> List[str]:
        return []


class AscendLinearMethod(LinearMethodBase):
    """Linear method for Ascend quantization.

    Args:
        quant_config: The Ascend quantization config.
    """

    def __init__(self, quant_config: AscendQuantConfig) -> None:
        self.quantizer = AscendQuantizer.get_quantizer(quant_config.quant_config)
        self.quant_method = self.quantizer.build_linear_method()

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        del output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        weights = self.quant_method.create_weights(input_size_per_partition,
                                                   output_size_per_partition,
                                                   params_dtype)

        weight_name = self.quant_method.get_weight()
        if weight_name in weights.keys():
            layer.register_parameter(
                weight_name,
                ModelWeightParameter(data=weights[weight_name].transpose(0, 1),
                                     input_dim=1,
                                     output_dim=0,
                                     weight_loader=weight_loader))
        else:
            raise ValueError(
                f"{weight_name} is not registered. Please check your linear "
                "quant method implementation.")

        pertensor_names = self.quant_method.get_pertensor_param()
        for pertensor_name in pertensor_names:
            if pertensor_name in weights.keys():
                layer.register_parameter(
                    pertensor_name,
                    BasevLLMParameter(data=weights[pertensor_name],
                                      weight_loader=weight_loader))
            else:
                raise ValueError(
                    f"{pertensor_name} is not registered. Please check your "
                    "linear quant method implementation.")

        perchannel_names = self.quant_method.get_perchannel_param()
        for perchannel_name in perchannel_names:
            if perchannel_name in weights.keys():
                layer.register_parameter(
                    perchannel_name,
                    ChannelQuantScaleParameter(data=weights[perchannel_name],
                                               output_dim=0,
                                               weight_loader=weight_loader))
            else:
                raise ValueError(
                    f"{perchannel_name} is not registered. Please check your "
                    "linear quant method implementation.")

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if hasattr(self.quant_method, 'transpose_weight') and self.quant_method.transpose_weight:
            layer.weight.data = layer.weight.data.transpose(1, 0)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if isinstance(layer, RowParallelLinear):
            tp_rank = get_tensor_model_parallel_rank()
            return self.quant_method.apply(layer, x, bias, tp_rank)
        return self.quant_method.apply(layer, x, bias)
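For orientation, a sketch of how this config plugs together. Only the `quant_description` field, keys of the form `<prefix>.weight`, the `"FLOAT"` marker, and the `dev_type: "npu"` check come from the code above; every other field name and value (including `"W8A8"`) is an illustrative assumption, not a documented checkpoint format. The sketch also assumes an environment where `vllm_ascend` and its quantizer backend import cleanly:

```python
# Illustrative only: field names other than "quant_description"/"dev_type" and
# values other than "FLOAT" are assumptions about an Ascend quantized checkpoint.
from vllm_ascend.quantization.quant_config import AscendQuantConfig

hf_quant_cfg = {
    "dev_type": "npu",   # makes override_quantization_method() return "ascend"
    "quant_description": {
        # Per-weight precision markers; "FLOAT" means the layer stays unquantized.
        "model.layers.0.self_attn.q_proj.weight": "W8A8",
        "model.layers.0.self_attn.k_proj.weight": "W8A8",
        "model.layers.0.self_attn.v_proj.weight": "W8A8",
        "model.layers.0.mlp.gate_proj.weight": "FLOAT",
        "model.layers.0.mlp.up_proj.weight": "FLOAT",
    },
}

quant_config = AscendQuantConfig.from_config(hf_quant_cfg)

# q/k/v shards are quantized, so the fused qkv_proj gets AscendLinearMethod.
assert not quant_config.is_layer_skipped_ascend("model.layers.0.self_attn.qkv_proj")
# gate/up shards are FLOAT, so the fused gate_up_proj falls back to UnquantizedLinearMethod.
assert quant_config.is_layer_skipped_ascend("model.layers.0.mlp.gate_up_proj")
```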
Review comment: Please add a TODO here. Once more quantization methods are available in vllm-ascend, the skip can be removed.

Reply: This test case is designed for quantization methods based on MindIE-Turbo. For other possible quantization methods in the future, we can add new test cases.