     get_symmetric_quantization_config,
 )

-from torchao.quantization.subclass import (
-    to_aqt,
-    to_laqt,
+from torchao.dtypes import (
+    to_aq,
     AffineQuantizedTensor,
-    LinearActQuantizedTensor,
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
     ZeroPointDomain,
 )
-
+from torchao.quantization.subclass import (
+    to_laq,
+    LinearActQuantizedTensor,
+)
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
     apply_dynamic_quant,
@@ -429,17 +430,17 @@ def get_per_token_block_size(x):
         # input settings
         input_mapping_type = MappingType.ASYMMETRIC
         input_target_dtype = torch.int8
-        input_quant_func = lambda x: to_aqt(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
+        input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)

         m = ToyLinearModel().eval()
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs()

         def apply_weight_quant(weight):
-            return to_aqt(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
+            return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)

         def apply_act_quant(weight):
-            return to_laqt(weight, input_quant_func)
+            return to_laq(weight, input_quant_func)

         # note: order is important
         m = quantize(m, apply_weight_quant)
@@ -484,7 +485,7 @@ def test_quantized_tensor_subclass_int4(self):
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))

         def apply_weight_quant(weight):
-            return to_aqt(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
+            return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)

         m = quantize(m, apply_weight_quant)
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
@@ -515,7 +516,7 @@ def test_quantized_tensor_subclass_int8(self):

         def apply_weight_quant(weight):
             block_size = (1, weight.shape[1])
-            return to_aqt(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
+            return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)

         m = quantize(m, apply_weight_quant)

@@ -555,7 +556,7 @@ def get_per_token_block_size(x):
         input_eps = 1e-5
         input_quant_min = -127
         input_quant_max = 127
-        input_quant_func = lambda x: to_aqt(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+        input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)

         # use 1024 so that we don't need padding
         m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
@@ -565,10 +566,10 @@ def get_per_token_block_size(x):

         def apply_weight_quant(weight):
             block_size = get_weight_block_size(weight)
-            return to_aqt(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
+            return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)

         def apply_act_quant(weight):
-            return to_laqt(weight, input_quant_func)
+            return to_laq(weight, input_quant_func)

         m = quantize(m, apply_weight_quant)
         m = quantize(m, apply_act_quant)
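For reviewers, a minimal end-to-end sketch of the dynamic-quant flow with the renamed entry points (`to_aq` from `torchao.dtypes`, `to_laq` from `torchao.quantization.subclass`). `ToyLinearModel` is redefined below as a stand-in for the test helper, and the `quantize()` import path and qparam values are assumptions drawn from this diff, not a definitive API reference.

```python
# Minimal sketch only -- mirrors the renamed calls in this diff (to_aqt -> to_aq,
# to_laqt -> to_laq). ToyLinearModel is a stand-in for the test helper, and the
# quantize() import location and qparam values are assumptions.
import torch
from torchao.dtypes import to_aq, AffineQuantizedTensor
from torchao.quantization.subclass import to_laq, LinearActQuantizedTensor
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quant_api import quantize  # assumed to still live here


class ToyLinearModel(torch.nn.Module):
    """Hypothetical stand-in for the ToyLinearModel defined in this test file."""

    def __init__(self, m=1024, n=1024, k=1024):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)

    def forward(self, x):
        return self.linear2(self.linear1(x))


def apply_weight_quant(weight):
    # symmetric per-channel int8 weight quantization, as in the int8 test above;
    # eps / zero_point_dtype values are assumed, not copied from the test
    block_size = (1, weight.shape[1])
    return to_aq(weight, MappingType.SYMMETRIC, block_size, torch.int8,
                 eps=torch.finfo(torch.float32).eps, zero_point_dtype=torch.int64)


def apply_act_quant(weight):
    # asymmetric per-token int8 activation quantization, applied lazily at runtime
    input_quant_func = lambda x: to_aq(
        x, MappingType.ASYMMETRIC,
        (1,) * (x.dim() - 1) + (x.shape[-1],),  # per-token block size
        torch.int8,
    )
    return to_laq(weight, input_quant_func)


m = ToyLinearModel().eval()
# note: order is important -- quantize weights first, then wrap with the
# activation-quantization subclass
m = quantize(m, apply_weight_quant)
m = quantize(m, apply_act_quant)
assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
```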