From 186bedec1d512f51ebaf24f8072925b4520a9a10 Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Tue, 9 Jul 2024 11:37:40 +0800
Subject: [PATCH 1/6] detect processor type

Signed-off-by: yiliu30
---
 .../torch/quantization/config.py          | 31 +++++++++-------
 neural_compressor/torch/utils/utility.py  | 35 ++++++++++++++++++-
 test/3x/torch/test_config.py              | 18 +++++++---
 3 files changed, 67 insertions(+), 17 deletions(-)

diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py
index 71b01353d5a..c5e6c181e5b 100644
--- a/neural_compressor/torch/quantization/config.py
+++ b/neural_compressor/torch/quantization/config.py
@@ -23,6 +23,7 @@
 
 import torch
 
+import neural_compressor.torch.utils as torch_utils
 from neural_compressor.common.base_config import (
     BaseConfig,
     config_registry,
@@ -219,14 +220,17 @@ def get_config_set_for_tuning(cls) -> Union[None, "RTNConfig", List["RTNConfig"]
             dtype=["int4", "nf4"], use_sym=[True, False], group_size=[32, 128], use_mse_search=[False, True]
         )
 
+    @classmethod
+    def get_predefined_configs(cls) -> Dict[torch_utils.ProcessorType, "RTNConfig"]:
+        pre_defined_configs: Dict[torch_utils.ProcessorType, RTNConfig] = {}
+        pre_defined_configs[torch_utils.ProcessorType.Client] = cls(use_layer_wise=True)
+        pre_defined_configs[torch_utils.ProcessorType.Server] = cls()
+        return pre_defined_configs
 
-def get_default_rtn_config() -> RTNConfig:
-    """Generate the default rtn config.
-
-    Returns:
-        the default rtn config.
-    """
-    return RTNConfig()
+def get_default_rtn_config(processor_type: Optional[Union[str, torch_utils.ProcessorType]] = None) -> RTNConfig:
+    process_type = torch_utils.get_processor_type_from_user_config(processor_type)
+    return RTNConfig.get_predefined_configs()[process_type]
 
 
 def get_default_double_quant_config(type="BNB_NF4"):
@@ -378,14 +382,17 @@ def get_config_set_for_tuning(cls) -> Union[None, "GPTQConfig", List["GPTQConfig
         # TODO fwk owner needs to update it.
         return GPTQConfig(act_order=[True, False], use_sym=[False, True])
 
+    @classmethod
+    def get_predefined_configs(cls) -> Dict[torch_utils.ProcessorType, "GPTQConfig"]:
+        pre_defined_configs: Dict[torch_utils.ProcessorType, GPTQConfig] = {}
+        pre_defined_configs[torch_utils.ProcessorType.Client] = cls(use_layer_wise=True)
+        pre_defined_configs[torch_utils.ProcessorType.Server] = cls()
+        return pre_defined_configs
 
-def get_default_gptq_config() -> GPTQConfig:
-    """Generate the default gptq config.
-
-    Returns:
-        the default gptq config.
-    """
-    return GPTQConfig()
+def get_default_gptq_config(processor_type: Optional[Union[str, torch_utils.ProcessorType]] = None) -> GPTQConfig:
+    process_type = torch_utils.get_processor_type_from_user_config(processor_type)
+    return GPTQConfig.get_predefined_configs()[process_type]
 
 
 ######################## AWQ Config ###############################
diff --git a/neural_compressor/torch/utils/utility.py b/neural_compressor/torch/utils/utility.py
index e312a9c388b..4914b40cacd 100644
--- a/neural_compressor/torch/utils/utility.py
+++ b/neural_compressor/torch/utils/utility.py
@@ -13,7 +13,7 @@
 # limitations under the License.
-from typing import Callable, Dict, List, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch from prettytable import PrettyTable @@ -278,3 +278,36 @@ def get_model_device(model: torch.nn.Module): """ for n, p in model.named_parameters(): return p.data.device.type # p.data.device == device(type='cpu') + + +import enum + +import psutil + + +class ProcessorType(enum.Enum): + Client = "Client" + Server = "Server" + + +def detect_processor_type_based_on_hw(): + # TODO: refine the logic + ram_size = psutil.virtual_memory().total / (1024**3) + if ram_size > 32: + return ProcessorType.Server + else: + return ProcessorType.Client + + +def get_processor_type_from_user_config(user_processor_type: Optional[Union[str, ProcessorType]] = None): + if user_processor_type is None: + processor_type = detect_processor_type_based_on_hw() + elif isinstance(user_processor_type, ProcessorType): + processor_type = user_processor_type + elif isinstance(user_processor_type, str): + user_processor_type = user_processor_type.lower().capitalize() + assert user_processor_type in ProcessorType.__members__, f"Unsupported processor type: {user_processor_type}" + processor_type = ProcessorType(user_processor_type) + else: + raise NotImplementedError(f"Unsupported processor type: {user_processor_type}") + return processor_type diff --git a/test/3x/torch/test_config.py b/test/3x/torch/test_config.py index c5bdc5261cf..4894145f834 100644 --- a/test/3x/torch/test_config.py +++ b/test/3x/torch/test_config.py @@ -1,9 +1,11 @@ import copy import unittest +import pytest import torch import transformers +import neural_compressor.torch.utils as torch_utils from neural_compressor.torch.quantization import ( AutoRoundConfig, AWQConfig, @@ -56,6 +58,18 @@ def setUp(self): # print the test name logger.info(f"Running TestQuantizationConfig test: {self.id()}") + @pytest.mark.parametrize("config_cls", [RTNConfig, GPTQConfig]) + def test_get_config_based_on_processor_type(self, config_cls): + config_for_client = config_cls.get_predefined_configs()[torch_utils.ProcessorType.Client] + assert ( + config_for_client.use_layer_wise + ), f"Expect use_layer_wise to be True, got {config_for_client.use_layer_wise}" + + config_for_server = config_cls.get_predefined_configs()[torch_utils.ProcessorType.Server] + assert ( + config_for_server.use_layer_wise is False + ), f"Expect use_layer_wise to be False, got {config_for_server.use_layer_wise}" + def test_quantize_rtn_from_dict_default(self): logger.info("test_quantize_rtn_from_dict_default") @@ -339,7 +353,3 @@ def test_expand_config(self): expand_config_list = RTNConfig.expand(tune_config) self.assertEqual(expand_config_list[0].bits, 4) self.assertEqual(expand_config_list[1].bits, 6) - - -if __name__ == "__main__": - unittest.main() From 4ad81989de0f1221bc5bc196ddac4bd916a20b60 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Tue, 9 Jul 2024 11:42:44 +0800 Subject: [PATCH 2/6] add more uts Signed-off-by: yiliu30 --- test/3x/torch/test_config.py | 31 ++++++++++++------------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/test/3x/torch/test_config.py b/test/3x/torch/test_config.py index 4894145f834..d0cf2c1835f 100644 --- a/test/3x/torch/test_config.py +++ b/test/3x/torch/test_config.py @@ -58,18 +58,6 @@ def setUp(self): # print the test name logger.info(f"Running TestQuantizationConfig test: {self.id()}") - @pytest.mark.parametrize("config_cls", [RTNConfig, GPTQConfig]) - def test_get_config_based_on_processor_type(self, config_cls): - 
config_for_client = config_cls.get_predefined_configs()[torch_utils.ProcessorType.Client] - assert ( - config_for_client.use_layer_wise - ), f"Expect use_layer_wise to be True, got {config_for_client.use_layer_wise}" - - config_for_server = config_cls.get_predefined_configs()[torch_utils.ProcessorType.Server] - assert ( - config_for_server.use_layer_wise is False - ), f"Expect use_layer_wise to be False, got {config_for_server.use_layer_wise}" - def test_quantize_rtn_from_dict_default(self): logger.info("test_quantize_rtn_from_dict_default") @@ -345,11 +333,16 @@ def test_hqq_config(self): self.assertEqual(hqq_config.to_dict(), hqq_config2.to_dict()) -class TestQuantConfigForAutotune(unittest.TestCase): - def test_expand_config(self): - # test the expand functionalities, the user is not aware it +class TestQuantConfigBasedonProcessorType: - tune_config = RTNConfig(bits=[4, 6]) - expand_config_list = RTNConfig.expand(tune_config) - self.assertEqual(expand_config_list[0].bits, 4) - self.assertEqual(expand_config_list[1].bits, 6) + @pytest.mark.parametrize("config_cls", [RTNConfig, GPTQConfig]) + def test_get_config_based_on_processor_type(self, config_cls): + config_for_client = config_cls.get_predefined_configs()[torch_utils.ProcessorType.Client] + assert ( + config_for_client.use_layer_wise + ), f"Expect use_layer_wise to be True, got {config_for_client.use_layer_wise}" + + config_for_server = config_cls.get_predefined_configs()[torch_utils.ProcessorType.Server] + assert ( + config_for_server.use_layer_wise is False + ), f"Expect use_layer_wise to be False, got {config_for_server.use_layer_wise}" From a15438c3798c3bc95c4f513c288e7376c5919a33 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 10 Jul 2024 14:16:39 +0800 Subject: [PATCH 3/6] enhance the detect logic Signed-off-by: yiliu30 --- neural_compressor/common/utils/utility.py | 91 +++++++++++++++++------ neural_compressor/torch/utils/utility.py | 45 ++++++++--- test/3x/torch/test_config.py | 40 ++++++++++ 3 files changed, 143 insertions(+), 33 deletions(-) diff --git a/neural_compressor/common/utils/utility.py b/neural_compressor/common/utils/utility.py index 82f24243a9b..beb62e7f8cb 100644 --- a/neural_compressor/common/utils/utility.py +++ b/neural_compressor/common/utils/utility.py @@ -38,6 +38,7 @@ "CpuInfo", "default_tuning_logger", "call_counter", + "cpu_info", ] @@ -89,7 +90,7 @@ def __call__(self, *args, **kwargs): @singleton class CpuInfo(object): - """CPU info collection.""" + """Get CPU Info.""" def __init__(self): """Get whether the cpu numerical format is bf16, the number of sockets, cores and cores per socket.""" @@ -110,14 +111,28 @@ def __init__(self): b"\xB8\x07\x00\x00\x00" b"\x0f\xa2" b"\xC3", # mov eax, 7 # cpuid # ret ) self._bf16 = bool(eax & (1 << 5)) - # TODO: The implementation will be refined in the future. 
- # https://github.com/intel/neural-compressor/tree/detect_sockets - if "arch" in info and "ARM" in info["arch"]: # pragma: no cover - self._sockets = 1 - else: - self._sockets = self.get_number_of_sockets() - self._cores = psutil.cpu_count(logical=False) - self._cores_per_socket = int(self._cores / self._sockets) + self._info = info + # detect the below info when needed + self._cores = None + self._sockets = None + self._cores_per_socket = None + + @staticmethod + def _detect_cores(): + physical_cores = psutil.cpu_count(logical=False) + return physical_cores + + @property + def cores(self): + """Get the number of cores in platform.""" + if self._cores is None: + self._cores = self._detect_cores() + return self._cores + + @cores.setter + def cores(self, num_of_cores): + """Set the number of cores in platform.""" + self._cores = num_of_cores @property def bf16(self): @@ -130,30 +145,58 @@ def vnni(self): return self._vnni @property - def cores_per_socket(self): + def cores_per_socket(self) -> int: """Get the cores per socket.""" + if self._cores_per_socket is None: + self._cores_per_socket = self.cores // self.sockets return self._cores_per_socket - def get_number_of_sockets(self) -> int: - """Get number of sockets in platform.""" + @property + def sockets(self): + """Get the number of sockets in platform.""" + if self._sockets is None: + self._sockets = self._get_number_of_sockets() + return self._sockets + + @sockets.setter + def sockets(self, num_of_sockets): + """Set the number of sockets in platform.""" + self._sockets = num_of_sockets + + def _get_number_of_sockets(self) -> int: + if "arch" in self._info and "ARM" in self._info["arch"]: # pragma: no cover + return 1 + + num_sockets = None cmd = "cat /proc/cpuinfo | grep 'physical id' | sort -u | wc -l" if psutil.WINDOWS: cmd = r'wmic cpu get DeviceID | C:\Windows\System32\find.exe /C "CPU"' elif psutil.MACOS: # pragma: no cover cmd = "sysctl -n machdep.cpu.core_count" - with subprocess.Popen( - args=cmd, - shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - universal_newlines=False, - ) as proc: - proc.wait() - if proc.stdout: - for line in proc.stdout: - return int(line.decode("utf-8", errors="ignore").strip()) - return 0 + num_sockets = None + try: + with subprocess.Popen( + args=cmd, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=False, + ) as proc: + proc.wait() + if proc.stdout: + for line in proc.stdout: + num_sockets = int(line.decode("utf-8", errors="ignore").strip()) + except Exception as e: + logger.error("Failed to get number of sockets: %s" % e) + if isinstance(num_sockets, int) and num_sockets >= 1: + return num_sockets + else: + logger.warning("Failed to get number of sockets, return 1 as default.") + return 1 + + +cpu_info = CpuInfo() def dump_elapsed_time(customized_msg=""): diff --git a/neural_compressor/torch/utils/utility.py b/neural_compressor/torch/utils/utility.py index 4914b40cacd..dee660d0de7 100644 --- a/neural_compressor/torch/utils/utility.py +++ b/neural_compressor/torch/utils/utility.py @@ -13,13 +13,15 @@ # limitations under the License. 
+import enum
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
+import psutil
 import torch
 from prettytable import PrettyTable
 from typing_extensions import TypeAlias
 
-from neural_compressor.common.utils import LazyImport, Mode, logger
+from neural_compressor.common.utils import LazyImport, Mode, cpu_info, logger
 
 OP_NAME_AND_TYPE_TUPLE_TYPE: TypeAlias = Tuple[str, Union[torch.nn.Module, Callable]]
 
@@ -280,26 +282,51 @@ def get_model_device(model: torch.nn.Module):
         return p.data.device.type  # p.data.device == device(type='cpu')
 
 
-import enum
-
-import psutil
-
-
 class ProcessorType(enum.Enum):
     Client = "Client"
     Server = "Server"
 
 
 def detect_processor_type_based_on_hw():
-    # TODO: refine the logic
-    ram_size = psutil.virtual_memory().total / (1024**3)
-    if ram_size > 32:
+    """Detects the processor type based on the hardware configuration.
+
+    Returns:
+        ProcessorType: The detected processor type (Server or Client).
+    """
+    # Detect the processor type based on the conditions below:
+    # 1. If there is more than one socket, it is a server.
+    # 2. If the memory size is greater than 64GB, it is a server.
+    log_msg = "Processor type detected as {processor_type} due to {reason}."
+    if cpu_info.sockets > 1:
+        logger.info(log_msg.format(processor_type=ProcessorType.Server.value, reason="there is more than one socket"))
+        return ProcessorType.Server
+    elif psutil.virtual_memory().total / (1024**3) > 64:
+        logger.info(
+            log_msg.format(processor_type=ProcessorType.Server.value, reason="the memory size is greater than 64GB")
+        )
         return ProcessorType.Server
     else:
+        logger.info(
+            f"Processor type detected as {ProcessorType.Client.value}, pass `processor_type='server'` to override it if needed."
+        )
        return ProcessorType.Client
 
 
 def get_processor_type_from_user_config(user_processor_type: Optional[Union[str, ProcessorType]] = None):
+    """Get the processor type.
+
+    Get the processor type based on the user configuration or automatically detect it based on the hardware.
+
+    Args:
+        user_processor_type (Optional[Union[str, ProcessorType]]): The user-specified processor type. Defaults to None.
+
+    Returns:
+        ProcessorType: The detected or user-specified processor type.
+
+    Raises:
+        AssertionError: If the user-specified processor type is not supported.
+        NotImplementedError: If the processor type is not recognized.
+ """ if user_processor_type is None: processor_type = detect_processor_type_based_on_hw() elif isinstance(user_processor_type, ProcessorType): diff --git a/test/3x/torch/test_config.py b/test/3x/torch/test_config.py index d0cf2c1835f..87ce1a27fb0 100644 --- a/test/3x/torch/test_config.py +++ b/test/3x/torch/test_config.py @@ -15,6 +15,7 @@ SmoothQuantConfig, StaticQuantConfig, TEQConfig, + get_default_gptq_config, get_default_hqq_config, get_default_rtn_config, quantize, @@ -346,3 +347,42 @@ def test_get_config_based_on_processor_type(self, config_cls): assert ( config_for_server.use_layer_wise is False ), f"Expect use_layer_wise to be False, got {config_for_server.use_layer_wise}" + + @pytest.fixture + def force_client(self, monkeypatch): + monkeypatch.setattr(torch_utils.utility.cpu_info, "sockets", 1) + + # force the ram size detected by psutil <= 64GB + class MockMemory: + def __init__(self, total): + self.total = total + + # Patch the psutil.virtual_memory() method + monkeypatch.setattr(torch_utils.utility.psutil, "virtual_memory", lambda: MockMemory(16 * 1024**3)) + + def test_auto_detect_processor_type(self, force_client): + p_type = torch_utils.detect_processor_type_based_on_hw() + assert ( + p_type == torch_utils.ProcessorType.Client + ), f"Expect processor type to be {torch_utils.ProcessorType.Client}, got {p_type}" + + @pytest.fixture + def force_server(self, monkeypatch): + monkeypatch.setattr(torch_utils.utility.cpu_info, "sockets", 2) + + def test_get_default_config_force_server(self, force_server): + rtn_config = get_default_rtn_config() + assert not rtn_config.use_layer_wise, f"Expect use_layer_wise to be `False`, got {rtn_config.use_layer_wise}" + gptq_config = get_default_gptq_config() + assert not gptq_config.use_layer_wise, f"Expect use_layer_wise to be `False`, got {gptq_config.use_layer_wise}" + + @pytest.mark.parametrize("p_type", [None, torch_utils.ProcessorType.Client, torch_utils.ProcessorType.Server]) + def test_get_default_config(self, p_type): + rtn_config = get_default_rtn_config(processor_type=p_type) + assert rtn_config.use_layer_wise == ( + p_type == torch_utils.ProcessorType.Client + ), f"Expect use_layer_wise to be {p_type == torch_utils.ProcessorType.Client}, got {rtn_config.use_layer_wise}" + gptq_config = get_default_gptq_config(processor_type=p_type) + assert gptq_config.use_layer_wise == ( + p_type == torch_utils.ProcessorType.Client + ), f"Expect use_layer_wise to be {p_type == torch_utils.ProcessorType.Client}, got {gptq_config.use_layer_wise}" From ab1a7bd78fd14d4a065151c1be6bbfa99362927a Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 11 Jul 2024 11:20:53 +0800 Subject: [PATCH 4/6] add autoround Signed-off-by: yiliu30 --- neural_compressor/common/utils/constants.py | 3 ++ neural_compressor/common/utils/utility.py | 53 ++++++++++++++++++- .../torch/quantization/config.py | 17 +++--- neural_compressor/torch/utils/utility.py | 32 +---------- test/3x/common/test_utility.py | 23 +++++++- test/3x/torch/test_config.py | 25 +++------ 6 files changed, 94 insertions(+), 59 deletions(-) diff --git a/neural_compressor/common/utils/constants.py b/neural_compressor/common/utils/constants.py index adf7755003b..76846682fd4 100644 --- a/neural_compressor/common/utils/constants.py +++ b/neural_compressor/common/utils/constants.py @@ -56,3 +56,6 @@ class Mode(Enum): PREPARE = "prepare" CONVERT = "convert" QUANTIZE = "quantize" + + +SERVER_PROCESSOR_BRAND_KEY_WORLD_LST = ["Xeon"] diff --git a/neural_compressor/common/utils/utility.py 
index 3c2dc18a009..6e525fdae5e 100644
--- a/neural_compressor/common/utils/utility.py
+++ b/neural_compressor/common/utils/utility.py
@@ -17,6 +17,7 @@
 """The utility of common module."""
 
 import collections
+import enum
 import importlib
 import subprocess
 import time
@@ -25,7 +26,7 @@
 import cpuinfo
 import psutil
 
-from neural_compressor.common.utils import Mode, TuningLogger, logger
+from neural_compressor.common.utils import Mode, TuningLogger, constants, logger
 
 __all__ = [
     "set_workspace",
@@ -40,6 +41,8 @@
     "default_tuning_logger",
     "call_counter",
     "cpu_info",
+    "ProcessorType",
+    "detect_processor_type_based_on_hw",
 ]
 
@@ -112,11 +115,22 @@ def __init__(self):
         )
         self._bf16 = bool(eax & (1 << 5))
         self._info = info
+        self._brand_raw = info.get("brand_raw", "")
         # detect the below info when needed
         self._cores = None
         self._sockets = None
         self._cores_per_socket = None
 
+    @property
+    def brand_raw(self):
+        """Get the brand name of the CPU."""
+        return self._brand_raw
+
+    @brand_raw.setter
+    def brand_raw(self, brand_name):
+        """Set the brand name of the CPU."""
+        self._brand_raw = brand_name
+
     @staticmethod
     def _detect_cores():
         physical_cores = psutil.cpu_count(logical=False)
         return physical_cores
@@ -301,3 +315,40 @@ def wrapper(*args, **kwargs):
             return func(*args, **kwargs)
 
     return wrapper
+
+
+class ProcessorType(enum.Enum):
+    Client = "Client"
+    Server = "Server"
+
+
+def detect_processor_type_based_on_hw():
+    """Detects the processor type based on the hardware configuration.
+
+    Returns:
+        ProcessorType: The detected processor type (Server or Client).
+    """
+    # Detect the processor type based on the conditions below:
+    # If there is more than one socket, it is a server.
+    # If the brand name includes a keyword in `SERVER_PROCESSOR_BRAND_KEY_WORD_LST`, it is a server.
+    # If the memory size is greater than 32GB, it is a server.
+    log_msg = "Processor type detected as {processor_type} due to {reason}."
+    if cpu_info.sockets > 1:
+        logger.info(log_msg.format(processor_type=ProcessorType.Server.value, reason="there is more than one socket"))
+        return ProcessorType.Server
+    elif any(brand in cpu_info.brand_raw for brand in constants.SERVER_PROCESSOR_BRAND_KEY_WORD_LST):
+        logger.info(
+            log_msg.format(processor_type=ProcessorType.Server.value, reason=f"the brand name is {cpu_info.brand_raw}")
+        )
+        return ProcessorType.Server
+    elif psutil.virtual_memory().total / (1024**3) > 32:
+        logger.info(
+            log_msg.format(processor_type=ProcessorType.Server.value, reason="the memory size is greater than 32GB")
+        )
+        return ProcessorType.Server
+    else:
+        logger.info(
+            "Processor type detected as %s, pass `processor_type='server'` to override it if needed.",
+            ProcessorType.Client.value,
+        )
+        return ProcessorType.Client
diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py
index c5e6c181e5b..9014f1576a3 100644
--- a/neural_compressor/torch/quantization/config.py
+++ b/neural_compressor/torch/quantization/config.py
@@ -732,6 +732,7 @@ def __init__(
         not_use_best_mse: bool = False,
         dynamic_max_gap: int = -1,
         scale_dtype: str = "fp16",
+        use_layer_wise: bool = False,
         white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
     ):
         """Init AUTOROUND weight-only quantization config.
@@ -784,6 +785,7 @@
         self.not_use_best_mse = not_use_best_mse
         self.dynamic_max_gap = dynamic_max_gap
         self.scale_dtype = scale_dtype
+        self.use_layer_wise = use_layer_wise
         self._post_init()
 
     @classmethod
@@ -810,14 +812,17 @@ def get_config_set_for_tuning(cls) -> Union[None, "AutoRoundConfig", List["AutoR
         # TODO fwk owner needs to update it.
         return AutoRoundConfig(bits=[4, 6])
 
+    @classmethod
+    def get_predefined_configs(cls) -> Dict[torch_utils.ProcessorType, "AutoRoundConfig"]:
+        pre_defined_configs: Dict[torch_utils.ProcessorType, AutoRoundConfig] = {}
+        pre_defined_configs[torch_utils.ProcessorType.Client] = cls(use_layer_wise=True)
+        pre_defined_configs[torch_utils.ProcessorType.Server] = cls()
+        return pre_defined_configs
 
-def get_default_AutoRound_config() -> AutoRoundConfig:
-    """Generate the default AUTOROUND config.
-
-    Returns:
-        the default AUTOROUND config.
-    """
-    return AutoRoundConfig()
+def get_default_AutoRound_config(processor_type: Optional[Union[str, torch_utils.ProcessorType]] = None) -> AutoRoundConfig:
+    process_type = torch_utils.get_processor_type_from_user_config(processor_type)
+    return AutoRoundConfig.get_predefined_configs()[process_type]
 
 
 ######################## MX Config ###############################
diff --git a/neural_compressor/torch/utils/utility.py b/neural_compressor/torch/utils/utility.py
index dee660d0de7..912ded3801c 100644
--- a/neural_compressor/torch/utils/utility.py
+++ b/neural_compressor/torch/utils/utility.py
@@ -21,7 +21,7 @@
 from prettytable import PrettyTable
 from typing_extensions import TypeAlias
 
-from neural_compressor.common.utils import LazyImport, Mode, cpu_info, logger
+from neural_compressor.common.utils import Mode, ProcessorType, cpu_info, detect_processor_type_based_on_hw, logger
 
 OP_NAME_AND_TYPE_TUPLE_TYPE: TypeAlias = Tuple[str, Union[torch.nn.Module, Callable]]
 
@@ -282,36 +282,6 @@ def get_model_device(model: torch.nn.Module):
         return p.data.device.type  # p.data.device == device(type='cpu')
 
 
-class ProcessorType(enum.Enum):
-    Client = "Client"
-    Server = "Server"
-
-
-def detect_processor_type_based_on_hw():
-    """Detects the processor type based on the hardware configuration.
-
-    Returns:
-        ProcessorType: The detected processor type (Server or Client).
-    """
-    # Detect the processor type based on the conditions below:
-    # 1. If there is more than one socket, it is a server.
-    # 2. If the memory size is greater than 64GB, it is a server.
-    log_msg = "Processor type detected as {processor_type} due to {reason}."
-    if cpu_info.sockets > 1:
-        logger.info(log_msg.format(processor_type=ProcessorType.Server.value, reason="there is more than one socket"))
-        return ProcessorType.Server
-    elif psutil.virtual_memory().total / (1024**3) > 64:
-        logger.info(
-            log_msg.format(processor_type=ProcessorType.Server.value, reason="the memory size is greater than 64GB")
-        )
-        return ProcessorType.Server
-    else:
-        logger.info(
-            f"Processor type detected as {ProcessorType.Client.value}, pass `processor_type='server'` to override it if needed."
-        )
-        return ProcessorType.Client
-
-
 def get_processor_type_from_user_config(user_processor_type: Optional[Union[str, ProcessorType]] = None):
     """Get the processor type.
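
A minimal usage sketch of the helpers after this move (import paths as the hunks above use them; the
auto-detected value naturally depends on the host hardware, so treat the results here as illustrative):

    from neural_compressor.common.utils import ProcessorType
    from neural_compressor.torch.utils import get_processor_type_from_user_config

    # No argument: fall back to hardware detection (socket count, brand keyword, RAM size).
    auto_type = get_processor_type_from_user_config()

    # String overrides are normalized, so "server", "Server", and "SERVER" all resolve the same way.
    assert get_processor_type_from_user_config("server") is ProcessorType.Server

    # A ProcessorType member is passed through unchanged.
    assert get_processor_type_from_user_config(ProcessorType.Client) is ProcessorType.Client
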
diff --git a/test/3x/common/test_utility.py b/test/3x/common/test_utility.py
index 527f74a4a13..db8b2e749ad 100644
--- a/test/3x/common/test_utility.py
+++ b/test/3x/common/test_utility.py
@@ -11,6 +11,8 @@
 import unittest
 from unittest.mock import MagicMock, patch
 
+import pytest
+
 import neural_compressor.common.utils.utility as inc_utils
 from neural_compressor.common import options
 from neural_compressor.common.utils import (
@@ -188,5 +190,22 @@ def add(a, b):
         self.assertEqual(inc_utils.FUNC_CALL_COUNTS["add"], 3)
 
 
-if __name__ == "__main__":
-    unittest.main()
+class TestAutoDetectProcessorType:
+    @pytest.fixture
+    def force_client(self, monkeypatch):
+        monkeypatch.setattr(inc_utils.cpu_info, "sockets", 1)
+        monkeypatch.setattr(inc_utils.cpu_info, "brand_raw", "")
+
+        # force the ram size detected by psutil below the 32GB server threshold
+        class MockMemory:
+            def __init__(self, total):
+                self.total = total
+
+        # Patch the psutil.virtual_memory() method
+        monkeypatch.setattr(inc_utils.psutil, "virtual_memory", lambda: MockMemory(16 * 1024**3))
+
+    def test_auto_detect_processor_type(self, force_client):
+        p_type = inc_utils.detect_processor_type_based_on_hw()
+        assert (
+            p_type == inc_utils.ProcessorType.Client
+        ), f"Expect processor type to be {inc_utils.ProcessorType.Client}, got {p_type}"
diff --git a/test/3x/torch/test_config.py b/test/3x/torch/test_config.py
index 87ce1a27fb0..68e7d5975cc 100644
--- a/test/3x/torch/test_config.py
+++ b/test/3x/torch/test_config.py
@@ -15,6 +15,7 @@
     SmoothQuantConfig,
     StaticQuantConfig,
     TEQConfig,
+    get_default_AutoRound_config,
     get_default_gptq_config,
     get_default_hqq_config,
     get_default_rtn_config,
@@ -336,7 +337,7 @@ def test_hqq_config(self):
         self.assertEqual(hqq_config.to_dict(), hqq_config2.to_dict())
 
 
 class TestQuantConfigBasedonProcessorType:
 
-    @pytest.mark.parametrize("config_cls", [RTNConfig, GPTQConfig])
+    @pytest.mark.parametrize("config_cls", [RTNConfig, GPTQConfig, AutoRoundConfig])
     def test_get_config_based_on_processor_type(self, config_cls):
         config_for_client = config_cls.get_predefined_configs()[torch_utils.ProcessorType.Client]
         assert (
@@ -348,24 +349,6 @@ def test_get_config_based_on_processor_type(self, config_cls):
             config_for_server.use_layer_wise is False
         ), f"Expect use_layer_wise to be False, got {config_for_server.use_layer_wise}"
 
-    @pytest.fixture
-    def force_client(self, monkeypatch):
-        monkeypatch.setattr(torch_utils.utility.cpu_info, "sockets", 1)
-
-        # force the ram size detected by psutil <= 64GB
-        class MockMemory:
-            def __init__(self, total):
-                self.total = total
-
-        # Patch the psutil.virtual_memory() method
-        monkeypatch.setattr(torch_utils.utility.psutil, "virtual_memory", lambda: MockMemory(16 * 1024**3))
-
-    def test_auto_detect_processor_type(self, force_client):
-        p_type = torch_utils.detect_processor_type_based_on_hw()
-        assert (
-            p_type == torch_utils.ProcessorType.Client
-        ), f"Expect processor type to be {torch_utils.ProcessorType.Client}, got {p_type}"
-
     @pytest.fixture
     def force_server(self, monkeypatch):
         monkeypatch.setattr(torch_utils.utility.cpu_info, "sockets", 2)
@@ -386,3 +369,7 @@ def test_get_default_config(self, p_type):
         assert gptq_config.use_layer_wise == (
             p_type == torch_utils.ProcessorType.Client
         ), f"Expect use_layer_wise to be {p_type == torch_utils.ProcessorType.Client}, got {gptq_config.use_layer_wise}"
+        autoround_config = get_default_AutoRound_config(processor_type=p_type)
+        assert autoround_config.use_layer_wise == (
+            p_type == torch_utils.ProcessorType.Client
+        ), f"Expect use_layer_wise to be {p_type == torch_utils.ProcessorType.Client}, got {autoround_config.use_layer_wise}"
{autoround_config.use_layer_wise}" From ccd3ca7c52b8f50afe127e4718a4a00a5b1bfda1 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 11 Jul 2024 15:06:12 +0800 Subject: [PATCH 5/6] fix ut Signed-off-by: yiliu30 --- test/3x/common/test_utility.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/3x/common/test_utility.py b/test/3x/common/test_utility.py index 4f08fbd77e5..cb4596889f3 100644 --- a/test/3x/common/test_utility.py +++ b/test/3x/common/test_utility.py @@ -43,7 +43,7 @@ def test_set_random_seed(self): set_random_seed(seed) def test_set_workspace(self): - workspace = "/path/to/workspace" + workspace = "/tmp/inc_workspace" set_workspace(workspace) self.assertEqual(options.workspace, workspace) returned_workspace = get_workspace() From b20dace63badb85c745d0719c1d810e7518033bd Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 11 Jul 2024 16:16:33 +0800 Subject: [PATCH 6/6] add more UTs Signed-off-by: yiliu30 --- test/3x/common/test_utility.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/test/3x/common/test_utility.py b/test/3x/common/test_utility.py index cb4596889f3..fd349ce1706 100644 --- a/test/3x/common/test_utility.py +++ b/test/3x/common/test_utility.py @@ -80,6 +80,9 @@ def test_cpu_info(self): cpu_info = CpuInfo() assert isinstance(cpu_info.bf16, bool), "bf16 should be a boolean" assert isinstance(cpu_info.vnni, bool), "avx512 should be a boolean" + assert cpu_info.cores >= 1 + assert cpu_info.sockets >= 1 + assert cpu_info.cores_per_socket >= 1 class TestLazyImport(unittest.TestCase): @@ -115,6 +118,11 @@ def test_lazy_import_access_attr(self): self.assertIsNotNone(lazy_import.module) + def test_call_method_module_not_found(self): + with self.assertRaises(ImportError): + lazy_import = LazyImport("non_existent_module") + lazy_import(3, 4) + class TestUtils(unittest.TestCase): def test_dump_elapsed_time(self): @@ -211,3 +219,20 @@ def test_auto_detect_processor_type(self, force_client): assert ( p_type == inc_utils.ProcessorType.Client ), f"Expect processor type to be {inc_utils.ProcessorType.Client}, got {p_type}" + + def test_detect_processor_type_based_on_hw(self): + # Test when the brand name includes a server keyword + inc_utils.cpu_info.brand_raw = "Intel Xeon Server" + assert inc_utils.detect_processor_type_based_on_hw() == inc_utils.ProcessorType.Server + + # Test when the memory size is greater than 32GB + with patch("psutil.virtual_memory") as mock_virtual_memory: + mock_virtual_memory.return_value.total = 64 * 1024**3 + assert inc_utils.detect_processor_type_based_on_hw() == inc_utils.ProcessorType.Server + + # Test when none of the conditions are met + inc_utils.cpu_info.sockets = 1 + inc_utils.cpu_info.brand_raw = "Intel Core i7" + with patch("psutil.virtual_memory") as mock_virtual_memory: + mock_virtual_memory.return_value.total = 16 * 1024**3 + assert inc_utils.detect_processor_type_based_on_hw() == inc_utils.ProcessorType.Client