Commit 3ad9c07

Move files from quantization/prototype -> prototype/quantization
1 parent: 79ea660

17 files changed: 65 additions, 63 deletions

test/quantization/test_mixed_precision.py renamed to test/prototype/test_mixed_precision.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 import torch.nn as nn
 from torchao.quantization import quantize_, int8_weight_only, int4_weight_only
 from torchao.quantization.utils import compute_error
-from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only
+from torchao.prototype.quantization.mixed_precision.scripts.naive_intNwo import intN_weight_only

 _CUDA_IS_AVAILABLE = torch.cuda.is_available()

torchao/prototype/quantization/mixed_precision/__init__.py

Whitespace-only changes.
torchao/prototype/quantization/mixed_precision/scripts/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from .naive_intNwo import intN_weight_only
torchao/prototype/quantization/mixed_precision/scripts/naive_intNwo.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+import torch
+
+from torchao.quantization.quant_primitives import (
+    MappingType,
+    ZeroPointDomain,
+)
+
+from torchao.quantization import int8_weight_only, int4_weight_only
+from torchao.quantization.quant_api import _get_linear_subclass_inserter
+
+def intN_weight_only(group_size=32, n=8, symmetric=False):
+    '''
+    Apply int N-bit weight only quantization to a linear layer.
+    Args:
+        `group_size`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [512, 256, 128, 64, 32]
+        `n`: number of bits to quantize to, choices are [8, 6, 5, 4, 3, 2]
+    Usage:
+        from torchao.quantization import quantize_
+        quantize_(model, intN_weight_only(n=your_bit_choice, group_size=group_size), optional_filter_func_for_desired_layers_to_quantize)
+    '''
+    # for asymmetric quantization
+    def apply_intN_weight_only_quant_asym(weight):
+        # avoid circular dependency
+        from torchao.dtypes import to_affine_quantized_intx
+        mapping_type = MappingType.ASYMMETRIC
+        block_size = (1, group_size)
+        target_dtype = torch.uint8
+        quant_min = 0
+        quant_max = 2**n-1
+        eps = 1e-6
+        preserve_zero = True
+        zero_point_dtype = torch.int64
+        zero_point_domain = ZeroPointDomain.INT
+        return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype)#, preserve_zero=preserve_zero,zero_point_domain=zero_point_domain)
+
+    # for symmetric quantization
+    def apply_intN_weight_only_quant_sym(weight):
+        # avoid circular dependency
+        from torchao.dtypes import to_affine_quantized_intx
+        mapping_type = MappingType.SYMMETRIC
+        block_size = (1, group_size)
+        target_dtype = torch.int8
+        quant_min = -2**(n-1)
+        quant_max = 2**(n-1)-1
+        eps = 1e-6
+        zero_point_dtype = torch.int64
+        return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps=eps, zero_point_dtype=zero_point_dtype)
+
+    try:
+        assert n in [8, 6, 5, 4, 3, 2], "n must be one of [8, 6, 5, 4, 3, 2]"
+        if n == 8:
+            return int8_weight_only()
+        elif n == 4:
+            return int4_weight_only(group_size=group_size)
+        else:
+            if symmetric:
+                return _get_linear_subclass_inserter(apply_intN_weight_only_quant_sym)
+            else:
+                return _get_linear_subclass_inserter(apply_intN_weight_only_quant_asym)
+    except Exception as e:
+        raise
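
For reference, the relocated helper is invoked the same way from its new home. A minimal usage sketch, not part of the commit; the toy model and the n=6 / group_size=32 choices are illustrative. With n=6, the asymmetric branch above targets the uint8 range [0, 2**6 - 1] = [0, 63], while the symmetric branch would use [-32, 31].

    import torch.nn as nn
    from torchao.quantization import quantize_
    from torchao.prototype.quantization.mixed_precision.scripts.naive_intNwo import intN_weight_only

    # toy two-layer model; any model containing nn.Linear layers works (illustrative)
    model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))

    # quantize linear weights to 6 bits (stored in uint8), asymmetric, group size 32;
    # n=8 and n=4 dispatch to int8_weight_only / int4_weight_only instead
    quantize_(model, intN_weight_only(n=6, group_size=32, symmetric=False))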
torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py

Lines changed: 1 addition & 62 deletions
@@ -1,62 +1 @@
-import torch
-
-from torchao.quantization.quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
-
-from torchao.quantization import int8_weight_only, int4_weight_only
-from torchao.quantization.quant_api import _get_linear_subclass_inserter
-
-def intN_weight_only(group_size=32, n=8, symmetric=False):
-    '''
-    Apply int N-bit weight only quantization to a linear layer.
-    Args:
-        `group_size`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [512, 256, 128, 64, 32]
-        `n`: number of bits to quantize to, choices are [8, 6, 5, 4, 3, 2]
-    Usage:
-        from torchao.quantization import quantize_
-        quantize_(model, intN_weight_only(n=your_bit_choice, group_size=group_size), optional_filter_func_for_desired_layers_to_quantize)
-    '''
-    # for asymmetric quantization
-    def apply_intN_weight_only_quant_asym(weight):
-        # avoid circular dependency
-        from torchao.dtypes import to_affine_quantized_intx
-        mapping_type = MappingType.ASYMMETRIC
-        block_size = (1, group_size)
-        target_dtype = torch.uint8
-        quant_min = 0
-        quant_max = 2**n-1
-        eps = 1e-6
-        preserve_zero = True
-        zero_point_dtype = torch.int64
-        zero_point_domain = ZeroPointDomain.INT
-        return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype)#, preserve_zero=preserve_zero,zero_point_domain=zero_point_domain)
-
-    # for symmetric quantization
-    def apply_intN_weight_only_quant_sym(weight):
-        # avoid circular dependency
-        from torchao.dtypes import to_affine_quantized_intx
-        mapping_type = MappingType.SYMMETRIC
-        block_size = (1, group_size)
-        target_dtype = torch.int8
-        quant_min = -2**(n-1)
-        quant_max = 2**(n-1)-1
-        eps = 1e-6
-        zero_point_dtype = torch.int64
-        return to_affine_quantized_intx(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps=eps, zero_point_dtype=zero_point_dtype)
-
-    try:
-        assert n in [8, 6, 5, 4, 3, 2], "n must be one of [8, 6, 5, 4, 3, 2]"
-        if n == 8:
-            return int8_weight_only()
-        elif n == 4:
-            return int4_weight_only(group_size=group_size)
-        else:
-            if symmetric:
-                return _get_linear_subclass_inserter(apply_intN_weight_only_quant_sym)
-            else:
-                return _get_linear_subclass_inserter(apply_intN_weight_only_quant_asym)
-    except Exception as e:
-        raise
+from torchao.prototype.quantization.mixed_precision.scripts.naive_intNwo import intN_weight_only
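
Since the old module now simply re-exports the symbol from its new location, code that still imports the old path keeps working. A quick sanity check, hypothetical and not part of the commit:

    # both paths should resolve to the same function object after this commit
    from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only as old_api
    from torchao.prototype.quantization.mixed_precision.scripts.naive_intNwo import intN_weight_only as new_api
    assert old_api is new_api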
