@@ -29,6 +29,8 @@
 from torchao.quantization.subclass import (
     to_laq,
     LinearActQuantizedTensor,
+    Int8WeightOnlyQuantizedLinearWeight,
+    Int4WeightOnlyQuantizedLinearWeight,
 )
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
@@ -138,6 +140,28 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs
         model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn
     )

+def _get_ref_change_linear_weights_to_woqtensors(deprecated_tensor_subclass):
+    def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
+        """
+        The deprecated implementation for the weight-only quant API, used as a reference
+        for numerics and performance.
+        """
+        from torchao.quantization.quant_api import _is_linear
+        from torchao.quantization.quant_api import _get_subclass_inserter
+
+        filter_fn = kwargs.pop("filter_fn", _is_linear)
+
+        _replace_with_custom_fn_if_matches_filter(
+            model,
+            _get_subclass_inserter(deprecated_tensor_subclass, enable_parametrization=True, **kwargs),
+            filter_fn,
+        )
+
+    return _ref_change_linear_weights_to_woqtensors
+
+_ref_change_linear_weights_to_int8_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight)
+_ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)
+
 class TestQuantFlow(unittest.TestCase):
     def test_dynamic_quant_gpu_singleline(self):
         m = ToyLinearModel().eval()
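
For orientation, here is a minimal usage sketch of the reference helpers generated above. This is not part of the diff; it assumes the `ToyLinearModel` fixture defined earlier in this test file and the helpers introduced in the hunk above:

```python
import copy
import torch

# Build the test fixture and keep a copy to quantize.
m = ToyLinearModel().eval().to(torch.bfloat16)
m_ref = copy.deepcopy(m)

# Swap each nn.Linear weight in the copy for the deprecated int8
# weight-only tensor subclass, to serve as a numerics reference.
_ref_change_linear_weights_to_int8_woqtensors(m_ref)

example_inputs = tuple(x.to(torch.bfloat16) for x in m.example_inputs())
ref = m_ref(*example_inputs)
```

The factory avoids duplicating the insertion logic for the int8 and int4 subclasses. Note that the inner helper reads `filter_fn` out of `kwargs` (defaulting to `_is_linear`) rather than using the `filter_fn` argument in its signature.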
@@ -478,8 +502,7 @@ def test_quantized_tensor_subclass_int4(self):
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)

         # reference
-        from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
-        change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)
+        _ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)

         res = m(*example_inputs)
         ref = m_copy(*example_inputs)
@@ -489,7 +512,7 @@ def test_quantized_tensor_subclass_int4(self):

     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_quantized_tensor_subclass_int8(self):
+    def test_quantized_tensor_subclass_int8_wo(self):
         m = ToyLinearModel().eval().to(torch.bfloat16)
         m_copy = copy.deepcopy(m)
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
@@ -500,13 +523,13 @@ def test_quantized_tensor_subclass_int8(self):
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)

         # reference
-        from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
-        change_linear_weights_to_int8_woqtensors(m_copy)
+        _ref_change_linear_weights_to_int8_woqtensors(m_copy)
+

         res = m(*example_inputs)
         ref = m_copy(*example_inputs)

-        torch.testing.assert_close(res, ref, rtol=0.00001, atol=1e-2)
+        self.assertTrue(torch.equal(res, ref))


     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
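
The change above from `torch.testing.assert_close` to `torch.equal` tightens the check from a tolerance-based comparison to exact, bitwise equality, since the new quantized path and the deprecated reference are expected to produce identical results. A small illustration of the difference, not part of the diff:

```python
import torch

a = torch.tensor([1.0, 2.0])
b = a + 1e-4

# assert_close passes as long as the difference is within rtol/atol...
torch.testing.assert_close(a, b, rtol=1e-5, atol=1e-2)

# ...while torch.equal requires identical values and shape.
assert not torch.equal(a, b)
assert torch.equal(a, a.clone())
```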
@@ -525,8 +548,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
         assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)

         # reference
-        from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
-        change_linear_weights_to_int8_dqtensors(m_copy)
+        _ref_change_linear_weights_to_int8_dqtensors(m_copy)

         res = m(*example_inputs)
         ref = m_copy(*example_inputs)
@@ -545,45 +567,5 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
         # make sure it compiles
         torch._export.aot_compile(m_unwrapped, example_inputs)

-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skip("This perf test is supposed to be run locally for sanity check performance when there is a change of int8 dynamic quant implementation")
-    def test_quantized_tensor_subclass_int8_dyn_quant_perf(self):
-        m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
-        m_ref = copy.deepcopy(m)
-        # setting batch_size to 20 to be compatible with the kernel
-        example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
-
-        from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
-        change_linear_weights_to_int8_dqtensors(m)
-
-        # reference
-        _ref_change_linear_weights_to_int8_dqtensors(m_ref)
-
-        res = m(*example_inputs)
-        ref = m_ref(*example_inputs)
-
-        self.assertTrue(torch.equal(res, ref))
-
-        # perf comparison
-        from torchao.utils import benchmark_model
-        # warmup
-        WARMUP = 5
-        RUNS = 100
-        input_tensor = example_inputs[0]
-        m = torch.compile(m, mode='max-autotune', fullgraph=True)
-
-        benchmark_model(m, WARMUP, input_tensor)
-        elapsed_time = benchmark_model(m, RUNS, input_tensor)
-
-        m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
-        benchmark_model(m_ref, WARMUP, input_tensor)
-        ref_elapsed_time = benchmark_model(m_ref, RUNS, input_tensor)
-
-        print(f"elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}")
-        self.assertTrue(elapsed_time < 1.05 * ref_elapsed_time)
-
-
-
 if __name__ == "__main__":
     unittest.main()
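
Since the perf sanity check above is removed rather than relocated, here is a minimal sketch of how to reproduce it locally. This is not part of the diff and assumes `torchao.utils.benchmark_model` with the `(model, num_runs, input_tensor)` call shape used in the deleted test:

```python
import torch
from torchao.utils import benchmark_model  # call shape as in the removed test

def compare_perf(m, m_ref, input_tensor, warmup=5, runs=100):
    # Compile both models the same way so the comparison is apples to apples.
    m = torch.compile(m, mode="max-autotune", fullgraph=True)
    m_ref = torch.compile(m_ref, mode="max-autotune", fullgraph=True)

    benchmark_model(m, warmup, input_tensor)        # warmup runs
    elapsed = benchmark_model(m, runs, input_tensor)

    benchmark_model(m_ref, warmup, input_tensor)    # warmup runs
    ref_elapsed = benchmark_model(m_ref, runs, input_tensor)

    print(f"elapsed time: {elapsed}, ref elapsed time: {ref_elapsed}")
    return elapsed, ref_elapsed
```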