Skip to content

Commit d84a93f

Browse files
authored
Complement UT of calibration function for TF 3x API (#1945)
Signed-off-by: zehao-intel <[email protected]>
1 parent fb85779 commit d84a93f

File tree

2 files changed

+162
-0
lines changed

2 files changed

+162
-0
lines changed

.azure-pipelines/scripts/ut/3x/run_3x_tf.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ ut_log_name=${LOG_DIR}/ut_3x_tf.log
2626
# test for tensorflow ut
2727
pytest --cov="${inc_path}" -vs --disable-warnings --html=report_tf_quant.html --self-contained-html ./tensorflow/quantization 2>&1 | tee -a ${ut_log_name}
2828
rm -rf tensorflow/quantization
29+
pytest --cov="${inc_path}" --cov-append -vs --disable-warnings --html=report_tf_test_quantize_model.html --self-contained-html ./tensorflow/test_quantize_model.py 2>&1 | tee -a ${ut_log_name}
30+
rm -rf tensorflow/test_quantize_model.py
2931
pytest --cov="${inc_path}" --cov-append -vs --disable-warnings --html=report_tf.html --self-contained-html . 2>&1 | tee -a ${ut_log_name}
3032

3133
# test for tensorflow new api ut
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright (c) 2024 Intel Corporation
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
import math
19+
import shutil
20+
import time
21+
import unittest
22+
23+
import numpy as np
24+
import tensorflow as tf
25+
from tensorflow import keras
26+
27+
from neural_compressor.common import logger
28+
from neural_compressor.tensorflow.utils import version1_gte_version2
29+
30+
31+
def build_model():
32+
# Load MNIST dataset
33+
mnist = keras.datasets.mnist
34+
35+
# 60000 images in train and 10000 images in test, but we don't need so much for ut
36+
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
37+
train_images, train_labels = train_images[:1000], train_labels[:1000]
38+
test_images, test_labels = test_images[:200], test_labels[:200]
39+
40+
# Normalize the input image so that each pixel value is between 0 to 1.
41+
train_images = train_images / 255.0
42+
test_images = test_images / 255.0
43+
44+
# Define the model architecture.
45+
model = keras.Sequential(
46+
[
47+
keras.layers.InputLayer(input_shape=(28, 28)),
48+
keras.layers.Reshape(target_shape=(28, 28, 1)),
49+
keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation="relu", name="conv2d"),
50+
keras.layers.MaxPooling2D(pool_size=(2, 2)),
51+
keras.layers.Flatten(),
52+
keras.layers.Dense(10, name="dense"),
53+
]
54+
)
55+
# Train the digit classification model
56+
model.compile(
57+
optimizer="adam", loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"]
58+
)
59+
60+
model.fit(
61+
train_images,
62+
train_labels,
63+
epochs=1,
64+
validation_split=0.1,
65+
)
66+
67+
_, baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0)
68+
69+
print("Baseline test accuracy:", baseline_model_accuracy)
70+
if version1_gte_version2(tf.__version__, "2.16.1"):
71+
model.export("baseline_model")
72+
else:
73+
model.save("baseline_model")
74+
75+
76+
class Dataset(object):
77+
def __init__(self, batch_size=1):
78+
self.batch_size = batch_size
79+
mnist = keras.datasets.mnist
80+
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
81+
train_images, train_labels = train_images[:1000], train_labels[:1000]
82+
test_images, test_labels = test_images[:200], test_labels[:200]
83+
# Normalize the input image so that each pixel value is between 0 to 1.
84+
self.train_images = train_images / 255.0
85+
self.test_images = test_images / 255.0
86+
self.train_labels = train_labels
87+
self.test_labels = test_labels
88+
89+
def __len__(self):
90+
return len(self.test_images)
91+
92+
def __getitem__(self, idx):
93+
return self.test_images[idx], self.test_labels[idx]
94+
95+
96+
class MyDataloader:
97+
def __init__(self, dataset, batch_size=1):
98+
self.dataset = dataset
99+
self.batch_size = batch_size
100+
self.length = math.ceil(len(dataset) / self.batch_size)
101+
102+
def __iter__(self):
103+
for _, (images, labels) in enumerate(self.dataset):
104+
images = np.expand_dims(images, axis=0)
105+
labels = np.expand_dims(labels, axis=0)
106+
yield (images, labels)
107+
108+
def __len__(self):
109+
return self.length
110+
111+
112+
def evaluate(model):
113+
input_tensor = model.input_tensor
114+
output_tensor = model.output_tensor if len(model.output_tensor) > 1 else model.output_tensor[0]
115+
116+
iteration = -1
117+
calib_dataloader = MyDataloader(dataset=Dataset())
118+
for idx, (inputs, labels) in enumerate(calib_dataloader):
119+
# dataloader should keep the order and len of inputs same with input_tensor
120+
inputs = np.array([inputs])
121+
feed_dict = dict(zip(input_tensor, inputs))
122+
123+
start = time.time()
124+
predictions = model.sess.run(output_tensor, feed_dict)
125+
end = time.time()
126+
127+
if idx + 1 == iteration:
128+
break
129+
130+
131+
class TestQuantizeModel(unittest.TestCase):
132+
@classmethod
133+
def setUpClass(self):
134+
build_model()
135+
self.fp32_model_path = "baseline_model"
136+
137+
@classmethod
138+
def tearDownClass(self):
139+
shutil.rmtree(self.fp32_model_path, ignore_errors=True)
140+
141+
def test_calib_func(self):
142+
logger.info("Run test_calib_func case...")
143+
144+
from neural_compressor.common import set_random_seed
145+
from neural_compressor.tensorflow import StaticQuantConfig, quantize_model
146+
147+
set_random_seed(9527)
148+
quant_config = StaticQuantConfig()
149+
q_model = quantize_model(self.fp32_model_path, quant_config, calib_func=evaluate)
150+
quantized = False
151+
for node in q_model.graph_def.node:
152+
if "Quantized" in node.op:
153+
quantized = True
154+
break
155+
156+
self.assertEqual(quantized, True)
157+
158+
159+
if __name__ == "__main__":
160+
unittest.main()

0 commit comments

Comments
 (0)