diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
index ee431c9148b8..da69bdbbd09b 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -12,7 +12,6 @@
 )
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     W8A8BlockFp8LinearOp,
-    check_aiter_fp8_linear_support,
     create_fp8_input_scale,
     create_fp8_scale_parameter,
     create_fp8_weight_parameter,
@@ -20,6 +19,7 @@
     process_fp8_weight_block_strategy,
     process_fp8_weight_channel_strategy,
     process_fp8_weight_tensor_strategy,
+    use_aiter_fp8_linear,
     validate_fp8_block_shape,
 )
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
@@ -61,7 +61,7 @@ def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool)
         )
 
         self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
-        self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
+        self.use_aiter_and_is_supported = use_aiter_fp8_linear()
 
         if self.weight_block_size is not None:
             assert not self.is_static_input_scheme
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index e5681cb85625..c59da2963663 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -56,7 +56,6 @@
 )
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     W8A8BlockFp8LinearOp,
-    check_aiter_fp8_linear_support,
     create_fp8_input_scale,
     create_fp8_scale_parameter,
     create_fp8_weight_parameter,
@@ -65,6 +64,7 @@
     process_fp8_weight_block_strategy,
     process_fp8_weight_tensor_strategy,
     requant_weight_ue8m0_inplace,
+    use_aiter_fp8_linear,
     validate_fp8_block_shape,
 )
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -362,7 +362,7 @@ def __init__(self, quant_config: Fp8Config):
         if vllm_is_batch_invariant():
             self.use_marlin = False
 
-        self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
+        self.use_aiter_and_is_supported = use_aiter_fp8_linear()
 
         self.weight_block_size = self.quant_config.weight_block_size
         self.block_quant = self.weight_block_size is not None
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index f25148abb619..a2b0d1368c17 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -104,12 +104,38 @@ def rocm_aiter_gemm_w8a8_blockscale_fake(
 if (
     envs.VLLM_ROCM_USE_AITER
     and envs.VLLM_ROCM_USE_AITER_LINEAR
-    and current_platform.is_fp8_fnuz()
+    and current_platform.supports_aiter_w8a8_block_fp8_linear()
 ):
     import aiter as rocm_aiter
-    from aiter import get_hip_quant
+    from aiter import per_group_quant_hip
+
+    def rocm_aiter_per1x128_quant_impl(
+        x: torch.Tensor,
+        scale: torch.Tensor | None = None,
+        quant_dtype: torch.dtype = torch.float8_e4m3fn,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return per_group_quant_hip(x, scale, quant_dtype, group_size=128)
+
+    def rocm_aiter_per1x128_quant_fake(
+        x: torch.Tensor,
+        scale: torch.Tensor | None = None,
+        quant_dtype: torch.dtype = torch.float8_e4m3fn,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        group_size = 128
+        y = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
+        scale = torch.empty(
+            (*x.shape[:-1], x.shape[-1] // group_size),
+            dtype=torch.float32,
+            device=x.device,
+        )
+        return y, scale
 
-    aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
+    direct_register_custom_op(
+        op_name="rocm_aiter_per1x128_quant",
+        op_func=rocm_aiter_per1x128_quant_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_per1x128_quant_fake,
+    )
 
 
 # TODO we should be able to change the type of block_size to GroupShape
@@ -352,7 +378,7 @@ def _run_aiter(
         weight_scale: torch.Tensor,
     ) -> torch.Tensor:
         assert self.act_quant_group_shape == GroupShape(1, 128)
-        q_input, input_scale = aiter_per1x128_quant(
+        q_input, input_scale = torch.ops.vllm.rocm_aiter_per1x128_quant(
             input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8
         )
         return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
@@ -938,14 +964,13 @@ def requant_weight_ue8m0_inplace(
         s_old.copy_(s_requant)
 
 
-def check_aiter_fp8_linear_support() -> bool:
-    """AITER is only supported on ROCm and only for FP8_FNUZ
-    and at the moment are MI300 series"""
+def use_aiter_fp8_linear() -> bool:
+    """Check whether the ROCm AITER FP8 linear op is supported and activated"""
     return (
         current_platform.is_rocm()
         and envs.VLLM_ROCM_USE_AITER
         and envs.VLLM_ROCM_USE_AITER_LINEAR
-        and current_platform.is_fp8_fnuz()
+        and current_platform.supports_aiter_w8a8_block_fp8_linear()
     )
 
 
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index d3535c9781c4..f857957643bc 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -434,19 +434,22 @@ def get_device_communicator_cls(cls) -> str:
         )
 
     @classmethod
-    def supports_mx(cls) -> bool:
+    def arch_is_in(cls, arch_list: list[str]) -> bool:
+        # only device 0 is checked, this assumes MI300 platforms are homogeneous
         gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
-        return any(gfx in gcn_arch for gfx in ["gfx95"])
+        return any(gfx in gcn_arch for gfx in arch_list)
+
+    @classmethod
+    def supports_mx(cls) -> bool:
+        return cls.arch_is_in(["gfx95"])
 
     @classmethod
     def supports_fp8(cls) -> bool:
-        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
-        return any(gfx in gcn_arch for gfx in ["gfx94", "gfx95", "gfx12"])
+        return cls.arch_is_in(["gfx94", "gfx95", "gfx12"])
 
     @classmethod
     def is_fp8_fnuz(cls) -> bool:
-        # only device 0 is checked, this assumes MI300 platforms are homogeneous
-        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
+        return cls.arch_is_in(["gfx94"])
 
     @classmethod
     def fp8_dtype(cls) -> torch.dtype:
@@ -457,10 +460,13 @@ def fp8_dtype(cls) -> torch.dtype:
 
     @classmethod
     def use_custom_allreduce(cls) -> bool:
-        # We only enable custom allreduce for MI300 series
-        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
-        supported_archs = ["gfx94", "gfx95"]
-        return any(gfx in gcn_arch for gfx in supported_archs)
+        # Enable for MI300 and MI350 series
+        return cls.arch_is_in(["gfx94", "gfx95"])
+
+    @classmethod
+    def supports_aiter_w8a8_block_fp8_linear(cls) -> bool:
+        # Enable for MI300 and MI350 series
+        return cls.arch_is_in(["gfx94", "gfx95"])
 
     @classmethod
     def opaque_attention_op(cls) -> bool:
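Note on the custom-op pattern above: the patch wraps AITER's per_group_quant_hip as torch.ops.vllm.rocm_aiter_per1x128_quant, pairing the real HIP implementation with a fake implementation that only allocates output shapes, so the op stays traceable by torch.compile. The snippet below is a minimal standalone sketch of that same real/fake pairing using plain torch.library rather than vLLM's direct_register_custom_op helper, with a naive reference quantizer standing in for the AITER kernel; the demo::per1x128_quant name and its single-argument signature are illustrative only, not part of the patch.

# Standalone sketch (assumes PyTorch >= 2.4): register a per-1x128 quant op
# with a real implementation plus a fake (meta) implementation, mirroring the
# structure of rocm_aiter_per1x128_quant_impl / _fake in the diff above.
import torch


@torch.library.custom_op("demo::per1x128_quant", mutates_args=())
def per1x128_quant(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Naive reference "real" implementation: dynamic per-group (group_size=128)
    # quantization to fp8, returning the quantized tensor and per-group scales.
    group_size = 128
    assert x.shape[-1] % group_size == 0
    x_grouped = x.reshape(*x.shape[:-1], x.shape[-1] // group_size, group_size)
    scale = x_grouped.abs().amax(dim=-1, keepdim=True).float() / 448.0
    y = (x_grouped / scale.clamp(min=1e-12)).to(torch.float8_e4m3fn)
    return y.reshape(x.shape), scale.squeeze(-1)


@per1x128_quant.register_fake
def _(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Fake impl: shapes and dtypes only, so torch.compile can trace the op
    # without running the kernel (same role as rocm_aiter_per1x128_quant_fake).
    group_size = 128
    y = torch.empty(x.shape, dtype=torch.float8_e4m3fn, device=x.device)
    scale = torch.empty(
        (*x.shape[:-1], x.shape[-1] // group_size),
        dtype=torch.float32,
        device=x.device,
    )
    return y, scale


if __name__ == "__main__":
    q, s = torch.ops.demo.per1x128_quant(torch.randn(4, 256))
    print(q.shape, q.dtype, s.shape, s.dtype)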