
Commit 16a7b11

Get default config based on the auto-detect CPU type (#1904)
Signed-off-by: yiliu30 <[email protected]>
1 parent 2fc7255 commit 16a7b11

6 files changed: +291 −37 lines changed

neural_compressor/common/utils/constants.py

Lines changed: 3 additions & 0 deletions
@@ -56,3 +56,6 @@ class Mode(Enum):
     PREPARE = "prepare"
     CONVERT = "convert"
     QUANTIZE = "quantize"
+
+
+SERVER_PROCESSOR_BRAND_KEY_WORLD_LST = ["Xeon"]
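The keyword list added here feeds the brand-name check in the new detect_processor_type_based_on_hw helper (see the utility.py diff below). A minimal sketch of that check, using a hypothetical brand string that is not part of this commit:

    SERVER_PROCESSOR_BRAND_KEY_WORLD_LST = ["Xeon"]
    brand_raw = "Intel(R) Xeon(R) Platinum 8480+"  # hypothetical example value, not from this commit
    # Same membership test as detect_processor_type_based_on_hw uses; True here because "Xeon" appears.
    is_server_brand = any(keyword in brand_raw for keyword in SERVER_PROCESSOR_BRAND_KEY_WORLD_LST)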

neural_compressor/common/utils/utility.py

Lines changed: 130 additions & 2 deletions
@@ -17,6 +17,7 @@
 """The utility of common module."""
 
 import collections
+import enum
 import importlib
 import subprocess
 import time
@@ -26,7 +27,7 @@
 import psutil
 from prettytable import PrettyTable
 
-from neural_compressor.common.utils import Mode, TuningLogger, logger
+from neural_compressor.common.utils import Mode, TuningLogger, constants, logger
 
 __all__ = [
     "set_workspace",
@@ -41,6 +42,9 @@
     "CpuInfo",
     "default_tuning_logger",
     "call_counter",
+    "cpu_info",
+    "ProcessorType",
+    "detect_processor_type_based_on_hw",
     "Statistics",
 ]
 
@@ -92,7 +96,7 @@ def __call__(self, *args, **kwargs):
 
 @singleton
 class CpuInfo(object):
-    """CPU info collection."""
+    """Get CPU Info."""
 
     def __init__(self):
         """Get whether the cpu numerical format is bf16, the number of sockets, cores and cores per socket."""
@@ -113,6 +117,39 @@ def __init__(self):
             b"\xB8\x07\x00\x00\x00" b"\x0f\xa2" b"\xC3",  # mov eax, 7 # cpuid # ret
         )
         self._bf16 = bool(eax & (1 << 5))
+        self._info = info
+        self._brand_raw = info.get("brand_raw", "")
+        # detect the below info when needed
+        self._cores = None
+        self._sockets = None
+        self._cores_per_socket = None
+
+    @property
+    def brand_raw(self):
+        """Get the brand name of the CPU."""
+        return self._brand_raw
+
+    @brand_raw.setter
+    def brand_raw(self, brand_name):
+        """Set the brand name of the CPU."""
+        self._brand_raw = brand_name
+
+    @staticmethod
+    def _detect_cores():
+        physical_cores = psutil.cpu_count(logical=False)
+        return physical_cores
+
+    @property
+    def cores(self):
+        """Get the number of cores in platform."""
+        if self._cores is None:
+            self._cores = self._detect_cores()
+        return self._cores
+
+    @cores.setter
+    def cores(self, num_of_cores):
+        """Set the number of cores in platform."""
+        self._cores = num_of_cores
 
     @property
     def bf16(self):
@@ -124,6 +161,60 @@ def vnni(self):
         """Get whether it is vnni."""
         return self._vnni
 
+    @property
+    def cores_per_socket(self) -> int:
+        """Get the cores per socket."""
+        if self._cores_per_socket is None:
+            self._cores_per_socket = self.cores // self.sockets
+        return self._cores_per_socket
+
+    @property
+    def sockets(self):
+        """Get the number of sockets in platform."""
+        if self._sockets is None:
+            self._sockets = self._get_number_of_sockets()
+        return self._sockets
+
+    @sockets.setter
+    def sockets(self, num_of_sockets):
+        """Set the number of sockets in platform."""
+        self._sockets = num_of_sockets
+
+    def _get_number_of_sockets(self) -> int:
+        if "arch" in self._info and "ARM" in self._info["arch"]:  # pragma: no cover
+            return 1
+
+        num_sockets = None
+        cmd = "cat /proc/cpuinfo | grep 'physical id' | sort -u | wc -l"
+        if psutil.WINDOWS:
+            cmd = r'wmic cpu get DeviceID | C:\Windows\System32\find.exe /C "CPU"'
+        elif psutil.MACOS:  # pragma: no cover
+            cmd = "sysctl -n machdep.cpu.core_count"
+
+        num_sockets = None
+        try:
+            with subprocess.Popen(
+                args=cmd,
+                shell=True,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.STDOUT,
+                universal_newlines=False,
+            ) as proc:
+                proc.wait()
+                if proc.stdout:
+                    for line in proc.stdout:
+                        num_sockets = int(line.decode("utf-8", errors="ignore").strip())
+        except Exception as e:
+            logger.error("Failed to get number of sockets: %s" % e)
+        if isinstance(num_sockets, int) and num_sockets >= 1:
+            return num_sockets
+        else:
+            logger.warning("Failed to get number of sockets, return 1 as default.")
+            return 1
+
+
+cpu_info = CpuInfo()
+
 
 def dump_elapsed_time(customized_msg=""):
     """Get the elapsed time for decorated functions.
@@ -236,6 +327,43 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
+class ProcessorType(enum.Enum):
+    Client = "Client"
+    Server = "Server"
+
+
+def detect_processor_type_based_on_hw():
+    """Detects the processor type based on the hardware configuration.
+
+    Returns:
+        ProcessorType: The detected processor type (Server or Client).
+    """
+    # Detect the processor type based on below conditions:
+    # If there are more than one sockets, it is a server.
+    # If the brand name includes key word in `SERVER_PROCESSOR_BRAND_KEY_WORLD_LST`, it is a server.
+    # If the memory size is greater than 32GB, it is a server.
+    log_mgs = "Processor type detected as {processor_type} due to {reason}."
+    if cpu_info.sockets > 1:
+        logger.info(log_mgs.format(processor_type=ProcessorType.Server.value, reason="there are more than one sockets"))
+        return ProcessorType.Server
+    elif any(brand in cpu_info.brand_raw for brand in constants.SERVER_PROCESSOR_BRAND_KEY_WORLD_LST):
+        logger.info(
+            log_mgs.format(processor_type=ProcessorType.Server.value, reason=f"the brand name is {cpu_info.brand_raw}.")
+        )
+        return ProcessorType.Server
+    elif psutil.virtual_memory().total / (1024**3) > 32:
+        logger.info(
+            log_mgs.format(processor_type=ProcessorType.Server.value, reason="the memory size is greater than 32GB")
+        )
+        return ProcessorType.Server
+    else:
+        logger.info(
+            "Processor type detected as %s, pass `processor_type='server'` to override it if needed.",
+            ProcessorType.Client.value,
+        )
+        return ProcessorType.Client
+
+
 class Statistics:
     """The statistics printer."""
 
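The cpu_info singleton and detect_processor_type_based_on_hw are added to __all__ above and are imported from neural_compressor.common.utils by the torch-side diff below, so a minimal usage sketch (assuming that package-level re-export) looks like:

    from neural_compressor.common.utils import cpu_info, detect_processor_type_based_on_hw

    # Socket count is detected lazily via a per-OS shell command and falls back to 1 on failure.
    print(cpu_info.sockets)
    # Brand string from cpuid, e.g. one containing "Xeon" on server parts; a setter allows overriding it.
    print(cpu_info.brand_raw)
    # Returns ProcessorType.Server or ProcessorType.Client based on sockets, brand keywords, and total memory.
    print(detect_processor_type_based_on_hw())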

neural_compressor/torch/quantization/config.py

Lines changed: 30 additions & 18 deletions
@@ -23,6 +23,7 @@
 
 import torch
 
+import neural_compressor.torch.utils as torch_utils
 from neural_compressor.common.base_config import (
     BaseConfig,
     config_registry,
@@ -219,14 +220,17 @@ def get_config_set_for_tuning(cls) -> Union[None, "RTNConfig", List["RTNConfig"]
             dtype=["int4", "nf4"], use_sym=[True, False], group_size=[32, 128], use_mse_search=[False, True]
         )
 
+    @classmethod
+    def get_predefined_configs(cls) -> Dict[torch_utils.ProcessorType, "RTNConfig"]:
+        pre_defined_configs: Dict[torch_utils.ProcessorType, RTNConfig] = {}
+        pre_defined_configs[torch_utils.ProcessorType.Client] = cls(use_layer_wise=True)
+        pre_defined_configs[torch_utils.ProcessorType.Server] = cls()
+        return pre_defined_configs
 
-def get_default_rtn_config() -> RTNConfig:
-    """Generate the default rtn config.
 
-    Returns:
-        the default rtn config.
-    """
-    return RTNConfig()
+def get_default_rtn_config(processor_type: Optional[Union[str, torch_utils.ProcessorType]] = None) -> RTNConfig:
+    process_type = torch_utils.get_processor_type_from_user_config(processor_type)
+    return RTNConfig.get_predefined_configs()[process_type]
 
 
 def get_default_double_quant_config(type="BNB_NF4"):
@@ -378,14 +382,17 @@ def get_config_set_for_tuning(cls) -> Union[None, "GPTQConfig", List["GPTQConfig
         # TODO fwk owner needs to update it.
         return GPTQConfig(act_order=[True, False], use_sym=[False, True])
 
+    @classmethod
+    def get_predefined_configs(cls) -> Dict[torch_utils.ProcessorType, "GPTQConfig"]:
+        pre_defined_configs: Dict[torch_utils.ProcessorType, GPTQConfig] = {}
+        pre_defined_configs[torch_utils.ProcessorType.Client] = cls(use_layer_wise=True)
+        pre_defined_configs[torch_utils.ProcessorType.Server] = cls()
+        return pre_defined_configs
 
-def get_default_gptq_config() -> GPTQConfig:
-    """Generate the default gptq config.
 
-    Returns:
-        the default gptq config.
-    """
-    return GPTQConfig()
+def get_default_gptq_config(processor_type: Optional[Union[str, torch_utils.ProcessorType]] = None) -> RTNConfig:
+    process_type = torch_utils.get_processor_type_from_user_config(processor_type)
+    return GPTQConfig.get_predefined_configs()[process_type]
 
 
 ######################## AWQ Config ###############################
@@ -725,6 +732,7 @@ def __init__(
         not_use_best_mse: bool = False,
         dynamic_max_gap: int = -1,
         scale_dtype: str = "fp16",
+        use_layer_wise: bool = False,
         white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
     ):
         """Init AUTOROUND weight-only quantization config.
@@ -777,6 +785,7 @@ def __init__(
         self.not_use_best_mse = not_use_best_mse
         self.dynamic_max_gap = dynamic_max_gap
         self.scale_dtype = scale_dtype
+        self.use_layer_wise = use_layer_wise
        self._post_init()
 
     @classmethod
@@ -803,14 +812,17 @@ def get_config_set_for_tuning(cls) -> Union[None, "AutoRoundConfig", List["AutoR
         # TODO fwk owner needs to update it.
         return AutoRoundConfig(bits=[4, 6])
 
+    @classmethod
+    def get_predefined_configs(cls) -> Dict[torch_utils.ProcessorType, "AutoRoundConfig"]:
+        pre_defined_configs: Dict[torch_utils.ProcessorType, AutoRoundConfig] = {}
+        pre_defined_configs[torch_utils.ProcessorType.Client] = cls(use_layer_wise=True)
+        pre_defined_configs[torch_utils.ProcessorType.Server] = cls()
+        return pre_defined_configs
 
-def get_default_AutoRound_config() -> AutoRoundConfig:
-    """Generate the default AUTOROUND config.
 
-    Returns:
-        the default AUTOROUND config.
-    """
-    return AutoRoundConfig()
+def get_default_AutoRound_config(processor_type: Optional[Union[str, torch_utils.ProcessorType]] = None) -> RTNConfig:
+    process_type = torch_utils.get_processor_type_from_user_config(processor_type)
+    return AutoRoundConfig.get_predefined_configs()[process_type]
 
 
 ######################## MX Config ###############################
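With this change the get_default_*_config entry points return a processor-specific preset instead of a bare config: the Client preset enables use_layer_wise, the Server preset keeps the defaults. A sketch of the intended call patterns, assuming get_default_rtn_config is re-exported from neural_compressor.torch.quantization (that import path is not part of this diff):

    from neural_compressor.torch.quantization import get_default_rtn_config

    auto_cfg = get_default_rtn_config()                           # auto-detects client vs. server hardware
    client_cfg = get_default_rtn_config(processor_type="client")  # client preset: RTNConfig(use_layer_wise=True)
    server_cfg = get_default_rtn_config(processor_type="server")  # server preset: plain RTNConfig()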

neural_compressor/torch/utils/utility.py

Lines changed: 39 additions & 2 deletions
@@ -13,12 +13,21 @@
 # limitations under the License.
 
 
-from typing import Callable, Dict, List, Tuple, Union
+import enum
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
+import psutil
 import torch
 from typing_extensions import TypeAlias
 
-from neural_compressor.common.utils import Mode, Statistics, logger
+from neural_compressor.common.utils import (
+    Mode,
+    ProcessorType,
+    Statistics,
+    cpu_info,
+    detect_processor_type_based_on_hw,
+    logger,
+)
 
 OP_NAME_AND_TYPE_TUPLE_TYPE: TypeAlias = Tuple[str, Union[torch.nn.Module, Callable]]
 
@@ -235,3 +244,31 @@ def get_model_device(model: torch.nn.Module):
     """
     for n, p in model.named_parameters():
         return p.data.device.type  # p.data.device == device(type='cpu')
+
+
+def get_processor_type_from_user_config(user_processor_type: Optional[Union[str, ProcessorType]] = None):
+    """Get the processor type.
+
+    Get the processor type based on the user configuration or automatically detect it based on the hardware.
+
+    Args:
+        user_processor_type (Optional[Union[str, ProcessorType]]): The user-specified processor type. Defaults to None.
+
+    Returns:
+        ProcessorType: The detected or user-specified processor type.
+
+    Raises:
+        AssertionError: If the user-specified processor type is not supported.
+        NotImplementedError: If the processor type is not recognized.
+    """
+    if user_processor_type is None:
+        processor_type = detect_processor_type_based_on_hw()
+    elif isinstance(user_processor_type, ProcessorType):
+        processor_type = user_processor_type
+    elif isinstance(user_processor_type, str):
+        user_processor_type = user_processor_type.lower().capitalize()
+        assert user_processor_type in ProcessorType.__members__, f"Unsupported processor type: {user_processor_type}"
+        processor_type = ProcessorType(user_processor_type)
+    else:
+        raise NotImplementedError(f"Unsupported processor type: {user_processor_type}")
+    return processor_type
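The string branch lowercases and then capitalizes the user value before validating it against the enum, so "client", "CLIENT", and "Client" all resolve to the same member. A minimal sketch, accessing the helper through neural_compressor.torch.utils the same way config.py does above:

    from neural_compressor.torch.utils import ProcessorType, get_processor_type_from_user_config

    assert get_processor_type_from_user_config("client") is ProcessorType.Client
    assert get_processor_type_from_user_config(ProcessorType.Server) is ProcessorType.Server
    # With None, the hardware heuristics decide: more than one socket, a "Xeon" brand string,
    # or more than 32GB of RAM all map to ProcessorType.Server; otherwise ProcessorType.Client.
    detected = get_processor_type_from_user_config(None)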
