@@ -12,14 +12,14 @@
 )
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     W8A8BlockFp8LinearOp,
-    check_aiter_fp8_linear_support,
     create_fp8_input_scale,
     create_fp8_scale_parameter,
     create_fp8_weight_parameter,
     maybe_post_process_fp8_weight_block,
     process_fp8_weight_block_strategy,
     process_fp8_weight_channel_strategy,
     process_fp8_weight_tensor_strategy,
+    use_aiter_fp8_linear,
     validate_fp8_block_shape,
 )
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
@@ -61,7 +61,7 @@ def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool)
         )

         self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
-        self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
+        self.use_aiter_and_is_supported = use_aiter_fp8_linear()

         if self.weight_block_size is not None:
             assert not self.is_static_input_scheme
vllm/model_executor/layers/quantization/fp8.py (2 additions & 2 deletions)
@@ -56,7 +56,6 @@
 )
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     W8A8BlockFp8LinearOp,
-    check_aiter_fp8_linear_support,
     create_fp8_input_scale,
     create_fp8_scale_parameter,
     create_fp8_weight_parameter,
@@ -65,6 +64,7 @@
     process_fp8_weight_block_strategy,
     process_fp8_weight_tensor_strategy,
     requant_weight_ue8m0_inplace,
+    use_aiter_fp8_linear,
     validate_fp8_block_shape,
 )
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -362,7 +362,7 @@ def __init__(self, quant_config: Fp8Config):
         if vllm_is_batch_invariant():
             self.use_marlin = False

-        self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
+        self.use_aiter_and_is_supported = use_aiter_fp8_linear()

         self.weight_block_size = self.quant_config.weight_block_size
         self.block_quant = self.weight_block_size is not None
vllm/model_executor/layers/quantization/utils/fp8_utils.py (33 additions & 8 deletions)
@@ -104,12 +104,38 @@ def rocm_aiter_gemm_w8a8_blockscale_fake(
 if (
     envs.VLLM_ROCM_USE_AITER
     and envs.VLLM_ROCM_USE_AITER_LINEAR
-    and current_platform.is_fp8_fnuz()
+    and current_platform.supports_aiter_w8a8_block_fp8_linear()
 ):
     import aiter as rocm_aiter
-    from aiter import get_hip_quant
+    from aiter import per_group_quant_hip

-    aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
+    def rocm_aiter_per1x128_quant_impl(
+        x: torch.Tensor,
+        scale: torch.Tensor | None = None,
+        quant_dtype: torch.dtype = torch.float8_e4m3fn,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return per_group_quant_hip(x, scale, quant_dtype, group_size=128)
+
+    def rocm_aiter_per1x128_quant_fake(
+        x: torch.Tensor,
+        scale: torch.Tensor | None = None,
+        quant_dtype: torch.dtype = torch.float8_e4m3fn,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        group_size = 128
+        y = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
+        scale = torch.empty(
+            (*x.shape[:-1], x.shape[-1] // group_size),
+            dtype=torch.float32,
+            device=x.device,
+        )
+        return y, scale
+
+    direct_register_custom_op(
+        op_name="rocm_aiter_per1x128_quant",
+        op_func=rocm_aiter_per1x128_quant_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_per1x128_quant_fake,
+    )


 # TODO we should be able to change the type of block_size to GroupShape
@@ -352,7 +378,7 @@ def _run_aiter(
         weight_scale: torch.Tensor,
     ) -> torch.Tensor:
         assert self.act_quant_group_shape == GroupShape(1, 128)
-        q_input, input_scale = aiter_per1x128_quant(
+        q_input, input_scale = torch.ops.vllm.rocm_aiter_per1x128_quant(
             input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8
         )
         return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
@@ -938,14 +964,13 @@ def requant_weight_ue8m0_inplace(
     s_old.copy_(s_requant)


-def check_aiter_fp8_linear_support() -> bool:
-    """AITER is only supported on ROCm and only for FP8_FNUZ
-    and at the moment are MI300 series"""
+def use_aiter_fp8_linear() -> bool:
+    """Check whether the ROCm AITER FP8 linear op is supported and activated"""
     return (
         current_platform.is_rocm()
         and envs.VLLM_ROCM_USE_AITER
         and envs.VLLM_ROCM_USE_AITER_LINEAR
-        and current_platform.is_fp8_fnuz()
+        and current_platform.supports_aiter_w8a8_block_fp8_linear()
     )

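Note on the fp8_utils.py change: wrapping AITER's per_group_quant_hip in a registered custom op via vLLM's direct_register_custom_op gives torch.compile a fake implementation that only allocates outputs of the right shape, so the graph can be traced without running the HIP kernel. A rough usage sketch of the newly registered op follows; it assumes a ROCm build with AITER installed and VLLM_ROCM_USE_AITER / VLLM_ROCM_USE_AITER_LINEAR enabled on a supported GPU (otherwise the op is never registered), and the tensor shapes are purely illustrative.

# Illustrative sketch only, not part of this diff. Importing fp8_utils under the
# conditions above registers torch.ops.vllm.rocm_aiter_per1x128_quant.
import torch

import vllm.model_executor.layers.quantization.utils.fp8_utils  # noqa: F401

x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")

# One FP8 value per element, plus one float32 scale per 1x128 group along the
# last dimension (1024 / 128 = 8 scales per row), matching the fake impl above.
q, scales = torch.ops.vllm.rocm_aiter_per1x128_quant(x.contiguous())

assert q.shape == (16, 1024)
assert scales.shape == (16, 8)
assert scales.dtype == torch.float32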
vllm/platforms/rocm.py (16 additions & 10 deletions)
@@ -434,19 +434,22 @@ def get_device_communicator_cls(cls) -> str:
         )

     @classmethod
-    def supports_mx(cls) -> bool:
+    def arch_is_in(cls, arch_list: list[str]) -> bool:
+        # only device 0 is checked, this assumes MI300 platforms are homogeneous
         gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
-        return any(gfx in gcn_arch for gfx in ["gfx95"])
+        return any(gfx in gcn_arch for gfx in arch_list)

+    @classmethod
+    def supports_mx(cls) -> bool:
+        return cls.arch_is_in(["gfx95"])
+
     @classmethod
     def supports_fp8(cls) -> bool:
-        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
-        return any(gfx in gcn_arch for gfx in ["gfx94", "gfx95", "gfx12"])
+        return cls.arch_is_in(["gfx94", "gfx95", "gfx12"])

     @classmethod
     def is_fp8_fnuz(cls) -> bool:
-        # only device 0 is checked, this assumes MI300 platforms are homogeneous
-        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
+        return cls.arch_is_in(["gfx94"])

     @classmethod
     def fp8_dtype(cls) -> torch.dtype:
@@ -457,10 +460,13 @@ def fp8_dtype(cls) -> torch.dtype:

     @classmethod
     def use_custom_allreduce(cls) -> bool:
-        # We only enable custom allreduce for MI300 series
-        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
-        supported_archs = ["gfx94", "gfx95"]
-        return any(gfx in gcn_arch for gfx in supported_archs)
+        # Enable for MI300 and MI350 series
+        return cls.arch_is_in(["gfx94", "gfx95"])

+    @classmethod
+    def supports_aiter_w8a8_block_fp8_linear(cls) -> bool:
+        # Enable for MI300 and MI350 series
+        return cls.arch_is_in(["gfx94", "gfx95"])
+
     @classmethod
     def opaque_attention_op(cls) -> bool:
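The rocm.py changes fold the repeated gcnArchName substring checks into a single arch_is_in helper that the capability methods (supports_mx, supports_fp8, is_fp8_fnuz, use_custom_allreduce, and the new supports_aiter_w8a8_block_fp8_linear) now share. A minimal standalone sketch of the same pattern, outside the platform class and with hypothetical module-level names, for illustration:

# Standalone sketch of the arch_is_in pattern above; assumes a ROCm build of
# PyTorch. Only device 0 is inspected, i.e. a homogeneous node is assumed,
# mirroring the comment carried over in rocm.py.
import torch


def arch_is_in(arch_list: list[str]) -> bool:
    # gcnArchName looks like e.g. "gfx942:sramecc+:xnack-" on MI300-class GPUs
    gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
    return any(gfx in gcn_arch for gfx in arch_list)


# MI300 (gfx94x) and MI350 (gfx95x) both qualify for the AITER W8A8 block-FP8
# linear path, while only gfx94x reports the FNUZ FP8 format.
supports_aiter_w8a8_block_fp8_linear = arch_is_in(["gfx94", "gfx95"])
is_fp8_fnuz = arch_is_in(["gfx94"])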