
Commit 5d33a53

Add layer wise quantization doc and ONNXRT example (#1434)
Signed-off-by: yuwenzho <[email protected]>
1 parent 789779b commit 5d33a53

File tree: 15 files changed (287 additions, 90 deletions)


README.md

Lines changed: 2 additions & 1 deletion
@@ -122,7 +122,8 @@ q_model = fit(
   </tr>
   <tr>
     <td colspan="4" align="center"><a href="./docs/source/quantization_weight_only.md">Weight-Only Quantization (INT8/INT4/FP4/NF4) </td>
-    <td colspan="4" align="center"><a href="https://github.com/intel/neural-compressor/blob/fp8_adaptor/docs/source/fp8.md">FP8 Quantization </td>
+    <td colspan="2" align="center"><a href="https://github.com/intel/neural-compressor/blob/fp8_adaptor/docs/source/fp8.md">FP8 Quantization </td>
+    <td colspan="2" align="center"><a href="./docs/source/quantization_layer_wise.md">Layer-Wise Quantization </td>
   </tr>
 </tbody>
 <thead>

docs/source/imgs/lwq_ort.png

157 KB
Binary file added
docs/source/quantization_layer_wise.md

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
Layer Wise Quantization (LWQ)
=====

1. [Introduction](#introduction)

2. [Supported Framework Model Matrix](#supported-framework-model-matrix)

3. [Examples](#examples)
## Introduction

Large language models (LLMs) have shown exceptional performance across various tasks, but their substantial parameter size poses significant challenges for deployment. Layer-wise quantization (LWQ) can greatly reduce the memory footprint of LLMs, typically by 80-90%, which means that users can quantize LLMs even on a single node using a GPU or CPU. Because the model is processed one layer at a time, quantization becomes feasible on memory-constrained devices, making huge-sized LLM quantization possible.

<img src="./imgs/lwq.png" width=780 height=429>

*Figure 1: The process of layer-wise quantization for a PyTorch model. Grey denotes empty parameters and blue denotes parameters that need to be quantized. Every rectangle inside the model represents one layer.*

<img src="./imgs/lwq_ort.png" width=900 height=400>

*Figure 2: The process of layer-wise quantization for an ONNX model. The graph of the LLM is split into several parts, and each subgraph is quantized in turn.*
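To make the idea concrete, the following is a minimal conceptual sketch of the layer-wise approach. It is not Intel Neural Compressor's actual implementation; the toy INT8 helper and the safetensors checkpoint are illustrative assumptions. Each weight tensor is materialized from disk, quantized, and released before the next one, so peak memory stays close to the size of a single layer.

```python
# Conceptual sketch only -- not the Neural Compressor implementation.
import torch
from safetensors import safe_open  # allows loading one tensor at a time


def quantize_int8(weight: torch.Tensor):
    """Toy symmetric per-tensor INT8 weight quantization (illustrative only)."""
    scale = weight.abs().amax().clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(weight / scale), -128, 127).to(torch.int8)
    return q, scale


def layer_wise_quantize(checkpoint_path: str) -> dict:
    quantized = {}
    with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
        for name in f.keys():                # iterate tensor by tensor
            weight = f.get_tensor(name)      # only this tensor is in memory
            quantized[name] = quantize_int8(weight)
            del weight                       # release before the next layer
    return quantized
```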
## Supported Framework Model Matrix

<table class="tg">
<thead>
  <tr>
    <th colspan="2" style="text-align:center;vertical-align:middle">Types/Framework</th>
    <th style="text-align:center;vertical-align:middle">PyTorch</th>
    <th style="text-align:center;vertical-align:middle">ONNX Runtime</th>
  </tr>
</thead>
<tbody>
  <tr>
    <td style="text-align:center;vertical-align:middle" colspan="2">W8A8 Post Training Static Quantization</td>
    <td style="text-align:center;vertical-align:middle">&#10004;</td>
    <td style="text-align:center;vertical-align:middle">&#10004;</td>
  </tr>
  <tr>
    <td style="text-align:center;vertical-align:middle" rowspan="4">Weight-only Quantization</td>
    <td style="text-align:center;vertical-align:middle">RTN</td>
    <td style="text-align:center;vertical-align:middle">&#10004;</td>
    <td style="text-align:center;vertical-align:middle" rowspan="4">&#10005;</td>
  </tr>
  <tr>
    <td style="text-align:center;vertical-align:middle">AWQ</td>
    <td style="text-align:center;vertical-align:middle">&#10005;</td>
  </tr>
  <tr>
    <td style="text-align:center;vertical-align:middle">GPTQ</td>
    <td style="text-align:center;vertical-align:middle">&#10004;</td>
  </tr>
  <tr>
    <td style="text-align:center;vertical-align:middle">TEQ</td>
    <td style="text-align:center;vertical-align:middle">&#10005;</td>
  </tr>
</tbody>
</table>
## Examples

#### PyTorch framework example

```python
from neural_compressor import PostTrainingQuantConfig, quantization
from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_empty_model
from neural_compressor.utils.pytorch import load

fp32_model = load_empty_model(model_name_or_path, torchscript=True)
conf = PostTrainingQuantConfig(
    approach="weight_only",
    recipes={
        "layer_wise_quant": True,
        "rtn_args": {"enable_full_range": True},
    },
)

q_model = quantization.fit(
    fp32_model,
    conf,
    calib_dataloader=eval_dataloader,
    eval_func=lambda x: 0.1,
)
output_dir = "./saved_model"
q_model.save(output_dir)
q_model = load(output_dir, fp32_model, weight_only=True, layer_wise=True)
```
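In the snippet above, `model_name_or_path` and `eval_dataloader` are assumed to be supplied by the user, and the dummy `eval_func` simply bypasses accuracy-driven tuning. A rough sketch of one way to build a calibration dataloader is shown below; the dataset, sample count, and sequence length are illustrative assumptions, not requirements.

```python
# Hypothetical calibration dataloader for the example above.
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
calib_data = load_dataset("NeelNanda/pile-10k", split="train").select(range(32))


def collate_fn(batch):
    enc = tokenizer(batch[0]["text"], max_length=512, truncation=True, return_tensors="pt")
    # Yield (input, label) pairs, the convention used by the dataloaders in this repo.
    return enc["input_ids"], enc["input_ids"]


eval_dataloader = DataLoader(calib_data, batch_size=1, collate_fn=collate_fn)
```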
#### ONNX Runtime framework example

```python
from neural_compressor import quantization, PostTrainingQuantConfig

conf = PostTrainingQuantConfig(recipes={"layer_wise_quant": True})
q_model = quantization.fit(fp32_model_path, conf, calib_dataloader=dataloader)
q_model.save(int8_model_path)
```
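Here `fp32_model_path`, `int8_model_path`, and `dataloader` are placeholders for user-supplied values. Below is a minimal sketch of a calibration dataloader that follows the interface used elsewhere in this commit (a `batch_size` attribute and an iterator yielding `(input, label)` pairs, as in the `KVDataloader` of the llama example); the input names, shapes, and vocabulary size are assumptions that must be adapted to your model.

```python
# Hypothetical calibration dataloader; adapt input names and shapes to your model.
import numpy as np


class DummyCalibDataloader:
    def __init__(self, batch_size=1, seq_len=32, num_samples=8):
        self.batch_size = batch_size  # attribute expected by the quantizer
        self.seq_len = seq_len
        self.num_samples = num_samples

    def __iter__(self):
        for _ in range(self.num_samples):
            ort_inputs = {
                "input_ids": np.random.randint(0, 32000, size=(self.batch_size, self.seq_len), dtype=np.int64),
                "attention_mask": np.ones((self.batch_size, self.seq_len), dtype=np.int64),
            }
            yield ort_inputs, None  # (input, label); the label is unused here


dataloader = DummyCalibDataloader()
```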
Refer to the [ONNX Runtime llama-2 LWQ example](../../examples/onnxrt/nlp/huggingface_model/text_generation/llama/quantization/weight_only).

docs/source/quantization_weight_only.md

Lines changed: 1 addition & 47 deletions
@@ -7,9 +7,7 @@ Weight Only Quantization (WOQ)
 
 3. [Examples](#examples)
 
-4. [Layer Wise Quantization](#layer-wise-quantization)
-
-5. [WOQ Algorithms Tuning](#woq-algorithms-tuning)
+4. [WOQ Algorithms Tuning](#woq-algorithms-tuning)
 
 
 ## Introduction
@@ -144,50 +142,6 @@ The saved_results folder contains two files: `best_model.pt` and `qconfig.json`,
 
 To seek the performance of weight-only quantized models, Please go to [Intel Extension for Transformers](https://github.com/intel/intel-extension-for-transformers/tree/main/examples/huggingface/pytorch/text-generation/quantization#1-performance) to quantize and deploy the model.
 
-
-## Layer Wise Quantization
-
-Large language models (LLMs) have shown exceptional performance across various tasks, meanwhile, the substantial parameter size poses significant challenges for deployment. Layer-wise quantization(LWQ) can greatly reduce the memory footprint of LLMs, usually 80-90% reduction, which means that users can quantize LLMs even on single node using GPU or CPU. We can quantize the model under memory-constrained devices, therefore making the huge-sized LLM quantization possible.
-
-<img src="./imgs/lwq.png">
-
-*Figure 1: The process of layer-wise quantization. The color grey means empty parameters and the color blue represents parameters need to be quantized. Every rectangle inside model represents one layer.*
-
-### Supported Matrix
-
-| Algorithms/Framework | PyTorch |
-|:--------------:|:----------:|
-| RTN | &#10004; |
-| AWQ | &#10005; |
-| GPTQ | &#10004; |
-| TEQ | &#10005; |
-
-### Example
-```python
-from neural_compressor import PostTrainingQuantConfig, quantization
-from neural_compressor.adaptor.torch_utils.layer_wise_quant import load_empty_model
-
-fp32_model = load_empty_model(model_name_or_path, torchscript=True)
-conf = PostTrainingQuantConfig(
-    approach="weight_only",
-    recipes={
-        "layer_wise_quant": True,
-        "rtn_args": {"enable_full_range": True},
-    },
-)
-
-q_model = quantization.fit(
-    fp32_model,
-    conf,
-    calib_dataloader=eval_dataloader,
-    eval_func=lambda x: 0.1,
-)
-ouput_dir = "./saved_model"
-q_model.save(ouput_dir)
-q_model = load(ouput_dir, fp32_model, weight_only=True, layer_wise=True)
-```
-
-
 ## WOQ Algorithms Tuning
 
 To find the best algorithm, users can omit specifying a particular algorithm. In comparison to setting a specific algorithm, this tuning process will traverse through a set of pre-defined WOQ configurations and identify the optimal one with the best result. For details usage, please refer to the [tuning strategy](./tuning_strategies.md#Basic).

docs/source/user_guide.md

Lines changed: 4 additions & 3 deletions
@@ -81,9 +81,10 @@ This part provides the advanced topics that help user dive deep into Intel® Neu
     <td colspan="4" align="center"><a href="add_new_adaptor.md">Add New Adaptor</a></td>
   </tr>
   <tr>
-    <td colspan="4" align="center"><a href="distillation_quantization.md">Distillation for Quantization</a></td>
-    <td colspan="4" align="center"><a href="smooth_quant.md">SmoothQuant</a></td>
-    <td colspan="4" align="center"><a href="quantization_weight_only.md">Weight-Only Quantization</a></td>
+    <td colspan="3" align="center"><a href="distillation_quantization.md">Distillation for Quantization</a></td>
+    <td colspan="3" align="center"><a href="smooth_quant.md">SmoothQuant</a></td>
+    <td colspan="3" align="center"><a href="quantization_weight_only.md">Weight-Only Quantization</a></td>
+    <td colspan="3" align="center"><a href="quantization_layer_wise.md">Layer-Wise Quantization</a></td>
   </tr>
 </tbody>
 </table>

examples/.config/model_params_onnxrt.json

Lines changed: 7 additions & 0 deletions
@@ -763,6 +763,13 @@
       "main_script": "main.py",
       "batch_size": 1
     },
+    "llama-2-7b-lwq": {
+      "model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/ptq_static",
+      "dataset_location": "",
+      "input_model": "/tf_dataset2/models/onnx/llama-2-7b",
+      "main_script": "main.py",
+      "batch_size": 1
+    },
     "llama-2-7b-rtn": {
       "model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only",
       "dataset_location": "",

examples/onnxrt/nlp/huggingface_model/text_generation/llama/quantization/ptq_static/README.md

Lines changed: 17 additions & 1 deletion
@@ -27,13 +27,15 @@ Note that this README.md uses meta-llama/Llama-2-7b-hf as an example. There are
 
 Export to ONNX model:
 ```bash
-optimum-cli export onnx --model meta-llama/Llama-2-7b-hf --task text-generation-with-past ./Llama-2-7b-hf
+python prepare_model.py --input_model="meta-llama/Llama-2-7b-hf" --output_model="./llama-2-7b-hf"
 ```
 
 # Run
 
 ## 1. Quantization
 
+### Run SmoothQuant
+
 ```bash
 bash run_quant.sh --input_model=/path/to/model \ # folder path of onnx model
                   --output_model=/path/to/model_tune \ # folder path to save onnx model
@@ -44,6 +46,20 @@ bash run_quant.sh --input_model=/path/to/model \ # folder path of onnx model
                   --quant_format="QOperator" # or QDQ, optional
 ```
 
+### Run layer-wise quantization
+Set `--layer_wise=True` to use layer-wise quantization and save memory. Please note that layer-wise quantization for ONNX models is still under development and only supports W8A8 quantization for now. For more details, please refer to [layer-wise quantization](https://github.com/intel/neural-compressor/blob/master/docs/source/quantization_layer_wise.md).
+
+```bash
+bash run_quant.sh --input_model=/path/to/model \ # folder path of onnx model
+                  --output_model=/path/to/model_tune \ # folder path to save onnx model
+                  --batch_size=batch_size \ # optional
+                  --dataset NeelNanda/pile-10k \
+                  --tokenizer=meta-llama/Llama-2-7b-hf \ # model name or folder path containing all relevant files for model's tokenizer
+                  --quant_format="QOperator" \ # or QDQ, optional
+                  --layer_wise=True
+```
+
+
 ## 2. Benchmark
 
 Accuracy:

examples/onnxrt/nlp/huggingface_model/text_generation/llama/quantization/ptq_static/main.py

Lines changed: 89 additions & 25 deletions
@@ -17,6 +17,7 @@
 # pylint:disable=redefined-outer-name,logging-format-interpolation
 import os
 import onnx
+import json
 import torch
 import logging
 import argparse
@@ -116,6 +117,11 @@
     type=int,
     default=4
 )
+parser.add_argument(
+    '--layer_wise',
+    action='store_true',
+    default=False,
+)
 args = parser.parse_args()
 
 # load model
@@ -131,16 +137,26 @@ def benchmark(model):
     config = LlamaConfig.from_pretrained(args.model_path)
     sess_options = ort.SessionOptions()
     sess_options.intra_op_num_threads = args.intra_op_num_threads
-    sessions = ORTModelForCausalLM.load_model(
-        os.path.join(model, 'decoder_model.onnx'),
-        os.path.join(model, 'decoder_with_past_model.onnx'),
+
+    if os.path.exists(os.path.join(model, "decoder_with_past_model.onnx")):
+        sessions = ORTModelForCausalLM.load_model(  # pylint: disable=E1123
+            os.path.join(model, "decoder_model.onnx"),
+            os.path.join(model, "decoder_with_past_model.onnx"),
             session_options=sess_options)
-    model = ORTModelForCausalLM(
-        sessions[0],
-        config,
-        model,
-        sessions[1],
-        use_cache=True)
+        model = ORTModelForCausalLM(sessions[0],  # pylint: disable=E1121
+                                    config,
+                                    model,
+                                    sessions[1],
+                                    use_cache=True)
+    else:
+        sessions = ORTModelForCausalLM.load_model(  # pylint: disable=E1123
+            os.path.join(model, "decoder_model.onnx"),
+            session_options=sess_options)
+        model = ORTModelForCausalLM(sessions[0],  # pylint: disable=E1121
+                                    config,
+                                    model,
+                                    use_cache=False,
+                                    use_io_binding=False)
 
     input_tokens = '32'
     max_new_tokens = 32
@@ -173,23 +189,50 @@ def benchmark(model):
         total_time += toc - tic
 
     print("\n", "-" * 10, "Summary:", "-" * 10)
-    latency = total_time / (num_iter - num_warmup)
     print(args)
-    print("Inference latency: %.3f sec." % latency)
+    throughput = (num_iter - num_warmup) / total_time
+    print("Throughput: {} samples/s".format(throughput))
+
+
+def replace_architectures(json_path):
+    # replace 'LLaMATokenizer' with the lowercase 'LlamaTokenizer'
+    # to avoid the bug 'Tokenizer class LLaMATokenizer does not exist or is not currently imported.'
+    # refer to https://github.com/huggingface/transformers/issues/22222#issuecomment-1477171703
+    with open(json_path, "r") as file:
+        data = json.load(file)
+    data["architectures"] = ["LlamaForCausalLM"]
+
+    with open(json_path, 'w') as file:
+        json.dump(data, file, indent=4)
 
 def eval_func(model):
+    model_dir = model
+    if isinstance(model, str) and model.endswith(".onnx"):
+        model_dir = os.path.dirname(model)
+
+    replace_architectures(os.path.join(model_dir, "config.json"))
+
     results = evaluate(
         model="hf-causal",
-        model_args='pretrained=' + model + ',tokenizer='+ args.tokenizer,
+        model_args="pretrained=" + model_dir + ",tokenizer=" + args.tokenizer,
         batch_size=args.batch_size,
         tasks=args.tasks,
-        model_format="onnx"
+        model_format="onnx",
     )
+
+    eval_acc = 0
     for task_name in args.tasks:
         if task_name == "wikitext":
             print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["word_perplexity"]))
+            eval_acc += results["results"][task_name]["word_perplexity"]
         else:
             print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["acc"]))
+            eval_acc += results["results"][task_name]["acc"]
+
+    if len(args.tasks) != 0:
+        eval_acc /= len(args.tasks)
+
+    return eval_acc
 
 class KVDataloader:
     def __init__(self, model_path, pad_max=196, batch_size=1, sub_folder='train'):
@@ -258,15 +301,36 @@ def __iter__(self):
 
 if args.tune:
     from neural_compressor import quantization, PostTrainingQuantConfig
-    config = PostTrainingQuantConfig(
-        calibration_sampling_size=[8],
-        recipes={'optypes_to_exclude_output_quant': ['MatMul'],
-                 'smooth_quant': True,
-                 'smooth_quant_args': {'alpha': args.smooth_quant_alpha}},
-        op_type_dict={'^((?!(MatMul|Gather|Conv)).)*$': {'weight': {'dtype': ['fp32']}, 'activation': {'dtype': ['fp32']}}})
-    for model in ['decoder_model.onnx', 'decoder_with_past_model.onnx']:
-        q_model = quantization.fit(
-            os.path.join(args.model_path, model),
-            config,
-            calib_dataloader=KVDataloader(os.path.join(args.model_path, model), pad_max=args.pad_max, batch_size=1))
-        q_model.save(os.path.join(args.output_model, model))
+    if args.layer_wise:
+        # layer-wise quantization for ONNX models is still under development and only supports W8A8 quantization for now
+        config = PostTrainingQuantConfig(
+            calibration_sampling_size=[8],
+            recipes={'optypes_to_exclude_output_quant': ['MatMul'],
+                     'layer_wise_quant': True},
+            op_type_dict={'^((?!(MatMul|Gather|Conv)).)*$': {'weight': {'dtype': ['fp32']}, 'activation': {'dtype': ['fp32']}}})
+        for model in ['decoder_model.onnx']:
+            # only test decoder_model
+            q_model = quantization.fit(
+                os.path.join(args.model_path, model),
+                config,
+                calib_dataloader=KVDataloader(os.path.join(args.model_path, model), pad_max=args.pad_max, batch_size=1))
+            q_model.save(os.path.join(args.output_model, model))

+        tokenizer.save_pretrained(args.output_model)
+
+    else:
+        config = PostTrainingQuantConfig(
+            calibration_sampling_size=[8],
+            recipes={'optypes_to_exclude_output_quant': ['MatMul'],
+                     'smooth_quant': True,
+                     'smooth_quant_args': {'alpha': args.smooth_quant_alpha},
+                     },
+            op_type_dict={'^((?!(MatMul|Gather|Conv)).)*$': {'weight': {'dtype': ['fp32']}, 'activation': {'dtype': ['fp32']}}})
+        for model in ['decoder_model.onnx', 'decoder_with_past_model.onnx']:
+            q_model = quantization.fit(
+                os.path.join(args.model_path, model),
+                config,
+                calib_dataloader=KVDataloader(os.path.join(args.model_path, model), pad_max=args.pad_max, batch_size=1))
+            q_model.save(os.path.join(args.output_model, model))
+
+        tokenizer.save_pretrained(args.output_model)
