# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

-from enum import Enum
+from enum import Enum, auto
from typing import List, Optional, Tuple, Dict
import torch

from torchao.kernel.intmm import int_scaled_matmul
from torchao.kernel.intmm import safe_int_mm
-from torchao.utils import TORCH_VERSION_AFTER_2_3
+from torchao.utils import (
+    TORCH_VERSION_AFTER_2_3,
+    TORCH_VERSION_AFTER_2_5,
+)


__all__ = [
@@ -34,17 +37,17 @@ class MappingType(Enum):
    based on this mapping
    e.g. scale = (10.2 - (-3.5)) / (7 - (-8))
    """
-    SYMMETRIC = 0
-    ASYMMETRIC = 1
+    SYMMETRIC = auto()
+    ASYMMETRIC = auto()


class ZeroPointDomain(Enum):
    """Enum that indicate whether zero_point is in integer domain or floating point domain

    integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer)
    float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
    """
-    INT = 0
-    FLOAT = 1
+    INT = auto()
+    FLOAT = auto()

"""
Map from dtype to the bound value of integers
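As a quick sanity check of the asymmetric example quoted in the MappingType docstring (illustrative only, not part of the diff): mapping the float range [-3.5, 10.2] onto the quantized range [-8, 7] gives

    scale = (10.2 - (-3.5)) / (7 - (-8))  # = 13.7 / 15 ≈ 0.913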
@@ -69,6 +72,20 @@ class ZeroPointDomain(Enum):
})


+def register_custom_op(name: str):
+    from torch._inductor.decomposition import register_decomposition
+
+    def decorator(fn):
+        if TORCH_VERSION_AFTER_2_5:
+            opdef = torch.library.custom_op(name, mutates_args=())(fn)
+            opdef.register_fake(fn)
+            register_decomposition([opdef._opoverload])(fn)
+            return opdef
+        else:
+            return fn
+
+    return decorator
+
# TODO: decide on if we want to allow custom quant_min/quant_max here
def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
    """Get quant_min and quant_max args based on dtype and also
@@ -140,7 +157,7 @@ def quantize_affine(
    quant_min: Optional[int] = None,
    quant_max: Optional[int] = None,
    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
-):
+) -> torch.Tensor:
    """
    Args:
      input (torch.Tensor): original float32, float16 or bfloat16 Tensor
@@ -174,6 +191,31 @@ def quantize_affine(
    Output:
      quantized tensor with requested dtype
    """
+    return _quantize_affine(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        output_dtype,
+        quant_min,
+        quant_max,
+        zero_point_domain.name,
+    )
+
+
+@register_custom_op("quant::quantize_affine")
+def _quantize_affine(
+    input: torch.Tensor,
+    block_size: List[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    output_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    zero_point_domain: str = "INT",
+) -> torch.Tensor:
+    """op definition that has compatible signatures with custom op library
+    """
    # TODO: validations
    # TODO: validate scale/zero_point dimensions are compatible with block_size
    assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported input dtype: {input.dtype}"
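Note that the public quantize_affine keeps the ZeroPointDomain enum in its signature but forwards zero_point_domain.name (a plain string) to the registered op, presumably because the custom op schema only supports basic argument types. A hypothetical per-tensor call through the public wrapper, with made-up values, might look like:

    x = torch.randn(4, 8)
    scale = torch.tensor([0.05])
    zero_point = torch.tensor([0])
    q = quantize_affine(
        x,
        block_size=(4, 8),          # one block covering the whole tensor
        scale=scale,
        zero_point=zero_point,
        output_dtype=torch.int8,
    )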
@@ -188,12 +230,12 @@ def quantize_affine(
    if zero_point is not None:
        zero_point = zero_point.view(shape_after_reduction)

-    if zero_point_domain == ZeroPointDomain.INT:
+    if zero_point_domain == ZeroPointDomain.INT.name:
        quant = torch.clamp(
            torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
        ).to(output_dtype)
    else:
-        assert zero_point_domain == ZeroPointDomain.FLOAT
+        assert zero_point_domain == ZeroPointDomain.FLOAT.name
        mid_point = (quant_max + quant_min + 1) / 2
        min_val = zero_point - scale * mid_point
        quant = (
@@ -216,7 +258,7 @@ def dequantize_affine(
    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
    *,
    output_dtype: torch.dtype = torch.float32,
-):
+) -> torch.Tensor:
    """
    Args:
      input (torch.Tensor): quantized tensor, should match the dtype `dtype` argument
@@ -238,6 +280,34 @@ def dequantize_affine(
    Output:
      dequantized Tensor, with requested dtype or fp32
    """
+    return _dequantize_affine(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        input_dtype,
+        quant_min,
+        quant_max,
+        zero_point_domain.name,
+        output_dtype=output_dtype,
+    )
+
+
+@register_custom_op("quant::dequantize_affine")
+def _dequantize_affine(
+    input: torch.Tensor,
+    block_size: List[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    input_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    zero_point_domain: str = "INT",
+    *,
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """op definition that has compatible signatures with custom op library
+    """

    # TODO: validations
    # TODO: validate scale/zero_point dimensions are compatible with block_size
@@ -255,16 +325,16 @@ def dequantize_affine(
    if zero_point is not None:
        zero_point = zero_point.view(shape_after_reduction)

-    if zero_point_domain == ZeroPointDomain.INT:
+    if zero_point_domain == ZeroPointDomain.INT.name:
        # Force a copy to avoid input modification due
        # to upcoming in-place operations.
        dequant = input.to(torch.int32, copy=True)
        if zero_point is not None:
-            dequant -= zero_point.to(torch.int32)
+            dequant = dequant - zero_point.to(torch.int32)
        dequant = dequant.to(output_dtype)
-        dequant *= scale
+        dequant = dequant * scale
    else:
-        assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}"
+        assert zero_point_domain == ZeroPointDomain.FLOAT.name, f"Unexpected zero point domain: {zero_point_domain}"
        mid_point = (quant_max + quant_min + 1) / 2
        # This should allocate new memory and avoid input modification
        dequant = input - mid_point
@@ -320,8 +390,38 @@ def choose_qparams_affine(
    Output:
        Tuple of scales and zero_points Tensor with requested dtype
    """
+    return _choose_qparams_affine(
+        input,
+        mapping_type.name,
+        block_size,
+        target_dtype,
+        quant_min,
+        quant_max,
+        eps,
+        scale_dtype,
+        zero_point_dtype,
+        preserve_zero,
+        zero_point_domain.name
+    )
+
+@register_custom_op("quant::choose_qparams_affine")
+def _choose_qparams_affine(
+    input: torch.Tensor,
+    mapping_type: str,
+    block_size: List[int],
+    target_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    eps: Optional[float] = None,
+    scale_dtype: Optional[torch.dtype] = None,
+    zero_point_dtype: Optional[torch.dtype] = None,
+    preserve_zero: bool = True,
+    zero_point_domain: str = "INT",
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """op definition that has compatible signatures with custom op library
+    """
    quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
-    assert mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC], f"Unsupported mapping type: {mapping_type}"
+    assert mapping_type in [MappingType.SYMMETRIC.name, MappingType.ASYMMETRIC.name], f"Unsupported mapping type: {mapping_type}"

    if scale_dtype is None:
        scale_dtype = input.dtype
@@ -342,21 +442,22 @@ def choose_qparams_affine(
        min_val_neg = min_val
        max_val_pos = max_val

-    if mapping_type == MappingType.SYMMETRIC:
+    if mapping_type == MappingType.SYMMETRIC.name:
        max_val_pos = torch.max(-min_val_neg, max_val_pos)
        scale = max_val_pos / (float(quant_max - quant_min) / 2)
        if not preserve_zero:
            raise ValueError("preserve_zero == False is not supported for symmetric quantization")
-        if zero_point_domain != ZeroPointDomain.INT:
+        if zero_point_domain != ZeroPointDomain.INT.name:
            raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization")
        zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
    else:
+        assert mapping_type == MappingType.ASYMMETRIC.name
        scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
        if preserve_zero:
            zero_point = quant_min - torch.round(min_val_neg / scale)
            zero_point = torch.clamp(zero_point, quant_min, quant_max)
        else:
-            assert zero_point_domain == ZeroPointDomain.FLOAT, "if not preserve_zero, zero_point must be in FLOAT domain"
+            assert zero_point_domain == ZeroPointDomain.FLOAT.name, "if not preserve_zero, zero_point must be in FLOAT domain"
            mid_point = (quant_max + quant_min + 1) / 2
            zero_point = min_val_neg + scale * mid_point

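Putting the three public wrappers together, a hypothetical round trip (shapes, dtypes and group size are illustrative only, not taken from the diff) might look like:

    x = torch.randn(4, 8)
    block_size = (1, 8)                      # per-row groups of 8 values
    scale, zero_point = choose_qparams_affine(
        x, MappingType.ASYMMETRIC, block_size, torch.int8
    )
    q = quantize_affine(x, block_size, scale, zero_point, torch.int8)
    x_hat = dequantize_affine(
        q, block_size, scale, zero_point, torch.int8, output_dtype=torch.float32
    )
    # x_hat approximates x up to the int8 rounding error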