1
1
import torch
2
2
import torch .nn as nn
3
3
4
- import os
5
- import sys
6
- # append the path to the naive_intNwo.py file
7
- sys .path .append (os .path .join (os .path .dirname (os .path .dirname (os .path .dirname (os .path .abspath (__file__ )))), "torchao/quantization/prototype/mixed_precision/scripts" ))
8
- from naive_intNwo import intN_weight_only
4
+ from torchao .quantization .prototype .mixed_precision .scripts .naive_intNwo import intN_weight_only
9
5
10
6
from torchao .quantization import quantize_ , int8_weight_only , int4_weight_only
11
7
18
14
)
19
15
20
16
def test_weight_only_quant (quantization_bit = 2 , symmetric = False ):
21
- for x_shape in [[32 , 64 ], [80 , 80 , 80 , 64 ], [16 , 64 , 64 ]]:
22
- x = torch .randn (* x_shape )
23
- m = nn .Sequential (nn .Linear (64 , 80 ))
17
+ for x_shape in [[64 , 32 ], [80 , 80 , 80 , 32 ], [16 , 64 , 32 ]]:
18
+ x = torch .randn (* x_shape , dtype = torch . bfloat16 )
19
+ m = nn .Sequential (nn .Linear (32 , 80 )). bfloat16 ( )
24
20
y_ref = m (x )
25
- quantize_ (m , intN_weight_only (n = quantization_bit , group_size = 16 , symmetric = symmetric ))
21
+ quantize_ (m , intN_weight_only (n = quantization_bit , group_size = 32 , symmetric = symmetric ))
26
22
y_wo = m (x )
27
23
sqnr = compute_error (y_ref , y_wo )
28
24
# SQNR in dB can be approximated by 6.02 * n, where n is the quantization bit width
@@ -31,7 +27,7 @@ def test_weight_only_quant(quantization_bit=2, symmetric=False):
31
27
32
28
33
29
# Test that the asymmetric and symmetric quantization APIs work with different bit widths
34
- for i in [2 ,3 , 5 , 6 , 8 ]:
30
+ for i in [2 , 3 , 4 , 5 , 6 , 8 ]:
35
31
#test for asymmetric quantization
36
32
try :
37
33
test_weight_only_quant (i , False )
0 commit comments