From 3cf0e089896352b13d721e470c7c5d180dd92288 Mon Sep 17 00:00:00 2001 From: gta Date: Fri, 12 Jul 2024 02:20:40 +0000 Subject: [PATCH 1/9] Support xpu for ipex static quant Signed-off-by: gta --- .../algorithms/static_quant/static_quant.py | 110 +++++++++++------- .../torch/algorithms/static_quant/utility.py | 86 +++++++++++++- .../torch/quantization/config.py | 74 ++++++------ .../torch/quantization/test_static_quant.py | 57 +++++++-- 4 files changed, 243 insertions(+), 84 deletions(-) diff --git a/neural_compressor/torch/algorithms/static_quant/static_quant.py b/neural_compressor/torch/algorithms/static_quant/static_quant.py index efd1880666c..490acacd23a 100644 --- a/neural_compressor/torch/algorithms/static_quant/static_quant.py +++ b/neural_compressor/torch/algorithms/static_quant/static_quant.py @@ -33,11 +33,13 @@ from neural_compressor.torch.algorithms import Quantizer from neural_compressor.torch.utils import logger +from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator from .utility import ( CpuInfo, cfg_to_qconfig, dump_model_op_stats, + generate_xpu_qconfig, get_ipex_version, get_quantizable_ops_recursively, ipex_config_path, @@ -68,45 +70,64 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs): Returns: A prepared model. """ + device = auto_detect_accelerator().current_device() assert example_inputs is not None, "Please provide example_inputs for static quantization." - _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively( - model, example_inputs - ) - # update json file in ipex_config_path; map ipex op_name to pt op_name - self.user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name) - model.eval() + if device == "cpu": + _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively( + model, example_inputs + ) + # update json file in ipex_config_path; map ipex op_name to pt op_name + self.user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name) + else: + model = model.to("xpu") - use_bf16 = self.quant_config.get("use_bf16", None) + model.eval() # Check save_qconf_summary part is a workaround for IPEX bug. - # Sometimes the prepared model from get_op_capablitiy loss this attribute - if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"): - from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig - - if ipex_ver.release >= Version("2.1").release: - # HistogramObserver will cause a performance issue. 
- # static_qconfig = ipex.quantization.default_static_qconfig_mapping - qconfig = QConfig( - activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8), - weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric), - ) - from torch.ao.quantization import QConfigMapping - - static_qconfig = QConfigMapping().set_global(qconfig) - else: - static_qconfig = QConfig( - activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8), - weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric), - ) - if isinstance(example_inputs, dict): - model = ipex.quantization.prepare( - model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace - ) + # Sometimes the prepared model from get_op_capablitiy loss this attributes + if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"): # pragma: no cover + from torch.ao.quantization import HistogramObserver, MinMaxObserver, PerChannelMinMaxObserver, QConfig + + if device != "cpu": # pragma: no cover + from torch.quantization.quantize_jit import prepare_jit + + with torch.no_grad(): + modelJit = torch.jit.trace(model, example_inputs) + qconfig = generate_xpu_qconfig(self.quant_config) + model = prepare_jit(modelJit, qconfig, inplace) else: - model = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_inputs, inplace=inplace) + if ipex_ver.release >= Version("2.1").release: + # HistogramObserver will cause a performance issue. + # static_qconfig = ipex.quantization.default_static_qconfig_mapping + qconfig = QConfig( + activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8), + weight=PerChannelMinMaxObserver.with_args( + dtype=torch.qint8, qscheme=torch.per_channel_symmetric + ), + ) + from torch.ao.quantization import QConfigMapping + + static_qconfig = QConfigMapping().set_global(qconfig) + else: + static_qconfig = QConfig( + activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8), + weight=PerChannelMinMaxObserver.with_args( + dtype=torch.qint8, qscheme=torch.per_channel_symmetric + ), + ) + if isinstance(example_inputs, dict): + model = ipex.quantization.prepare( + model, static_qconfig, example_kwarg_inputs=example_inputs, inplace=inplace + ) + else: + model = ipex.quantization.prepare( + model, static_qconfig, example_inputs=example_inputs, inplace=inplace + ) + + if device == "cpu": + model.load_qconf_summary(qconf_summary=ipex_config_path) - model.load_qconf_summary(qconf_summary=ipex_config_path) return model def convert(self, model, example_inputs, inplace=True, *args, **kwargs): @@ -120,22 +141,31 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs): Returns: A quantized model. 
""" + device = auto_detect_accelerator().current_device() use_bf16 = self.quant_config.get("use_bf16", None) from neural_compressor.torch.algorithms.static_quant import save - model.save_qconf_summary(qconf_summary=ipex_config_path) - model = _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=inplace) + if device != "cpu": # pragma: no cover + from torch.quantization.quantize_jit import convert_jit + + model = convert_jit(model, inplace) + simple_inference(model, example_inputs, iterations=2) + dump_model_op_stats(self.quant_config["op"]) + else: + model.save_qconf_summary(qconf_summary=ipex_config_path) + model = _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=inplace) + + with open(ipex_config_path, "r") as f: + model.tune_cfg = json.load(f) + model.ipex_config_path = ipex_config_path - with open(ipex_config_path, "r") as f: - model.tune_cfg = json.load(f) - model.ipex_config_path = ipex_config_path + dump_model_op_stats(self.user_cfg) - dump_model_op_stats(self.user_cfg) + model.ori_save = model.save + model.save = MethodType(save, model) logger.info("Static quantization done.") - model.ori_save = model.save - model.save = MethodType(save, model) return model diff --git a/neural_compressor/torch/algorithms/static_quant/utility.py b/neural_compressor/torch/algorithms/static_quant/utility.py index 23ac16630a4..a0eb6ad67f7 100644 --- a/neural_compressor/torch/algorithms/static_quant/utility.py +++ b/neural_compressor/torch/algorithms/static_quant/utility.py @@ -24,11 +24,12 @@ try: import intel_extension_for_pytorch as ipex + import prettytable as pt except: # pragma: no cover pass from neural_compressor.common.utils import DEFAULT_WORKSPACE, CpuInfo -from neural_compressor.torch.utils import Statistics, get_ipex_version, get_torch_version, logger +from neural_compressor.torch.utils import get_ipex_version, get_torch_version, logger version = get_torch_version() ipex_ver = get_ipex_version() @@ -163,6 +164,47 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_ return cfgs, ori_user_cfg +def generate_xpu_qconfig(tune_cfg): + # qconfig observer & config constants for ipex-xpu + from torch.ao.quantization import HistogramObserver, MinMaxObserver, QConfig + + act_observer_minmax_asym = MinMaxObserver.with_args(quant_min=0, quant_max=127) + act_observer_minmax_sym = MinMaxObserver.with_args( + dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, quant_min=-128, quant_max=127 + ) + act_observer_kl_asym = HistogramObserver.with_args(quant_min=0, quant_max=127) + act_observer_kl_sym = HistogramObserver.with_args( + dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, quant_min=-128, quant_max=127 + ) + # no tuning for granularity due to tuning space + weight_observer_minmax_sym = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric) + + qconfig = {} + user_cfg = copy.deepcopy(tune_cfg["op"]) + for _, cfg in user_cfg.items(): + act_algo = cfg["activation"]["algorithm"] + act_sym = cfg["activation"]["scheme"] + break + + if act_algo == "minmax": + if act_sym == "sym": + activation = act_observer_minmax_sym + else: + activation = act_observer_minmax_asym + else: + if act_sym == "sym": + activation = act_observer_kl_sym + else: + activation = act_observer_kl_asym + + qconfig[""] = QConfig(activation=activation, weight=weight_observer_minmax_sym) + + for (op_name, op_type), cfg in user_cfg.items(): + if cfg["weight"]["dtype"] == "fp32": + qconfig[op_name] = None + return qconfig + + def 
generate_activation_observer( scheme, algorithm, smooth_quant=False, smooth_quant_enable=False, alpha=0.5 ): # pragma: no cover @@ -566,6 +608,48 @@ def get_quantizable_ops_from_cfgs(ops_name, op_infos_from_cfgs, input_tensor_ids return quantizable_ops +class Statistics: # pragma: no cover + """The statistics printer.""" + + def __init__(self, data, header, field_names, output_handle=logger.info): + """Init a Statistics object. + + Args: + data: The statistics data + header: The table header + field_names: The field names + output_handle: The output logging method + """ + self.field_names = field_names + self.header = header + self.data = data + self.output_handle = output_handle + self.tb = pt.PrettyTable(min_table_width=40) + + def print_stat(self): + """Print the statistics.""" + valid_field_names = [] + for index, value in enumerate(self.field_names): + if index < 2: + valid_field_names.append(value) + continue + + if any(i[index] for i in self.data): + valid_field_names.append(value) + self.tb.field_names = valid_field_names + for i in self.data: + tmp_data = [] + for index, value in enumerate(i): + if self.field_names[index] in valid_field_names: + tmp_data.append(value) + if any(tmp_data[1:]): + self.tb.add_row(tmp_data) + lines = self.tb.get_string().split("\n") + self.output_handle("|" + self.header.center(len(lines[0]) - 2, "*") + "|") + for i in lines: + self.output_handle(i) + + class TransformerBasedModelBlockPatternDetector: # pragma: no cover """Detect the attention block and FFN block in transformer-based model.""" diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 9014f1576a3..08b1cd45489 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -23,7 +23,6 @@ import torch -import neural_compressor.torch.utils as torch_utils from neural_compressor.common.base_config import ( BaseConfig, config_registry, @@ -220,17 +219,14 @@ def get_config_set_for_tuning(cls) -> Union[None, "RTNConfig", List["RTNConfig"] dtype=["int4", "nf4"], use_sym=[True, False], group_size=[32, 128], use_mse_search=[False, True] ) - @classmethod - def get_predefined_configs(cls) -> Dict[torch_utils.ProcessorType, "RTNConfig"]: - pre_defined_configs: Dict[torch_utils.ProcessorType, RTNConfig] = {} - pre_defined_configs[torch_utils.ProcessorType.Client] = cls(use_layer_wise=True) - pre_defined_configs[torch_utils.ProcessorType.Server] = cls() - return pre_defined_configs +def get_default_rtn_config() -> RTNConfig: + """Generate the default rtn config. -def get_default_rtn_config(processor_type: Optional[Union[str, torch_utils.ProcessorType]] = None) -> RTNConfig: - process_type = torch_utils.get_processor_type_from_user_config(processor_type) - return RTNConfig.get_predefined_configs()[process_type] + Returns: + the default rtn config. + """ + return RTNConfig() def get_default_double_quant_config(type="BNB_NF4"): @@ -382,17 +378,14 @@ def get_config_set_for_tuning(cls) -> Union[None, "GPTQConfig", List["GPTQConfig # TODO fwk owner needs to update it. 
return GPTQConfig(act_order=[True, False], use_sym=[False, True]) - @classmethod - def get_predefined_configs(cls) -> Dict[torch_utils.ProcessorType, "GPTQConfig"]: - pre_defined_configs: Dict[torch_utils.ProcessorType, GPTQConfig] = {} - pre_defined_configs[torch_utils.ProcessorType.Client] = cls(use_layer_wise=True) - pre_defined_configs[torch_utils.ProcessorType.Server] = cls() - return pre_defined_configs +def get_default_gptq_config() -> GPTQConfig: + """Generate the default gptq config. -def get_default_gptq_config(processor_type: Optional[Union[str, torch_utils.ProcessorType]] = None) -> RTNConfig: - process_type = torch_utils.get_processor_type_from_user_config(processor_type) - return GPTQConfig.get_predefined_configs()[process_type] + Returns: + the default gptq config. + """ + return GPTQConfig() ######################## AWQ Config ############################### @@ -732,7 +725,6 @@ def __init__( not_use_best_mse: bool = False, dynamic_max_gap: int = -1, scale_dtype: str = "fp16", - use_layer_wise: bool = False, white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, ): """Init AUTOROUND weight-only quantization config. @@ -785,7 +777,6 @@ def __init__( self.not_use_best_mse = not_use_best_mse self.dynamic_max_gap = dynamic_max_gap self.scale_dtype = scale_dtype - self.use_layer_wise = use_layer_wise self._post_init() @classmethod @@ -812,17 +803,14 @@ def get_config_set_for_tuning(cls) -> Union[None, "AutoRoundConfig", List["AutoR # TODO fwk owner needs to update it. return AutoRoundConfig(bits=[4, 6]) - @classmethod - def get_predefined_configs(cls) -> Dict[torch_utils.ProcessorType, "AutoRoundConfig"]: - pre_defined_configs: Dict[torch_utils.ProcessorType, AutoRoundConfig] = {} - pre_defined_configs[torch_utils.ProcessorType.Client] = cls(use_layer_wise=True) - pre_defined_configs[torch_utils.ProcessorType.Server] = cls() - return pre_defined_configs +def get_default_AutoRound_config() -> AutoRoundConfig: + """Generate the default AUTOROUND config. -def get_default_AutoRound_config(processor_type: Optional[Union[str, torch_utils.ProcessorType]] = None) -> RTNConfig: - process_type = torch_utils.get_processor_type_from_user_config(processor_type) - return AutoRoundConfig.get_predefined_configs()[process_type] + Returns: + the default AUTOROUND config. 
+ """ + return AutoRoundConfig() ######################## MX Config ############################### @@ -1043,6 +1031,7 @@ def __init__( act_algo: str = "minmax", excluded_precisions: list = [], white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, + model_info: Optional[List[Tuple[str, Callable]]] = None, ): """Init Static Quant Configs.""" super().__init__(white_list=white_list) @@ -1055,6 +1044,7 @@ def __init__( self.act_granularity = act_granularity self.act_algo = act_algo self.excluded_precisions = excluded_precisions + self.model_info = model_info self._post_init() @classmethod @@ -1072,10 +1062,28 @@ def get_model_info_for_ipex(model: torch.nn.Module, example_inputs) -> List[Tupl _, _, _, _, model_info = get_quantizable_ops_recursively(model, example_inputs=example_inputs) return model_info - @staticmethod - def get_model_info(model: torch.nn.Module, example_inputs=None) -> List[Tuple[str, Callable]]: + def get_model_info_for_ipex_xpu(self, model: torch.nn.Module) -> List[Tuple[str, Callable]]: + if self.model_info: + return self.model_info + else: + white_list = torch.quantization.quantization_mappings.get_default_qconfig_propagation_list() + filter_result = [] + for op_name, module in model.named_modules(): + if type(module) in white_list: + pair = (op_name, type(module).__name__) + filter_result.append(pair) + logger.debug(f"Get model info: {filter_result}") + self.model_info = filter_result + return filter_result + + def get_model_info(self, model: torch.nn.Module, example_inputs=None) -> List[Tuple[str, Callable]]: + from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator + if is_ipex_imported(): - return StaticQuantConfig.get_model_info_for_ipex(model, example_inputs) + if auto_detect_accelerator().current_device() == "cpu": + return StaticQuantConfig.get_model_info_for_ipex(model, example_inputs) + else: + return StaticQuantConfig.get_model_info_for_ipex_xpu(self, model) def to_config_mapping( self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None diff --git a/test/3x/torch/quantization/test_static_quant.py b/test/3x/torch/quantization/test_static_quant.py index 46e791aa52f..23efde53b58 100644 --- a/test/3x/torch/quantization/test_static_quant.py +++ b/test/3x/torch/quantization/test_static_quant.py @@ -4,6 +4,14 @@ import pytest import torch +try: + import intel_extension_for_pytorch as ipex + + is_ipex_available = True +except: # pragma: no cover + is_ipex_available = False + assert False, "Please install IPEX for static quantization." + from neural_compressor.torch.quantization import ( StaticQuantConfig, convert, @@ -11,10 +19,9 @@ prepare, quantize, ) -from neural_compressor.torch.utils import is_ipex_available +from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator -if is_ipex_available(): - import intel_extension_for_pytorch as ipex +device = auto_detect_accelerator().current_device() def build_simple_torch_model(): @@ -53,7 +60,7 @@ def setup_class(self): def teardown_class(self): shutil.rmtree("saved_results", ignore_errors=True) - @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") + @pytest.mark.skipif(not is_ipex_available or device != "cpu", reason="Requires IPEX on CPU device") def test_static_quant_default(self): fp32_model = copy.deepcopy(self.fp32_model) quant_config = get_default_static_config() @@ -70,7 +77,7 @@ def test_static_quant_default(self): q_model = convert(prepared_model) assert q_model is not None, "Quantization failed!" 
- @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") + @pytest.mark.skipif(not is_ipex_available or device != "cpu", reason="Requires IPEX on CPU device") def test_static_quant_fallback(self): fp32_model = copy.deepcopy(self.fp32_model) quant_config = get_default_static_config() @@ -100,7 +107,7 @@ def test_static_quant_fallback(self): dtype = q_model.tune_cfg[" "]["q_op_infos"][op]["input_tensor_infos"][0]["force_dtype"] assert dtype == "torch.float32", "Failed to fallback fc2 layer, please check!" - @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") + @pytest.mark.skipif(not is_ipex_available or device != "cpu", reason="Requires IPEX on CPU device") @pytest.mark.parametrize( "act_sym, act_algo", [ @@ -119,7 +126,7 @@ def test_static_quant_params(self, act_sym, act_algo): q_model = convert(prepared_model) assert q_model is not None, "Quantization failed!" - @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") + @pytest.mark.skipif(not is_ipex_available or device != "cpu", reason="Requires IPEX on CPU device") def test_static_quant_accuracy(self): class M(torch.nn.Module): def __init__(self): @@ -148,7 +155,7 @@ def run_fn(model): # set a big atol to avoid random issue assert torch.allclose(output1, output2, atol=2e-2), "Accuracy gap atol > 0.02 is unexpected. Please check." - @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") + @pytest.mark.skipif(not is_ipex_available or device != "cpu", reason="Requires IPEX on CPU device") def test_static_quant_save_load(self): from intel_extension_for_pytorch.quantization import convert as ipex_convert from intel_extension_for_pytorch.quantization import prepare as ipex_prepare @@ -196,7 +203,7 @@ def run_fn(model): loaded_model = load("saved_results") assert isinstance(loaded_model, torch.jit.ScriptModule) - @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") + @pytest.mark.skipif(not is_ipex_available or device != "cpu", reason="Requires IPEX on CPU device") def test_static_quant_with_quantize_API(self): # quantize API fp32_model = copy.deepcopy(self.fp32_model) @@ -205,7 +212,7 @@ def test_static_quant_with_quantize_API(self): q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs) assert q_model is not None, "Quantization failed!" - @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX") + @pytest.mark.skipif(not is_ipex_available or device != "cpu", reason="Requires IPEX on CPU device") def test_static_quant_mixed_precision(self): fp32_model = copy.deepcopy(self.fp32_model) example_inputs = self.input @@ -227,3 +234,33 @@ def test_static_quant_mixed_precision(self): run_fn(prepared_model) q_model = convert(prepared_model) assert q_model is not None, "Quantization failed!" 
+ + @pytest.mark.skipif(not is_ipex_available or device == "cpu", reason="Requires IPEX on XPU device") + @pytest.mark.parametrize( + "act_sym, act_algo", + [ + (True, "kl"), + (True, "minmax"), + (False, "kl"), + (False, "minmax"), + ], + ) + def test_static_quant_params_xpu(self, act_sym, act_algo): + import torchvision.models as models + + model = models.resnet50(pretrained=True) + fp32_model = copy.deepcopy(model) + data = torch.rand(1, 3, 224, 224) + example_inputs = data.to("xpu") + + def run_fn(model): + model(example_inputs) + + quant_config = StaticQuantConfig(act_sym=act_sym, act_algo=act_algo, excluded_precisions=["bf16"]) + # fallback by op_name + quant_config.set_local("conv1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32")) + prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs) + run_fn(prepared_model) + q_model = convert(prepared_model) + run_fn(q_model) + assert q_model is not None, "Quantization failed!" From fcd8c77386b9eeb3ddbed8ee259b79feeb202678 Mon Sep 17 00:00:00 2001 From: Zixuan Cheng <110808245+violetch24@users.noreply.github.com> Date: Fri, 12 Jul 2024 11:05:49 +0800 Subject: [PATCH 2/9] Update utility.py --- .../torch/algorithms/static_quant/utility.py | 45 +------------------ 1 file changed, 1 insertion(+), 44 deletions(-) diff --git a/neural_compressor/torch/algorithms/static_quant/utility.py b/neural_compressor/torch/algorithms/static_quant/utility.py index a0eb6ad67f7..5e95dae2aac 100644 --- a/neural_compressor/torch/algorithms/static_quant/utility.py +++ b/neural_compressor/torch/algorithms/static_quant/utility.py @@ -24,12 +24,11 @@ try: import intel_extension_for_pytorch as ipex - import prettytable as pt except: # pragma: no cover pass from neural_compressor.common.utils import DEFAULT_WORKSPACE, CpuInfo -from neural_compressor.torch.utils import get_ipex_version, get_torch_version, logger +from neural_compressor.torch.utils import Statistics, get_ipex_version, get_torch_version, logger version = get_torch_version() ipex_ver = get_ipex_version() @@ -608,48 +607,6 @@ def get_quantizable_ops_from_cfgs(ops_name, op_infos_from_cfgs, input_tensor_ids return quantizable_ops -class Statistics: # pragma: no cover - """The statistics printer.""" - - def __init__(self, data, header, field_names, output_handle=logger.info): - """Init a Statistics object. 
- - Args: - data: The statistics data - header: The table header - field_names: The field names - output_handle: The output logging method - """ - self.field_names = field_names - self.header = header - self.data = data - self.output_handle = output_handle - self.tb = pt.PrettyTable(min_table_width=40) - - def print_stat(self): - """Print the statistics.""" - valid_field_names = [] - for index, value in enumerate(self.field_names): - if index < 2: - valid_field_names.append(value) - continue - - if any(i[index] for i in self.data): - valid_field_names.append(value) - self.tb.field_names = valid_field_names - for i in self.data: - tmp_data = [] - for index, value in enumerate(i): - if self.field_names[index] in valid_field_names: - tmp_data.append(value) - if any(tmp_data[1:]): - self.tb.add_row(tmp_data) - lines = self.tb.get_string().split("\n") - self.output_handle("|" + self.header.center(len(lines[0]) - 2, "*") + "|") - for i in lines: - self.output_handle(i) - - class TransformerBasedModelBlockPatternDetector: # pragma: no cover """Detect the attention block and FFN block in transformer-based model.""" From d2285264ebd789ea355e0e8d1640ad3aa201e99a Mon Sep 17 00:00:00 2001 From: gta Date: Fri, 12 Jul 2024 03:10:50 +0000 Subject: [PATCH 3/9] minor fix Signed-off-by: gta --- .../torch/quantization/config.py | 48 ++++++++++++------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 08b1cd45489..22bfe3268c9 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -23,6 +23,7 @@ import torch +import neural_compressor.torch.utils as torch_utils from neural_compressor.common.base_config import ( BaseConfig, config_registry, @@ -219,14 +220,17 @@ def get_config_set_for_tuning(cls) -> Union[None, "RTNConfig", List["RTNConfig"] dtype=["int4", "nf4"], use_sym=[True, False], group_size=[32, 128], use_mse_search=[False, True] ) + @classmethod + def get_predefined_configs(cls) -> Dict[torch_utils.ProcessorType, "RTNConfig"]: + pre_defined_configs: Dict[torch_utils.ProcessorType, RTNConfig] = {} + pre_defined_configs[torch_utils.ProcessorType.Client] = cls(use_layer_wise=True) + pre_defined_configs[torch_utils.ProcessorType.Server] = cls() + return pre_defined_configs -def get_default_rtn_config() -> RTNConfig: - """Generate the default rtn config. - Returns: - the default rtn config. - """ - return RTNConfig() +def get_default_rtn_config(processor_type: Optional[Union[str, torch_utils.ProcessorType]] = None) -> RTNConfig: + process_type = torch_utils.get_processor_type_from_user_config(processor_type) + return RTNConfig.get_predefined_configs()[process_type] def get_default_double_quant_config(type="BNB_NF4"): @@ -378,14 +382,17 @@ def get_config_set_for_tuning(cls) -> Union[None, "GPTQConfig", List["GPTQConfig # TODO fwk owner needs to update it. return GPTQConfig(act_order=[True, False], use_sym=[False, True]) + @classmethod + def get_predefined_configs(cls) -> Dict[torch_utils.ProcessorType, "GPTQConfig"]: + pre_defined_configs: Dict[torch_utils.ProcessorType, GPTQConfig] = {} + pre_defined_configs[torch_utils.ProcessorType.Client] = cls(use_layer_wise=True) + pre_defined_configs[torch_utils.ProcessorType.Server] = cls() + return pre_defined_configs -def get_default_gptq_config() -> GPTQConfig: - """Generate the default gptq config. - Returns: - the default gptq config. 
- """ - return GPTQConfig() +def get_default_gptq_config(processor_type: Optional[Union[str, torch_utils.ProcessorType]] = None) -> RTNConfig: + process_type = torch_utils.get_processor_type_from_user_config(processor_type) + return GPTQConfig.get_predefined_configs()[process_type] ######################## AWQ Config ############################### @@ -725,6 +732,7 @@ def __init__( not_use_best_mse: bool = False, dynamic_max_gap: int = -1, scale_dtype: str = "fp16", + use_layer_wise: bool = False, white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, ): """Init AUTOROUND weight-only quantization config. @@ -777,6 +785,7 @@ def __init__( self.not_use_best_mse = not_use_best_mse self.dynamic_max_gap = dynamic_max_gap self.scale_dtype = scale_dtype + self.use_layer_wise = use_layer_wise self._post_init() @classmethod @@ -803,14 +812,17 @@ def get_config_set_for_tuning(cls) -> Union[None, "AutoRoundConfig", List["AutoR # TODO fwk owner needs to update it. return AutoRoundConfig(bits=[4, 6]) + @classmethod + def get_predefined_configs(cls) -> Dict[torch_utils.ProcessorType, "AutoRoundConfig"]: + pre_defined_configs: Dict[torch_utils.ProcessorType, AutoRoundConfig] = {} + pre_defined_configs[torch_utils.ProcessorType.Client] = cls(use_layer_wise=True) + pre_defined_configs[torch_utils.ProcessorType.Server] = cls() + return pre_defined_configs -def get_default_AutoRound_config() -> AutoRoundConfig: - """Generate the default AUTOROUND config. - Returns: - the default AUTOROUND config. - """ - return AutoRoundConfig() +def get_default_AutoRound_config(processor_type: Optional[Union[str, torch_utils.ProcessorType]] = None) -> RTNConfig: + process_type = torch_utils.get_processor_type_from_user_config(processor_type) + return AutoRoundConfig.get_predefined_configs()[process_type] ######################## MX Config ############################### From 66506623934ab6349729a99a3ede664aad2a7061 Mon Sep 17 00:00:00 2001 From: gta Date: Fri, 12 Jul 2024 06:14:51 +0000 Subject: [PATCH 4/9] add save and load Signed-off-by: gta --- .../torch/algorithms/static_quant/save_load.py | 14 +++++++++++--- .../torch/algorithms/static_quant/static_quant.py | 11 +++++------ 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/neural_compressor/torch/algorithms/static_quant/save_load.py b/neural_compressor/torch/algorithms/static_quant/save_load.py index 557c1577728..e20db2b85b1 100644 --- a/neural_compressor/torch/algorithms/static_quant/save_load.py +++ b/neural_compressor/torch/algorithms/static_quant/save_load.py @@ -32,9 +32,17 @@ def save(model, output_dir="./saved_results"): qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME) qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME) - model.ori_save(qmodel_file_path) - with open(qconfig_file_path, "w") as f: - json.dump(model.tune_cfg, f, indent=4) + if next(model.parameters()).device.type == "cpu": + model.ori_save(qmodel_file_path) + with open(qconfig_file_path, "w") as f: + json.dump(model.tune_cfg, f, indent=4) + else: + from neural_compressor.common.utils import save_config_mapping + + save_config_mapping(model.qconfig, qconfig_file_path) + # MethodType 'save' not in state_dict + del model.save + torch.save(model.state_dict(), qmodel_file_path) logger.info("Save quantized model to {}.".format(qmodel_file_path)) logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path)) diff --git 
a/neural_compressor/torch/algorithms/static_quant/static_quant.py b/neural_compressor/torch/algorithms/static_quant/static_quant.py index 490acacd23a..ce5fe1d1a97 100644 --- a/neural_compressor/torch/algorithms/static_quant/static_quant.py +++ b/neural_compressor/torch/algorithms/static_quant/static_quant.py @@ -58,6 +58,7 @@ def __init__(self, quant_config: OrderedDict = {}): """ super().__init__(quant_config) self.user_cfg = OrderedDict() + self.device = auto_detect_accelerator().current_device() def prepare(self, model, example_inputs, inplace=True, *args, **kwargs): """Prepares a given model for quantization. @@ -70,10 +71,9 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs): Returns: A prepared model. """ - device = auto_detect_accelerator().current_device() assert example_inputs is not None, "Please provide example_inputs for static quantization." - if device == "cpu": + if self.device == "cpu": _, cfgs, op_infos_from_cfgs, output_tensor_id_op_name, _ = get_quantizable_ops_recursively( model, example_inputs ) @@ -89,7 +89,7 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs): if not hasattr(model, "save_qconf_summary") or not hasattr(model, "load_qconf_summary"): # pragma: no cover from torch.ao.quantization import HistogramObserver, MinMaxObserver, PerChannelMinMaxObserver, QConfig - if device != "cpu": # pragma: no cover + if self.device != "cpu": # pragma: no cover from torch.quantization.quantize_jit import prepare_jit with torch.no_grad(): @@ -125,7 +125,7 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs): model, static_qconfig, example_inputs=example_inputs, inplace=inplace ) - if device == "cpu": + if self.device == "cpu": model.load_qconf_summary(qconf_summary=ipex_config_path) return model @@ -141,12 +141,11 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs): Returns: A quantized model. 
""" - device = auto_detect_accelerator().current_device() use_bf16 = self.quant_config.get("use_bf16", None) from neural_compressor.torch.algorithms.static_quant import save - if device != "cpu": # pragma: no cover + if self.device != "cpu": # pragma: no cover from torch.quantization.quantize_jit import convert_jit model = convert_jit(model, inplace) From facde1875297643b8d6cae0d5e7ba75d3390c728 Mon Sep 17 00:00:00 2001 From: gta Date: Mon, 15 Jul 2024 15:20:01 +0000 Subject: [PATCH 5/9] fix save Signed-off-by: gta --- .../torch/algorithms/static_quant/save_load.py | 6 ++---- .../torch/algorithms/static_quant/static_quant.py | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/neural_compressor/torch/algorithms/static_quant/save_load.py b/neural_compressor/torch/algorithms/static_quant/save_load.py index e20db2b85b1..dedac50c5e9 100644 --- a/neural_compressor/torch/algorithms/static_quant/save_load.py +++ b/neural_compressor/torch/algorithms/static_quant/save_load.py @@ -32,17 +32,15 @@ def save(model, output_dir="./saved_results"): qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME) qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME) - if next(model.parameters()).device.type == "cpu": + if next(model.parameters()).device.type == "cpu": # pragma: no cover model.ori_save(qmodel_file_path) with open(qconfig_file_path, "w") as f: json.dump(model.tune_cfg, f, indent=4) else: from neural_compressor.common.utils import save_config_mapping + model.ori_save(qmodel_file_path) save_config_mapping(model.qconfig, qconfig_file_path) - # MethodType 'save' not in state_dict - del model.save - torch.save(model.state_dict(), qmodel_file_path) logger.info("Save quantized model to {}.".format(qmodel_file_path)) logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path)) diff --git a/neural_compressor/torch/algorithms/static_quant/static_quant.py b/neural_compressor/torch/algorithms/static_quant/static_quant.py index ce5fe1d1a97..f55eea467c6 100644 --- a/neural_compressor/torch/algorithms/static_quant/static_quant.py +++ b/neural_compressor/torch/algorithms/static_quant/static_quant.py @@ -161,8 +161,8 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs): dump_model_op_stats(self.user_cfg) - model.ori_save = model.save - model.save = MethodType(save, model) + model.ori_save = model.save + model.save = MethodType(save, model) logger.info("Static quantization done.") return model From 888f7514622d009b35fce6608d47f52d6a209775 Mon Sep 17 00:00:00 2001 From: gta Date: Tue, 16 Jul 2024 06:37:42 +0000 Subject: [PATCH 6/9] add save load ut Signed-off-by: gta --- .../torch/algorithms/static_quant/save_load.py | 6 +++--- .../algorithms/static_quant/static_quant.py | 3 ++- .../3x/torch/quantization/test_static_quant.py | 18 +++++++++++++++++- 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/neural_compressor/torch/algorithms/static_quant/save_load.py b/neural_compressor/torch/algorithms/static_quant/save_load.py index dedac50c5e9..2d40182661a 100644 --- a/neural_compressor/torch/algorithms/static_quant/save_load.py +++ b/neural_compressor/torch/algorithms/static_quant/save_load.py @@ -32,14 +32,14 @@ def save(model, output_dir="./saved_results"): qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME) qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME) - if 
next(model.parameters()).device.type == "cpu": # pragma: no cover + if next(model.parameters()).device.type == "cpu": model.ori_save(qmodel_file_path) with open(qconfig_file_path, "w") as f: json.dump(model.tune_cfg, f, indent=4) - else: + else: # pragma: no cover from neural_compressor.common.utils import save_config_mapping - model.ori_save(qmodel_file_path) + torch.jit.save(model, qmodel_file_path) save_config_mapping(model.qconfig, qconfig_file_path) logger.info("Save quantized model to {}.".format(qmodel_file_path)) diff --git a/neural_compressor/torch/algorithms/static_quant/static_quant.py b/neural_compressor/torch/algorithms/static_quant/static_quant.py index f55eea467c6..95d740c8691 100644 --- a/neural_compressor/torch/algorithms/static_quant/static_quant.py +++ b/neural_compressor/torch/algorithms/static_quant/static_quant.py @@ -150,7 +150,8 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs): model = convert_jit(model, inplace) simple_inference(model, example_inputs, iterations=2) - dump_model_op_stats(self.quant_config["op"]) + model.qconfig = self.quant_config["op"] + dump_model_op_stats(model.qconfig) else: model.save_qconf_summary(qconf_summary=ipex_config_path) model = _ipex_post_quant_process(model, example_inputs, use_bf16, inplace=inplace) diff --git a/test/3x/torch/quantization/test_static_quant.py b/test/3x/torch/quantization/test_static_quant.py index facd0999d86..4aecd29eecf 100644 --- a/test/3x/torch/quantization/test_static_quant.py +++ b/test/3x/torch/quantization/test_static_quant.py @@ -245,7 +245,7 @@ def test_static_quant_mixed_precision(self): (False, "minmax"), ], ) - def test_static_quant_params_xpu(self, act_sym, act_algo): + def test_static_quant_xpu(self, act_sym, act_algo): import torchvision.models as models model = models.resnet50(pretrained=True) @@ -264,3 +264,19 @@ def run_fn(model): q_model = convert(prepared_model) run_fn(q_model) assert q_model is not None, "Quantization failed!" + + quant_config = StaticQuantConfig(act_sym=act_sym, act_algo=act_algo, excluded_precisions=["bf16"]) + # fallback by op_type + quant_config.set_local("Conv2d", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32")) + prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs) + run_fn(prepared_model) + q_model = convert(prepared_model) + run_fn(q_model) + assert q_model is not None, "Quantization failed!" + + q_model.save("saved_results") + from neural_compressor.torch.quantization import load + + # load + loaded_model = load("saved_results") + assert isinstance(loaded_model, torch.jit.ScriptModule), "Loading failed!" 
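Patches 1/9 through 6/9 together enable IPEX static quantization on XPU end to end: device detection, qconfig generation via generate_xpu_qconfig, JIT trace plus prepare_jit/convert_jit, and torch.jit based save/load. A minimal usage sketch, condensed from the new unit test (it assumes an IPEX build with XPU support and torchvision installed; the model, input shape and output directory simply mirror test_static_quant_xpu):

    import torch
    import torchvision.models as models
    from neural_compressor.torch.quantization import StaticQuantConfig, convert, load, prepare

    fp32_model = models.resnet50(pretrained=True)
    example_inputs = torch.rand(1, 3, 224, 224).to("xpu")

    quant_config = StaticQuantConfig(act_sym=True, act_algo="minmax", excluded_precisions=["bf16"])
    # optional per-op fallback, as exercised in the unit test
    quant_config.set_local("conv1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))

    # prepare() moves the model to "xpu", traces it and calls prepare_jit with the generated qconfig
    prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
    prepared_model(example_inputs)        # calibration pass
    q_model = convert(prepared_model)     # convert_jit + warm-up inference + op stats dump

    q_model.save("saved_results")         # XPU branch: torch.jit.save + save_config_mapping
    loaded_model = load("saved_results")  # returns a torch.jit.ScriptModule
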
From 110f6c46a1c66c77e47e8d179b9c9bda0b81c47d Mon Sep 17 00:00:00 2001 From: violetch24 Date: Tue, 16 Jul 2024 23:15:59 -0700 Subject: [PATCH 7/9] fix ci error Signed-off-by: violetch24 --- neural_compressor/torch/algorithms/static_quant/save_load.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/neural_compressor/torch/algorithms/static_quant/save_load.py b/neural_compressor/torch/algorithms/static_quant/save_load.py index 2d40182661a..9a7808c17eb 100644 --- a/neural_compressor/torch/algorithms/static_quant/save_load.py +++ b/neural_compressor/torch/algorithms/static_quant/save_load.py @@ -32,7 +32,8 @@ def save(model, output_dir="./saved_results"): qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME) qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME) - if next(model.parameters()).device.type == "cpu": + device = next(model.parameters(), None).device.type if next(model.parameters(), None) else "cpu" + if device == "cpu": model.ori_save(qmodel_file_path) with open(qconfig_file_path, "w") as f: json.dump(model.tune_cfg, f, indent=4) From 841c22a18dfac9a7c30e27d73f916fdc03ea3af4 Mon Sep 17 00:00:00 2001 From: violetch24 Date: Wed, 17 Jul 2024 00:47:39 -0700 Subject: [PATCH 8/9] ut coverage Signed-off-by: violetch24 --- .../torch/algorithms/static_quant/static_quant.py | 4 ++-- neural_compressor/torch/algorithms/static_quant/utility.py | 2 +- neural_compressor/torch/quantization/config.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/neural_compressor/torch/algorithms/static_quant/static_quant.py b/neural_compressor/torch/algorithms/static_quant/static_quant.py index 95d740c8691..35eb76596b3 100644 --- a/neural_compressor/torch/algorithms/static_quant/static_quant.py +++ b/neural_compressor/torch/algorithms/static_quant/static_quant.py @@ -79,7 +79,7 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs): ) # update json file in ipex_config_path; map ipex op_name to pt op_name self.user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name) - else: + else: # pragma: no cover model = model.to("xpu") model.eval() @@ -109,7 +109,7 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs): from torch.ao.quantization import QConfigMapping static_qconfig = QConfigMapping().set_global(qconfig) - else: + else: # pragma: no cover static_qconfig = QConfig( activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8), weight=PerChannelMinMaxObserver.with_args( diff --git a/neural_compressor/torch/algorithms/static_quant/utility.py b/neural_compressor/torch/algorithms/static_quant/utility.py index 5e95dae2aac..4d02513ad03 100644 --- a/neural_compressor/torch/algorithms/static_quant/utility.py +++ b/neural_compressor/torch/algorithms/static_quant/utility.py @@ -163,7 +163,7 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_ return cfgs, ori_user_cfg -def generate_xpu_qconfig(tune_cfg): +def generate_xpu_qconfig(tune_cfg): # pragma: no cover # qconfig observer & config constants for ipex-xpu from torch.ao.quantization import HistogramObserver, MinMaxObserver, QConfig diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 1addbd15b8d..14ed5139062 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -1126,7 +1126,7 @@ def 
get_model_info_for_ipex(model: torch.nn.Module, example_inputs) -> List[Tupl _, _, _, _, model_info = get_quantizable_ops_recursively(model, example_inputs=example_inputs) return model_info - def get_model_info_for_ipex_xpu(self, model: torch.nn.Module) -> List[Tuple[str, Callable]]: + def get_model_info_for_ipex_xpu(self, model: torch.nn.Module) -> List[Tuple[str, Callable]]: # pragma: no cover if self.model_info: return self.model_info else: From 57b59a312fff471a2c31181c334df3d3b0f5829a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Jul 2024 07:49:28 +0000 Subject: [PATCH 9/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../torch/algorithms/static_quant/static_quant.py | 4 ++-- neural_compressor/torch/algorithms/static_quant/utility.py | 2 +- neural_compressor/torch/quantization/config.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/neural_compressor/torch/algorithms/static_quant/static_quant.py b/neural_compressor/torch/algorithms/static_quant/static_quant.py index 35eb76596b3..08dc5a1035f 100644 --- a/neural_compressor/torch/algorithms/static_quant/static_quant.py +++ b/neural_compressor/torch/algorithms/static_quant/static_quant.py @@ -79,7 +79,7 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs): ) # update json file in ipex_config_path; map ipex op_name to pt op_name self.user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name) - else: # pragma: no cover + else: # pragma: no cover model = model.to("xpu") model.eval() @@ -109,7 +109,7 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs): from torch.ao.quantization import QConfigMapping static_qconfig = QConfigMapping().set_global(qconfig) - else: # pragma: no cover + else: # pragma: no cover static_qconfig = QConfig( activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8), weight=PerChannelMinMaxObserver.with_args( diff --git a/neural_compressor/torch/algorithms/static_quant/utility.py b/neural_compressor/torch/algorithms/static_quant/utility.py index 4d02513ad03..f4930a22ddd 100644 --- a/neural_compressor/torch/algorithms/static_quant/utility.py +++ b/neural_compressor/torch/algorithms/static_quant/utility.py @@ -163,7 +163,7 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_ return cfgs, ori_user_cfg -def generate_xpu_qconfig(tune_cfg): # pragma: no cover +def generate_xpu_qconfig(tune_cfg): # pragma: no cover # qconfig observer & config constants for ipex-xpu from torch.ao.quantization import HistogramObserver, MinMaxObserver, QConfig diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 14ed5139062..2c43f1e59c1 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -1126,7 +1126,7 @@ def get_model_info_for_ipex(model: torch.nn.Module, example_inputs) -> List[Tupl _, _, _, _, model_info = get_quantizable_ops_recursively(model, example_inputs=example_inputs) return model_info - def get_model_info_for_ipex_xpu(self, model: torch.nn.Module) -> List[Tuple[str, Callable]]: # pragma: no cover + def get_model_info_for_ipex_xpu(self, model: torch.nn.Module) -> List[Tuple[str, Callable]]: # pragma: no cover if self.model_info: return self.model_info else:
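
The remaining patches (7/9-9/9) tighten the CPU-vs-XPU branch in save() (the device now comes from the model's first parameter, falling back to "cpu" when there is none) and adjust coverage pragmas and formatting. As a reviewer-side sketch only (the helper name is illustrative and not part of the series), the same guard can be written with an explicit None check, which also avoids truth-testing a multi-element Parameter:

    import torch

    def model_device_type(model: torch.nn.Module) -> str:
        # The first parameter decides the device; parameter-less (e.g. traced) modules default to "cpu".
        first_param = next(model.parameters(), None)
        return first_param.device.type if first_param is not None else "cpu"

    # save() then branches on the result:
    #   "cpu"     -> model.ori_save(...) plus a JSON dump of model.tune_cfg
    #   otherwise -> torch.jit.save(model, ...) plus save_config_mapping(model.qconfig, ...)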