 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     _normalize_quant_group_shape, scaled_dequantize)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    CUTLASS_BLOCK_FP8_SUPPORTED, CUTLASS_FP8_SUPPORTED, apply_fp8_linear)
+    CUTLASS_BLOCK_FP8_SUPPORTED, Fp8LinearOp, cutlass_block_fp8_supported,
+    cutlass_fp8_supported)
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
@@ -32,6 +33,7 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
     return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
 
 
+# TODO fix ROCm->Triton custom path
 def apply_w8a8_block_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -104,43 +106,54 @@ def apply_w8a8_block_fp8_linear_fake(
 # Unify the interface between `apply_w8a8_block_fp8_linear` and
 # `apply_fp8_linear`
 # NOTE(lucas): this is quite messy, we should think through this more formally
-def apply_fp8_linear_generic(
-        input: torch.Tensor,
-        weight: torch.Tensor,
-        weight_scale: torch.Tensor,
-        input_group_shape: Tuple[int, int],
-        weight_group_shape: Tuple[int, int],
-        input_scale: Optional[torch.Tensor] = None,  # static scale if one
-        cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED,
-        cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
-) -> torch.Tensor:
-    # View input as 2D matrix for fp8 methods
-    input = input.view(-1, input.shape[-1])
-
-    weight_group_shape = _normalize_quant_group_shape(\
-        weight, weight_group_shape)
-    input_group_shape = _normalize_quant_group_shape(input, input_group_shape)
-
-    def is_dim_blocked(dim, shape, group_shape):
-        return group_shape < shape[dim] and group_shape > 1
-
-    if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
-        and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and \
-            input_group_shape == (1, weight_group_shape[1]):
-        return apply_w8a8_block_fp8_linear(
-            input,
-            weight,
-            list(weight_group_shape),
-            weight_scale,
-            cutlass_block_fp8_supported=cutlass_block_fp8_supported)
-    else:
-        # Despite having linear in the it doesn't conform to
-        # `torch.nn.functional.linear` which is defined as `input @ weight.T`
-        # so we explicitly transpose the weight matrix here
-        return apply_fp8_linear(input, weight.T, weight_scale.T,
-                                cutlass_fp8_supported=cutlass_fp8_supported,
-                                use_per_token_if_dynamic=\
-                                (input_group_shape == (1, input.shape[1])))
+# TODO(luka): unify this better
+class Fp8LinearGenericOp:
+
+    def __init__(
+        self,
+        cutlass_fp8_supported: bool = cutlass_fp8_supported(),
+        cutlass_block_fp8_supported: bool = cutlass_block_fp8_supported(),
+    ):
+        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported
+        self.fp8_linear = Fp8LinearOp(
+            cutlass_fp8_supported=cutlass_fp8_supported)
+
+    def apply(
+        self,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        input_group_shape: Tuple[int, int],
+        weight_group_shape: Tuple[int, int],
+        input_scale: Optional[torch.Tensor] = None,  # static scale if one
+    ) -> torch.Tensor:
+        # View input as 2D matrix for fp8 methods
+        input = input.view(-1, input.shape[-1])
+
+        weight_group_shape = _normalize_quant_group_shape( \
+            weight, weight_group_shape)
+        input_group_shape = _normalize_quant_group_shape(
+            input, input_group_shape)
+
+        def is_dim_blocked(dim, shape, group_shape):
+            return group_shape < shape[dim] and group_shape > 1
+
+        if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
+            and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and \
+                input_group_shape == (1, weight_group_shape[1]):
+            return apply_w8a8_block_fp8_linear(
+                input,
+                weight,
+                list(weight_group_shape),
+                weight_scale,
+                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported)
+        else:
+            # Despite having linear in the name it doesn't conform to
+            # `torch.nn.functional.linear` which is defined as
+            # `input @ weight.T` so we explicitly transpose the weight matrix
+            return self.fp8_linear.apply(input, weight.T, weight_scale.T,
+                                         use_per_token_if_dynamic=\
+                                         (input_group_shape == (1, input.shape[1])))
 
 
 def input_to_float8(
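
A minimal usage sketch of the refactored interface, inferred only from the signatures and dispatch condition visible in this diff. The module path, tensor shapes, dtypes, and group-shape values below are illustrative assumptions, not part of the commit; presumably the class form mirrors Fp8LinearOp so that per-instance state (such as the CUTLASS capability flags) lives on the op object rather than in module-level constants.

import torch

# Assumption: this diff touches
# vllm/model_executor/layers/quantization/utils/fp8_utils.py, so the class
# would be importable from there; adjust the path if the file differs.
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    Fp8LinearGenericOp)

# Construct once (e.g. at layer init). The CUTLASS support defaults now come
# from cutlass_fp8_supported()/cutlass_block_fp8_supported() instead of the
# old module-level constants.
fp8_linear_generic = Fp8LinearGenericOp()


def fp8_forward(x: torch.Tensor, weight_fp8: torch.Tensor,
                weight_scale: torch.Tensor) -> torch.Tensor:
    # Assumed shapes: x is (num_tokens, K) in bf16/fp16, weight_fp8 is (N, K)
    # in float8_e4m3fn, and weight_scale is the block-wise scale tensor, e.g.
    # (ceil(N / 128), ceil(K / 128)).
    #
    # Per the dispatch in `apply`: weight groups blocked on both dims plus an
    # activation group shape of (1, weight_group_shape[1]) route to
    # `apply_w8a8_block_fp8_linear`; any other combination falls back to
    # `Fp8LinearOp.apply` with the weight transposed.
    return fp8_linear_generic.apply(
        x,
        weight_fp8,
        weight_scale,
        input_group_shape=(1, 128),      # per-token, per-128-channel groups
        weight_group_shape=(128, 128),   # 128x128 weight blocks
    )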