Commit 95e67ea

refine load API for 3.x ipex backend (#1755)

Signed-off-by: Cheng, Zixuan <[email protected]>

1 parent 0b2080b · commit 95e67ea

File tree: 6 files changed, +24 -15 lines

examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py (+1 -1)

@@ -143,7 +143,7 @@ def calib_func(model):
 
 if args.load:
     from neural_compressor.torch.quantization import load
-    user_model = load(user_model, "saved_results")
+    user_model = load("saved_results", user_model)
 
 
 if args.approach in ["dynamic", "static"] or args.load:

neural_compressor/common/utils/save_load.py (+2 -2)

@@ -19,7 +19,7 @@
 import os
 
 
-def save_config_mapping(config_mapping, qconfig_file_path):
+def save_config_mapping(config_mapping, qconfig_file_path): # pragma: no cover
     """Save config mapping to json file.
 
     Args:
@@ -36,7 +36,7 @@ def save_config_mapping(config_mapping, qconfig_file_path):
         json.dump(per_op_qconfig, f, indent=4)
 
 
-def load_config_mapping(qconfig_file_path, config_name_mapping):
+def load_config_mapping(qconfig_file_path, config_name_mapping): # pragma: no cover
     """Reload config mapping from json file.
 
     Args:
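(The `# pragma: no cover` markers simply tell coverage.py to exclude these two serialization helpers from coverage reports; the functions themselves are unchanged.)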

neural_compressor/torch/quantization/load_entry.py (+17 -9)

@@ -24,15 +24,23 @@
 }
 
 
-def load(model, output_dir="./saved_results"):
+def load(output_dir="./saved_results", model=None):
     from neural_compressor.common.base_config import ConfigRegistry
 
     qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), "qconfig.json")
-    config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"])
-    model.qconfig = config_mapping
-    # select load function
-    config_object = config_mapping[next(iter(config_mapping))]
-    if isinstance(config_object, FP8Config):
-        from neural_compressor.torch.algorithms.habana_fp8 import load
-
-        return load(model, output_dir)
+    with open(qconfig_file_path, "r") as f:
+        per_op_qconfig = json.load(f)
+    if " " in per_op_qconfig.keys(): # ipex qconfig format: {' ': {'q_op_infos': {'0': {'op_type': ...
+        from neural_compressor.torch.algorithms.static_quant import load
+
+        return load(output_dir)
+
+    else: # FP8
+        config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"])
+        model.qconfig = config_mapping
+        # select load function
+        config_object = config_mapping[next(iter(config_mapping))]
+        if isinstance(config_object, FP8Config):
+            from neural_compressor.torch.algorithms.habana_fp8 import load
+
+            return load(model, output_dir)
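The refined entry point now dispatches on the contents of qconfig.json rather than on a positional model argument, so both backends go through the same call. A minimal usage sketch of the new signature (`user_model` here stands in for the caller's original model and is an assumed name, not part of this diff):

    from neural_compressor.torch.quantization import load

    # IPEX static-quant checkpoint: qconfig.json uses the ' ' key, so load()
    # delegates to the static_quant loader and no source model is required.
    loaded_model = load("saved_results")

    # Habana FP8 checkpoint: the config mapping resolves to an FP8Config, so
    # load() delegates to the habana_fp8 loader; pass the original model in.
    loaded_model = load("saved_results", user_model)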

test/3x/torch/quantization/habana_fp8/test_fp8.py (+1 -1)

@@ -153,7 +153,7 @@ def calib_func(model):
         from neural_compressor.torch.quantization import load
 
         m = copy.deepcopy(self.model)
-        m = load(m, "saved_results")
+        m = load("saved_results", m)
         recovered_out = m(inp)
         assert (recovered_out == fp8_out).all(), "Unexpected result. Please double check."
         assert isinstance(m.fc1, FP8Linear), "Unexpected result. Please double check."

test/3x/torch/quantization/test_smooth_quant.py (+2 -1)

@@ -133,7 +133,8 @@ def test_sq_save_load(self):
         q_model.save("saved_results")
         inc_out = q_model(example_inputs)
 
-        from neural_compressor.torch.algorithms.smooth_quant import load, recover_model_from_json
+        from neural_compressor.torch.algorithms.smooth_quant import recover_model_from_json
+        from neural_compressor.torch.quantization import load
 
         # load using saved model
         loaded_model = load("saved_results")

test/3x/torch/quantization/test_static_quant.py (+1 -1)

@@ -153,7 +153,7 @@ def run_fn(model):
         assert torch.allclose(inc_out, ipex_out, atol=2e-02), "Unexpected result. Please double check."
         q_model.save("saved_results")
 
-        from neural_compressor.torch.algorithms.static_quant import load
+        from neural_compressor.torch.quantization import load
 
         # load
         loaded_model = load("saved_results")
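As the test diffs show, callers now import `load` from the top-level `neural_compressor.torch.quantization` module instead of the per-algorithm packages. A minimal sketch of the save/load round trip the static-quant test exercises, assuming `q_model` and `example_inputs` from the surrounding test fixture:

    import torch
    from neural_compressor.torch.quantization import load

    q_model.save("saved_results")          # writes qconfig.json plus the quantized model
    inc_out = q_model(example_inputs)      # reference output before reload

    loaded_model = load("saved_results")   # dispatches to the ipex static-quant loader
    loaded_out = loaded_model(example_inputs)
    assert torch.allclose(inc_out, loaded_out, atol=2e-02), "Unexpected result. Please double check."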
