Commit 5dafe5f

3.x SQ autotune supports calib_func w/ capture input (#1821)
Signed-off-by: Cheng, Zixuan <[email protected]>
1 parent 7120dd4 commit 5dafe5f
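
In short: SmoothQuant auto-tuning (alpha="auto") in the 3.x API previously bailed out when no dataloader was supplied; with this change it builds one by capturing the inputs that the user's calibration function feeds to the model. A minimal usage sketch, assuming the 3.x quantize API and the import path used by the tests below (the toy model is illustrative, not from the commit):

import torch
from neural_compressor.torch.quantization import SmoothQuantConfig, quantize

# Illustrative toy model; any torch.nn.Module works the same way.
model = torch.nn.Sequential(torch.nn.Linear(3, 4), torch.nn.Linear(4, 3))
example_inputs = torch.rand([1, 3])

def run_fn(model):
    # Calibration function: plain forward passes, no dataloader required.
    for _ in range(100):
        model(example_inputs)

quant_config = SmoothQuantConfig(alpha="auto", folding=False)
q_model = quantize(model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)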

File tree: 3 files changed, +137 −22 lines changed

neural_compressor/torch/algorithms/smooth_quant/utility.py

Lines changed: 57 additions & 11 deletions
@@ -21,8 +21,8 @@
 import intel_extension_for_pytorch as ipex
 import numpy
 import torch
-import tqdm
 from packaging.version import Version
+from tqdm import tqdm
 
 from neural_compressor.torch.algorithms.static_quant import (
     CpuInfo,
@@ -78,6 +78,9 @@ def get_quantizable_ops_recursively(model, example_inputs, alpha, act_algo, inpl
 
     from torch.ao.quantization import MinMaxObserver
 
+    if alpha == "auto":  # for quantize API
+        alpha = 0.5
+
     if ipex_ver.release >= Version("2.1.1").release:
         static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(
             alpha=alpha, act_observer=MinMaxObserver
@@ -390,6 +393,9 @@ def forward_wrapper(model, input, device=torch.device("cpu")):  # pragma: no cov
             output = model(*input)
         except:
             output = model(input)
+    elif isinstance(input, zip):
+        for args, kwargs in input:
+            output = model(*args, **kwargs)
     else:
         output = model(input)
     return output
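
Context for the new zip branch: the captured-input dataloader added further down stores positional and keyword arguments in parallel lists, so a calibration input can arrive as a zip of (args, kwargs) pairs that must be replayed call by call. A standalone sketch of what that branch handles (toy model and inputs are illustrative):

import torch

model = torch.nn.Linear(3, 3)
args_list = [(torch.rand(1, 3),), (torch.rand(1, 3),)]
kwargs_list = [{}, {}]

# Mirrors the new elif branch in forward_wrapper: replay each captured call.
for args, kwargs in zip(args_list, kwargs_list):
    output = model(*args, **kwargs)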
@@ -412,6 +418,43 @@ def model_forward(model, dataloader, iters, device):  # pragma: no cover
             break
 
 
+def build_captured_dataloader(model, run_fn, calib_num=None):
+    class CapturedDataloader:
+        def __init__(self, args_list, kwargs_list) -> None:
+            self.args_list = args_list
+            self.kwargs_list = kwargs_list
+
+        def __iter__(self):
+            for args, kwargs in zip(self.args_list, self.kwargs_list):
+                if not args:
+                    yield kwargs
+                elif not kwargs:
+                    yield args
+                else:
+                    yield args, kwargs
+
+    class InputCaptureModule(torch.nn.Module):
+        def __init__(self, model) -> None:
+            super().__init__()
+            self.args_list = []
+            self.kwargs_list = []
+            self.orig_model = model
+            self.iters = 0
+            self.calib_num = calib_num
+
+        def forward(self, *args, **kwargs):
+            if self.iters < self.calib_num:
+                self.args_list.append(args)
+                self.kwargs_list.append(kwargs)
+                self.iters += 1
+
+    captured_model = InputCaptureModule(model)
+    run_fn(captured_model)
+    dataloader = CapturedDataloader(captured_model.args_list, captured_model.kwargs_list)
+    model = captured_model.orig_model
+    return model, dataloader
+
+
 def cal_scale(input_max_abs, weights, alpha, weight_max_lb=1e-5):  # pragma: no cover
     weights = torch.cat(weights, dim=0)
     weight_max = torch.max(torch.abs(weights), dim=0)[0]
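
Taken in isolation, the new helper works like this: InputCaptureModule wraps the model so that each forward call records its (args, kwargs), up to calib_num calls, without invoking the original model; CapturedDataloader then replays those records. A minimal sketch of calling it directly (the import is the one exercised by the new unit test; the toy model and run_fn are illustrative):

import torch
from neural_compressor.torch.algorithms.smooth_quant import build_captured_dataloader

model = torch.nn.Linear(3, 3)

def run_fn(m):
    for _ in range(10):
        # Inputs are recorded, not executed: the wrapper's forward
        # returns None during capture, so don't use the output here.
        m(torch.rand([1, 3]))

model, dataloader = build_captured_dataloader(model, run_fn, calib_num=32)
for args in dataloader:
    # Positional-only calls are yielded back as a bare args tuple.
    output = model(*args)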
@@ -1349,14 +1392,15 @@ def _auto_tune_alpha(self):
         best_alphas = self.init_alpha
 
         if not self.dataloader:
-            logger.info(f"Auto-tuning failed due to no dataloader, using {best_alphas} instead.")
-            self._qdq_model_unwrapper_for_auto()
-            return best_alphas
+            logger.info("No dataloader, performing auto-tuning with calibration function instead.")
+            self.model, self.dataloader = build_captured_dataloader(self.model, self.q_func, self.calib_sample_num)
+
         bar = tqdm(self.dataloader, total=self.calib_sample_num, desc="auto tune alpha")  # pylint: disable=E1102
         for input in bar:
             if isinstance(input, tuple) or isinstance(input, list):
                 if len(input) == 2:
                     input, _ = input  # Extract input when both input and label are yielded by dataloader.
+
             loss_alphas = {}
             best_alphas_per_module = best_alphas
             if isinstance(best_alphas, dict):
@@ -1374,8 +1418,9 @@ def _auto_tune_alpha(self):
                     cur_loss = loss_alphas[key]
                     for alpha_key in cur_loss.keys():
                         cur_loss[alpha_key] += loss_tmp[key][alpha_key]
-            total_cnt += self.dataloader.batch_size
-            tmp_cnt += self.dataloader.batch_size
+
+            total_cnt += 1
+            tmp_cnt += 1
             if tmp_cnt // multiply_factor >= 1:
                 alpha_update_iter += 1
                 tmp_cnt = 0
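
Why the counters changed: CapturedDataloader has no batch_size attribute, and each item it yields is one captured call, so progress toward calib_sample_num and the alpha-update cadence are now counted per iteration instead of per dataloader.batch_size samples. A small sketch of the cadence under that reading (the constants are hypothetical):

calib_sample_num = 32           # hypothetical
multiply_factor = 32 // 4       # hypothetical
total_cnt = tmp_cnt = alpha_update_iter = 0
for _ in range(calib_sample_num):
    total_cnt += 1
    tmp_cnt += 1
    if tmp_cnt // multiply_factor >= 1:  # every multiply_factor captured calls...
        alpha_update_iter += 1           # ...the per-module alphas are re-estimated
        tmp_cnt = 0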
@@ -1418,13 +1463,14 @@ def _auto_tune_alpha_blockwise(self):
         best_alphas = self.init_alpha
 
         if not self.dataloader:
-            logger.info(f"Auto-tuning failed due to no dataloader, using {best_alphas} instead.")
-            self._qdq_model_unwrapper_for_auto()
-            return best_alphas
+            logger.info("No dataloader, performing auto-tuning with calibration function instead.")
+            self.model, self.dataloader = build_captured_dataloader(self.model, self.q_func, self.calib_sample_num)
+
         bar = tqdm(self.dataloader, total=self.calib_sample_num, desc="auto tune alpha")  # pylint: disable=E1102
         for input in bar:
             if isinstance(input, tuple):  # Extract input when both input and label are yielded by dataloader.
                 input = input[0]
+
             loss_alphas = {}
             best_alphas_per_module = best_alphas
             if isinstance(best_alphas, dict):
@@ -1446,8 +1492,8 @@ def _auto_tune_alpha_blockwise(self):
                     for alpha_key in cur_loss.keys():
                         cur_loss[alpha_key] += loss_tmp[block_name][alpha_key]
 
-            total_cnt += self.dataloader.batch_size
-            tmp_cnt += self.dataloader.batch_size
+            total_cnt += 1
+            tmp_cnt += 1
             if tmp_cnt // multiply_factor >= 1:
                 alpha_update_iter += 1
                 tmp_cnt = 0
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+import copy
+
+import pytest
+import torch
+
+
+class Model(torch.nn.Module):
+    device = torch.device("cpu")
+
+    def __init__(self):
+        super(Model, self).__init__()
+        self.fc1 = torch.nn.Linear(3, 4)
+        self.fc2 = torch.nn.Linear(4, 3)
+
+    def forward(self, x):
+        out = self.fc1(x)
+        out = self.fc2(out)
+        return out
+
+
+model = Model()
+
+
+def test_captured_dataloader():
+    from neural_compressor.torch.algorithms.smooth_quant import build_captured_dataloader
+
+    fp32_model = copy.deepcopy(model)
+
+    def run_fn(model):
+        for i in range(10):
+            example_inputs = torch.randn([1, 3])
+            model(example_inputs)
+
+    tmp_model, dataloader = build_captured_dataloader(fp32_model, run_fn, calib_num=32)
+    assert tmp_model == fp32_model, "Model should be same after building dataloader. Please check."
+    assert isinstance(dataloader.args_list[0][0], torch.Tensor), "Args list should contain tensors. Please check."
+    assert not dataloader.kwargs_list[0], "Kwargs list should be empty. Please check."

test/3x/torch/quantization/test_smooth_quant.py

Lines changed: 43 additions & 11 deletions
@@ -26,10 +26,12 @@ def forward(self, x):
 
 
 model = Model()
+example_inputs = torch.rand([1, 3])
 
 
 def run_fn(model):
-    model(torch.randn([1, 3]))
+    for i in range(10):
+        model(example_inputs)
 
 
 class TestSmoothQuant:
@@ -40,7 +42,6 @@ def teardown_class(self):
     def test_smooth_quant_default(self):
         fp32_model = copy.deepcopy(model)
         quant_config = get_default_sq_config()
-        example_inputs = torch.randn([1, 3])
         prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
         run_fn(prepared_model)
         q_model = convert(prepared_model)
@@ -57,7 +58,6 @@ def test_smooth_quant_default(self):
     def test_smooth_quant_fallback(self):
         fp32_model = copy.deepcopy(model)
         quant_config = get_default_sq_config()
-        example_inputs = torch.randn([1, 3])
         # fallback by op_type
         quant_config.set_local(torch.nn.Linear, SmoothQuantConfig(w_dtype="fp32", act_dtype="fp32"))
         prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
@@ -87,10 +87,6 @@ def test_sq_linear_params(self, act_sym, act_algo, alpha, folding, scale_sharing
         quant_config = SmoothQuantConfig(
             act_sym=act_sym, act_algo=act_algo, alpha=alpha, folding=folding, scale_sharing=scale_sharing
         )
-        example_inputs = torch.zeros([1, 3])
-
-        def run_fn(model):
-            model(example_inputs)
 
         prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
         run_fn(prepared_model)
@@ -102,7 +98,6 @@ def run_fn(model):
 
     @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
     def test_sq_ipex_accuracy(self):
-        example_inputs = torch.zeros([1, 3])
         qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
         user_model = copy.deepcopy(model)
         user_model = ipex.quantization.prepare(user_model.eval(), qconfig, example_inputs=example_inputs, inplace=True)
@@ -144,7 +139,6 @@ def run_fn(model):
     def test_sq_save_load(self):
         fp32_model = copy.deepcopy(model)
         quant_config = get_default_sq_config()
-        example_inputs = torch.zeros([1, 3])
         prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
         run_fn(prepared_model)
         q_model = convert(prepared_model)
@@ -171,7 +165,6 @@ def test_sq_save_load(self):
     def test_smooth_quant_with_quantize_API(self):
         fp32_model = copy.deepcopy(model)
         quant_config = get_default_sq_config()
-        example_inputs = torch.randn([1, 3])
         q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
         assert q_model is not None, "Quantization failed!"
 
@@ -184,7 +177,6 @@ def test_smooth_quant_with_quantize_API(self):
     def test_smooth_quant_mixed_precision(self):
         fp32_model = copy.deepcopy(model)
         quant_config = get_default_sq_config()  # do mixed_precison by default.
-        example_inputs = torch.randn([1, 3])
 
         # prepare/convert API
         prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
@@ -203,3 +195,43 @@ def test_smooth_quant_mixed_precision(self):
         quant_config.folding = True
         q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
         assert q_model is not None, "Quantization failed!"
+
+    @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
+    def test_smooth_quant_auto(self):
+        fp32_model = copy.deepcopy(model)
+        example_inputs = torch.rand([1, 3])
+
+        def run_fn(model):
+            for i in range(100):
+                model(example_inputs)
+
+        # block-wise
+        quant_config = SmoothQuantConfig(
+            alpha="auto",
+            alpha_min=0.45,
+            alpha_max=0.55,
+            alpha_step=0.01,
+            shared_criterion="mean",
+            do_blockwise=True,
+            folding=False,
+        )
+        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
+        assert q_model is not None, "Quantization failed!"
+        output1 = fp32_model(example_inputs)
+        output2 = q_model(example_inputs)
+        assert torch.allclose(output1, output2, atol=2e-2), "Accuracy gap atol > 0.02 is unexpected. Please check."
+
+        # layer-wise
+        quant_config = SmoothQuantConfig(
+            alpha="auto",
+            alpha_min=0.45,
+            alpha_max=0.55,
+            alpha_step=0.01,
+            shared_criterion="max",
+            do_blockwise=False,
+            folding=False,
+        )
+        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
+        assert q_model is not None, "Quantization failed!"
+        output2 = q_model(example_inputs)
+        assert torch.allclose(output1, output2, atol=2e-2), "Accuracy gap atol > 0.02 is unexpected. Please check."
