Skip to content

Commit 75b55c2

Browse files
committed
Fixed test for int4wo and added __init__.py
1 parent c08ab33 commit 75b55c2

File tree

3 files changed

+7
-10
lines changed

3 files changed

+7
-10
lines changed

test/quantization/test_mixed_precision.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
import torch
22
import torch.nn as nn
33

4-
import os
5-
import sys
6-
# append the path to the naive_intNwo.py file
7-
sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "torchao/quantization/prototype/mixed_precision/scripts"))
8-
from naive_intNwo import intN_weight_only
4+
from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only
95

106
from torchao.quantization import quantize_, int8_weight_only, int4_weight_only
117

@@ -18,11 +14,11 @@
1814
)
1915

2016
def test_weight_only_quant(quantization_bit=2, symmetric=False):
21-
for x_shape in [[32, 64], [80, 80, 80, 64], [16, 64, 64]]:
22-
x = torch.randn(*x_shape)
23-
m = nn.Sequential(nn.Linear(64, 80))
17+
for x_shape in [[64, 32], [80, 80, 80, 32], [16, 64, 32]]:
18+
x = torch.randn(*x_shape, dtype=torch.bfloat16)
19+
m = nn.Sequential(nn.Linear(32, 80)).bfloat16()
2420
y_ref = m(x)
25-
quantize_(m, intN_weight_only(n=quantization_bit, group_size=16, symmetric=symmetric))
21+
quantize_(m, intN_weight_only(n=quantization_bit, group_size=32, symmetric=symmetric))
2622
y_wo = m(x)
2723
sqnr = compute_error(y_ref, y_wo)
2824
#SQNR_dB can be approximated by 6.02n, where n is the bit width of the quantization
@@ -31,7 +27,7 @@ def test_weight_only_quant(quantization_bit=2, symmetric=False):
3127

3228

3329
# test if the asymmetric and symmetric quantization API works with different bit widths
34-
for i in [2,3,5,6,8]:
30+
for i in [2, 3, 4, 5, 6, 8]:
3531
#test for asymmetric quantization
3632
try:
3733
test_weight_only_quant(i, False)

torchao/quantization/prototype/mixed_precision/__init__.py

Whitespace-only changes.
torchao/quantization/prototype/mixed_precision/scripts/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .naive_intNwo import intN_weight_only

0 commit comments

Comments
 (0)