Skip to content

Fix smoothquant minmax observer init #1421

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 28, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions neural_compressor/adaptor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3114,14 +3114,23 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
smooth_quant_args = self.recipes.get("smooth_quant_args", {})
folding = smooth_quant_args.get("folding", False)
if not folding:
if self.sq_minmax_init or self.version.release >= Version("2.1.1").release:
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.observer import MinMaxObserver

if self.version.release >= Version("2.1.1").release:
static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(
alpha=0.5, act_observer=MinMaxObserver
)
else:
static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
if self.sq_minmax_init:
static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(
alpha=0.5, act_observer=MinMaxObserver()
)
logger.warning(
"The int8 model accuracy will be close to 0 with MinMaxobserver, "
+ "the suggested IPEX version is higher or equal than 2.1.100."
)
else:
static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
if self.example_inputs is None:
self.example_inputs = get_example_inputs(model, self.q_dataloader)
from neural_compressor.adaptor.torch_utils.util import move_input_device
Expand Down Expand Up @@ -3304,14 +3313,23 @@ def qdq_quantize(self, model, q_model, tune_cfg, dataloader, q_func):
# Check save_qconf_summary part is a workaround for IPEX bug.
# Sometimes the prepared model from get_op_capablitiy loss this attribute
if not hasattr(model._model, "save_qconf_summary") or not hasattr(model._model, "load_qconf_summary"):
if self.sq_minmax_init or self.version.release >= Version("2.1.1").release:
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.observer import MinMaxObserver

if self.version.release >= Version("2.1.1").release:
static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(
alpha=0.5, act_observer=MinMaxObserver
)
else:
static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
if self.sq_minmax_init:
static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(
alpha=0.5, act_observer=MinMaxObserver()
)
logger.warning(
"The int8 model accuracy will be close to 0 with MinMaxobserver, "
+ "the suggested IPEX version is higher or equal than 2.1.100+cpu."
)
else:
static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
if isinstance(self.example_inputs, dict):
model._model = ipex.quantization.prepare(
model._model, static_qconfig, example_kwarg_inputs=self.example_inputs, inplace=inplace
Expand Down