diff --git a/vllm/envs.py b/vllm/envs.py
index ce719cd3d2d9..6d76e041dff7 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -173,6 +173,7 @@
     VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
     VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
     VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT: bool = True
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT: bool = True
     VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD: bool = True
     VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE: bool = True
     VLLM_ROCM_USE_AITER_TRITON_FP8_BMM: bool = True
@@ -1241,6 +1242,8 @@ def get_vllm_port() -> Optional[int]:
     # Use AITER Triton fused RMSNORM + Quantization
     "VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT":
     lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT", "1"))),
+    "VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT":
+    lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT", "1"))),
     # Use AITER Triton fused elementwise multiply + elementwise addtion
     "VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD":
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 52126604b363..c83608f2a0bb 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -26,22 +26,52 @@
 if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT:
     from aiter.ops.triton.activation import act_mul_and_mxfp4_quant
+    rocm_aiter_fp4_quant_group_size = 32
+
+    def rocm_aiter_act_mul_and_fp4_group_quant_impl(
+        x: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        M = x.shape[0]
+        shuffle = VLLM_TRITON_FP4_GEMM_USE_ASM and (M >= 32)
+        x_fp4, out_bs = act_mul_and_mxfp4_quant(x, activation="silu", shuffle=shuffle, scale_shuffle_padding=True)
+        return x_fp4, out_bs
+
+    def rocm_aiter_act_mul_and_fp4_group_quant_fake(
+        x: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        M, N = x.shape
+        assert N % 4 == 0
+        N_half = N // 2
+        x_fp4 = torch.empty((M, N_half // 2), dtype=torch.uint8, device=x.device)
+        scaleN_valid = (N_half + rocm_aiter_fp4_quant_group_size - 1) // rocm_aiter_fp4_quant_group_size
+        scaleM = (M + 255) // 256 * 256
+        scaleN = (scaleN_valid + 7) // 8 * 8
+        out_bs = torch.empty((scaleM, scaleN), dtype=torch.uint8, device=x.device)
+        return x_fp4, out_bs
+
+    direct_register_custom_op(
+        op_name="rocm_aiter_act_mul_and_fp4_group_quant",
+        op_func=rocm_aiter_act_mul_and_fp4_group_quant_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_act_mul_and_fp4_group_quant_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
 VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT = envs.VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT
 if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT:
-    logger.info("[Aiter] VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT=1")
     from aiter.ops.triton.activation import act_mul_and_fp8_group_quant
     import aiter as rocm_aiter
     rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8
     rocm_aiter_fp8_quant_group_size = 128
 
-    def act_mul_and_fp8_group_quant_impl(
+    def rocm_aiter_act_mul_and_fp8_group_quant_impl(
         x: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         return act_mul_and_fp8_group_quant(x, activation="silu", group_size=rocm_aiter_fp8_quant_group_size, dtype_quant=rocm_aiter_fp8_dtype)
 
-    def act_mul_and_fp8_group_quant_fake(
+    def rocm_aiter_act_mul_and_fp8_group_quant_fake(
         x: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         M, N = x.shape
@@ -52,10 +82,10 @@ def act_mul_and_fp8_group_quant_fake(
         return x_fp8, out_bs
 
     direct_register_custom_op(
-        op_name="act_mul_and_fp8_group_quant",
-        op_func=act_mul_and_fp8_group_quant_impl,
+        op_name="rocm_aiter_act_mul_and_fp8_group_quant",
+        op_func=rocm_aiter_act_mul_and_fp8_group_quant_impl,
         mutates_args=[],
-        fake_impl=act_mul_and_fp8_group_quant_fake,
+        fake_impl=rocm_aiter_act_mul_and_fp8_group_quant_fake,
         dispatch_key=current_platform.dispatch_key,
     )
 
@@ -118,9 +148,7 @@ class SiluAndMul(CustomOp):
 
     def __init__(self):
         super().__init__()
-        if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT:
-            self.op = lambda x, shuffle: act_mul_and_mxfp4_quant(x, "silu", shuffle=shuffle)
-        elif current_platform.is_cuda_alike():
+        if current_platform.is_cuda_alike():
             self.op = torch.ops._C.silu_and_mul
         elif current_platform.is_xpu():
             from vllm._ipex_ops import ipex_ops
@@ -136,16 +164,11 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
     def forward_cuda(self,
                      x: torch.Tensor,
                      scale: Optional[torch.Tensor] = None) -> torch.Tensor:
-        if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT:
-            shuffle = VLLM_TRITON_FP4_GEMM_USE_ASM and x.shape[0] >= 32
-            out, out_scales = self.op(x, shuffle)
-            return out, out_scales
-        else:
-            d = x.shape[-1] // 2
-            output_shape = (x.shape[:-1] + (d, ))
-            out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-            self.op(out, x)
-            return out
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        self.op(out, x)
+        return out
 
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
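For reference, outside the diff: the fake implementation above only has to reproduce output shapes for tracing, and its arithmetic can be restated on its own. This is a minimal sketch in plain torch with a hypothetical helper name: silu_and_mul halves the hidden dimension, two FP4 values then pack into one uint8 byte, and the e8m0 block scales (one per 32 values) are padded to a 256-row by 8-column multiple for the shuffled layout the AITER GEMMs expect.

import torch

FP4_GROUP_SIZE = 32  # one block scale per 32 FP4 values

def fp4_act_mul_fake_shapes(M: int, N: int) -> tuple[torch.Size, torch.Size]:
    # silu_and_mul halves the hidden dim; 2 FP4 values pack into 1 uint8 byte
    n_half = N // 2
    x_fp4_shape = torch.Size((M, n_half // 2))
    # one e8m0 scale per 32-element group, padded for the shuffled scale layout
    scale_n_valid = (n_half + FP4_GROUP_SIZE - 1) // FP4_GROUP_SIZE
    scale_m = (M + 255) // 256 * 256        # rows padded to a multiple of 256
    scale_n = (scale_n_valid + 7) // 8 * 8  # cols padded to a multiple of 8
    return x_fp4_shape, torch.Size((scale_m, scale_n))

# e.g. a (48, 28672) gate_up output -> FP4 payload (48, 7168), scales (256, 448)
print(fp4_act_mul_fake_shapes(48, 28672))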
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index b2d753018b4e..c7b52f5f2924 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -11,6 +11,71 @@
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
+from vllm.logger import init_logger
+logger = init_logger(__name__)
+
+if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
+    VLLM_TRITON_FP4_GEMM_USE_ASM = envs.VLLM_ROCM_USE_AITER and envs.VLLM_TRITON_FP4_GEMM_USE_ASM
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT = envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT
+
+    if VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT:
+        from aiter import per_1x32_f4_quant_hip
+        from aiter.ops.triton.fused_mxfp4_quant import fused_rms_mxfp4_quant
+        rocm_aiter_fp4_quant_group_size = 32
+
+        def rocm_aiter_fused_rms_and_fp4_group_quant_impl(
+            x: torch.Tensor,
+            weight: torch.Tensor,
+            eps: float,
+            residual: Optional[torch.Tensor] = None,
+        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+            M = x.shape[0]
+            res = None
+            if M <= 64:
+                shuffle = VLLM_TRITON_FP4_GEMM_USE_ASM and (M >= 32)
+                (x_fp4, out_bs), _, res = fused_rms_mxfp4_quant(x, weight, eps,
+                                                                None, None, eps,
+                                                                res1=residual,
+                                                                shuffle=shuffle,
+                                                                scale_shuffle_padding=True)
+            else:
+                shuffle = VLLM_TRITON_FP4_GEMM_USE_ASM
+                if residual is not None:
+                    x_rms, res = torch.ops.vllm.rocm_aiter_fused_add_rms_norm(x, residual, weight, eps)
+                else:
+                    x_rms = torch.ops.vllm.rocm_aiter_rms_norm(x, weight, eps)
+                x_fp4, out_bs = per_1x32_f4_quant_hip(x_rms, shuffle=shuffle)
+
+            if res is None:
+                return x_fp4, out_bs, x
+            return x_fp4, out_bs, res
+
+        def rocm_aiter_fused_rms_and_fp4_group_quant_fake(
+            x: torch.Tensor,
+            weight: torch.Tensor,
+            eps: float,
+            residual: Optional[torch.Tensor] = None,
+        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+            M, N = x.shape
+            x_fp4 = torch.empty((M, N // 2), dtype=torch.uint8, device=x.device)
+            scaleN_valid = (N + rocm_aiter_fp4_quant_group_size - 1) // rocm_aiter_fp4_quant_group_size
+            scaleM = (M + 255) // 256 * 256
+            scaleN = (scaleN_valid + 7) // 8 * 8
+            out_bs = torch.empty((scaleM, scaleN), dtype=torch.uint8, device=x.device)
+            res = torch.empty((M, N), dtype=x.dtype, device=x.device)
+            return x_fp4, out_bs, res
+
+        direct_register_custom_op(
+            op_name="rocm_aiter_fused_rms_and_fp4_group_quant",
+            op_func=rocm_aiter_fused_rms_and_fp4_group_quant_impl,
+            mutates_args=[],
+            fake_impl=rocm_aiter_fused_rms_and_fp4_group_quant_fake,
+            dispatch_key=current_platform.dispatch_key,
+        )
+else:
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT = False
+
+logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT=}")
 
 
 def is_rocm_aiter_rmsnorm_enabled() -> bool:
     return current_platform.is_rocm() \
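For reference, outside the diff: the registered impl above switches kernels on the token count M. A small stand-alone restatement of that dispatch rule in plain Python (the helper name is ours, the logic is copied from the impl):

def fused_rms_fp4_path(M: int, use_asm_gemm: bool) -> tuple[str, bool]:
    # Mirrors rocm_aiter_fused_rms_and_fp4_group_quant_impl: (chosen path, shuffle flag).
    if M <= 64:
        # small M: one fused Triton kernel does RMSNorm + MXFP4 quant (and residual add)
        return "fused_rms_mxfp4_quant", use_asm_gemm and M >= 32
    # large M: (fused add +) RMSNorm custom op first, then the HIP per-1x32 FP4 quant
    return "rms_norm + per_1x32_f4_quant_hip", use_asm_gemm

for m in (16, 32, 64, 65, 4096):
    print(m, fused_rms_fp4_path(m, use_asm_gemm=True))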
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 5f331765e87c..328cf69023f2 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -597,8 +597,8 @@ def forward(
         # Matrix multiply.
         assert self.quant_method is not None
         from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
-        from vllm.model_executor.layers.quantization.quark.schemes.quark_w4a4_mxfp4 import QuarkW4A4MXFP4
-        if isinstance(self.quant_method, Fp8LinearMethod) or isinstance(self.quant_method, QuarkW4A4MXFP4):
+        from vllm.model_executor.layers.quantization.quark.quark import QuarkLinearMethod
+        if isinstance(self.quant_method, Fp8LinearMethod) or isinstance(self.quant_method, QuarkLinearMethod):
             output_parallel = self.quant_method.apply(self, input_, bias, x_quant_scales=x_quant_scales)
         else:
             assert x_quant_scales is None, f"x_quant_scales input is not supported for {self.quant_method.__class__}"
diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
index ce7713be18a2..b8f234a6c2ee 100644
--- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
+++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
@@ -25,8 +25,10 @@
 def aiter_triton_gemm_check(m, n, k):
     if m <= 64:
-        return ((n == 8192 and k == 8192) or (n == 10240 and k == 8192)
-                or (n == 57344 and k == 8192) or (n == 8192 and k == 28672))
+        return (
+            (n == 10240 and k == 8192) or (n == 8192 and k == 8192) or (n == 57344 and k == 8192) or (n == 8192 and k == 28672) or
+            (n == 1280 and k == 8192) or (n == 8192 and k == 1024) or (n == 7168 and k == 8192) or (n == 8192 and k == 3584)
+        )
     return False
 
 def gemm_with_dynamic_quant(
@@ -67,14 +69,15 @@ def gemm_with_dynamic_quant(
         gemm_afp4wfp4_preshuffled_weight_scales(x_q.view(torch.uint8),
                                                 weight.view(torch.uint8).view(weight.shape[0] // 16, -1),
                                                 x_s,
                                                 weight_scale.view(torch.uint8).view(weight_scale.shape[0] // 32, -1),
                                                 out_dtype,
                                                 y)
+
     else:
         if x_scales is None:
             # use hip quant kernel for performance
             x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
         else:
-            x_q = x
-            x_s = x_scales
+            x_q = x.view(torch.float4_e2m1fn_x2)
+            x_s = x_scales.view(torch.float8_e8m0fnu)
 
         # 32 alignment is enough for dim0 padding of output for
         # gemm_a4w4 kernel
@@ -82,7 +85,7 @@ def gemm_with_dynamic_quant(
                         weight.shape[0],
                         device=x_q.device,
                         dtype=out_dtype)
-
+
         gemm_a4w4(x_q,
                   weight.view(x_q.dtype),
                   x_s,
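For reference, outside the diff: when the caller already supplies quantized activations, the new else branch reinterprets the raw uint8 buffers as MX dtypes instead of re-quantizing. A hedged illustration of that view-only reinterpretation with dummy buffers; it assumes a PyTorch build recent enough to expose torch.float4_e2m1fn_x2 and torch.float8_e8m0fnu.

import torch

M, K = 4, 128
x_packed = torch.randint(0, 256, (M, K // 2), dtype=torch.uint8)   # K FP4 values -> K/2 bytes
x_scales = torch.randint(0, 256, (M, K // 32), dtype=torch.uint8)  # one e8m0 scale per 32 values

# Same storage and shape, only the dtype changes; each float4_e2m1fn_x2 element
# packs two FP4 values, each float8_e8m0fnu element is one block scale.
x_q = x_packed.view(torch.float4_e2m1fn_x2)
x_s = x_scales.view(torch.float8_e8m0fnu)
print(x_q.dtype, x_q.shape, x_s.dtype, x_s.shape)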
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 1dc51fcc00ea..9095e4e3ce99 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -56,16 +56,22 @@
     is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
     make_layers, maybe_prefix)
+from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
+from vllm.model_executor.layers.quantization.quark.quark import QuarkLinearMethod
+from vllm.model_executor.layers.quantization.quark.schemes.quark_w4a4_mxfp4 import QuarkW4A4MXFP4
 from vllm.platforms import current_platform
 
 from vllm.logger import init_logger
 logger = init_logger(__name__)
 
 if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
-    from vllm.model_executor.layers.activation import VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT
+    from vllm.model_executor.layers.activation import VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT, VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT
+    from vllm.model_executor.layers.layernorm import VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT
     VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
 else:
     VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT = False
+    VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT = False
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT = False
     VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
 VLLM_ROCM_USE_AITER_MHA = envs.VLLM_ROCM_USE_AITER_MHA
@@ -104,15 +110,23 @@ def __init__(
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
-        self.block_quant = hasattr(quant_config, "weight_block_size") and quant_config.weight_block_size is not None
+        self.block_quant = isinstance(self.down_proj.quant_method, Fp8LinearMethod) and hasattr(quant_config, "weight_block_size") and quant_config.weight_block_size is not None
         if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT and not self.block_quant:
-            logger.info("[Aiter] [WARNING] VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT will not be activated because this model is not using blocked quantization")
+            logger.info(f"[Aiter] [WARNING] VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT will not be activated because {self.__class__.__name__} is not using FP8 blockscale GEMM")
+        self.fp4_block_quant_gemm = (isinstance(self.down_proj.quant_method, QuarkLinearMethod) and hasattr(self.down_proj, "scheme") and isinstance(self.down_proj.scheme, QuarkW4A4MXFP4))
+        if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT and not self.fp4_block_quant_gemm:
+            logger.info(f"[Aiter] [WARNING] VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT will not be activated because {self.__class__.__name__} is not using FP4 blockscale GEMM")
         self.act_fn = SiluAndMul()
 
     def forward(self, x):
-        x, _ = self.gate_up_proj(x)
-        if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT and self.block_quant:
-            x = torch.ops.vllm.act_mul_and_fp8_group_quant(x)
+        x_quant_scales = None
+        if isinstance(x, tuple):
+            x, x_quant_scales = x
+        x, _ = self.gate_up_proj(x, x_quant_scales=x_quant_scales)
+        if VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP4_QUANT and self.fp4_block_quant_gemm:
+            x = torch.ops.vllm.rocm_aiter_act_mul_and_fp4_group_quant(x)
+        elif VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT and self.block_quant:
+            x = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant(x)
         else:
             x = self.act_fn(x)
         x_quant_scales = None
@@ -220,7 +234,11 @@ def forward(
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
-        qkv, _ = self.qkv_proj(hidden_states)
+        hidden_states_quant = None
+        if isinstance(hidden_states, tuple):
+            hidden_states, hidden_states_quant = hidden_states
+
+        qkv, _ = self.qkv_proj(hidden_states, x_quant_scales=hidden_states_quant)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
 
         if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
@@ -316,6 +334,12 @@ def __init__(
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
+        self.input_layernorm_with_fp4_block_quant_gemm = (isinstance(self.self_attn.qkv_proj.quant_method, QuarkLinearMethod) and hasattr(self.self_attn.qkv_proj, "scheme") and isinstance(self.self_attn.qkv_proj.scheme, QuarkW4A4MXFP4))
+        if VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT and not self.input_layernorm_with_fp4_block_quant_gemm:
+            logger.info(f"[Aiter] [WARNING] VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT will not be activated because {self.self_attn.__class__.__name__} is not using FP4 blockscale GEMM")
+        self.post_attention_layernorm_with_fp4_block_quant_gemm = (isinstance(self.mlp.gate_up_proj.quant_method, QuarkLinearMethod) and hasattr(self.mlp.gate_up_proj, "scheme") and isinstance(self.mlp.gate_up_proj.scheme, QuarkW4A4MXFP4))
+        if VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT and not self.post_attention_layernorm_with_fp4_block_quant_gemm:
+            logger.info(f"[Aiter] [WARNING] VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT will not be activated because {self.mlp.__class__.__name__} is not using FP4 blockscale GEMM")
@@ -324,18 +348,34 @@ def forward(
         residual: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        if residual is None:
-            residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
+        if VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT and self.input_layernorm_with_fp4_block_quant_gemm:
+            weight = self.input_layernorm.weight
+            eps = self.input_layernorm.variance_epsilon
+            if residual is None:
+                residual = hidden_states
+                hidden_states_quant, hidden_states_quant_scales, _ = torch.ops.vllm.rocm_aiter_fused_rms_and_fp4_group_quant(hidden_states, weight, eps, None)
+            else:
+                hidden_states_quant, hidden_states_quant_scales, residual = torch.ops.vllm.rocm_aiter_fused_rms_and_fp4_group_quant(hidden_states, weight, eps, residual)
+            hidden_states = (hidden_states_quant, hidden_states_quant_scales)
         else:
-            hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(
+                    hidden_states, residual)
         hidden_states = self.self_attn(positions=positions,
                                        hidden_states=hidden_states)
 
         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
+        if VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT and self.post_attention_layernorm_with_fp4_block_quant_gemm:
+            weight = self.post_attention_layernorm.weight
+            eps = self.post_attention_layernorm.variance_epsilon
+            hidden_states_quant, hidden_states_quant_scales, residual = torch.ops.vllm.rocm_aiter_fused_rms_and_fp4_group_quant(hidden_states, weight, eps, residual)
+            hidden_states = (hidden_states_quant, hidden_states_quant_scales)
+        else:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
 
         return hidden_states, residual
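For reference, outside the diff: the residual bookkeeping around the fused op is the subtle part of this hunk. On the first layer the pre-norm hidden states become the residual and the op's third output is discarded; afterwards the op returns the updated residual itself. A runnable sketch with a dummy stand-in for the custom op (shapes only, no real quantization; the stub name is ours):

import torch

def fused_rms_fp4_stub(x, weight, eps, residual=None):
    # Stand-in for torch.ops.vllm.rocm_aiter_fused_rms_and_fp4_group_quant:
    # returns (fp4_payload, block_scales, residual_out) with dummy buffers.
    M, N = x.shape
    payload = torch.empty(M, N // 2, dtype=torch.uint8)
    scales = torch.empty((M + 255) // 256 * 256,
                         (((N + 31) // 32) + 7) // 8 * 8, dtype=torch.uint8)
    res_out = x if residual is None else x + residual  # fused add + norm path
    return payload, scales, res_out

def input_layernorm_path(hidden_states, weight, eps, residual):
    # Mirrors the residual handling in LlamaDecoderLayer.forward above.
    if residual is None:
        residual = hidden_states
        payload, scales, _ = fused_rms_fp4_stub(hidden_states, weight, eps, None)
    else:
        payload, scales, residual = fused_rms_fp4_stub(hidden_states, weight, eps, residual)
    return (payload, scales), residual

x = torch.randn(8, 64)
(hq, hs), res = input_layernorm_path(x, torch.ones(64), 1e-6, None)
print(hq.shape, hs.shape, res.shape)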