 }
 
 
-def load(model, output_dir="./saved_results"):
+def load(output_dir="./saved_results", model=None):
     from neural_compressor.common.base_config import ConfigRegistry
 
     qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), "qconfig.json")
-    config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"])
-    model.qconfig = config_mapping
-    # select load function
-    config_object = config_mapping[next(iter(config_mapping))]
-    if isinstance(config_object, FP8Config):
-        from neural_compressor.torch.algorithms.habana_fp8 import load
-
-        return load(model, output_dir)
+    with open(qconfig_file_path, "r") as f:
+        per_op_qconfig = json.load(f)
+    if " " in per_op_qconfig.keys():  # ipex qconfig format: {' ': {'q_op_infos': {'0': {'op_type': ...
+        from neural_compressor.torch.algorithms.static_quant import load
+
+        return load(output_dir)
+
+    else:  # FP8
+        config_mapping = load_config_mapping(qconfig_file_path, ConfigRegistry.get_all_configs()["torch"])
+        model.qconfig = config_mapping
+        # select load function
+        config_object = config_mapping[next(iter(config_mapping))]
+        if isinstance(config_object, FP8Config):
+            from neural_compressor.torch.algorithms.habana_fp8 import load
+
+            return load(model, output_dir)
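For context, a minimal usage sketch of the reworked entry point after this change. The import path, the checkpoint directory, and the `float_model` object are assumptions for illustration, not part of the commit; the sketch only reflects the dispatch behavior the diff introduces.

# Hypothetical usage (not part of the commit); the import path is an assumption.
from neural_compressor.torch.quantization import load

# IPEX static-quant checkpoint: qconfig.json uses the {' ': {'q_op_infos': ...}}
# layout, so load() dispatches to the static_quant loader and needs no model.
qmodel = load(output_dir="./saved_results")

# FP8 checkpoint: qconfig.json holds a config mapping, so the original model
# must be passed in for the habana_fp8 loader to restore quantized modules.
qmodel = load(output_dir="./saved_results", model=float_model)  # float_model: user's fp32 model (assumed)

Note that the new signature puts output_dir first and makes model optional, which is why the IPEX path can be taken without a model object at all.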