Commit d4662ad

Add transformers-like api doc (#2018)

Signed-off-by: Kaihui-intel <[email protected]>

1 file changed: 214 additions, 0 deletions

Transformers-like API
=====

1. [Introduction](#introduction)

2. [Supported Algorithms](#supported-algorithms)

3. [Usage For Intel CPU](#usage-for-cpu)

4. [Usage For Intel GPU](#usage-for-intel-gpu)

5. [Examples](#examples)

## Introduction

The Transformers-like API provides a seamless model compression experience for Transformer-based models by extending the Hugging Face transformers APIs, leveraging Intel® Neural Compressor's existing weight-only quantization capability, and replacing the Linear operator with Intel® Extension for PyTorch.
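
A minimal sketch of the intended workflow (the model path is a placeholder, and `RtnConfig` is used here only as one example of the config classes; per-algorithm and per-device details follow in the sections below):

```python
from neural_compressor.transformers import AutoModelForCausalLM, RtnConfig
from transformers import AutoTokenizer

model_name_or_path = "MODEL_NAME_OR_PATH"  # placeholder

# Passing a weight-only quantization config to from_pretrained quantizes the model on load.
q_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    quantization_config=RtnConfig(bits=4),
)

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
input_ids = tokenizer("Once upon a time, a little girl", return_tensors="pt")["input_ids"]
print(tokenizer.batch_decode(q_model.generate(input_ids), skip_special_tokens=True))
```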

## Supported Algorithms

| Supported Device | RTN | AWQ | TEQ | GPTQ | AutoRound |
|:--------------:|:----------:|:----------:|:----------:|:----:|:----:|
| Intel CPU | &#10004; | &#10004; | &#10004; | &#10004; | &#10004; |
| Intel GPU | &#10004; | stay tuned | stay tuned | &#10004; | &#10004; |

> Please refer to the [weight-only quantization document](./PT_WeightOnlyQuant.md) for more details.
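
Each algorithm in the table corresponds to a config class exported by `neural_compressor.transformers`; as a quick orientation (the classes themselves are demonstrated in the usage sections below):

```python
from neural_compressor.transformers import (
    AutoRoundConfig,  # AutoRound
    AwqConfig,        # AWQ
    GPTQConfig,       # GPTQ
    RtnConfig,        # RTN
    TeqConfig,        # TEQ
)

# Choosing an algorithm is simply choosing which config to pass as quantization_config.
ALGORITHM_CONFIGS = {
    "RTN": RtnConfig,
    "AWQ": AwqConfig,
    "TEQ": TeqConfig,
    "GPTQ": GPTQConfig,
    "AutoRound": AutoRoundConfig,
}
```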

## Usage For CPU

Our motivation is to improve CPU support for weight-only quantization. We have extended the `from_pretrained` function so that `quantization_config` can accept [`RtnConfig`](https://github.com/intel/neural-compressor/blob/master/neural_compressor/transformers/utils/quantization_config.py#L243), [`AwqConfig`](https://github.com/intel/neural-compressor/blob/72398b69334d90cdd7664ac12a025cd36695b55c/neural_compressor/transformers/utils/quantization_config.py#L394), [`TeqConfig`](https://github.com/intel/neural-compressor/blob/72398b69334d90cdd7664ac12a025cd36695b55c/neural_compressor/transformers/utils/quantization_config.py#L464), [`GPTQConfig`](https://github.com/intel/neural-compressor/blob/72398b69334d90cdd7664ac12a025cd36695b55c/neural_compressor/transformers/utils/quantization_config.py#L298), and [`AutoRoundConfig`](https://github.com/intel/neural-compressor/blob/72398b69334d90cdd7664ac12a025cd36695b55c/neural_compressor/transformers/utils/quantization_config.py#L527) to perform the conversion on the CPU.

### Usage examples for CPU device

The following example shows quantization and inference with `RtnConfig`, `AwqConfig`, `TeqConfig`, `GPTQConfig`, and `AutoRoundConfig` on the CPU device:

```python
# RTN
from neural_compressor.transformers import AutoModelForCausalLM, RtnConfig

model_name_or_path = "MODEL_NAME_OR_PATH"
woq_config = RtnConfig(bits=4)
q_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    quantization_config=woq_config,
)

# AWQ
from neural_compressor.transformers import AutoModelForCausalLM, AwqConfig

model_name_or_path = "MODEL_NAME_OR_PATH"
woq_config = AwqConfig(bits=4)
q_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    quantization_config=woq_config,
)

# TEQ
from transformers import AutoTokenizer
from neural_compressor.transformers import AutoModelForCausalLM, TeqConfig

model_name_or_path = "MODEL_NAME_OR_PATH"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
woq_config = TeqConfig(bits=4, tokenizer=tokenizer)
q_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    quantization_config=woq_config,
)

# GPTQ
from transformers import AutoTokenizer
from neural_compressor.transformers import AutoModelForCausalLM, GPTQConfig

model_name_or_path = "MODEL_NAME_OR_PATH"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
woq_config = GPTQConfig(bits=4, tokenizer=tokenizer)
q_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    quantization_config=woq_config,
)

# AutoRound
from transformers import AutoTokenizer
from neural_compressor.transformers import AutoModelForCausalLM, AutoRoundConfig

model_name_or_path = "MODEL_NAME_OR_PATH"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
woq_config = AutoRoundConfig(bits=4, tokenizer=tokenizer)
q_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    quantization_config=woq_config,
)

# inference with the quantized model
from transformers import AutoTokenizer

prompt = "Once upon a time, a little girl"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=4)
gen_ids = q_model.generate(input_ids, **generate_kwargs)
gen_text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
print(gen_text)
```

You can also save and load your quantized low-bit model with the code below.

```python
# quantization
from neural_compressor.transformers import AutoModelForCausalLM, RtnConfig

model_name_or_path = "MODEL_NAME_OR_PATH"
woq_config = RtnConfig(bits=4)
q_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    quantization_config=woq_config,
)

# save the quantized model
saved_dir = "SAVE_DIR"
q_model.save_pretrained(saved_dir)

# load the quantized model
loaded_model = AutoModelForCausalLM.from_pretrained(saved_dir)
```
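
The reloaded model can be used for generation in the same way as the freshly quantized one; a minimal sketch, reusing the tokenizer from `model_name_or_path` above:

```python
from transformers import AutoTokenizer

prompt = "Once upon a time, a little girl"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
gen_ids = loaded_model.generate(input_ids, num_beams=4, do_sample=False)
print(tokenizer.batch_decode(gen_ids, skip_special_tokens=True))
```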

## Usage For Intel GPU

Intel® Neural Compressor implements weight-only quantization for Intel GPUs (PVC/ARC/MTL/LNL) with [intel-extension-for-pytorch](https://github.com/intel/intel-extension-for-pytorch).

Currently, 4-bit/8-bit inference with `RtnConfig`, `GPTQConfig`, and `AutoRoundConfig` is supported on Intel GPU devices.

We support experimental weight-only quantization (WOQ) inference on Intel GPUs (PVC/ARC/MTL/LNL) by replacing the Linear op in PyTorch. Validated models: meta-llama/Meta-Llama-3-8B, meta-llama/Llama-2-7b-hf, Qwen/Qwen-7B-Chat, and microsoft/Phi-3-mini-4k-instruct.

Example code is shown below.

#### Prepare Dependency Packages

1. Install the oneAPI Package

The oneAPI DPC++ compiler is required to compile intel-extension-for-pytorch. Please follow [the link](https://www.intel.com/content/www/us/en/developer/articles/guide/installation-guide-for-oneapi-toolkits.html) to install oneAPI to the "/opt/intel" folder.

2. Build and Install PyTorch and intel-extension-for-pytorch. Please follow [the link](https://intel.github.io/intel-extension-for-pytorch/index.html#installation). A quick sanity check of the installation is sketched below.
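
A minimal sanity check, assuming an XPU build of intel-extension-for-pytorch (the exact version strings will differ per install):

```python
import torch
import intel_extension_for_pytorch as ipex  # registers the XPU backend with PyTorch

print(torch.__version__, ipex.__version__)
print(torch.xpu.is_available())  # expected to print True on a supported Intel GPU
```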

3. Quantize the model and run inference

```python
import intel_extension_for_pytorch as ipex
from neural_compressor.transformers import AutoModelForCausalLM, RtnConfig
from transformers import AutoTokenizer
import torch

model_name_or_path = "Qwen/Qwen-7B-Chat"  # MODEL_NAME_OR_PATH
prompt = "Once upon a time, a little girl"

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to("xpu")

# quantize on load; RtnConfig is used here as an example (see step 4 for the save/load variant)
woq_config = RtnConfig(bits=4)
q_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path, quantization_config=woq_config, device_map="xpu", trust_remote_code=True
)

# optimize the model with ipex to improve performance
quantization_config = q_model.quantization_config if hasattr(q_model, "quantization_config") else None
q_model = ipex.optimize_transformers(
    q_model, inplace=True, dtype=torch.float16, quantization_config=quantization_config, device="xpu"
)

output = q_model.generate(input_ids, max_new_tokens=100, do_sample=True)
print(tokenizer.batch_decode(output, skip_special_tokens=True))
```

> Note: If your device memory is not enough, please quantize and save the model first, then rerun the example and load the saved model as shown below. If your device memory is enough, skip the instructions below and simply quantize and run inference.

4. Save and load the quantized model

* First step: quantize and save the model

```python
from neural_compressor.transformers import AutoModelForCausalLM, RtnConfig

model_name_or_path = "MODEL_NAME_OR_PATH"
woq_config = RtnConfig(bits=4)
q_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path, quantization_config=woq_config, device_map="xpu", trust_remote_code=True
)

# Please note: saving the model should be done before the ipex.optimize_transformers function is called.
q_model.save_pretrained("saved_dir")
```

* Second step: load the model and run inference (to reduce memory usage, you may need to end the quantization process and rerun the script to load the model)

```python
import intel_extension_for_pytorch as ipex
import torch
from neural_compressor.transformers import AutoModelForCausalLM

# Load the quantized model
loaded_model = AutoModelForCausalLM.from_pretrained("saved_dir", trust_remote_code=True)

# Before executing the loaded model, you can call the ipex.optimize_transformers function.
quantization_config = loaded_model.quantization_config if hasattr(loaded_model, "quantization_config") else None
loaded_model = ipex.optimize_transformers(
    loaded_model, inplace=True, dtype=torch.float16, quantization_config=quantization_config, device="xpu"
)

# inference
from transformers import AutoTokenizer

model_name_or_path = "MODEL_NAME_OR_PATH"  # the original model path, used here only for the tokenizer
prompt = "Once upon a time, a little girl"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=4)
gen_ids = loaded_model.generate(input_ids, **generate_kwargs)
gen_text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
print(gen_text)
```

5. You can also directly use the [example script](https://github.com/intel/neural-compressor/blob/master/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/transformers/weight_only/text-generation/run_generation_gpu_woq.py):

```bash
python run_generation_gpu_woq.py --woq --benchmark --model save_dir
```

> Note:
> * Saving the quantized model should be done before the optimize_transformers function is called.
> * The optimize_transformers function is designed to optimize transformer-based models within frontend Python modules, with a particular focus on Large Language Models (LLMs). It provides both model-wise and content-generation-wise optimizations. For details of `optimize_transformers`, please refer to [the link](https://github.com/intel/intel-extension-for-pytorch/blob/xpu-main/docs/tutorials/llm/llm_optimize_transformers.md).

## Examples

Users can also refer to the [examples](https://github.com/intel/neural-compressor/blob/master/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/transformers/weight_only/text-generation) for how to quantize a model with the transformers-like API.
