
Commit ee27d08

ProExpertProg authored and lulmer committed
[FP8] Refactor apply_fp8_linear and apply_fp8_linear_generic into an object (vllm-project#14390)
Signed-off-by: luka <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
1 parent c1c2455 commit ee27d08

File tree

11 files changed (+268 -242 lines)

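The refactor replaces the free functions apply_fp8_linear and apply_fp8_linear_generic with Fp8LinearOp and Fp8LinearGenericOp objects, so configuration flags such as cutlass_fp8_supported and use_per_token_if_dynamic are fixed once at construction instead of being threaded through every call. A minimal before/after sketch of a caller, mirroring the call sites in the hunks below (attribute names are illustrative, not from any one file):

# Before: every call site passes the configuration flags.
out = apply_fp8_linear(input=x,
                       weight=layer.weight,
                       weight_scale=layer.weight_scale,
                       input_scale=layer.input_scale,
                       bias=bias,
                       cutlass_fp8_supported=self.cutlass_fp8_supported,
                       use_per_token_if_dynamic=True)

# After: the flags are captured once in the object (in __init__),
# and call sites only pass tensors.
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
out = self.fp8_linear.apply(input=x,
                            weight=layer.weight,
                            weight_scale=layer.weight_scale,
                            input_scale=layer.input_scale,
                            bias=bias)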

tests/compile/test_fusion.py

Lines changed: 7 additions & 13 deletions
@@ -13,7 +13,7 @@
 from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    CUTLASS_FP8_SUPPORTED, apply_fp8_linear, maybe_create_device_identity)
+    CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
 
 from .backend import TestBackend
 
@@ -34,26 +34,20 @@ def __init__(self, hidden_size: int, eps: float, static: bool,
             torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
             for _ in range(2)
         ]
+        self.fp8_linear = Fp8LinearOp(
+            cutlass_fp8_supported=cutlass_fp8_enabled,
+            use_per_token_if_dynamic=True)
 
     def forward(self, x):
         resid = torch.sqrt(x)
         y = self.norm[0](x)
 
-        x2 = apply_fp8_linear(y,
-                              self.w[0],
-                              self.wscale[0],
-                              self.scale[0],
-                              use_per_token_if_dynamic=True,
-                              cutlass_fp8_supported=self.cutlass_fp8_enabled)
+        x2 = self.fp8_linear.apply(y, self.w[0], self.wscale[0], self.scale[0])
         # make sure resid is used for replacement to work
         y2, resid = self.norm[1](x2, resid)
 
-        x3 = apply_fp8_linear(y2,
-                              self.w[1],
-                              self.wscale[1],
-                              self.scale[1],
-                              use_per_token_if_dynamic=True,
-                              cutlass_fp8_supported=self.cutlass_fp8_enabled)
+        x3 = self.fp8_linear.apply(y2, self.w[1], self.wscale[1],
+                                   self.scale[1])
         y3, resid = self.norm[2](x3, resid)  # use resid here
         return y3

vllm/attention/backends/mla/common.py

Lines changed: 4 additions & 3 deletions
@@ -226,7 +226,7 @@
     CompressedTensorsW8A8Fp8)
 from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
+    Fp8LinearGenericOp, current_platform_fp8_dtype, is_fp8)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_quantize)
 from vllm.model_executor.layers.rotary_embedding import (
@@ -1057,6 +1057,7 @@ def __init__(
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
         self.triton_fa_func = triton_attention
+        self.fp8_linear_generic = Fp8LinearGenericOp()
 
         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
@@ -1071,7 +1072,7 @@ def __init__(
     def _v_up_proj_and_o_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
             if is_fp8(self.W_UV_O):
-                output_parallel = apply_fp8_linear_generic(
+                output_parallel = self.fp8_linear_generic.apply(
                     x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
                     self.reqaunt_input_group_shape,
                     self.reqaunt_weight_group_shape)
@@ -1091,7 +1092,7 @@ def _v_up_proj_and_o_proj(self, x):
     def _q_proj_and_k_up_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
             if is_fp8(self.W_Q_UK):
-                return apply_fp8_linear_generic(
+                return self.fp8_linear_generic.apply(
                     x, self.W_Q_UK, self.W_Q_UK_scales,
                     self.reqaunt_input_group_shape,
                     self.reqaunt_weight_group_shape).view(
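For the group-quantized path used by MLA, the same pattern applies with Fp8LinearGenericOp: the object is created once in __init__ (the hunk at line 1057 above) and the former apply_fp8_linear_generic arguments move unchanged to its apply method. A sketch of the resulting call shape, taken from the hunks above:

self.fp8_linear_generic = Fp8LinearGenericOp()  # constructed once in __init__

# Same arguments as the old free function: activation, FP8 weight, weight
# scales, plus the input/weight requantization group shapes stored on the
# attention module.
output_parallel = self.fp8_linear_generic.apply(
    x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
    self.reqaunt_input_group_shape,
    self.reqaunt_weight_group_shape)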

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py

Lines changed: 8 additions & 11 deletions
@@ -9,8 +9,8 @@
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, cutlass_fp8_supported, maybe_create_device_identity,
-    normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
+    Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
+    requantize_with_max_scale)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -24,7 +24,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
     def __init__(self, strategy: str, is_static_input_scheme: bool):
         self.strategy = strategy
         self.is_static_input_scheme = is_static_input_scheme
-        self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -140,11 +140,8 @@ def apply_weights(self,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
-        return apply_fp8_linear(
-            input=x,
-            weight=layer.weight,
-            weight_scale=layer.weight_scale,
-            input_scale=layer.input_scale,
-            bias=bias,
-            cutlass_fp8_supported=self.cutlass_fp8_supported,
-            use_per_token_if_dynamic=True)
+        return self.fp8_linear.apply(input=x,
+                                     weight=layer.weight,
+                                     weight_scale=layer.weight_scale,
+                                     input_scale=layer.input_scale,
+                                     bias=bias)

vllm/model_executor/layers/quantization/fbgemm_fp8.py

Lines changed: 9 additions & 13 deletions
@@ -11,14 +11,12 @@
     UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
-from vllm.model_executor.layers.quantization.fp8 import cutlass_fp8_supported
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, maybe_create_device_identity,
-    normalize_e4m3fn_to_e4m3fnuz)
+    Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter)
 from vllm.platforms import current_platform
@@ -37,6 +35,7 @@ def __init__(self, ignore_list: List[str], input_scale_ub: float):
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = not current_platform.has_device_capability(89)
+        self.fp8_linear = Fp8LinearOp()
 
     @classmethod
     def get_name(cls) -> str:
@@ -73,7 +72,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
 
     def __init__(self, quant_config: FBGEMMFp8Config):
         self.quant_config = quant_config
-        self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
 
     def create_weights(
         self,
@@ -159,12 +158,9 @@ def apply(self,
                 size_k=layer.input_size_per_partition,
                 bias=bias)
 
-        return apply_fp8_linear(
-            input=x,
-            weight=layer.weight,
-            weight_scale=layer.weight_scale,
-            input_scale=None,
-            input_scale_ub=layer.input_scale_ub,
-            bias=bias,
-            cutlass_fp8_supported=self.cutlass_fp8_supported,
-            use_per_token_if_dynamic=True)
+        return self.fp8_linear.apply(input=x,
+                                     weight=layer.weight,
+                                     weight_scale=layer.weight_scale,
+                                     input_scale=None,
+                                     input_scale_ub=layer.input_scale_ub,
+                                     bias=bias)

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 10 additions & 11 deletions
@@ -23,7 +23,7 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    all_close_1d, apply_fp8_linear, convert_to_channelwise,
+    Fp8LinearOp, all_close_1d, convert_to_channelwise,
     cutlass_block_fp8_supported, cutlass_fp8_supported,
     maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
     per_tensor_dequantize, requantize_with_max_scale)
@@ -137,7 +137,6 @@ class Fp8LinearMethod(LinearMethodBase):
 
     def __init__(self, quant_config: Fp8Config):
         self.quant_config = quant_config
-        self.cutlass_fp8_supported = cutlass_fp8_supported()
         self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
 
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
@@ -153,6 +152,10 @@ def __init__(self, quant_config: Fp8Config):
             # Marlin doesn't support block-wise fp8
             self.use_marlin = False
 
+        self.fp8_linear = Fp8LinearOp(
+            # Default to using per_token quantization if cutlass is supported
+            use_per_token_if_dynamic=cutlass_fp8_supported())
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -381,15 +384,11 @@ def apply(self,
                 cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
             )
 
-        return apply_fp8_linear(
-            input=x,
-            weight=layer.weight,
-            weight_scale=layer.weight_scale,
-            input_scale=layer.input_scale,
-            bias=bias,
-            cutlass_fp8_supported=self.cutlass_fp8_supported,
-            # Default to using per_token quantization if cutlass is supported
-            use_per_token_if_dynamic=self.cutlass_fp8_supported)
+        return self.fp8_linear.apply(input=x,
+                                     weight=layer.weight,
+                                     weight_scale=layer.weight_scale,
+                                     input_scale=layer.input_scale,
+                                     bias=bias)
 
 
 class Fp8MoEMethod(FusedMoEMethodBase):

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 7 additions & 9 deletions
@@ -12,7 +12,7 @@
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale)
+    Fp8LinearOp, requantize_with_max_scale)
 from vllm.model_executor.parameter import (ModelWeightParameter,
                                            PerTensorScaleParameter)
 
@@ -95,7 +95,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
 
     def __init__(self, quant_config: ModelOptFp8Config):
         self.quant_config = quant_config
-        self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.fp8_linear = Fp8LinearOp()
 
     def create_weights(
         self,
@@ -157,10 +157,8 @@ def apply(
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        return apply_fp8_linear(
-            input=x,
-            weight=layer.weight,
-            weight_scale=layer.weight_scale,
-            input_scale=layer.input_scale,
-            bias=bias,
-            cutlass_fp8_supported=self.cutlass_fp8_supported)
+        return self.fp8_linear.apply(input=x,
+                                     weight=layer.weight,
+                                     weight_scale=layer.weight_scale,
+                                     input_scale=layer.input_scale,
+                                     bias=bias)

vllm/model_executor/layers/quantization/ptpc_fp8.py

Lines changed: 9 additions & 9 deletions
@@ -17,7 +17,7 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear)
+    Fp8LinearOp)
 from vllm.platforms import current_platform
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -93,6 +93,8 @@ def __init__(self, quant_config: PTPCFp8Config):
         super().__init__(quant_config=quant_config)
         # Force weight quantization
        self.quant_config.is_checkpoint_fp8_serialized = False
+        self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=False,
+                                      use_per_token_if_dynamic=True)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.weight = torch.nn.Parameter(layer.weight.data,
@@ -115,11 +117,9 @@ def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
-        return apply_fp8_linear(input=x,
-                                weight=layer.weight,
-                                weight_scale=layer.weight_scale,
-                                input_scale=None,
-                                input_scale_ub=None,
-                                bias=bias,
-                                cutlass_fp8_supported=False,
-                                use_per_token_if_dynamic=True)
+        return self.fp8_linear.apply(input=x,
+                                     weight=layer.weight,
+                                     weight_scale=layer.weight_scale,
+                                     input_scale=None,
+                                     input_scale_ub=None,
+                                     bias=bias)

vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py

Lines changed: 7 additions & 11 deletions
@@ -7,8 +7,7 @@
 
 from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz,
-    requantize_with_max_scale)
+    Fp8LinearOp, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -22,7 +21,7 @@ class QuarkW8A8Fp8(QuarkScheme):
     def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
         self.qscheme = qscheme
         self.is_static_input_scheme = is_static_input_scheme
-        self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -132,11 +131,8 @@ def apply_weights(self,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
-        return apply_fp8_linear(
-            input=x,
-            weight=layer.weight,
-            weight_scale=layer.weight_scale,
-            input_scale=layer.input_scale,
-            bias=bias,
-            cutlass_fp8_supported=self.cutlass_fp8_supported,
-            use_per_token_if_dynamic=True)
+        return self.fp8_linear.apply(input=x,
+                                     weight=layer.weight,
+                                     weight_scale=layer.weight_scale,
+                                     input_scale=layer.input_scale,
+                                     bias=bias)
