
Commit ac47d9b

Authored by yiliu30, Kaihui-intel, and chensuyue
Enhance auto-tune module (#1608)
Signed-off-by: Kaihui-intel <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
Signed-off-by: chensuyue <[email protected]>
Co-authored-by: Kaihui-intel <[email protected]>
Co-authored-by: chensuyue <[email protected]>
1 parent 191383e commit ac47d9b

15 files changed: +345 −98 lines changed

.azure-pipelines/scripts/codeScan/pylint/pylint.sh (+2 −1)

@@ -39,7 +39,8 @@ pip install torch \
     prettytable \
     psutil \
     py-cpuinfo \
-    pyyaml
+    pyyaml \
+    pydantic \

 if [ "${scan_module}" = "neural_solution" ]; then
     cd /neural-compressor

neural_compressor/common/__init__.py (+2)

@@ -20,6 +20,7 @@
     set_resume_from,
     set_workspace,
     set_tensorboard,
+    dump_elapsed_time,
 )
 from neural_compressor.common.base_config import options

@@ -33,4 +34,5 @@
     "set_random_seed",
     "set_resume_from",
     "set_tensorboard",
+    "dump_elapsed_time",
 ]

neural_compressor/common/base_config.py (+44 −10)

@@ -17,6 +17,7 @@

 from __future__ import annotations

+import inspect
 import json
 import re
 from abc import ABC, abstractmethod
@@ -25,6 +26,7 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

 from neural_compressor.common import Logger
+from neural_compressor.common.tuning_param import TuningParam
 from neural_compressor.common.utils import (
     BASE_CONFIG,
     COMPOSABLE_CONFIG,
@@ -295,6 +297,15 @@ def __add__(self, other: BaseConfig) -> BaseConfig:
         else:
             return ComposableConfig(configs=[self, other])

+    @staticmethod
+    def get_the_default_value_of_param(config: BaseConfig, param: str) -> Any:
+        # Get the signature of the __init__ method
+        signature = inspect.signature(config.__init__)
+
+        # Get the parameters and their default values
+        parameters = signature.parameters
+        return parameters.get(param).default
+
     def expand(self) -> List[BaseConfig]:
         """Expand the config.

@@ -331,19 +342,42 @@ def expand(self) -> List[BaseConfig]:
         """
         config_list: List[BaseConfig] = []
         params_list = self.params_list
-        params_dict = OrderedDict()
         config = self
+        tuning_param_list = []
+        not_tuning_param_pair = {}  # key is the param name, value is the user specified value
         for param in params_list:
-            param_val = getattr(config, param)
-            # TODO (Yi) to handle param_val itself is a list
-            if isinstance(param_val, list):
-                params_dict[param] = param_val
+            # Create `TuningParam` for each param
+            # There are two cases:
+            # 1. The param is a string.
+            # 2. The param is a `TuningParam` instance.
+            if isinstance(param, str):
+                default_param = self.get_the_default_value_of_param(config, param)
+                tuning_param = TuningParam(name=param, tunable_type=List[type(default_param)])
+            elif isinstance(param, TuningParam):
+                tuning_param = param
             else:
-                params_dict[param] = [param_val]
-        for params_values in product(*params_dict.values()):
-            new_config = self.__class__(**dict(zip(params_list, params_values)))
-            config_list.append(new_config)
-        logger.info(f"Expanded the {self.__class__.name} and got {len(config_list)} configs.")
+                raise ValueError(f"Unsupported param type: {param}")
+            # Assign the options to the `TuningParam` instance
+            param_val = getattr(config, tuning_param.name)
+            if param_val is not None:
+                if tuning_param.is_tunable(param_val):
+                    tuning_param.options = param_val
+                    tuning_param_list.append(tuning_param)
+                else:
+                    not_tuning_param_pair[tuning_param.name] = param_val
+        logger.debug("Tuning param list: %s", tuning_param_list)
+        logger.debug("Not tuning param pair: %s", not_tuning_param_pair)
+        if len(tuning_param_list) == 0:
+            config_list = [config]
+        else:
+            tuning_param_name_lst = [tuning_param.name for tuning_param in tuning_param_list]
+            for params_values in product(*[tuning_param.options for tuning_param in tuning_param_list]):
+                tuning_param_pair = dict(zip(tuning_param_name_lst, params_values))
+                tmp_params_dict = {**not_tuning_param_pair, **tuning_param_pair}
+                new_config = self.__class__(**tmp_params_dict)
+                logger.info(new_config.to_dict())
+                config_list.append(new_config)
+        logger.info("Expanded the %s and got %d configs.", self.__class__.name, len(config_list))
         return config_list

     def _get_op_name_op_type_config(self):
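The reworked expand() treats a list-valued option as a tunable parameter and takes the Cartesian product over the tunable options, while scalar options are carried into every generated config unchanged. Below is a minimal, self-contained sketch of that behavior; DemoConfig and its parameters (bits, group_size, use_sym) are hypothetical stand-ins for a BaseConfig subclass and are not part of the neural_compressor API.

from itertools import product
from typing import Any, Dict, List

class DemoConfig:
    """Hypothetical stand-in for a BaseConfig subclass; not part of neural_compressor."""

    params_list = ["bits", "group_size", "use_sym"]

    def __init__(self, bits=4, group_size=32, use_sym=True):
        self.bits = bits
        self.group_size = group_size
        self.use_sym = use_sym

    def to_dict(self) -> Dict[str, Any]:
        return {"bits": self.bits, "group_size": self.group_size, "use_sym": self.use_sym}

    def expand(self) -> List["DemoConfig"]:
        # Mirror the tuning_param_list / not_tuning_param_pair split above:
        # list-valued params are tunable, everything else is fixed.
        tunable, fixed = {}, {}
        for name in self.params_list:
            value = getattr(self, name)
            (tunable if isinstance(value, list) else fixed)[name] = value
        if not tunable:
            return [self]
        expanded = []
        for values in product(*tunable.values()):
            expanded.append(DemoConfig(**{**fixed, **dict(zip(tunable, values))}))
        return expanded

# Two tunable params with two options each -> 2 x 2 = 4 configs; use_sym stays False in all of them.
for cfg in DemoConfig(bits=[4, 8], group_size=[32, 128], use_sym=False).expand():
    print(cfg.to_dict())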

neural_compressor/common/base_tuning.py (+99 −41)

@@ -16,7 +16,7 @@
 import copy
 import inspect
 import uuid
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Sized, Tuple, Union

 from neural_compressor.common import Logger
 from neural_compressor.common.base_config import BaseConfig, ComposableConfig
@@ -31,6 +31,10 @@
     "TuningMonitor",
     "TuningLogger",
     "init_tuning",
+    "Sampler",
+    "SequentialSampler",
+    "default_sampler",
+    "ConfigSet",
 ]


@@ -123,36 +127,103 @@ def self_check(self) -> None:
 evaluator = Evaluator()


-class Sampler:
-    # TODO Separate sorting functionality of `ConfigLoader` into `Sampler` in the follow-up PR.
-    pass
+class ConfigSet:

+    def __init__(self, config_list: List[BaseConfig]) -> None:
+        self.config_list = config_list

-class ConfigLoader:
-    def __init__(self, config_set, sampler: Sampler) -> None:
-        self.config_set = config_set
-        self.sampler = sampler
+    def __getitem__(self, index) -> BaseConfig:
+        assert 0 <= index < len(self.config_list), f"Index {index} out of range."
+        return self.config_list[index]

-    @staticmethod
-    def parse_quant_config(quant_config: BaseConfig) -> List[BaseConfig]:
-        if isinstance(quant_config, ComposableConfig):
-            result = []
-            for q_config in quant_config.config_list:
-                result += q_config.expand()
-            return result
+    def __len__(self) -> int:
+        return len(self.config_list)
+
+    @classmethod
+    def _from_single_config(cls, config: BaseConfig) -> List[BaseConfig]:
+        config_list = []
+        config_list = config.expand()
+        return config_list
+
+    @classmethod
+    def _from_list_of_configs(cls, fwk_configs: List[BaseConfig]) -> List[BaseConfig]:
+        config_list = []
+        for config in fwk_configs:
+            config_list += cls._from_single_config(config)
+        return config_list
+
+    @classmethod
+    def generate_config_list(cls, fwk_configs: Union[BaseConfig, List[BaseConfig]]):
+        # There are several cases for the input `fwk_configs`:
+        # 1. fwk_configs is a single config
+        # 2. fwk_configs is a list of configs
+        # For a single config, we need to check if it can be expanded or not.
+        config_list = []
+        if isinstance(fwk_configs, BaseConfig):
+            config_list = cls._from_single_config(fwk_configs)
+        elif isinstance(fwk_configs, List):
+            config_list = cls._from_list_of_configs(fwk_configs)
         else:
-            return quant_config.expand()
+            raise NotImplementedError(f"Unsupported type {type(fwk_configs)} for fwk_configs.")
+        return config_list
+
+    @classmethod
+    def from_fwk_configs(cls, fwk_configs: Union[BaseConfig, List[BaseConfig]]) -> "ConfigSet":
+        """Create a ConfigSet object from a single config or a list of configs.
+
+        Args:
+            fwk_configs: A single config or a list of configs.
+                Examples:
+                    1) single config: RTNConfig(weight_group_size=32)
+                    2) single expandable config: RTNConfig(weight_group_size=[32, 64])
+                    3) mixed 1) and 2): [RTNConfig(weight_group_size=32), RTNConfig(weight_group_size=[32, 64])]
+
+        Returns:
+            ConfigSet: A ConfigSet object.
+        """
+        config_list = cls.generate_config_list(fwk_configs)
+        return cls(config_list)
+
+
+class Sampler:
+    def __init__(self, config_source: Optional[ConfigSet]) -> None:
+        pass
+
+    def __iter__(self) -> Iterator[BaseConfig]:
+        """Iterate over indices of config set elements."""
+        raise NotImplementedError

-    def parse_quant_configs(self) -> List[BaseConfig]:
-        # TODO (Yi) separate this functionality into `Sampler` in the next PR
-        quant_config_list = []
-        for quant_config in self.config_set:
-            quant_config_list.extend(ConfigLoader.parse_quant_config(quant_config))
-        return quant_config_list
+
+class SequentialSampler(Sampler):
+    """Samples elements sequentially, always in the same order.
+
+    Args:
+        config_source (_ConfigSet): config set to sample from
+    """
+
+    config_source: Sized
+
+    def __init__(self, config_source: Sized) -> None:
+        self.config_source = config_source
+
+    def __iter__(self) -> Iterator[int]:
+        return iter(range(len(self.config_source)))
+
+    def __len__(self) -> int:
+        return len(self.config_source)
+
+
+default_sampler = SequentialSampler
+
+
+class ConfigLoader:
+    def __init__(self, config_set: ConfigSet, sampler: Sampler = default_sampler) -> None:
+        self.config_set = ConfigSet.from_fwk_configs(config_set)
+        self._sampler = sampler(self.config_set)

     def __iter__(self) -> Generator[BaseConfig, Any, None]:
-        for config in self.parse_quant_configs():
-            yield config
+        for index in self._sampler:
+            yield self.config_set[index]


 class TuningLogger:
@@ -211,12 +282,14 @@ class TuningConfig:

     Args:
         config_set: quantization configs. Default value is empty.
-        timeout: Tuning timeout (seconds). Default value is 0 which means early stop.
+            A single config or a list of configs. More details can
+            be found in the `from_fwk_configs` of `ConfigSet` class.
         max_trials: Max tuning times. Default value is 100. Combine with timeout field to decide when to exit.
         tolerable_loss: This float indicates how much metric loss we can accept. \
             The metric loss is relative, it can be both positive and negative. Default is 0.01.

     Examples:
+        # TODO: to refine it
         from neural_compressor import TuningConfig
         tune_config = TuningConfig(
             config_set=[config1, config2, ...],
@@ -239,28 +312,13 @@ class TuningConfig:
         # The best tuning config is config2, because of the following:
         # 1. Not achieving the set goal. (config_metric < fp32_baseline * (1 - tolerable_loss))
         # 2. Reached maximum tuning times.
-
-        # Case 3: Timeout
-        tune_config = TuningConfig(
-            config_set=[config1, config2, ...],
-            timeout=10, # seconds
-            max_trials=3,
-            tolerable_loss=0.01
-        )
-        config1_tuning_time, config2_tuning_time, config3_tuning_time, ... = 4, 5, 6, ... # seconds
-        fp32_baseline = 100
-        config1_metric, config2_metric, config3_metric, ... = 98, 98, 97, ...
-
-        # Tuning result of case 3:
-        # The best tuning config is config2, due to timeout, the third trial was forced to exit.
     """

     def __init__(
-        self, config_set=None, timeout=0, max_trials=100, sampler: Sampler = None, tolerable_loss=0.01
+        self, config_set=None, max_trials=100, sampler: Sampler = default_sampler, tolerable_loss=0.01
     ) -> None:
         """Init a TuneCriterion object."""
         self.config_set = config_set
-        self.timeout = timeout
         self.max_trials = max_trials
         self.sampler = sampler
         self.tolerable_loss = tolerable_loss