@@ -7,15 +7,17 @@
 
 from vllm import _custom_ops as ops
 from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.quantization.utils.fp8_quant_ops import (
-    quantize_fp8_per_group, quantize_fp8_per_tensor, quantize_fp8_per_token)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape)
 from vllm.platforms import current_platform
 
 # Using the default value (240.0) from pytorch will cause accuracy
 # issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm.
 _FP8_DTYPE = current_platform.fp8_dtype()
+_FP8_FINFO = torch.finfo(_FP8_DTYPE)
+_FP8_MAX = 224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.max
+_FP8_MIN = -224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.min
+_FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)
 
 
 @CustomOp.register("quant_fp8")
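
For reference, a minimal sketch (not part of the diff) of how the new module-level constants resolve for the two common FP8 dtypes. It assumes a PyTorch build that exposes both `torch.float8_e4m3fn` and `torch.float8_e4m3fnuz`; the loop variables are illustrative only.

```python
import torch

# On ROCm's fnuz format the code above caps the range at +/-224.0 instead of
# the finfo maximum (240.0) to avoid the accuracy issue noted in its comment;
# other platforms use the full finfo range (+/-448.0 for float8_e4m3fn).
for dtype, is_fnuz in ((torch.float8_e4m3fn, False),
                       (torch.float8_e4m3fnuz, True)):
    finfo = torch.finfo(dtype)
    fp8_max = 224.0 if is_fnuz else finfo.max
    fp8_min = -224.0 if is_fnuz else finfo.min
    min_scaling_factor = 1.0 / (fp8_max * 512.0)
    print(dtype, fp8_max, fp8_min, min_scaling_factor)
```
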
@@ -92,9 +94,25 @@ def forward_native(
                                     and scale_ub.numel() == 1)
 
         if self.use_per_token_if_dynamic and scale is None:
-            out, scale = quantize_fp8_per_token(x, scale, scale_ub)
+            # Per-token quantization logic
+            x_max, _ = x.abs().max(dim=-1)
+            x_max = x_max.unsqueeze(-1).to(torch.float32)
+            if scale_ub is not None:
+                x_max = x_max.clamp(max=scale_ub)
+            scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
+
+            out = x.to(torch.float32) * scale.reciprocal()
+            out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
         else:
-            out, scale = quantize_fp8_per_tensor(x, scale)
+            # Per-tensor quantization logic
+            if scale is None:
+                x_max = x.abs().max().unsqueeze(-1).to(torch.float32)
+                scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
+
+            # Even for dynamic per-token scales,
+            # reciprocal performs slightly better than division
+            out = x.to(torch.float32) * scale.reciprocal()
+            out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
 
         # This currently generates an extra Triton kernel in compilation.
         # Fortunately, we don't use padding if compiling.
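
The per-token branch above derives one dynamic scale per row (absmax over the hidden dimension, optionally clamped by `scale_ub`), then multiplies by the reciprocal of the scale and clamps into the FP8 range. Below is a self-contained sketch of that path against plain PyTorch; `per_token_quant_fp8_ref` is an illustrative name, and the hard-coded `float8_e4m3fn` constants stand in for the platform-dependent module globals.

```python
from typing import Optional

import torch

# Hard-coded stand-ins for the platform-dependent module globals above.
_FP8_DTYPE = torch.float8_e4m3fn
_FP8_MAX = torch.finfo(_FP8_DTYPE).max   # 448.0 here; 224.0 on ROCm fnuz
_FP8_MIN = torch.finfo(_FP8_DTYPE).min
_FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)


def per_token_quant_fp8_ref(
        x: torch.Tensor,
        scale_ub: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    # One dynamic scale per token (row): absmax over the hidden dimension.
    x_max = x.abs().max(dim=-1, keepdim=True).values.to(torch.float32)
    if scale_ub is not None:
        x_max = x_max.clamp(max=scale_ub)
    scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
    # Multiply by the reciprocal rather than dividing, then clamp into range.
    out = (x.to(torch.float32) * scale.reciprocal()).clamp(_FP8_MIN, _FP8_MAX)
    return out.to(_FP8_DTYPE), scale


x = torch.randn(4, 16, dtype=torch.bfloat16)
q, s = per_token_quant_fp8_ref(x)
assert q.shape == x.shape and s.shape == (4, 1)
```
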
@@ -118,5 +136,31 @@ def _quantize_group_cuda(
 
     def _quantize_group_native(
             self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        return quantize_fp8_per_group(x, self.group_size,
-                                      self.column_major_scales)
+        orig_shape = x.shape
+        hidden_dim = x.shape[-1]
+        num_groups = (hidden_dim + self.group_size - 1) // self.group_size
+        padded_dim = num_groups * self.group_size
+
+        if padded_dim != hidden_dim:
+            padding = padded_dim - hidden_dim
+            x = F.pad(x, (0, padding), mode='constant', value=0.0)
+
+        x_grouped = x.view(-1, num_groups, self.group_size)
+        absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float()
+        scales = (absmax / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
+
+        x_scaled = x_grouped / scales
+        x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
+
+        x_quant = x_quant.view(-1, padded_dim)
+        if padded_dim != hidden_dim:
+            x_quant = x_quant[..., :hidden_dim]
+        x_quant = x_quant.view(orig_shape)
+
+        scales = scales.squeeze(-1)
+        scales = scales.reshape(orig_shape[:-1] + (num_groups, ))
+
+        if self.column_major_scales:
+            scales = scales.transpose(-2, -1).contiguous()
+
+        return x_quant, scales
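
And a self-contained sketch mirroring the new group-wise path: the hidden dimension is zero-padded up to a multiple of the group size, one scale is computed per group from its absmax, and the padded columns are dropped again after quantization. `group_quant_fp8_ref` and the hard-coded FP8 constants are illustrative, not vLLM APIs; the `column_major_scales` transpose is omitted since it only changes the memory layout of the returned scales.

```python
import torch
import torch.nn.functional as F

_FP8_DTYPE = torch.float8_e4m3fn
_FP8_MAX = torch.finfo(_FP8_DTYPE).max
_FP8_MIN = torch.finfo(_FP8_DTYPE).min
_FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)


def group_quant_fp8_ref(x: torch.Tensor,
                        group_size: int) -> tuple[torch.Tensor, torch.Tensor]:
    orig_shape = x.shape
    hidden_dim = orig_shape[-1]
    num_groups = (hidden_dim + group_size - 1) // group_size
    padded_dim = num_groups * group_size
    if padded_dim != hidden_dim:
        # Zero-pad the hidden dim so it splits evenly into groups.
        x = F.pad(x, (0, padded_dim - hidden_dim))
    x_grouped = x.view(-1, num_groups, group_size)
    absmax = x_grouped.abs().max(dim=-1, keepdim=True).values.float()
    scales = (absmax / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
    x_quant = (x_grouped / scales).clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
    # Strip the padding and restore the original shape.
    x_quant = x_quant.view(-1, padded_dim)[..., :hidden_dim]
    x_quant = x_quant.reshape(orig_shape)
    scales = scales.squeeze(-1).reshape(orig_shape[:-1] + (num_groups, ))
    return x_quant, scales


x = torch.randn(4, 200, dtype=torch.bfloat16)
x_quant, scales = group_quant_fp8_ref(x, group_size=128)
print(x_quant.shape, scales.shape)  # torch.Size([4, 200]) torch.Size([4, 2])
```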