
Commit aadda53

remove experiment scripts

1 parent 8e1d8d2 commit aadda53

3 files changed: +98 -62 lines changed
Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
+import torch
+import torch.nn as nn
+
+import os
+import sys
+# append the path to the naive_intNwo.py file
+sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "torchao/quantization/prototype/mixed_precision/scripts"))
+from naive_intNwo import intN_weight_only
+
+from torchao.quantization import quantize_, int8_weight_only, int4_weight_only
+
+from torchao.quantization.utils import (
+    _apply_logging_hook,
+    compute_error,
+    compute_error as SQNR,
+    _fqn_to_op_to_shape_to_count,
+    LoggingTensorMode,
+)
+
+def test_weight_only_quant(quantization_bit=2, symmetric=False):
+    for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]:
+        x = torch.randn(*x_shape)
+        m = nn.Sequential(nn.Linear(4, 5))
+        y_ref = m(x)
+        quantize_(m, intN_weight_only(n=quantization_bit, group_size=2, symmetric=symmetric))
+        y_wo = m(x)
+        sqnr = compute_error(y_ref, y_wo)
+        print(sqnr)
+        assert sqnr > 44.0, "sqnr: {} is too low".format(sqnr)
+
+
+# test if the asymmetric and symmetric quantization API works with different bit widths
+for i in range(2, 9):
+    # test for asymmetric quantization
+    try:
+        test_weight_only_quant(i, False)
+        print(f"Test passed for {i}-bit using naive intNwo asymmetric quantization implementation")
+    except Exception as e:
+        print(f"Exception handled in test loop for {i}-bit asymmetric quantization. Details: {e}")
+
+    # test for symmetric quantization
+    try:
+        test_weight_only_quant(i, True)
+        print(f"Test passed for {i}-bit using naive intNwo symmetric quantization implementation")
+    except Exception as e:
+        print(f"Exception handled in test loop for {i}-bit symmetric quantization. Details: {e}")
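For reference on the assertion above, a minimal sketch of the SQNR metric, assuming compute_error reports the signal-to-quantization-noise ratio in decibels (the helper name sqnr_db below is illustrative, not a torchao API):

import torch

def sqnr_db(reference: torch.Tensor, quantized_output: torch.Tensor) -> torch.Tensor:
    # ratio of reference signal energy to quantization error energy, in dB
    error = reference - quantized_output
    return 20 * torch.log10(torch.linalg.norm(reference) / torch.linalg.norm(error))

Under this reading, the 44.0 dB threshold requires the quantization error to be roughly 160x smaller than the reference output in L2 norm.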
Lines changed: 18 additions & 48 deletions

@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 
-from naive_intNwo import intN_weight_only_asym, intN_weight_only_sym
+from naive_intNwo import intN_weight_only
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from lm_eval.models.huggingface import HFLM
@@ -28,63 +28,33 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
     tokenizer = AutoTokenizer.from_pretrained(repo_id)
     model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)
 
-    if quantization == "int8dq":
-        quantize_(model.to(device=device), int8_dynamic_activation_int4_weight())
-
-    elif quantization == "int8wo":
-        quantize_(model.to(device=device), int8_weight_only())
-
-    elif quantization == "int4wo":
-        quantize_(model.to(device=device), int4_weight_only(group_size=group_size))
-
-    elif quantization == "autoquant":
+    if quantization == "autoquant":
         model = autoquant(model.to(device=device))
 
     # naive implementation of uniform precision quantization all layers
     elif quantization in ["2","3","4","5","6","8"]:
-        if quant_sym == "asym":
-            quantize_(model.to(device=device), intN_weight_only_asym(n=int(quantization), group_size=group_size))
-        elif quant_sym == "sym":
-            quantize_(model.to(device=device), intN_weight_only_sym(n=int(quantization), group_size=group_size))
-
+        quantize_(model.to(device=device), intN_weight_only(n=int(quantization), group_size=group_size, symmetric=quant_sym))
+
+    # mix precision quantization for Llama3
     elif quantization == "MP_llama3":
 
-        # filter for sensitive layers
+        # filter for sensitive layers (the first 3 and last 2 layers for Llama3)
         def filter_fn_sen(child: torch.nn.Module, cur_fqn:str) -> bool:
            return isinstance(child, nn.Linear) and any(skiplayer in cur_fqn for skiplayer in ['.0.', '.1.', '.2.', '.30.', '.31.'])
 
-        # filter for non-sensitive layers
+        # filter for non-sensitive layers (other 27 layers for Llama3)
         def filter_fn_nonsen(child: torch.nn.Module, cur_fqn:str) -> bool:
             return isinstance(child, nn.Linear) and not(any(skiplayer in cur_fqn for skiplayer in ['.0.', '.1.', '.2.', '.30.', '.31.']))
 
+        # quantize the sensitive layers
         if sensi_bit != 16:
-            # quantize the sensitive layers
-            if sensi_bit == 8:
-                quantize_(model.to(device=device), int8_weight_only(), filter_fn_sen)
-            elif sensi_bit == 4:
-                quantize_(model.to(device=device), int4_weight_only(group_size=group_size), filter_fn_sen)
-            elif sensi_bit in [6,5,3,2]:
-                if quant_sym == "asym":
-                    quantize_(model.to(device=device), intN_weight_only_asym(n=sensi_bit, group_size=group_size), filter_fn_sen)
-                elif quant_sym == "sym":
-                    quantize_(model.to(device=device), intN_weight_only_sym(n=sensi_bit, group_size=group_size), filter_fn_sen)
+            quantize_(model.to(device=device), intN_weight_only(n=sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_sen)
 
         # quantize the less-sensitive layers
-        if non_sensi_bit == 8:
-            quantize_(model.to(device=device), int8_weight_only(), filter_fn_nonsen)
-        elif non_sensi_bit == 4:
-            quantize_(model.to(device=device), int4_weight_only(group_size=group_size), filter_fn_nonsen)
-        elif non_sensi_bit in [6,5,3,2]:
-            if sensi_bit == 4:
-                if quant_sym == "asym":
-                    quantize_(model, intN_weight_only_asym(n=non_sensi_bit, group_size=group_size), filter_fn_nonsen)
-                elif quant_sym == "sym":
-                    quantize_(model, intN_weight_only_sym(n=non_sensi_bit, group_size=group_size), filter_fn_nonsen)
-            else:
-                if quant_sym == "asym":
-                    quantize_(model.to(device=device), intN_weight_only_asym(n=non_sensi_bit, group_size=group_size), filter_fn_nonsen)
-                elif quant_sym == "sym":
-                    quantize_(model.to(device=device), intN_weight_only_sym(n=non_sensi_bit, group_size=group_size), filter_fn_nonsen)
+        if sensi_bit == 4:
+            quantize_(model, intN_weight_only(n=non_sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_nonsen)
+        else:
+            quantize_(model.to(device=device), intN_weight_only(n=non_sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_nonsen)
 
     if compile:
         model = torch.compile(model, mode="max-autotune", fullgraph=True)
@@ -113,13 +83,13 @@ def filter_fn_nonsen(child: torch.nn.Module, cur_fqn:str) -> bool:
     parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
     parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
     parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
-    parser.add_argument('-q', '--quantization', default = "None", help='Which quantization technique to apply')
+    parser.add_argument('-q', '--quantization', default = "None", choices = ["2", "3", "4", "5", "6", "8", "MP_llama3", "None"], help='Which quantization technique to apply, choose from ["2", "3", "4", "5", "6", "8"] for uniform quantizatoin, choose "MP_llama3" for mixed-precision for Llama3 and need to set corresponding sensi_bit and non_sensi_bit, choose "None" for no quantization')
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
     parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation, note int8wo and int4wo work best with small batchsizes, int8dq works better with large batchsizes')
     parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
-    parser.add_argument('--sensi_bit', type=int, default=16, help='Bit setting for sensitive layers')
-    parser.add_argument('--non_sensi_bit', type=int, default=16, help='Bit setting for non-sensitive layers')
-    parser.add_argument('--quant_sym', type=str, default="asym", help='symmetric or asymmetric quantization')
-    parser.add_argument('--group_size', type=int, default=32, help='group size to perform quantization on')
+    parser.add_argument('--sensi_bit', type=int, default=16, choices = [16, 8, 6, 5, 4, 3], help='Bit setting for sensitive layers')
+    parser.add_argument('--non_sensi_bit', type=int, default=8, choices = [8, 6, 5, 4, 3, 2], help='Bit setting for non-sensitive layers')
+    parser.add_argument('--quant_sym', type=bool, default=False, help='Symmetric or asymmetric quantization, asymmetric by default')
+    parser.add_argument('--group_size', type=int, default=32, help='Group size to perform quantization on')
     args = parser.parse_args()
     run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.batch_size, args.max_length, args.sensi_bit, args.non_sensi_bit, args.quant_sym, args.group_size)
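To make the filter-based mixed-precision flow above concrete, here is a minimal sketch using only the public torchao APIs already referenced in this diff (quantize_, int8_weight_only); the toy module and the choice of which layer counts as sensitive are illustrative, not the eval script's defaults:

import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only

# toy stack whose linear layers get fully-qualified names like layers.0.0, layers.1.0, ...
model = nn.ModuleDict({"layers": nn.ModuleList([nn.Sequential(nn.Linear(64, 64)) for _ in range(4)])})

# treat layer 0 as "sensitive": the filter matches only its nn.Linear modules
def filter_fn_sen(child: torch.nn.Module, cur_fqn: str) -> bool:
    return isinstance(child, nn.Linear) and ".0." in cur_fqn

# only the layers selected by the filter get their weights quantized
quantize_(model, int8_weight_only(), filter_fn_sen)

The eval script follows the same pattern twice: one quantize_ call with filter_fn_sen and the sensi_bit configuration, then a second call with the complementary filter_fn_nonsen and the lower non_sensi_bit configuration.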

torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py

Lines changed: 34 additions & 14 deletions

@@ -5,13 +5,26 @@
     ZeroPointDomain,
 )
 
-def intN_weight_only_asym(group_size=32, n=8):
+from torchao.quantization import int8_weight_only, int4_weight_only
+
+
+def intN_weight_only(group_size=32, n=8, symmetric=False):
+    '''
+    Apply int N-bit weight only quantization to a linear layer.
+    Args:
+        `groupsize`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [512, 256, 128, 64, 32]
+        `n`: number of bits to quantize to, choices are [8, 6, 5, 4, 3, 2]
+    Usage:
+        from torchao.quantization import quantize_
+        quantize_(model, intN_weight_only(n=your_bit_choice, group_size=group_size), optional_filter_func_for_desired_layers_to_quantize)
+    '''
+    # for asymmetric quantization
     def apply_intN_weight_only_quant_asym(weight):
-        # avoid circular dep
+        # avoid circular dependency
         from torchao.dtypes import to_affine_quantized
         mapping_type = MappingType.ASYMMETRIC
         block_size = (1, group_size)
-        target_dtype = torch.int8
+        target_dtype = torch.uint8
         quant_min = 0
         quant_max = 2**n-1
         eps = 1e-6
@@ -20,21 +33,28 @@ def apply_intN_weight_only_quant_asym(weight):
         zero_point_domain = ZeroPointDomain.FLOAT
         return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero,zero_point_domain=zero_point_domain)
 
-    return apply_intN_weight_only_quant_asym
-
-def intN_weight_only_sym(group_size=32, n=8):
+    # for symmetric quantization
     def apply_intN_weight_only_quant_sym(weight):
-        # avoid circular dep
+        # avoid circular dependency
         from torchao.dtypes import to_affine_quantized
         mapping_type = MappingType.SYMMETRIC
         block_size = (1, group_size)
         target_dtype = torch.int8
-        quant_min = -2**(n-1)
-        quant_max = 2**(n-1)-1
         eps = 1e-6
-        preserve_zero = True
-        zero_point_dtype = torch.bfloat16
-        zero_point_domain = ZeroPointDomain.INT
-        return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero,zero_point_domain=zero_point_domain)
+        zero_point_dtype = torch.int64
+        return to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
 
-    return apply_intN_weight_only_quant_sym
+    try:
+        assert n in [8, 6, 5, 4, 3, 2], "n must be one of [8, 6, 5, 4, 3, 2]"
+        if n == 8:
+            return int8_weight_only()
+        elif n == 4:
+            return int4_weight_only(group_size=group_size)
+        else:
+            if symmetric:
+                return apply_intN_weight_only_quant_sym
+            else:
+                return apply_intN_weight_only_quant_asym
+    except Exception as e:
+        raise
+
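As a usage note, the refactored intN_weight_only dispatches on n: 8 and 4 fall through to the existing int8_weight_only and int4_weight_only configs, while the remaining bit widths return the nested asymmetric or symmetric apply function depending on the symmetric flag. A minimal sketch, assuming naive_intNwo.py is importable (for example by appending its directory to sys.path, as the test script above does); the 3-bit setting and toy model are illustrative:

import torch.nn as nn
from torchao.quantization import quantize_
from naive_intNwo import intN_weight_only

model = nn.Sequential(nn.Linear(64, 64))

# 3-bit asymmetric weight-only quantization, grouped along the input dimension
quantize_(model, intN_weight_only(n=3, group_size=32, symmetric=False))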
