Empty file added tests/__init__.py
Empty file.
Empty file added tests/quantization/__init__.py
Empty file.
58 changes: 58 additions & 0 deletions tests/quantization/test_mindie_turbo.py
@@ -0,0 +1,58 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Tests whether ascend quantization based on MindIE-Turbo is enabled correctly.
Run `pytest tests/quantization/test_mindie_turbo.py`.
"""

import pytest

from .utils import is_mindie_turbo_supported

MODELS = [
"LLaMA3-8B_W8A8/",
]


@pytest.mark.skipif(not is_mindie_turbo_supported(),
Collaborator: Please add a TODO here. Once more methods are available in vllm-ascend, the skip can be removed.

Collaborator Author: This test case is designed for quantization methods based on MindIE-Turbo. For other possible quantization methods in the future, we can add new test cases.

reason="MindIE-Turbo is not installed.")
@pytest.mark.parametrize("model_name_or_path", MODELS)
@pytest.mark.parametrize("max_tokens", [5])
def test_mindie_turbo(
model_name_or_path: str,
max_tokens: int,
) -> None:

import vllm # noqa: F401
from ..conftest import VllmRunner

import vllm_ascend # noqa: F401
from vllm_ascend.quantization.quant_config import AscendLinearMethod

Collaborator: Why are these imports inside the test?

Collaborator Author: This is because mindie_turbo had to be imported before vllm in early versions of mindie_turbo. Perhaps this conflict has been resolved by now, in which case these imports can be moved outside the test, as sketched below.
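A minimal sketch of the hoisted form, assuming the ordering constraint is in fact resolved (names taken from the in-test imports above):

# Hypothetical module-level form of these imports, assuming the
# mindie_turbo-before-vllm ordering constraint no longer applies.
import vllm  # noqa: F401
import vllm_ascend  # noqa: F401
from vllm_ascend.quantization.quant_config import AscendLinearMethod  # noqa: F401

from ..conftest import VllmRunner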

prompt = "What's deep learning?"
example_prompts = [prompt]

with VllmRunner(model_name_or_path,
max_model_len=8192,
dtype="bfloat16",
enforce_eager=False,
gpu_memory_utilization=0.7) as vllm_model:

output = vllm_model.generate_greedy(example_prompts, max_tokens)
assert output
26 changes: 26 additions & 0 deletions tests/quantization/utils.py
@@ -0,0 +1,26 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

def is_mindie_turbo_supported() -> bool:
try:
import mindie_turbo  # noqa: F401
except ImportError:
return False

return True
32 changes: 32 additions & 0 deletions vllm_ascend/ops/layernorm.py
@@ -20,12 +20,20 @@
import torch
from vllm.model_executor.layers.layernorm import RMSNorm

try:
from mindie_turbo import RMSNormWithAntiOutlier
except Exception:
pass


def forward_oot(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if hasattr(self, "module"):
return self.module.forward_anti_outlier(x, residual)
Collaborator (@wangxiyuan, Feb 7, 2025): Is self.module only used here? If yes, how about something like:

try:
    from mindie_turbo import RMSNormWithAntiOutlier
except:
    RMSNormWithAntiOutlier = None


def forward_oot():
    if RMSNormWithAntiOutlier is not None:
        return RMSNormWithAntiOutlier(self.hidden_size).forward_anti_outlier(x, residual)
  ....

Not sure enable_rmsnorm_with_antioutlier is needed; it seems to only add a new self.module there.

Copy link
Collaborator Author

@Angazenn Angazenn Feb 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Details of RMSNormWithAntiOutlier are moved out of vllm_ascend. There's no need to change the implemantation of rmsnorm in vllm_ascend now.


import torch_npu

if residual is not None:
@@ -37,4 +45,28 @@ def forward_oot(
return x


def enable_rmsnorm_with_antioutlier():
def init(
self,
hidden_size: int,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
has_weight: bool = True,
) -> None:
super(RMSNorm, self).__init__()
self.hidden_size = hidden_size
self.variance_epsilon = eps
self.variance_size_override = (None if var_hidden_size == hidden_size
else var_hidden_size)
self.has_weight = has_weight

self.weight = torch.ones(hidden_size)
if self.has_weight:
self.weight = torch.nn.Parameter(self.weight)

self.module = RMSNormWithAntiOutlier(self.hidden_size)

RMSNorm.__init__ = init


RMSNorm.forward_oot = forward_oot
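For context, a minimal sketch of how this patch could be activated at setup time; the call site and the import guard are assumptions, not part of this diff:

# Hypothetical call site (e.g. during platform setup), guarded the same way
# as the optional import at the top of layernorm.py.
try:
    import mindie_turbo  # noqa: F401
    from vllm_ascend.ops.layernorm import enable_rmsnorm_with_antioutlier

    # Patches RMSNorm.__init__ so every RMSNorm instance carries an
    # RMSNormWithAntiOutlier module that forward_oot can dispatch to.
    enable_rmsnorm_with_antioutlier()
except ImportError:
    # mindie_turbo absent: RMSNorm keeps the plain torch_npu forward_oot path.
    pass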
20 changes: 19 additions & 1 deletion vllm_ascend/platform.py
@@ -16,7 +16,7 @@
#

import os
from typing import Optional, Tuple
from typing import TYPE_CHECKING, Optional, Tuple

import torch

@@ -27,6 +27,10 @@

from vllm.config import VllmConfig
from vllm.platforms import Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.utils import FlexibleArgumentParser
else:
FlexibleArgumentParser = None

os.environ["RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES"] = "1"

@@ -53,6 +57,10 @@ class NPUPlatform(Platform):
ray_device_key: str = "NPU"
device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"

supported_quantization: list[str] = [
"ascend"
]

@classmethod
def get_device_capability(cls, device_id: int = 0):
return None
@@ -86,6 +94,16 @@ def synchronize(cls):
def mem_get_info(cls) -> Tuple[int, int]:
return torch.npu.mem_get_info()

# Relies on this pull request https://github.com/vllm-project/vllm/pull/12432.
@classmethod
def pre_register_and_update(cls,
parser: Optional[FlexibleArgumentParser] = None
) -> None:
"""
Do some pre-registration or update actions for the Ascend platform.
"""
from vllm_ascend.quantization.quant_config import AscendQuantConfig # noqa: F401
Collaborator: vllm-project/vllm#12432 is not merged yet. Maybe you can move this import to the register function in __init__.py, but I'm not sure whether it will lead to a circular import error or not. You can have a try first.

Collaborator Author: It seems there is a circular import if this import is moved. But the code still works and the whole inference process can generate correct text, so I moved this code to the register function.
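A minimal sketch of what the deferred import inside the plugin's register hook could look like; the exact body of vllm_ascend/__init__.py is an assumption here, not part of this diff:

# vllm_ascend/__init__.py (hypothetical sketch of the register hook)
def register():
    # Importing here, rather than in platform.py, avoids pulling in the
    # quantization package while the platform module is still being set up.
    from vllm_ascend.quantization.quant_config import AscendQuantConfig  # noqa: F401

    # vLLM's platform plugin machinery expects the platform class path back.
    return "vllm_ascend.platform.NPUPlatform"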


@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
# Register ops when setup.
Empty file added vllm_ascend/quantization/__init__.py
Empty file.
199 changes: 199 additions & 0 deletions vllm_ascend/quantization/quant_config.py
@@ -0,0 +1,199 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Dict, List, Optional

import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
RowParallelLinear, UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import (register_quantization_config)
from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
ModelWeightParameter)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
FUSED_LAYER_NAME_MAPPING)
from vllm.distributed import get_tensor_model_parallel_rank
from .quantizer import AscendQuantizer

logger = init_logger(__name__)

@register_quantization_config("ascend")
class AscendQuantConfig(QuantizationConfig):
"""Config class for Ascend"""

def __init__(self, quant_config: Dict[str, Any]):
self.quant_description = quant_config.pop("quant_description")
self.quant_config = quant_config

def __repr__(self) -> str:
return "AscendQuantConfig:\n" + super().__repr__()

@classmethod
def get_name(cls) -> str:
return "ascend"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.int8, torch.float16, torch.bfloat16]

@classmethod
def get_min_capability(cls) -> int:
raise NotImplementedError("Ascend hardware does not support \"get_min_capability\" feature.")

@classmethod
def get_config_filenames(cls) -> List[str]:
return []

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AscendQuantConfig":
return cls(config)

@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
dev_type = hf_quant_cfg.get("dev_type", None)
if dev_type == "npu":
return "ascend"
return None

def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
if self.is_layer_skipped_ascend(prefix):
return UnquantizedLinearMethod()
return AscendLinearMethod(self)
return None

def is_layer_skipped_ascend(self, prefix: str):
# adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
proj_name = prefix.split(".")[-1]
if proj_name in FUSED_LAYER_NAME_MAPPING:
shard_prefixes = [
prefix.replace(proj_name, shard_proj_name)
for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
]

is_skipped = None
for shard_prefix in shard_prefixes:
is_shard_skipped = self.quant_description[shard_prefix + '.weight'] == "FLOAT"

if is_skipped is None:
is_skipped = is_shard_skipped
elif is_shard_skipped != is_skipped:
raise ValueError(
f"Detected some but not all shards of {prefix} "
"are quantized. All shards of fused layers "
"to have the same precision.")
else:
is_skipped = self.quant_description[prefix + '.weight'] == "FLOAT"

assert is_skipped is not None
return is_skipped

def get_scaled_act_names(self) -> List[str]:
return []


class AscendLinearMethod(LinearMethodBase):
"""Linear method for Ascend quantization.

Args:
quant_config: The Ascend quantization config.
"""

def __init__(self, quant_config: AscendQuantConfig) -> None:
self.quantizer = AscendQuantizer.get_quantizer(quant_config.quant_config)
self.quant_method = self.quantizer.build_linear_method()

def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
del output_size
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")

weights = self.quant_method.create_weights(
input_size_per_partition,
output_size_per_partition,
params_dtype
)

weight_name = self.quant_method.get_weight()
if weight_name in weights.keys():
layer.register_parameter(
weight_name,
ModelWeightParameter(
data=weights[weight_name].transpose(0, 1),
input_dim=1,
output_dim=0,
weight_loader=weight_loader
)
)
else:
raise ValueError(f"{weight_name} is nor registered. Please check your linear quant method implementation.")

pertensor_names = self.quant_method.get_pertensor_param()
for pertensor_name in pertensor_names:
if pertensor_name in weights.keys():
layer.register_parameter(
pertensor_name,
BasevLLMParameter(
data=weights[pertensor_name],
weight_loader=weight_loader
)
)
else:
raise ValueError(f"{pertensor_name} is nor registered. Please check your linear quant method implementation.")

perchannel_names = self.quant_method.get_perchannel_param()
for perchannel_name in perchannel_names:
if perchannel_name in weights.keys():
layer.register_parameter(
perchannel_name,
ChannelQuantScaleParameter(
data=weights[perchannel_name],
output_dim=0,
weight_loader=weight_loader
)
)
else:
raise ValueError(f"{perchannel_name} is nor registered. Please check your linear quant method implementation.")

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, 'transpose_weight') and self.quant_method.transpose_weight:
layer.weight.data = layer.weight.data.transpose(1, 0)

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if isinstance(layer, RowParallelLinear):
tp_rank = get_tensor_model_parallel_rank()
return self.quant_method.apply(layer, x, bias, tp_rank)
return self.quant_method.apply(layer, x, bias)
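As a usage-level sketch of the path this config enables (the model path and sampling options are placeholders mirroring the test above, not verified settings):

# Hypothetical end-to-end use of the "ascend" quantization method.
from vllm import LLM, SamplingParams

llm = LLM(model="LLaMA3-8B_W8A8/",  # quantized checkpoint whose quant config has dev_type == "npu"
          quantization="ascend",     # resolves to AscendQuantConfig via register_quantization_config
          max_model_len=8192,
          dtype="bfloat16")
outputs = llm.generate(["What's deep learning?"],
                       SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)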