@@ -16,7 +16,7 @@
 import copy
 import inspect
 import uuid
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Sized, Tuple, Union
 
 from neural_compressor.common import Logger
 from neural_compressor.common.base_config import BaseConfig, ComposableConfig
@@ -31,6 +31,10 @@
     "TuningMonitor",
     "TuningLogger",
     "init_tuning",
+    "Sampler",
+    "SequentialSampler",
+    "default_sampler",
+    "ConfigSet",
 ]
 
 
@@ -123,36 +127,103 @@ def self_check(self) -> None:
 evaluator = Evaluator()
 
 
-class Sampler:
-    # TODO Separate sorting functionality of `ConfigLoader` into `Sampler` in the follow-up PR.
-    pass
+class ConfigSet:
 
+    def __init__(self, config_list: List[BaseConfig]) -> None:
+        self.config_list = config_list
 
-class ConfigLoader:
-    def __init__(self, config_set, sampler: Sampler) -> None:
-        self.config_set = config_set
-        self.sampler = sampler
+    def __getitem__(self, index) -> BaseConfig:
+        assert 0 <= index < len(self.config_list), f"Index {index} out of range."
+        return self.config_list[index]
 
-    @staticmethod
-    def parse_quant_config(quant_config: BaseConfig) -> List[BaseConfig]:
-        if isinstance(quant_config, ComposableConfig):
-            result = []
-            for q_config in quant_config.config_list:
-                result += q_config.expand()
-            return result
+    def __len__(self) -> int:
+        return len(self.config_list)
+
+    @classmethod
+    def _from_single_config(cls, config: BaseConfig) -> List[BaseConfig]:
+        config_list = []
+        config_list = config.expand()
+        return config_list
+
+    @classmethod
+    def _from_list_of_configs(cls, fwk_configs: List[BaseConfig]) -> List[BaseConfig]:
+        config_list = []
+        for config in fwk_configs:
+            config_list += cls._from_single_config(config)
+        return config_list
+
+    @classmethod
+    def generate_config_list(cls, fwk_configs: Union[BaseConfig, List[BaseConfig]]):
+        # There are several cases for the input `fwk_configs`:
+        # 1. fwk_configs is a single config
+        # 2. fwk_configs is a list of configs
+        # For a single config, we need to check whether it can be expanded or not.
+        config_list = []
+        if isinstance(fwk_configs, BaseConfig):
+            config_list = cls._from_single_config(fwk_configs)
+        elif isinstance(fwk_configs, List):
+            config_list = cls._from_list_of_configs(fwk_configs)
         else:
-            return quant_config.expand()
+            raise NotImplementedError(f"Unsupported type {type(fwk_configs)} for fwk_configs.")
+        return config_list
+
+    @classmethod
+    def from_fwk_configs(cls, fwk_configs: Union[BaseConfig, List[BaseConfig]]) -> "ConfigSet":
+        """Create a ConfigSet object from a single config or a list of configs.
+
+        Args:
+            fwk_configs: A single config or a list of configs.
+        Examples:
+            1) a single config: RTNConfig(weight_group_size=32)
+            2) a single expandable config: RTNConfig(weight_group_size=[32, 64])
+            3) a mix of 1) and 2): [RTNConfig(weight_group_size=32), RTNConfig(weight_group_size=[32, 64])]
+
+        Returns:
+            ConfigSet: A ConfigSet object.
+        """
+        config_list = cls.generate_config_list(fwk_configs)
+        return cls(config_list)
+
+
+class Sampler:
+    def __init__(self, config_source: Optional[ConfigSet]) -> None:
+        pass
+
+    def __iter__(self) -> Iterator[int]:
+        """Iterate over indices of config set elements."""
+        raise NotImplementedError
 
-    def parse_quant_configs(self) -> List[BaseConfig]:
-        # TODO (Yi) separate this functionality into `Sampler` in the next PR
-        quant_config_list = []
-        for quant_config in self.config_set:
-            quant_config_list.extend(ConfigLoader.parse_quant_config(quant_config))
-        return quant_config_list
+
+class SequentialSampler(Sampler):
+    """Samples elements sequentially, always in the same order.
+
+    Args:
+        config_source (ConfigSet): config set to sample from
+    """
+
+    config_source: Sized
+
+    def __init__(self, config_source: Sized) -> None:
+        self.config_source = config_source
+
+    def __iter__(self) -> Iterator[int]:
+        return iter(range(len(self.config_source)))
+
+    def __len__(self) -> int:
+        return len(self.config_source)
+
+
+default_sampler = SequentialSampler
+
+
+class ConfigLoader:
+    def __init__(self, config_set: ConfigSet, sampler: Sampler = default_sampler) -> None:
+        self.config_set = ConfigSet.from_fwk_configs(config_set)
+        self._sampler = sampler(self.config_set)
 
     def __iter__(self) -> Generator[BaseConfig, Any, None]:
-        for config in self.parse_quant_configs():
-            yield config
+        for index in self._sampler:
+            yield self.config_set[index]
 
 
 class TuningLogger:
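
# --- Usage sketch (illustration only; not part of this commit) ---
# A minimal example of how the new ConfigSet is expected to behave. It
# assumes the RTNConfig named in the docstring above; the import path and
# the expansion count are assumptions based on that docstring.
from neural_compressor.torch.quantization import RTNConfig  # assumed path

cs = ConfigSet.from_fwk_configs(
    [RTNConfig(weight_group_size=32), RTNConfig(weight_group_size=[32, 64])]
)
# The plain config contributes one entry; the expandable config is
# expanded into one entry per candidate value, so three in total.
assert len(cs) == 3
first = cs[0]  # __getitem__ with bounds checking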
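
# --- Usage sketch (illustration only; not part of this commit) ---
# SequentialSampler yields indices into a ConfigSet, and ConfigLoader maps
# each index back to a config, so iteration follows expansion order. Note
# that ConfigLoader accepts raw framework configs and builds the ConfigSet
# itself via ConfigSet.from_fwk_configs. RTNConfig is assumed as above, and
# weight_group_size=[32, 64] is assumed to expand into two configs.
from neural_compressor.torch.quantization import RTNConfig  # assumed path

loader = ConfigLoader(config_set=RTNConfig(weight_group_size=[32, 64]))
assert list(loader._sampler) == [0, 1]  # default_sampler is SequentialSampler
for config in loader:
    print(config)  # one expanded candidate config per tuning trial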
@@ -211,12 +282,14 @@ class TuningConfig:
 
     Args:
         config_set: quantization configs. Default value is empty.
-        timeout: Tuning timeout (seconds). Default value is 0 which means early stop.
+            A single config or a list of configs; more details can
+            be found in the `from_fwk_configs` method of the `ConfigSet` class.
         max_trials: Max tuning times. Default value is 100. Combine with timeout field to decide when to exit.
         tolerable_loss: This float indicates how much metric loss we can accept. \
             The metric loss is relative, it can be both positive and negative. Default is 0.01.
 
     Examples:
+        # TODO: refine this example
         from neural_compressor import TuningConfig
         tune_config = TuningConfig(
             config_set=[config1, config2, ...],
@@ -239,28 +312,13 @@ class TuningConfig:
         # The best tuning config is config2, because of the following:
         # 1. Not achieving the set goal. (config_metric < fp32_baseline * (1 - tolerable_loss))
         # 2. Reached maximum tuning times.
-
-        # Case 3: Timeout
-        tune_config = TuningConfig(
-            config_set=[config1, config2, ...],
-            timeout=10,  # seconds
-            max_trials=3,
-            tolerable_loss=0.01
-        )
-        config1_tuning_time, config2_tuning_time, config3_tuning_time, ... = 4, 5, 6, ...  # seconds
-        fp32_baseline = 100
-        config1_metric, config2_metric, config3_metric, ... = 98, 98, 97, ...
-
-        # Tuning result of case 3:
-        # The best tuning config is config2, due to timeout, the third trial was forced to exit.
     """
 
     def __init__(
-        self, config_set=None, timeout=0, max_trials=100, sampler: Sampler = None, tolerable_loss=0.01
+        self, config_set=None, max_trials=100, sampler: Sampler = default_sampler, tolerable_loss=0.01
     ) -> None:
         """Init a TuneCriterion object."""
         self.config_set = config_set
-        self.timeout = timeout
         self.max_trials = max_trials
         self.sampler = sampler
         self.tolerable_loss = tolerable_loss
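
# --- Usage sketch (illustration only; not part of this commit) ---
# After this change TuningConfig no longer accepts `timeout`; tuning stops
# once max_trials is reached or a trial meets the accuracy goal
# fp32_baseline * (1 - tolerable_loss). RTNConfig is assumed as in the
# docstring examples above.
from neural_compressor.torch.quantization import RTNConfig  # assumed path

tune_config = TuningConfig(
    config_set=RTNConfig(weight_group_size=[32, 64]),  # expands to two trial candidates
    max_trials=2,
    tolerable_loss=0.01,
)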