Commit b1db1cf: Fix smoothquant minmax observer init (#1421)

1 parent: 9150181

1 file changed: neural_compressor/adaptor/pytorch.py (+24, -6)

--- a/neural_compressor/adaptor/pytorch.py
+++ b/neural_compressor/adaptor/pytorch.py
@@ -3114,14 +3114,23 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
         smooth_quant_args = self.recipes.get("smooth_quant_args", {})
         folding = smooth_quant_args.get("folding", False)
         if not folding:
-            if self.sq_minmax_init or self.version.release >= Version("2.1.1").release:
-                from torch.ao.quantization.observer import MinMaxObserver
+            from torch.ao.quantization.observer import MinMaxObserver
 
+            if self.version.release >= Version("2.1.1").release:
                 static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(
                     alpha=0.5, act_observer=MinMaxObserver
                 )
             else:
-                static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
+                if self.sq_minmax_init:
+                    static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(
+                        alpha=0.5, act_observer=MinMaxObserver()
+                    )
+                    logger.warning(
+                        "The int8 model accuracy will be close to 0 with MinMaxobserver, "
+                        + "the suggested IPEX version is higher or equal than 2.1.100."
+                    )
+                else:
+                    static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
         if self.example_inputs is None:
             self.example_inputs = get_example_inputs(model, self.q_dataloader)
         from neural_compressor.adaptor.torch_utils.util import move_input_device
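
The same observer-selection logic is applied again in qdq_quantize below. As a standalone illustration (not part of the patch), here is a minimal sketch of the branching this hunk introduces; the helper name build_sq_qconfig and the string-typed version argument are assumptions made for the example.

# Minimal sketch of the patched branching; build_sq_qconfig is a hypothetical
# helper, not a function from the patch.
from packaging.version import Version

import intel_extension_for_pytorch as ipex
from torch.ao.quantization.observer import MinMaxObserver


def build_sq_qconfig(ipex_version: str, sq_minmax_init: bool):
    if Version(ipex_version).release >= Version("2.1.1").release:
        # On IPEX >= 2.1.1 the patch passes the observer class itself.
        return ipex.quantization.get_smooth_quant_qconfig_mapping(
            alpha=0.5, act_observer=MinMaxObserver
        )
    if sq_minmax_init:
        # On older IPEX the patch passes an observer instance and warns that
        # int8 accuracy can collapse on this path.
        return ipex.quantization.get_smooth_quant_qconfig_mapping(
            alpha=0.5, act_observer=MinMaxObserver()
        )
    # Otherwise fall back to IPEX's default activation observer.
    return ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)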
@@ -3304,14 +3313,23 @@ def qdq_quantize(self, model, q_model, tune_cfg, dataloader, q_func):
         # Check save_qconf_summary part is a workaround for IPEX bug.
         # Sometimes the prepared model from get_op_capablitiy loss this attribute
         if not hasattr(model._model, "save_qconf_summary") or not hasattr(model._model, "load_qconf_summary"):
-            if self.sq_minmax_init or self.version.release >= Version("2.1.1").release:
-                from torch.ao.quantization.observer import MinMaxObserver
+            from torch.ao.quantization.observer import MinMaxObserver
 
+            if self.version.release >= Version("2.1.1").release:
                 static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(
                     alpha=0.5, act_observer=MinMaxObserver
                 )
             else:
-                static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
+                if self.sq_minmax_init:
+                    static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(
+                        alpha=0.5, act_observer=MinMaxObserver()
+                    )
+                    logger.warning(
+                        "The int8 model accuracy will be close to 0 with MinMaxobserver, "
+                        + "the suggested IPEX version is higher or equal than 2.1.100+cpu."
+                    )
+                else:
+                    static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
         if isinstance(self.example_inputs, dict):
             model._model = ipex.quantization.prepare(
                 model._model, static_qconfig, example_kwarg_inputs=self.example_inputs, inplace=inplace
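
For context, a hedged usage sketch of how the resulting qconfig feeds the ipex.quantization.prepare call visible in the context lines above; model and example_inputs are placeholder names, not identifiers from the patch.

# Usage sketch under the IPEX >= 2.1.1 path of the patch; `model` and
# `example_inputs` (a dict of keyword inputs) are placeholders.
import intel_extension_for_pytorch as ipex
from torch.ao.quantization.observer import MinMaxObserver

static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(
    alpha=0.5, act_observer=MinMaxObserver
)
# Mirrors the dict-input branch of qdq_quantize shown above.
prepared_model = ipex.quantization.prepare(
    model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=False
)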
