
Commit 76b8b3f

Authored by yiliu30 and chensuyue
Add white_list to config (#1430)
* add white list

---------

Signed-off-by: yiliu30 <[email protected]>
Signed-off-by: chensuyue <[email protected]>
Co-authored-by: chensuyue <[email protected]>
1 parent 464af67 commit 76b8b3f

11 files changed: +109 −90 lines changed


.azure-pipelines/model-test.yml

+1 −1

@@ -11,13 +11,13 @@ pr:
       - neural_compressor
       - setup.py
       - requirements.txt
-      - .azure-pipelines/model-test.yml
       - .azure-pipelines/scripts/models
       - examples/tensorflow/oob_models/quantization/ptq
     exclude:
       - test
       - neural_compressor/common
       - neural_compressor/torch
+      - neural_compressor/tensorflow
 
 pool: MODEL_PERF_TEST_TF

.azure-pipelines/ut-basic-no-cover.yml

+1

@@ -17,6 +17,7 @@ pr:
       - test/3x
       - neural_compressor/common
       - neural_compressor/torch
+      - neural_compressor/tensorflow
 
 pool: ICX-16C

.azure-pipelines/ut-basic.yml

+1

@@ -17,6 +17,7 @@ pr:
       - test/3x
       - neural_compressor/common
       - neural_compressor/torch
+      - neural_compressor/tensorflow
 
 pool: ICX-16C

.azure-pipelines/ut-itrex.yml

+4 −1

@@ -12,7 +12,10 @@ pr:
       - setup.py
       - requirements.txt
       - .azure-pipelines/scripts/ut/run_itrex.sh
-      - .azure-pipelines/ut-itrex.yml
+    exclude:
+      - neural_compressor/common
+      - neural_compressor/torch
+      - neural_compressor/tensorflow
 
 pool: MODEL_PERF_TEST

neural_compressor/common/base_config.py

+48 −11

@@ -23,7 +23,15 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 from neural_compressor.common.logger import Logger
-from neural_compressor.common.utility import BASE_CONFIG, COMPOSABLE_CONFIG, GLOBAL, LOCAL
+from neural_compressor.common.utility import (
+    BASE_CONFIG,
+    COMPOSABLE_CONFIG,
+    DEFAULT_WHITE_LIST,
+    EMPTY_WHITE_LIST,
+    GLOBAL,
+    LOCAL,
+    OP_NAME_OR_MODULE_TYPE,
+)
 
 logger = Logger().get_logger()
 
@@ -59,18 +67,43 @@ class BaseConfig(ABC):
     """The base config for all algorithm configs."""
 
     name = BASE_CONFIG
+    params_list = []
 
-    def __init__(self) -> None:
+    def __init__(self, white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST) -> None:
         self._global_config: Optional[BaseConfig] = None
         # For PyTorch, operator_type is the collective name for module type and functional operation type,
         # for example, `torch.nn.Linear`, and `torch.nn.functional.linear`.
         # local config is the collections of operator_type configs and operator configs
         self._local_config: Dict[str, Optional[BaseConfig]] = {}
+        self._white_list = white_list
+
+    def _post_init(self):
+        if self.white_list == DEFAULT_WHITE_LIST:
+            global_config = self.get_params_dict()
+            self._global_config = self.__class__(**global_config, white_list=None)
+        elif isinstance(self.white_list, list) and len(self.white_list) > 0:
+            for op_name_or_type in self.white_list:
+                global_config = self.get_params_dict()
+                tmp_config = self.__class__(**global_config, white_list=None)
+                self.set_local(op_name_or_type, tmp_config)
+        elif self.white_list == EMPTY_WHITE_LIST:
+            return
+        else:
+            raise NotImplementedError(
+                f"The white list should be one of {DEFAULT_WHITE_LIST}, {EMPTY_WHITE_LIST},"
+                f" or a non-empty list, but got {self.white_list}"
+            )
+
+    @property
+    def white_list(self):
+        return self._white_list
+
+    @white_list.setter
+    def white_list(self, op_name_or_type_list: Optional[List[OP_NAME_OR_MODULE_TYPE]]):
+        self._white_list = op_name_or_type_list
 
     @property
     def global_config(self):
-        if self._global_config is None:
-            self._global_config = self.__class__(**self.to_dict())
         return self._global_config
 
     @global_config.setter
@@ -88,25 +121,28 @@ def local_config(self, config):
     def set_local(self, operator_name: str, config: BaseConfig) -> BaseConfig:
         if operator_name in self.local_config:
             logger.warning("The configuration for %s has already been set, update it.", operator_name)
-        if self.global_config is None:
-            self.global_config = self.__class__(**self.to_dict())
         self.local_config[operator_name] = config
         return self
 
     def to_dict(self, params_list=[], operator2str=None):
         result = {}
-        global_config = {}
-        for param in params_list:
-            global_config[param] = getattr(self, param)
+        global_config = self.get_params_dict()
         if bool(self.local_config):
             result[LOCAL] = {}
             for op_name, config in self.local_config.items():
                 result[LOCAL][op_name] = config.to_dict()
-            result[GLOBAL] = global_config
+            if self.global_config:
+                result[GLOBAL] = global_config
         else:
             result = global_config
         return result
 
+    def get_params_dict(self):
+        result = dict()
+        for param in self.params_list:
+            result[param] = getattr(self, param)
+        return result
+
     @classmethod
     def from_dict(cls, config_dict, str2operator=None):
         """Construct config from a dict.
@@ -205,7 +241,8 @@ def to_config_mapping(
             global_config = config.global_config
             op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config()
             for op_name, op_type in model_info:
-                config_mapping[(op_type, op_name)] = global_config
+                if self.global_config is not None:
+                    config_mapping[(op_type, op_name)] = global_config
                 if op_type in op_type_config_dict:
                     config_mapping[(op_type, op_name)] = op_name_config_dict[op_type]
                 if op_name in op_name_config_dict:
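Taken together, `_post_init` gives `white_list` three modes. A minimal sketch of each, using the `RTNWeightQuantConfig` that ships with this commit (the op name "fc1" is a placeholder):

import torch

from neural_compressor.torch.quantization import RTNWeightQuantConfig

# DEFAULT_WHITE_LIST ("*", the default): _post_init materializes a global
# config, so to_config_mapping applies it to every matching op.
rtn_global = RTNWeightQuantConfig(weight_bits=4)

# A non-empty list: one local entry per op name or module type is set via
# set_local; global_config stays None, so unlisted ops are left untouched.
rtn_scoped = RTNWeightQuantConfig(weight_bits=4, white_list=["fc1", torch.nn.Linear])

# EMPTY_WHITE_LIST (None): neither global nor local config is populated;
# this is also what the internal per-op copies above are constructed with.
rtn_none = RTNWeightQuantConfig(weight_bits=4, white_list=None)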

neural_compressor/common/utility.py

+7 −1

@@ -21,11 +21,17 @@
 # constants for configs
 GLOBAL = "global"
 LOCAL = "local"
+DEFAULT_WHITE_LIST = "*"
+EMPTY_WHITE_LIST = None
 
 # config name
 BASE_CONFIG = "base_config"
 COMPOSABLE_CONFIG = "composable_config"
 RTN_WEIGHT_ONLY_QUANT = "rtn_weight_only_quant"
 STATIC_QUANT = "static_quant"
 GPTQ = "gptq"
-DUMMY_CONFIG = "dummy_config"
+
+
+from typing import Callable, Union
+
+OP_NAME_OR_MODULE_TYPE = Union[str, Callable]
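The new `OP_NAME_OR_MODULE_TYPE` alias accepts either an op name or a callable module type; a small illustration (the names below are placeholders):

import torch

from neural_compressor.common.utility import OP_NAME_OR_MODULE_TYPE

by_name: OP_NAME_OR_MODULE_TYPE = "fc1"            # select one op by name
by_type: OP_NAME_OR_MODULE_TYPE = torch.nn.Linear  # select all ops of a type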

neural_compressor/tensorflow/quantization/config.py

+5 −3

@@ -18,12 +18,12 @@
 from __future__ import annotations
 
 from enum import Enum
-from typing import Callable, Dict, List, NamedTuple, Union
+from typing import Callable, Dict, List, NamedTuple, Optional, Union
 
 import tensorflow as tf
 
 from neural_compressor.common.base_config import BaseConfig, register_config, registered_configs
-from neural_compressor.common.utility import STATIC_QUANT
+from neural_compressor.common.utility import DEFAULT_WHITE_LIST, OP_NAME_OR_MODULE_TYPE, STATIC_QUANT
 
 FRAMEWORK_NAME = "keras"
 
@@ -89,6 +89,7 @@ def __init__(
         act_dtype: str = "int8",
         act_sym: bool = True,
         act_granularity: str = "per_tensor",
+        white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
     ):
         """Init static quantization config.
 
@@ -100,13 +101,14 @@ def __init__(
             act_sym (bool): Indicates whether activations are symmetric, default is True.
             act_granularity (str): Calculate tensor-wise scales or channel-wise scales for activations.
         """
-        super().__init__()
+        super().__init__(white_list=white_list)
         self.weight_dtype = weight_dtype
         self.weight_sym = weight_sym
         self.weight_granularity = weight_granularity
         self.act_dtype = act_dtype
         self.act_sym = act_sym
         self.act_granularity = act_granularity
+        self._post_init()
 
     def to_dict(self):
         return super().to_dict(params_list=self.params_list, operator2str=operator2str)
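The Keras-side config gains the same knob. A usage sketch, assuming the `StaticQuantConfig` class this file registers under `STATIC_QUANT` (the layer names are placeholders):

from neural_compressor.tensorflow.quantization.config import StaticQuantConfig

# Quantize only the two named layers; everything outside the white list
# keeps its original precision.
keras_cfg = StaticQuantConfig(act_dtype="int8", white_list=["dense_1", "dense_2"])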

neural_compressor/torch/__init__.py

−2

@@ -19,8 +19,6 @@
     quantize,
     RTNWeightQuantConfig,
     get_default_rtn_config,
-    DummyConfig,
-    get_default_dummy_config,
     GPTQConfig,
     get_default_gptq_config,
 )

neural_compressor/torch/quantization/__init__.py

−2

@@ -16,8 +16,6 @@
 from neural_compressor.torch.quantization.config import (
     RTNWeightQuantConfig,
     get_default_rtn_config,
-    DummyConfig,
-    get_default_dummy_config,
     GPTQConfig,
     get_default_gptq_config,
 )

neural_compressor/torch/quantization/config.py

+8 −63

@@ -18,12 +18,12 @@
 from __future__ import annotations
 
 from enum import Enum
-from typing import Callable, Dict, List, NamedTuple, Union
+from typing import Callable, Dict, List, NamedTuple, Optional, Union
 
 import torch
 
 from neural_compressor.common.base_config import BaseConfig, register_config, registered_configs
-from neural_compressor.common.utility import DUMMY_CONFIG, GPTQ, RTN_WEIGHT_ONLY_QUANT
+from neural_compressor.common.utility import DEFAULT_WHITE_LIST, GPTQ, OP_NAME_OR_MODULE_TYPE, RTN_WEIGHT_ONLY_QUANT
 
 FRAMEWORK_NAME = "torch"
 
@@ -87,6 +87,7 @@ def __init__(
         double_quant_bits: int = 8,
         double_quant_sym: bool = True,
         double_quant_group_size: int = 256,
+        white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
     ):
         """Init RTN weight-only quantization config.
 
@@ -105,7 +106,7 @@ def __init__(
             double_quant_sym (bool): Indicates whether double_quant scale are symmetric, default is True.
             double_quant_group_size (int): Size of double_quant groups, default is 32.
         """
-        super().__init__()
+        super().__init__(white_list=white_list)
         self.weight_bits = weight_bits
         self.weight_dtype = weight_dtype
         self.weight_group_size = weight_group_size
@@ -119,6 +120,7 @@ def __init__(
         self.double_quant_dtype = double_quant_dtype
         self.double_quant_sym = double_quant_sym
         self.double_quant_group_size = double_quant_group_size
+        self._post_init()
 
     def to_dict(self):
         return super().to_dict(params_list=self.params_list, operator2str=operator2str)
@@ -220,12 +222,13 @@ def __init__(
         double_quant_bits: int = 8,
         double_quant_sym: bool = True,
         double_quant_group_size: int = 256,
+        white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
     ):
         """Init GPTQ config.
 
         Args:
         """
-        super().__init__()
+        super().__init__(white_list=white_list)
         self.weight_dtype = weight_dtype
         self.weight_bits = weight_bits
         self.weight_group_size = weight_group_size
@@ -248,6 +251,7 @@ def __init__(
         self.double_quant_dtype = double_quant_dtype
         self.double_quant_sym = double_quant_sym
         self.double_quant_group_size = double_quant_group_size
+        self._post_init()
 
     def to_dict(self):
         return super().to_dict(params_list=self.params_list, operator2str=operator2str)
@@ -281,65 +285,6 @@ def get_default_gptq_config() -> GPTQConfig:
     return GPTQConfig()
 
 
-######################## Dummy Config ###############################
-# TODO (Yi) remove it after finishing the GPTQ config
-@register_config(framework_name=FRAMEWORK_NAME, algo_name=DUMMY_CONFIG)
-class DummyConfig(BaseConfig):
-    """Config class for round-to-nearest weight-only quantization."""
-
-    supported_configs: List[OperatorConfig] = []
-    params_list = ["act_dtype", "weight_dtype", "dummy_attr"]
-    name = DUMMY_CONFIG
-
-    def __init__(
-        self,
-        weight_dtype: str = "int",
-        act_dtype: str = "fp32",
-        dummy_attr: int = 0,
-    ):
-        """Init RTN weight-only quantization config.
-
-        Args:
-            act_dtype (str): Data type for activations, default is "fp32".
-            weight_dtype (str): Data type for weights, default is "int".
-            dummy_attr (int): Dummy attribute, default is 0.
-        """
-        super().__init__()
-        self.act_dtype = act_dtype
-        self.weight_dtype = weight_dtype
-        self.dummy_attr = dummy_attr
-
-    def to_dict(self):
-        return super().to_dict(params_list=self.params_list, operator2str=operator2str)
-
-    @classmethod
-    def from_dict(cls, config_dict):
-        return super(DummyConfig, cls).from_dict(config_dict=config_dict, str2operator=str2operator)
-
-    @classmethod
-    def register_supported_configs(cls) -> List[OperatorConfig]:
-        supported_configs = []
-        linear_dummy_config = DummyConfig(
-            act_dtype=["fp32"],
-            weight_dtype=["int4", "int8"],
-            dummy_attr=[1, 2, 3],
-        )
-        operators = [torch.nn.Linear, torch.nn.functional.linear]
-        supported_configs.append(
-            OperatorConfig(config=linear_dummy_config, operators=operators, backend=Backend.DEFAULT)
-        )
-        cls.supported_configs = supported_configs
-
-
-def get_default_dummy_config() -> DummyConfig:
-    """Generate the default dummy config.
-
-    Returns:
-        the default dummy config.
-    """
-    return DummyConfig()
-
-
 ##################### Algo Configs End ###################################
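With DummyConfig removed, `GPTQConfig` exercises the same plumbing. A short sketch ("lm_head" is a placeholder op name):

from neural_compressor.torch.quantization import GPTQConfig, get_default_gptq_config

# Default white_list ("*"): the config applies globally.
gptq_default = get_default_gptq_config()

# Scoped: only the listed op receives a local GPTQ entry.
gptq_scoped = GPTQConfig(weight_bits=4, white_list=["lm_head"])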
