Skip to content

Commit 3add25e

Browse files
tjtanaa and liuyumoye
authored and committed
[FEAT] [ROCm] [AITER]: Add AITER HIP block quant kernel (vllm-project#21242)
1 parent a417335 commit 3add25e

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,13 @@ def rocm_aiter_gemm_w8a8_blockscale_fake(
8282
fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
8383
dispatch_key=current_platform.dispatch_key,
8484
)
85+
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR
86+
and current_platform.is_fp8_fnuz()):
87+
88+
import aiter as rocm_aiter
89+
from aiter import get_hip_quant
90+
91+
aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
8592

8693

8794
def dispatch_w8a8_blockscale_func(
@@ -178,8 +185,12 @@ def apply_w8a8_block_fp8_linear(
178185
block_size, input.dtype)
179186

180187
else:
181-
q_input, x_scale = per_token_group_quant_fp8(
182-
input_2d, block_size[1], column_major_scales=use_cutlass)
188+
if use_aiter_and_is_supported:
189+
q_input, x_scale = aiter_per1x128_quant(
190+
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
191+
else:
192+
q_input, x_scale = per_token_group_quant_fp8(
193+
input_2d, block_size[1], column_major_scales=use_cutlass)
183194

184195
output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
185196
block_size, input.dtype)

0 commit comments

Comments (0)