
Commit c3214c9

support for restoring ipex model from json (#1405)

Signed-off-by: Kaihui-intel <[email protected]>

1 parent 130c215

2 files changed, +67 -0


neural_compressor/utils/pytorch.py

+47

@@ -464,3 +464,50 @@ def load(checkpoint_dir=None, model=None, layer_wise=False, history_cfg=None, **
     assert len(mismatch_log.unexpected_keys) == 0, "Loading state_dict failed: {}".format(mismatch_log)
     util.get_embedding_contiguous(model)
     return model
+
+
+def recover_model_from_json(model, json_file_path, example_inputs):
+    """Recover an IPEX model from a JSON configuration file.
+
+    Args:
+        model (object): fp32 model to be quantized.
+        json_file_path (str): path to the IPEX configuration JSON file.
+        example_inputs (tuple or torch.Tensor or dict): example inputs that will be passed to the IPEX function.
+
+    Returns:
+        (object): quantized model
+    """
+    from ..utils.utility import LazyImport
+
+    ipex = LazyImport("intel_extension_for_pytorch")
+    from torch.ao.quantization.observer import MinMaxObserver
+
+    qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5, act_observer=MinMaxObserver())
+    if isinstance(example_inputs, dict):
+        model = ipex.quantization.prepare(model, qconfig, example_kwarg_inputs=example_inputs, inplace=True)
+    else:
+        model = ipex.quantization.prepare(model, qconfig, example_inputs=example_inputs, inplace=True)
+    # Restore the calibrated quantization configuration from JSON, then convert.
+    model.load_qconf_summary(qconf_summary=json_file_path)
+    model = ipex.quantization.convert(model, inplace=True)
+    with torch.no_grad():
+        try:
+            if isinstance(example_inputs, dict):
+                # pylint: disable=E1120,E1123
+                model = torch.jit.trace(model, example_kwarg_inputs=example_inputs)
+            else:
+                model = torch.jit.trace(model, example_inputs)
+            model = torch.jit.freeze(model.eval())
+        except Exception:
+            # Fall back to non-strict tracing, e.g. for models with dict outputs.
+            if isinstance(example_inputs, dict):
+                # pylint: disable=E1120,E1123
+                model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False, check_trace=False)
+            else:
+                model = torch.jit.trace(model, example_inputs, strict=False)
+            model = torch.jit.freeze(model.eval())
+        # Run the traced model twice so the JIT fusion passes are applied
+        # before the model is returned.
+        if isinstance(example_inputs, dict):
+            model(**example_inputs)
+            model(**example_inputs)
+        else:
+            model(example_inputs)
+            model(example_inputs)
+    return model
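
As a usage sketch of the new API (the toy module, shapes, and paths below are hypothetical; the JSON must come from an earlier quantization run on the same architecture, e.g. written by q_model.save("saved")):

import torch
from neural_compressor.utils.pytorch import recover_model_from_json

# Stand-in fp32 model; any torch.nn.Module works.
class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

fp32_model = ToyModel()
example_inputs = torch.randn(1, 4)

# Rebuild the int8 IPEX model directly from the saved configuration,
# with no calibration data needed.
int8_model = recover_model_from_json(fp32_model, "./saved/best_configure.json", example_inputs)
output = int8_model(example_inputs)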

test/algorithm/test_smooth_quant.py

+20

@@ -880,6 +880,26 @@ def calib_func(model):
             calib_func=calib_func,
         )
         q_model.save("saved")
+        # test recover_model_from_json
+        from neural_compressor.utils.pytorch import recover_model_from_json
+
+        tmp_model = copy.deepcopy(fp32_model)
+
+        ipex_model = recover_model_from_json(tmp_model, "./saved/best_configure.json", example_inputs=input_ids)
+        inc_output = q_model.model(input_ids)
+        ipex_output = ipex_model(input_ids)
+        self.assertTrue(torch.allclose(inc_output, ipex_output, atol=1e-05))
+
+        example_tuple = (input_ids,)
+        ipex_model = recover_model_from_json(tmp_model, "./saved/best_configure.json", example_inputs=example_tuple)
+        ipex_output = ipex_model(input_ids)
+        self.assertTrue(torch.allclose(inc_output, ipex_output, atol=1e-05))
+
+        example_dict = {"x": input_ids}
+        ipex_model = recover_model_from_json(tmp_model, "./saved/best_configure.json", example_inputs=example_dict)
+        ipex_output = ipex_model(input_ids)
+        self.assertTrue(torch.allclose(inc_output, ipex_output, atol=1e-05))
+
         # compare ipex and inc quantization
         with open("saved/best_configure.json", "r") as f:
             inc_config_json = json.load(f)

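The assertions above check that a model restored purely from best_configure.json reproduces the outputs of the originally quantized model. For reference, a sketch of the save side that produces that JSON, following the INC 2.x PostTrainingQuantConfig flow; fp32_model, example_inputs, and calib_func are placeholders, and the exact recipe keys are an assumption:

from neural_compressor import PostTrainingQuantConfig, quantization

# SmoothQuant post-training quantization through the IPEX backend.
conf = PostTrainingQuantConfig(
    backend="ipex",
    example_inputs=example_inputs,  # assumed: needed by the ipex backend for tracing
    recipes={"smooth_quant": True, "smooth_quant_args": {"alpha": 0.5}},
)
q_model = quantization.fit(fp32_model, conf, calib_func=calib_func)

# Writes ./saved/best_configure.json, the file recover_model_from_json consumes.
q_model.save("saved")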