     Quantizer,
     TwoStepQuantizer,
     quantize,
+    get_apply_8da4w_quant,
+    get_apply_int4wo_quant,
+    get_apply_int8wo_quant,
+    get_apply_int8dyn_quant,
 )
 from torchao.quantization.utils import (
     TORCH_VERSION_AFTER_2_3,
@@ -416,42 +420,11 @@ def test_eval_wrapper(self):
     # TODO: move to a separate test file
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     def test_quantized_tensor_subclass_8da4w(self):
-        # weight settings
         groupsize = 32
-        mapping_type = MappingType.SYMMETRIC
-        block_size = (1, groupsize)
-        target_dtype = torch.int8
-        eps = torch.finfo(torch.float32).eps
-        quant_min = -8
-        quant_max = 7
-
-        # TODO: make a general helper function?
-        # input settings
-        def get_per_token_block_size(x):
-            block_size = []
-            for i in range(len(x.shape) - 1):
-                block_size.append(1)
-            block_size.append(x.shape[-1])
-            return block_size
-
-        # input settings
-        input_mapping_type = MappingType.ASYMMETRIC
-        input_target_dtype = torch.int8
-        input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
-
         m = ToyLinearModel().eval()
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs()
-
-        def apply_weight_quant(weight):
-            return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
-
-        def apply_act_quant(weight):
-            return to_laq(weight, input_quant_func)
-
-        # note: order is important
-        m = quantize(m, apply_weight_quant)
-        m = quantize(m, apply_act_quant)
+        m = quantize(m, get_apply_8da4w_quant(groupsize=groupsize))
 
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
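
Note: the deleted inline settings presumably just move behind the new factory. A minimal sketch of what get_apply_8da4w_quant could look like, reconstructed from the removed lines above; to_aq, to_laq, and MappingType are the names this test file already imports from torchao, and the real helper's signature may differ:

    # Hypothetical reconstruction from the removed settings; to_aq, to_laq,
    # and MappingType are this test file's existing torchao imports.
    import torch

    def get_apply_8da4w_quant(groupsize=32):
        def get_per_token_block_size(x):
            # per-token granularity: one quantization block per token,
            # spanning only the last (feature) dimension
            return [1] * (len(x.shape) - 1) + [x.shape[-1]]

        def input_quant_func(x):
            # dynamic activation quantization: asymmetric per-token int8
            return to_aq(x, MappingType.ASYMMETRIC, get_per_token_block_size(x), torch.int8)

        def apply_8da4w_quant(weight):
            # weight quantization: symmetric 4-bit range (-8..7) stored in
            # int8, grouped along the input-channel dimension
            weight = to_aq(
                weight, MappingType.SYMMETRIC, (1, groupsize), torch.int8,
                quant_min=-8, quant_max=7, eps=torch.finfo(torch.float32).eps,
            )
            # wrap so activations get quantized on the fly; the removed
            # "note: order is important" comment applies here too
            return to_laq(weight, input_quant_func)

        return apply_8da4w_quant

With this shape, one quantize(m, get_apply_8da4w_quant(groupsize=32)) call replaces the previous two-step quantize sequence.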
@@ -474,27 +447,13 @@ def apply_act_quant(weight):
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_int4(self):
-        # weight settings
-        groupsize = 32
-        mapping_type = MappingType.ASYMMETRIC
-        block_size = (1, groupsize)
-        target_dtype = torch.int32
-        quant_min = 0
-        quant_max = 15
-        eps = 1e-6
-        preserve_zero = False
-        zero_point_dtype = torch.bfloat16
-        zero_point_domain = ZeroPointDomain.FLOAT
-
         # use 1024 so that we don't need padding
         m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
         m_copy = copy.deepcopy(m)
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
 
-        def apply_weight_quant(weight):
-            return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
-
-        m = quantize(m, apply_weight_quant)
+        groupsize = 32
+        m = quantize(m, get_apply_int4wo_quant(groupsize=groupsize))
 
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
 
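Same pattern for the weight-only int4 path. A hypothetical sketch of get_apply_int4wo_quant, assuming it only bundles the deleted settings (uint4 range stored in int32, bfloat16 zero point kept in the float domain, preserve_zero=False); to_aq, MappingType, and ZeroPointDomain are the test file's existing imports:

    def get_apply_int4wo_quant(groupsize=32):
        def apply_int4wo_quant(weight):
            # asymmetric 4-bit weight-only quant; zero point stays in the
            # float domain, matching the removed preserve_zero and
            # zero_point_domain settings
            return to_aq(
                weight, MappingType.ASYMMETRIC, (1, groupsize), torch.int32,
                quant_min=0, quant_max=15, eps=1e-6,
                zero_point_dtype=torch.bfloat16,
                preserve_zero=False,
                zero_point_domain=ZeroPointDomain.FLOAT,
            )
        return apply_int4wo_quant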
@@ -511,21 +470,11 @@ def apply_weight_quant(weight):
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_int8(self):
-        # weight settings
-        mapping_type = MappingType.SYMMETRIC
-        target_dtype = torch.int8
-        eps = torch.finfo(torch.float32).eps
-        zero_point_dtype = torch.int64
-
         m = ToyLinearModel().eval().to(torch.bfloat16)
         m_copy = copy.deepcopy(m)
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
 
-        def apply_weight_quant(weight):
-            block_size = (1, weight.shape[1])
-            return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
-
-        m = quantize(m, apply_weight_quant)
+        m = quantize(m, get_apply_int8wo_quant())
 
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
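
For weight-only int8, a sketch assuming get_apply_int8wo_quant wraps the removed per-channel symmetric settings. The one shape-dependent piece, block_size, comes from the weight itself, which is why this stays a closure rather than a constant config:

    def get_apply_int8wo_quant():
        def apply_int8wo_quant(weight):
            # symmetric per-channel int8 weight-only quant:
            # one quantization block per output row
            block_size = (1, weight.shape[1])
            return to_aq(
                weight, MappingType.SYMMETRIC, block_size, torch.int8,
                eps=torch.finfo(torch.float32).eps,
                zero_point_dtype=torch.int64,
            )
        return apply_int8wo_quant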
@@ -543,43 +492,12 @@ def apply_weight_quant(weight):
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_int8_dyn_quant(self):
-        # weight settings
-        mapping_type = MappingType.SYMMETRIC
-        def get_weight_block_size(x):
-            return (1, x.shape[1])
-        target_dtype = torch.int8
-        eps = torch.finfo(torch.float32).eps
-        zero_point_dtype = torch.int64
-
-        # input settings
-        def get_per_token_block_size(x):
-            block_size = list(x.shape)
-            for i in range(len(block_size) - 1):
-                block_size[i] = 1
-            return block_size
-
-        input_mapping_type = MappingType.SYMMETRIC
-        input_target_dtype = torch.int8
-        input_eps = 1e-5
-        input_quant_min = -127
-        input_quant_max = 127
-        input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
-
         # use 1024 so that we don't need padding
         m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
         m_copy = copy.deepcopy(m)
         # setting batch_size to 20 to be compatible with the kernel
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs(batch_size=20)))
-
-        def apply_weight_quant(weight):
-            block_size = get_weight_block_size(weight)
-            return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
-
-        def apply_act_quant(weight):
-            return to_laq(weight, input_quant_func)
-
-        m = quantize(m, apply_weight_quant)
-        m = quantize(m, apply_act_quant)
+        m = quantize(m, get_apply_int8dyn_quant())
 
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
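
Finally, a sketch of get_apply_int8dyn_quant, again assuming it simply fuses the removed weight and activation settings into one transform (the old code applied apply_weight_quant and then apply_act_quant in two quantize passes):

    def get_apply_int8dyn_quant():
        def get_per_token_block_size(x):
            # per-token granularity over the last dimension
            block_size = list(x.shape)
            for i in range(len(block_size) - 1):
                block_size[i] = 1
            return block_size

        def input_quant_func(x):
            # symmetric per-token int8 dynamic activation quant with a
            # clipped (-127..127) range; fp16 inputs keep fp32 scales
            return to_aq(
                x, MappingType.SYMMETRIC, get_per_token_block_size(x), torch.int8,
                eps=1e-5, quant_min=-127, quant_max=127,
                scale_dtype=torch.float32 if x.dtype == torch.float16 else None,
            )

        def apply_int8dyn_quant(weight):
            # symmetric per-channel int8 weight quant, then wrap with the
            # dynamic activation quant above
            weight = to_aq(
                weight, MappingType.SYMMETRIC, (1, weight.shape[1]), torch.int8,
                eps=torch.finfo(torch.float32).eps, zero_point_dtype=torch.int64,
            )
            return to_laq(weight, input_quant_func)

        return apply_int8dyn_quant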