@@ -6,7 +6,11 @@

import torch.nn.functional as F
from torch import Tensor
-from torchao.dtypes import to_affine_quantized_intx_static
+from torchao.dtypes import (
+    to_affine_quantized_intx_static,
+    to_affine_quantized_floatx_static,
+    Float8LayoutType,
+)
from torchao.quantization.utils import compute_error
from torchao.quantization import quantize_
from torchao.quantization import to_linear_activation_quantized
@@ -18,6 +22,7 @@
)
from torchao.quantization.quant_primitives import (
    MappingType,
+    FP8_TYPES,
)

@@ -51,53 +56,81 @@ def replacement_fn(m):

# converting observed linear module to linear module with quantized weights (and quantized activations)
# with tensor subclasses
-def apply_static_quant(observed_linear):
-    target_dtype = torch.uint8
-
-    # weight quantization
-    weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()
-    def weight_quant_func(weight):
-        block_size = (1, weight.shape[1])
-        return to_affine_quantized_intx_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
-    linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype)
-    linear.weight = observed_linear.weight
-    linear.bias = observed_linear.bias
-
-    linear.weight = torch.nn.Parameter(weight_quant_func(linear.weight), requires_grad=False)
-
-    # activation quantization
-    act_scale, act_zero_point = observed_linear.act_obs.calculate_qparams()
-    input_quant_func = lambda x: to_affine_quantized_intx_static(x, act_scale, act_zero_point, x.shape, target_dtype)
-    linear.weight = torch.nn.Parameter(to_linear_activation_quantized(linear.weight, input_quant_func), requires_grad=False)
-
-    return linear
-
+def apply_static_quant(target_dtype: torch.dtype):
+    def _apply_static_quant_to_linear(observed_linear):
+        # weight quantization
+        weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()
+        def weight_quant_func(weight):
+            block_size = (1, weight.shape[1])
+            if target_dtype == torch.uint8:
+                return to_affine_quantized_intx_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
+            elif target_dtype == torch.float8_e4m3fn:
+                return to_affine_quantized_floatx_static(weight, weight_scale, block_size, target_dtype, Float8LayoutType(mm_config=None))
+            else:
+                raise ValueError(f"Unsupported target dtype {target_dtype}")
+        linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype)
+        linear.weight = observed_linear.weight
+        linear.bias = observed_linear.bias
+
+        linear.weight = torch.nn.Parameter(weight_quant_func(linear.weight), requires_grad=False)
+
+        # activation quantization
+        act_scale, act_zero_point = observed_linear.act_obs.calculate_qparams()
+        if target_dtype == torch.uint8:
+            input_quant_func = lambda x: to_affine_quantized_intx_static(x, act_scale, act_zero_point, x.shape, target_dtype)
+        elif target_dtype == torch.float8_e4m3fn:
+            input_quant_func = lambda x: to_affine_quantized_floatx_static(x, act_scale, x.shape, target_dtype, Float8LayoutType(mm_config=None))
+        else:
+            raise ValueError(f"Unsupported target dtype {target_dtype}")
+        linear.weight = torch.nn.Parameter(to_linear_activation_quantized(linear.weight, input_quant_func), requires_grad=False)
+
+        return linear
+
+    return _apply_static_quant_to_linear
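+# A usage sketch (mirroring test_static_quant below; `model` is a placeholder):
+# the returned closure plugs into quantize_ together with a filter that selects
+# the observed linears, e.g.:
+#
+#   quantize_(model, apply_static_quant(torch.uint8),
+#             lambda mod, fqn: isinstance(mod, ObservedLinear))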

# alternative for converting observed linear module to quantized linear module
class QuantizedLinear(torch.nn.Module):
-    def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, weight: torch.Tensor, bias: torch.Tensor):
+    def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, weight: torch.Tensor, bias: torch.Tensor, target_dtype: torch.dtype):
        super().__init__()
        self.act_scale, self.act_zero_point = act_obs.calculate_qparams()
        weight_scale, weight_zero_point = weight_obs.calculate_qparams()
        assert weight.dim() == 2
        block_size = (1, weight.shape[1])
-        target_dtype = torch.uint8
-        self.qweight = to_affine_quantized_intx_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
+        self.target_dtype = target_dtype
        self.bias = bias
+        if self.target_dtype == torch.uint8:
+            self.qweight = to_affine_quantized_intx_static(weight, weight_scale, weight_zero_point, block_size, self.target_dtype)
+        elif self.target_dtype == torch.float8_e4m3fn:
+            self.qweight = to_affine_quantized_floatx_static(weight, weight_scale, block_size, self.target_dtype, Float8LayoutType(mm_config=None))
+        else:
+            raise ValueError(f"Unsupported target dtype {self.target_dtype}")

    def forward(self, input: Tensor):
        block_size = input.shape
-        target_dtype = torch.uint8
-        qinput = to_affine_quantized_intx_static(input, self.act_scale, self.act_zero_point, block_size, target_dtype)
+        if self.target_dtype == torch.uint8:
+            qinput = to_affine_quantized_intx_static(input, self.act_scale, self.act_zero_point, block_size, self.target_dtype)
+        elif self.target_dtype == torch.float8_e4m3fn:
+            qinput = to_affine_quantized_floatx_static(input, self.act_scale, block_size, self.target_dtype, Float8LayoutType(mm_config=None))
+        else:
+            raise ValueError(f"Unsupported target dtype {self.target_dtype}")
        return F.linear(qinput, self.qweight, self.bias)
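+    # qinput and self.qweight are both AffineQuantizedTensor subclasses here, so
+    # the F.linear call above dispatches through the tensor subclass to a
+    # quantized linear implementation rather than a plain high-precision matmul.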

    @classmethod
-    def from_observed(cls, observed_linear):
-        quantized_linear = cls(observed_linear.in_features, observed_linear.out_features, observed_linear.act_obs, observed_linear.weight_obs, observed_linear.weight, observed_linear.bias)
+    def from_observed(cls, observed_linear, target_dtype):
+        quantized_linear = cls(observed_linear.in_features,
+                               observed_linear.out_features,
+                               observed_linear.act_obs,
+                               observed_linear.weight_obs,
+                               observed_linear.weight,
+                               observed_linear.bias,
+                               target_dtype)
        return quantized_linear

-def apply_static_quant2(observed_linear):
-    return QuantizedLinear.from_observed(observed_linear)
+def apply_static_quant2(target_dtype: torch.dtype):
+    def _apply_static_quant2(observed_linear):
+        return QuantizedLinear.from_observed(observed_linear, target_dtype)
+    return _apply_static_quant2
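+# As with apply_static_quant above, the returned closure is intended for
+# quantize_; a hypothetical call, mirroring test_static_quant below:
+#
+#   quantize_(model, apply_static_quant2(torch.float8_e4m3fn), is_observed_linear)
+#
+# which swaps each matching ObservedLinear for a standalone QuantizedLinear.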

class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=64, n=32, k=64):
@@ -113,46 +146,54 @@ def forward(self, x):
        x = self.linear2(x)
        return x

-torch.manual_seed(0)

-dtype = torch.bfloat16
-m = ToyLinearModel().eval().to(dtype).to("cuda")
+def test_static_quant(target_dtype: torch.dtype, mapping_type: MappingType):
+    print(f"Testing {target_dtype} static quantization:")
+    torch.manual_seed(0)
+
+    dtype = torch.bfloat16
+    m = ToyLinearModel().eval().to(dtype).to("cuda")
+
+    m_for_test = copy.deepcopy(m)
+
+    m_bf16 = copy.deepcopy(m)
+    example_inputs = m.example_inputs(dtype=dtype, device="cuda")
+    print("example inputs shape:", example_inputs[0].shape)

-m_for_test = copy.deepcopy(m)
+    m_bf16 = torch.compile(m_bf16, mode='max-autotune')

-m_bf16 = copy.deepcopy(m)
-example_inputs = m.example_inputs(dtype=dtype, device="cuda")
-print("example inputs shape:", example_inputs[0].shape)
+    act_obs = AffineQuantizedMinMaxObserver(mapping_type, target_dtype, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.float32)
+    weight_obs = AffineQuantizedMinMaxObserver(mapping_type, target_dtype, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.float32)
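+    # Activations are observed per-tensor while weights are observed per output
+    # row (PerAxis(axis=0)), matching the (1, weight.shape[1]) block_size used at
+    # quantization time; for the float8 path the mapping type passed in is
+    # symmetric, so the zero_point is effectively unused.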

-m_bf16 = torch.compile(m_bf16, mode='max-autotune')
+    before_quant = m(*example_inputs)

-act_obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.int32)
-weight_obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.int32)
+    insert_observers_(m, act_obs, weight_obs)
+    # calibrating / training
+    for _ in range(10):
+        m(*example_inputs)

-before_quant = m(*example_inputs)
+    after_obs = m(*example_inputs)

-insert_observers_(m, act_obs, weight_obs)
-# calibrating / training
-for _ in range(10):
-    m(*example_inputs)
+    m2 = copy.deepcopy(m)

-after_obs = m(*example_inputs)
+    is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)

-m2 = copy.deepcopy(m)
+    # quantized linear represented as an nn.Linear with modified tensor subclass weights
+    # for both activation and weight quantization
+    quantize_(m, apply_static_quant(target_dtype), is_observed_linear)
+    print("quantized model (applying tensor subclass to weight):", m)
+    after_quant = m(*example_inputs)
+    assert compute_error(before_quant, after_quant) > 25
+    print("test passed")

-is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)
+    # quantized linear as a standalone module
+    quantize_(m2, apply_static_quant2(target_dtype), is_observed_linear)
+    print("quantized model (quantized module):", m2)
+    after_quant = m2(*example_inputs)
+    assert compute_error(before_quant, after_quant) > 25
+    print("test passed")

-# quantized linear represented as an nn.Linear with modified tensor subclass weights
-# for both activation and weight quantization
-quantize_(m, apply_static_quant, is_observed_linear)
-print("quantized model (applying tensor subclass to weight):", m)
-after_quant = m(*example_inputs)
-assert compute_error(before_quant, after_quant) > 30
-print("test passed")

-# quantized linear as a standalone module
-quantize_(m2, apply_static_quant2, is_observed_linear)
-print("quantized model (quantized module):", m2)
-after_quant = m2(*example_inputs)
-assert compute_error(before_quant, after_quant) > 30
-print("test passed")
+if __name__ == "__main__":
+    test_static_quant(torch.uint8, MappingType.ASYMMETRIC)
+    test_static_quant(torch.float8_e4m3fn, MappingType.SYMMETRIC)
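+
+# Note: running the script end to end requires a CUDA device; each
+# test_static_quant call exercises both conversion paths, so a successful
+# run prints "test passed" four times.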