Commit f81fe11

Activation Aware Weight Quantization (AWQ) (#743)
Integrate AWQ within the TorchAO framework
1 parent 92dd5f5 commit f81fe11

15 files changed: 807 additions, 9 deletions
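
For orientation: AWQ (activation-aware weight quantization) picks per-input-channel scales from calibration activations so that channels seeing large activations lose less precision when weights are quantized. The snippet below is only a simplified, hypothetical illustration of that scaling idea; it is not the torchao implementation this commit adds under torchao/prototype/awq.

import torch

def awq_style_scales(weight: torch.Tensor, calib_acts: torch.Tensor, alpha: float = 0.5):
    """Toy illustration of activation-aware scaling (not torchao's code).

    weight:     [out_features, in_features]
    calib_acts: [num_tokens, in_features] activations collected during calibration
    """
    # Per-input-channel activation magnitude observed during calibration.
    act_scale = calib_acts.abs().mean(dim=0).clamp(min=1e-5)
    # Channels with larger activations are scaled up before weight quantization ...
    scales = act_scale.pow(alpha)
    scales = scales / (scales.max() * scales.min()).sqrt()
    # ... and the inverse scale is folded into the input at runtime, so
    # (x / s) @ (W * s).T equals x @ W.T up to quantization error.
    return weight * scales, 1.0 / scales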

scripts/create_weight_map.py

Lines changed: 43 additions & 0 deletions
import json
import torch
from transformers import AutoModel
from pathlib import Path

def create_weight_map(checkpoint_dir: Path):
    """
    This function, create_weight_map, generates a mapping of a model's weights to a file (pytorch_model.bin)
    and saves this mapping, along with the model's total size, to a JSON file (pytorch_model.bin.index.json).
    The model is loaded from a pre-trained model specified by model_name.
    This weight map is used by the HF conversion script (convert_hf_checkpoint.py).
    """
    # Load the model
    model_name = checkpoint_dir.parent.name + "/" + checkpoint_dir.name
    print(model_name)
    model = AutoModel.from_pretrained(model_name)
    # Get the state dict
    state_dict = model.state_dict()
    # Create the weight map
    weight_map = {}
    for key, tensor in state_dict.items():
        # In this example, we're assuming all weights are in a single file.
        # You may need to adjust this if your model uses sharded weights.
        weight_map[key] = "pytorch_model.bin"
    # Create the index dictionary
    index_dict = {
        "metadata": {"total_size": sum(param.numel() * param.element_size() for param in model.parameters())},
        "weight_map": weight_map
    }
    # Save the index dictionary to a JSON file
    with open(f"{checkpoint_dir}/pytorch_model.bin.index.json", "w") as f:
        json.dump(index_dict, f, indent=2)
    print("Created pytorch_model.bin.index.json")

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Create weight map for hf model')
    parser.add_argument('--checkpoint_dir', type=Path, default=Path("checkpoints/Xenova/llama2.c-stories15M"))

    args = parser.parse_args()
    create_weight_map(args.checkpoint_dir)
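
A typical invocation uses the argparse default shown above: the checkpoint_dir's parent and leaf names double as the Hugging Face model id (e.g. Xenova/llama2.c-stories15M), and the index JSON is written into that local directory, e.g. `python scripts/create_weight_map.py --checkpoint_dir checkpoints/Xenova/llama2.c-stories15M`.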

scripts/hf_eval.py

Lines changed: 23 additions & 3 deletions
@@ -48,7 +48,7 @@ def format_value(value):
 def run_evaluation(repo_id, tasks, limit, device, precision, quantization, sparsity, compile, save, batch_size, max_length):
 
     tokenizer = AutoTokenizer.from_pretrained(repo_id)
-    model = AutoModelForCausalLM.from_pretrained(repo_id).to(dtype=precision, device=device)
+    model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=precision).to(device)
 
     if quantization == "autoquant" and compile:
         model = torch.compile(model, mode="max-autotune", fullgraph=True)

@@ -64,9 +64,29 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, sparsity, compile, save, batch_size, max_length):
         quantize_(model, fpx_weight_only(3, 2))
     elif quantization == "autoquant":
         model = autoquant(model.to(device=device))
+    elif quantization == "awq":
+        from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
+        from torchao.prototype.awq.example import get_calib_dataset
+        if not TORCH_VERSION_AT_LEAST_2_3:
+            print("AWQ quantization requires torch2.3+")
+            exit()
+        from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear
+        quant_dtype = torch.uint4
+        group_size = 64
+        calibration_limit = 10
+        calibration_seq_length = 1024
+        model = model.to(device)
+        insert_awq_observer_(model, calibration_limit, calibration_seq_length, quant_dtype=quant_dtype, group_size=group_size)
+        with torch.no_grad():
+            calibration_data = get_calib_dataset(tokenizer=tokenizer, n_samples=calibration_limit, block_size=calibration_seq_length)
+            for batch in calibration_data:
+                model(batch.to(device))
+                del batch
+        is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
+        quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear)
 
     if quantization != "autoquant" and compile:
-        model = torch.compile(model, mode="max-autotune", fullgraph=True)
+        model = torch.compile(model, mode= "max-autotune", fullgraph=True)
 
     if sparsity == "semi_sparse":
         def all_linear(mod, name):

@@ -114,7 +134,7 @@ def all_linear(mod, name):
     parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
     parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
     parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
-    parser.add_argument('-q', '--quantization', default="None", choices=["int8dq", "int8wo", "int4wo", "autoquant", "None"], help='Which quantization technique to apply')
+    parser.add_argument('-q', '--quantization', default="None", choices=["int8dq", "int8wo", "int4wo", "autoquant", "awq", "None"], help='Which quantization technique to apply')
     parser.add_argument('-s', '--sparsity', default="None", choices=["semi_sparse", "semi_sparse_mlp_only", "None"], help='Which sparsity technique to apply')
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
     parser.add_argument('--save', action='store_true', help='Whether to save the model.')
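
With "awq" added to the choices above, AWQ evaluation can be selected from the command line; a hypothetical run (remaining flags, such as the one selecting the model repo, follow the script's existing interface) would be `python scripts/hf_eval.py -q awq --precision bfloat16 --device cuda`. Note that this code path currently hard-codes quant_dtype=torch.uint4, group_size=64, and 10 calibration batches of sequence length 1024.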

test/dtypes/test_uintx.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 
 import torch
 
-from torchao.dtypes.uintx.uintx import to_uintx
+from torchao.dtypes.uintx import to_uintx
 from torchao.quantization.quant_api import quantize_, uintx_weight_only
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,

test/prototype/test_awq.py

Lines changed: 129 additions & 0 deletions
from copy import deepcopy
import os
import pytest
import torch
from torchao.quantization import quantize_

from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5
if TORCH_VERSION_AT_LEAST_2_3:
    from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear

class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=512, n=256, k=128):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)
        self.linear3 = torch.nn.Linear(k, 1, bias=False)

    def example_inputs(self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"):
        return [torch.randn(1, sequence_length, self.linear1.in_features, dtype=dtype, device=device) for j in range(batch_size)]

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        return x

devices = ["cpu", "cuda"]
# torch.uintx dtypes are introduced in 2.3
if TORCH_VERSION_AT_LEAST_2_3:
    qdtypes = (torch.uint4, torch.uint7)
else:
    qdtypes = ()

@pytest.fixture(autouse=True)
def run_before_and_after_tests():
    yield
    torch._dynamo.reset()  # reset cache between tests

@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("qdtype", qdtypes)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
def test_awq_loading(device, qdtype):
    if qdtype == torch.uint4 and device == "cpu":
        pytest.skip("uint4 not supported on cpu")

    dataset_size = 100
    l1, l2, l3 = 512, 256, 128
    original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
    quant_dtype = qdtype
    group_size = 128
    n_calibration_examples = 10
    n_validation_examples = 10
    sequence_length = 5

    m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
    dataset = m.example_inputs(dataset_size, sequence_length=sequence_length, dtype=original_dtype, device=device)
    calibration_data = dataset[:n_calibration_examples]

    # calibrate
    insert_awq_observer_(m, n_validation_examples, sequence_length, quant_dtype=quant_dtype, group_size=group_size)

    for example in calibration_data:
        m(example.to(device))

    # quantize
    is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
    quantize_(m, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear)

    model_save_path = "awq_model.pth"
    torch.save(m, model_save_path)
    loaded_model = torch.load(model_save_path)
    os.remove(model_save_path)

    if torch.cuda.is_available():
        m = torch.compile(m, fullgraph=True)
        loaded_model = torch.compile(loaded_model, fullgraph=True)

    awq_out = torch.cat([m(i.squeeze(0)) for i in dataset])
    awq_save_load_out = torch.cat([loaded_model(i.squeeze(0)) for i in dataset])

    assert awq_out is not None
    assert awq_save_load_out is not None
    assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)

@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_save_weights_only():
    dataset_size = 100
    l1, l2, l3 = 512, 256, 128
    original_dtype = torch.bfloat16
    quant_dtype = torch.uint4
    device = "cuda"
    group_size = 128
    n_calibration_examples = 10
    n_validation_examples = 10
    sequence_length = 5

    m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
    m2 = deepcopy(m)
    dataset = m.example_inputs(dataset_size, sequence_length=sequence_length, dtype=original_dtype, device=device)
    calibration_data = dataset[:n_calibration_examples]

    # calibrate
    insert_awq_observer_(m, n_validation_examples, sequence_length, quant_dtype=quant_dtype, group_size=group_size)

    for example in calibration_data:
        m(example.to(device))

    # quantize
    is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
    quantize_(m, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear)

    model_save_path = "awq_model.pth"
    torch.save(m.state_dict(), model_save_path)
    m2.load_state_dict(torch.load(model_save_path), assign=True)  # load weights only
    os.remove(model_save_path)

    m = torch.compile(m, fullgraph=True)
    m2 = torch.compile(m2, fullgraph=True)

    awq_out = torch.cat([m(i.squeeze(0)) for i in dataset])
    awq_save_load_out = torch.cat([m2(i.squeeze(0)) for i in dataset])

    assert awq_out is not None
    assert awq_save_load_out is not None
    assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
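
Per the skip markers, both tests require CUDA and a recent (2.5+/nightly) PyTorch; they can be run directly with pytest, e.g. `pytest test/prototype/test_awq.py`.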

torchao/_models/llama/eval.py

Lines changed: 1 addition & 1 deletion
@@ -255,4 +255,4 @@ def run_evaluation(
         args.calibration_limit,
         args.calibration_seq_length,
         args.pad_calibration_inputs,
     )

(The only change on line 258 is whitespace formatting of the closing parenthesis.)

torchao/_models/llama/generate.py

Lines changed: 32 additions & 1 deletion
@@ -161,6 +161,8 @@ def main(
     temperature: float = 0.8,
     checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
     quantization: Optional[str] = None,
+    calibration_limit: int = 10,
+    calibration_seq_length: int = 256,
     kv_cache_quantization: bool = False,
     cache_size: Optional[int] = None,
     linear_causal_mask: bool=False,

@@ -232,6 +234,33 @@ def main(
             quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
         if "fp6" in quantization:
             quantize_(model, fpx_weight_only(3, 2))
+        if quantization.startswith("awq"):
+            from torchao._models._eval import TransformerEvalWrapper
+            from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
+            from torchao.prototype.awq.example import get_calib_dataset
+            if not TORCH_VERSION_AT_LEAST_2_3:
+                print("Awq requires torch2.3+")
+                exit()
+            from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear
+            quant_dtype = quantization.split("-")[1]
+            group_size = int(quantization.split("-")[2])
+            quant_dtype = getattr(torch, quant_dtype, torch.uint8)
+            model = model.to(device)
+            # get calibration data
+            insert_awq_observer_(model, calibration_limit, calibration_seq_length, quant_dtype=quant_dtype, group_size=group_size)
+            TransformerEvalWrapper(
+                model=model.to(device),
+                tokenizer=tokenizer,
+                max_seq_length=calibration_seq_length,
+                input_prep_func=prepare_inputs_for_model,
+                device=device,
+            ).run_eval(
+                tasks=['wikitext'],
+                limit=calibration_limit,
+            )
+            is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
+            use_hqq = "hqq" in quantization
+            quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq), is_observed_linear)
         if "uintx" in quantization:
             # uintx-nbits-groupsize, e.g. "uintx-2-64"
             if "hqq" in quantization:

@@ -434,6 +463,8 @@ def callback(x):
             +'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin'
         )
     )
+    parser.add_argument("--calibration_limit", type=int, default=10, help="Number of calibration examples")
+    parser.add_argument("--calibration_seq_length", type=int, default=256, help="Sequence length for calibration")
     parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
     parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size')
     parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)')

@@ -449,5 +480,5 @@ def callback(x):
     args = parser.parse_args()
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
-        args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
+        args.temperature, args.checkpoint_path, args.quantization, args.calibration_limit, args.calibration_seq_length, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
     )
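
Given the parsing above (dtype from the second dash-separated field, group size from the third, optional hqq detection), a hypothetical benchmark run with the new flags might look like `python torchao/_models/llama/generate.py --quantization awq-uint4-64 --calibration_limit 10 --calibration_seq_length 256`, with the script's usual checkpoint and compile options added as needed; awq-uint4-64 selects torch.uint4 weights with group size 64.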

torchao/dtypes/uintx/__init__.py

Lines changed: 1 addition & 0 deletions
from .uintx import UintxTensor, UintxLayoutType, UintxAQTLayout, to_uintx, _DTYPE_TO_BIT_WIDTH

torchao/dtypes/uintx/uintx.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@
 
     _BIT_WIDTH_TO_DTYPE = {v: k for k, v in _DTYPE_TO_BIT_WIDTH.items()}
 else:
-    print("uintx feature need torch 2.3+, please upgrade pytorch")
+    print("uintx feature requires torch 2.3+, please upgrade pytorch")
 
 
 class UintxTensor(TorchAOBaseTensor):

torchao/prototype/awq/README.md

Lines changed: 29 additions & 0 deletions
# AWQ Quantization
Adapted from https://github.com/mit-han-lab/llm-awq

## Benchmarks
Evaluation perplexity numbers were calculated using the script in awq/example.py. A group size of 64 was used for all quantization methods. For Llama-2-7b-chat-hf, performance benchmarks were measured with the torchao/_models/llama/generate.py script on a 1xA100 80GB SXM4 instance. The awq-uint4 quantization method does not use an efficient fused kernel, which is why its performance is not great. awq-hqq uses the tinygemm int4->bf16 kernel + hqq to provide better performance.

| Model              | Quantization | Tokens/sec | Throughput (GB/sec) | Peak Mem (GB) | Model Size (GB) |
|--------------------|--------------|------------|---------------------|---------------|-----------------|
| Llama-2-7b-chat-hf | bfloat16     | 107.38     | 1418.93             | 13.88         | 13.21           |
|                    | awq-hqq-int4 | 196.6      | 761.2               | 5.05          | 3.87            |
|                    | awq-uint4    | 43.59      | 194.93              | 7.31          | 4.47            |
|                    | int4wo-hqq   | 209.19     | 804.32              | 4.89          | 3.84            |
|                    | int4wo-64    | 201.14     | 751.42              | 4.87          | 3.74            |

The following tests were performed using LM Eval and a group size of 128:

| Model               | Quantization | Perplexity | Truthful QA MC2 | WinoGrande | ARC challenge |
|---------------------|--------------|------------|-----------------|------------|---------------|
| Llama-3-8B-Instruct | bfloat16     | 10.936     | 0.540           | 0.783      | 0.567         |
|                     | awq-hqq-int4 | 11.383     | 0.522           | 0.772      | 0.543         |
|                     | awq-uint4    | 11.409     | 0.519           | 0.756      | 0.577         |
|                     | int4wo-hqq   | 11.905     | 0.528           | 0.757      | 0.563         |
|                     | int4wo-128   | 12.380     | 0.502           | 0.753      | 0.548         |

torchao/prototype/awq/__init__.py

Lines changed: 2 additions & 0 deletions
from .api import insert_awq_observer_, awq_uintx
from .core import AWQObservedLinear
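
Putting the exported pieces together, the intended flow (the same one exercised in test_awq.py and hf_eval.py above) is: insert observers, push calibration data through the model, then swap the observed linears for quantized ones. A minimal sketch, assuming a CUDA device, a PyTorch recent enough for uintx dtypes (2.3+, with the tests gating on 2.5 nightly), and a placeholder toy model with random calibration data:

import torch
from torchao.quantization import quantize_
from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear

# Placeholder model and data; any module containing nn.Linear layers works.
model = torch.nn.Sequential(
    torch.nn.Linear(512, 256, bias=False),
    torch.nn.Linear(256, 64, bias=False),
).to(torch.bfloat16).to("cuda").eval()

calibration_limit = 10        # number of calibration batches
calibration_seq_length = 128  # tokens per batch
calibration_data = [
    torch.randn(1, calibration_seq_length, 512, dtype=torch.bfloat16, device="cuda")
    for _ in range(calibration_limit)
]

# 1) Wrap the Linear layers with AWQ observers.
insert_awq_observer_(model, calibration_limit, calibration_seq_length,
                     quant_dtype=torch.uint4, group_size=64)

# 2) Calibrate: run representative inputs through the model.
with torch.no_grad():
    for batch in calibration_data:
        model(batch)

# 3) Quantize only the observed linears.
is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
quantize_(model, awq_uintx(quant_dtype=torch.uint4, group_size=64), is_observed_linear)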
