17
17
"""The utility of common module."""
18
18
19
19
import collections
20
+ import enum
20
21
import importlib
21
22
import subprocess
22
23
import time
26
27
import psutil
27
28
from prettytable import PrettyTable
28
29
29
- from neural_compressor .common .utils import Mode , TuningLogger , logger
30
+ from neural_compressor .common .utils import Mode , TuningLogger , constants , logger
30
31
31
32
__all__ = [
32
33
"set_workspace" ,
41
42
"CpuInfo" ,
42
43
"default_tuning_logger" ,
43
44
"call_counter" ,
45
+ "cpu_info" ,
46
+ "ProcessorType" ,
47
+ "detect_processor_type_based_on_hw" ,
44
48
"Statistics" ,
45
49
]
46
50
@@ -92,7 +96,7 @@ def __call__(self, *args, **kwargs):
92
96
93
97
@singleton
94
98
class CpuInfo (object ):
95
- """CPU info collection ."""
99
+ """Get CPU Info ."""
96
100
97
101
def __init__ (self ):
98
102
"""Get whether the cpu numerical format is bf16, the number of sockets, cores and cores per socket."""
@@ -113,6 +117,39 @@ def __init__(self):
113
117
b"\xB8 \x07 \x00 \x00 \x00 " b"\x0f \xa2 " b"\xC3 " , # mov eax, 7 # cpuid # ret
114
118
)
115
119
self ._bf16 = bool (eax & (1 << 5 ))
120
+ self ._info = info
121
+ self ._brand_raw = info .get ("brand_raw" , "" )
122
+ # detect the below info when needed
123
+ self ._cores = None
124
+ self ._sockets = None
125
+ self ._cores_per_socket = None
126
+
127
+ @property
128
+ def brand_raw (self ):
129
+ """Get the brand name of the CPU."""
130
+ return self ._brand_raw
131
+
132
+ @brand_raw .setter
133
+ def brand_raw (self , brand_name ):
134
+ """Set the brand name of the CPU."""
135
+ self ._brand_raw = brand_name
136
+
137
+ @staticmethod
138
+ def _detect_cores ():
139
+ physical_cores = psutil .cpu_count (logical = False )
140
+ return physical_cores
141
+
142
+ @property
143
+ def cores (self ):
144
+ """Get the number of cores in platform."""
145
+ if self ._cores is None :
146
+ self ._cores = self ._detect_cores ()
147
+ return self ._cores
148
+
149
+ @cores .setter
150
+ def cores (self , num_of_cores ):
151
+ """Set the number of cores in platform."""
152
+ self ._cores = num_of_cores
116
153
117
154
@property
118
155
def bf16 (self ):
@@ -124,6 +161,60 @@ def vnni(self):
124
161
"""Get whether it is vnni."""
125
162
return self ._vnni
126
163
164
+ @property
165
+ def cores_per_socket (self ) -> int :
166
+ """Get the cores per socket."""
167
+ if self ._cores_per_socket is None :
168
+ self ._cores_per_socket = self .cores // self .sockets
169
+ return self ._cores_per_socket
170
+
171
+ @property
172
+ def sockets (self ):
173
+ """Get the number of sockets in platform."""
174
+ if self ._sockets is None :
175
+ self ._sockets = self ._get_number_of_sockets ()
176
+ return self ._sockets
177
+
178
+ @sockets .setter
179
+ def sockets (self , num_of_sockets ):
180
+ """Set the number of sockets in platform."""
181
+ self ._sockets = num_of_sockets
182
+
183
+ def _get_number_of_sockets (self ) -> int :
184
+ if "arch" in self ._info and "ARM" in self ._info ["arch" ]: # pragma: no cover
185
+ return 1
186
+
187
+ num_sockets = None
188
+ cmd = "cat /proc/cpuinfo | grep 'physical id' | sort -u | wc -l"
189
+ if psutil .WINDOWS :
190
+ cmd = r'wmic cpu get DeviceID | C:\Windows\System32\find.exe /C "CPU"'
191
+ elif psutil .MACOS : # pragma: no cover
192
+ cmd = "sysctl -n machdep.cpu.core_count"
193
+
194
+ num_sockets = None
195
+ try :
196
+ with subprocess .Popen (
197
+ args = cmd ,
198
+ shell = True ,
199
+ stdout = subprocess .PIPE ,
200
+ stderr = subprocess .STDOUT ,
201
+ universal_newlines = False ,
202
+ ) as proc :
203
+ proc .wait ()
204
+ if proc .stdout :
205
+ for line in proc .stdout :
206
+ num_sockets = int (line .decode ("utf-8" , errors = "ignore" ).strip ())
207
+ except Exception as e :
208
+ logger .error ("Failed to get number of sockets: %s" % e )
209
+ if isinstance (num_sockets , int ) and num_sockets >= 1 :
210
+ return num_sockets
211
+ else :
212
+ logger .warning ("Failed to get number of sockets, return 1 as default." )
213
+ return 1
214
+
215
+
216
+ cpu_info = CpuInfo ()
217
+
127
218
128
219
def dump_elapsed_time (customized_msg = "" ):
129
220
"""Get the elapsed time for decorated functions.
@@ -236,6 +327,43 @@ def wrapper(*args, **kwargs):
236
327
return wrapper
237
328
238
329
330
+ class ProcessorType (enum .Enum ):
331
+ Client = "Client"
332
+ Server = "Server"
333
+
334
+
335
+ def detect_processor_type_based_on_hw ():
336
+ """Detects the processor type based on the hardware configuration.
337
+
338
+ Returns:
339
+ ProcessorType: The detected processor type (Server or Client).
340
+ """
341
+ # Detect the processor type based on below conditions:
342
+ # If there are more than one sockets, it is a server.
343
+ # If the brand name includes key word in `SERVER_PROCESSOR_BRAND_KEY_WORLD_LST`, it is a server.
344
+ # If the memory size is greater than 32GB, it is a server.
345
+ log_mgs = "Processor type detected as {processor_type} due to {reason}."
346
+ if cpu_info .sockets > 1 :
347
+ logger .info (log_mgs .format (processor_type = ProcessorType .Server .value , reason = "there are more than one sockets" ))
348
+ return ProcessorType .Server
349
+ elif any (brand in cpu_info .brand_raw for brand in constants .SERVER_PROCESSOR_BRAND_KEY_WORLD_LST ):
350
+ logger .info (
351
+ log_mgs .format (processor_type = ProcessorType .Server .value , reason = f"the brand name is { cpu_info .brand_raw } ." )
352
+ )
353
+ return ProcessorType .Server
354
+ elif psutil .virtual_memory ().total / (1024 ** 3 ) > 32 :
355
+ logger .info (
356
+ log_mgs .format (processor_type = ProcessorType .Server .value , reason = "the memory size is greater than 32GB" )
357
+ )
358
+ return ProcessorType .Server
359
+ else :
360
+ logger .info (
361
+ "Processor type detected as %s, pass `processor_type='server'` to override it if needed." ,
362
+ ProcessorType .Client .value ,
363
+ )
364
+ return ProcessorType .Client
365
+
366
+
239
367
class Statistics :
240
368
"""The statistics printer."""
241
369
0 commit comments