
Commit 4f2c35d

Add TF Example gpt-j-6B (#1446)
Signed-off-by: zehao-intel <[email protected]>
1 parent 5d33a53 commit 4f2c35d

File tree

9 files changed: +635 -0 lines changed


examples/.config/model_params_tensorflow.json

+7
```diff
@@ -1787,6 +1787,13 @@
       "main_script": "main.py",
       "batch_size": 16
     },
+    "gpt-j-6B": {
+      "model_src_dir": "nlp/large_language_models/quantization/ptq/gpt-j",
+      "dataset_location": "",
+      "input_model": "/tf_dataset2/models/tensorflow/gpt-j-6B",
+      "main_script": "main.py",
+      "batch_size": 1
+    },
     "opt_125m_sq": {
       "model_src_dir": "nlp/large_language_models/quantization/ptq/smoothquant",
       "dataset_location": "",
```

examples/README.md

+6
```diff
@@ -289,6 +289,12 @@ Intel® Neural Compressor validated examples with multiple compression technique
       <td>Post-Training Static Quantization</td>
       <td><a href="./tensorflow/graph_networks/graphsage/">pb</a></td>
     </tr>
+    <tr>
+      <td>EleutherAI/gpt-j-6B</td>
+      <td>Natural Language Processing</td>
+      <td>Post-Training Static Quantization</td>
+      <td><a href="./tensorflow/nlp/large_language_models/quantization/ptq/gpt-j">saved_model (smooth quant)</a></td>
+    </tr>
   </tbody>
 </table>
```

examples/tensorflow/nlp/large_language_models/quantization/ptq/gpt-j/README.md (new file)

+157

Step-by-Step
============

This document lists the steps to reproduce the TensorFlow Intel® Neural Compressor smooth quantization of the gpt-j-6B language model.

# Prerequisite

## 1. Environment

### Installation
```shell
# Install Intel® Neural Compressor
pip install neural-compressor
pip install -r requirements.txt
```

## 2. Prepare Pretrained model
Run the following script to download the gpt-j-6B saved_model to `./gpt-j-6B`:
```
bash prepare_model.sh
```
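For reference, here is a minimal sketch of what the model-preparation step amounts to, assuming the Hugging Face `transformers` library; the actual `prepare_model.sh` shipped with the example is authoritative and may differ:

```python
# Hypothetical sketch only -- prepare_model.sh in the example is the source of truth.
from transformers import TFGPTJForCausalLM

# Load the EleutherAI checkpoint (converting from the PyTorch weights) and
# export it as a TensorFlow saved_model under ./gpt-j-6B.
model = TFGPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", from_pt=True)
model.save_pretrained("./gpt-j-6B", saved_model=True)
```
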
## 3. Install TensorFlow 2.11.dev202242
Build a TensorFlow pip package from the [intel-tensorflow spr_ww42 branch](https://github.com/Intel-tensorflow/tensorflow/tree/spr_ww42) and install it. For how to build a TensorFlow pip package from source, please refer to this [tutorial](https://www.tensorflow.org/install/source).

The performance of the int8 gpt-j-6B model is expected to improve once intel-tensorflow for GNR is released.

## 4. Prepare Dataset
The dataset will be automatically loaded.

# Run

## Smooth Quantization

```shell
bash run_quant.sh --input_model=<FP32_MODEL_PATH> --output_model=<INT8_MODEL_PATH>
```

## Benchmark

### Evaluate Performance

```shell
bash run_benchmark.sh --input_model=<MODEL_PATH> --mode=benchmark
```

### Evaluate Accuracy

```shell
bash run_benchmark.sh --input_model=<MODEL_PATH> --mode=accuracy
```

Details of enabling Intel® Neural Compressor on gpt-j-6B for TensorFlow
=========================

This is a tutorial on how to enable the gpt-j-6B model with Intel® Neural Compressor.
## User Code Analysis

The user specifies the fp32 *model*, a calibration dataloader *calib_dataloader*, and a custom *eval_func* that encapsulates the evaluation dataloader and metric by itself.
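A minimal sketch of what such an *eval_func* could look like is shown below. This is an assumption for illustration only: the example's `main.py` defines its own `evaluate` function, dataset, and metric, and the `eval_dataloader` used here, yielding `((attention_mask, input_ids), label)` pairs where `label` is the id of the next token, is hypothetical.

```python
import tensorflow as tf

# Hypothetical sketch: eval_dataloader and the label format are assumptions,
# not part of the original example.
def evaluate(model):
    correct, total = 0, 0
    for (attention_mask, input_ids), label in eval_dataloader:
        # Run the model and take the logits at the last position.
        logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]
        pred = int(tf.argmax(logits[0, -1, :]))
        correct += int(pred == label)
        total += 1
    return correct / total  # simple next-token accuracy
```
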

### calib_dataloader Part Adaption
The dataloader class below uses a generator function to provide the model with input.

```python
import math

import tensorflow as tf


class MyDataloader:
    def __init__(self, dataset, batch_size=1):
        self.dataset = dataset
        self.batch_size = batch_size
        self.length = math.ceil(len(dataset) / self.batch_size)

    def generate_data(self, data, pad_token_id=50256):
        # Use all tokens except the last one as the model input.
        input_ids = tf.convert_to_tensor([data[:-1]], dtype=tf.int32)
        cur_len = len(data) - 1
        input_ids_padding = tf.ones((self.batch_size, 1), dtype=tf.int32) * (pad_token_id or 0)
        generated = tf.concat([input_ids, input_ids_padding], axis=-1)
        # prepare_attention_mask_for_generation is defined elsewhere in the example.
        model_kwargs = {'attention_mask': prepare_attention_mask_for_generation(input_ids)}
        if model_kwargs.get("past_key_values") is None:
            input_ids = generated[:, :cur_len]
        else:
            input_ids = tf.expand_dims(generated[:, cur_len - 1], -1)
        return model_kwargs['attention_mask'], input_ids

    def __iter__(self):
        labels = None
        for _, data in enumerate(self.dataset):
            cur_input = self.generate_data(data)
            yield (cur_input, labels)

    def __len__(self):
        return self.length
```

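The dataloader above calls `prepare_attention_mask_for_generation`, which is defined elsewhere in the example. As an assumption for illustration, such a helper typically mirrors the Hugging Face generation utility of the same name:

```python
import tensorflow as tf

def prepare_attention_mask_for_generation(input_ids, pad_token_id=50256):
    """Hypothetical sketch: attend to every non-padding token."""
    if tf.reduce_any(tf.equal(input_ids, pad_token_id)):
        # Mask out padding positions.
        return tf.cast(tf.math.not_equal(input_ids, pad_token_id), tf.int32)
    # No padding present: attend to every position.
    return tf.ones_like(input_ids, dtype=tf.int32)
```
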
### Quantization Config
The PostTrainingQuantConfig class has default parameter settings for running on Intel CPUs. If running this example on Intel GPUs, set the 'backend' parameter to 'itex' and the 'device' parameter to 'gpu'.

```python
config = PostTrainingQuantConfig(
    device="gpu",
    backend="itex",
    ...
)
```

### Code Update
After the preparation steps are done, we add the code for quantization tuning to generate the quantized model.

First, let's load an INC Model object from the path of the gpt-j-6B saved_model.
```python
from neural_compressor import Model

model = Model(run_args.input_model, modelType='llm_saved_model')
```

#### Tune

To apply quantization, a function that maps the names of AutoTrackable variables to graph node names must be defined, because the two use different naming formats.
```python
def weight_name_mapping(name):
    """Map a name from the AutoTrackable variables to the corresponding graph node."""
    name = name.replace('tfgptj_for_causal_lm', 'StatefulPartitionedCall')
    name = name.replace('kernel:0', 'Tensordot/ReadVariableOp')
    return name
```
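For instance, applying the mapping to a sample variable name (illustrative only; the exact names depend on the exported saved_model) gives:

```python
# Illustrative input: a typical Hugging Face TF GPT-J variable name.
name = "tfgptj_for_causal_lm/transformer/h_._0/attn/q_proj/kernel:0"
print(weight_name_mapping(name))
# -> StatefulPartitionedCall/transformer/h_._0/attn/q_proj/Tensordot/ReadVariableOp
```
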
Use the recipes argument to enable smooth quantization.
```python
from neural_compressor import quantization, PostTrainingQuantConfig

calib_dataloader = MyDataloader(mydata, batch_size=run_args.batch_size)
recipes = {"smooth_quant": True, "smooth_quant_args": {'alpha': 0.52705}}
conf = PostTrainingQuantConfig(quant_level=1,
                               excluded_precisions=["bf16"],  # use basic tuning
                               recipes=recipes,
                               calibration_sampling_size=[1])

model.weight_name_mapping = weight_name_mapping
q_model = quantization.fit(model,
                           conf,
                           eval_func=evaluate,
                           calib_dataloader=calib_dataloader)

q_model.save(run_args.output_model)
```
#### Benchmark
```python
if run_args.mode == "performance":
    from neural_compressor.benchmark import fit
    from neural_compressor.config import BenchmarkConfig

    conf = BenchmarkConfig(warmup=10, iteration=run_args.iteration, cores_per_instance=4, num_of_instance=1)
    fit(model, conf, b_func=evaluate)
elif run_args.mode == "accuracy":
    acc_result = evaluate(model.model)
    print("Batch size = %d" % run_args.batch_size)
    print("Accuracy: %.5f" % acc_result)
```

The Intel® Neural Compressor quantization.fit() function returns the best quantized model it finds under the tuning time constraint.
