@@ -213,14 +213,57 @@ class TuningConfig:
        config_set: quantization configs. Default value is empty.
        timeout: Tuning timeout (seconds). Default value is 0 which means early stop.
        max_trials: Max tuning times. Default value is 100. Combine with timeout field to decide when to exit.
+        tolerable_loss: This float indicates how much metric loss can be tolerated. \
+            The metric loss is relative; it can be either positive or negative. Default is 0.01.
+
+    Examples:
+        from neural_compressor import TuningConfig
+        tune_config = TuningConfig(
+            config_set=[config1, config2, ...],
+            max_trials=3,
+            tolerable_loss=0.01,
+        )
+
+        # Case 1: Tolerable Loss
+        fp32_baseline = 100
+        config1_metric, config2_metric, ... = 98, 99, ...
+
+        # Tuning result of case 1:
+        # The best tuning config is config2, because config2_metric >= fp32_baseline * (1 - tolerable_loss).
+
+        # Case 2: Maximum Trials
+        fp32_baseline = 100
+        config1_metric, config2_metric, config3_metric, ... = 98, 98, 97, ...
+
+        # Tuning result of case 2:
+        # The best tuning config is config2, because:
+        # 1. No trial reached the accuracy goal (config_metric < fp32_baseline * (1 - tolerable_loss)).
+        # 2. The maximum number of tuning trials was reached.
+
+        # Case 3: Timeout
+        tune_config = TuningConfig(
+            config_set=[config1, config2, ...],
+            timeout=10,  # seconds
+            max_trials=3,
+            tolerable_loss=0.01,
+        )
+        config1_tuning_time, config2_tuning_time, config3_tuning_time, ... = 4, 5, 6, ...  # seconds
+        fp32_baseline = 100
+        config1_metric, config2_metric, config3_metric, ... = 98, 98, 97, ...
+
+        # Tuning result of case 3:
+        # The best tuning config is config2; due to the timeout, the third trial was forced to exit.
    """

-    def __init__(self, config_set=None, timeout=0, max_trials=100, sampler: Sampler = None) -> None:
+    def __init__(
+        self, config_set=None, timeout=0, max_trials=100, sampler: Sampler = None, tolerable_loss=0.01
+    ) -> None:
        """Init a TuneCriterion object."""
        self.config_set = config_set
        self.timeout = timeout
        self.max_trials = max_trials
        self.sampler = sampler
+        self.tolerable_loss = tolerable_loss


class _TrialRecord:
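
All three cases in the docstring above reduce to the same comparison against fp32_baseline * (1 - tolerable_loss). The following standalone sketch is illustrative only (it is not part of the patch) and spells out that arithmetic for Case 1 and Case 2:

    # Standalone sketch of the acceptance check behind the docstring cases.
    fp32_baseline = 100
    tolerable_loss = 0.01
    accuracy_goal = fp32_baseline * (1 - tolerable_loss)  # 99.0

    # Case 1: config2 (metric 99) is the first trial that satisfies metric >= 99.0,
    # so tuning stops there and config2 is the best config.
    assert 98 < accuracy_goal <= 99

    # Case 2: none of the metrics (98, 98, 97) reaches 99.0, so tuning only stops
    # once max_trials (3) is exhausted, and the best record seen so far is kept.
    assert all(metric < accuracy_goal for metric in (98, 98, 97))
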
@@ -242,12 +285,17 @@ def __init__(self, tuning_config: TuningConfig) -> None:
        self.tuning_config = tuning_config
        self.trial_cnt = 0
        self.tuning_history: List[_TrialRecord] = []
+        self.baseline = None

    def add_trial_result(self, trial_index: int, trial_result: Union[int, float], quant_config: BaseConfig) -> None:
        self.trial_cnt += 1
        trial_record = _TrialRecord(trial_index, trial_result, quant_config)
        self.tuning_history.append(trial_record)

+    def set_baseline(self, baseline: float):
+        self.baseline = baseline
+        logger.info(f"Fp32 baseline is {self.baseline}")
+
    def get_number_of_trials(self):
        return len(self.tuning_history)

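
Note on ordering (not part of the patch): until set_baseline() is called, the accuracy-goal check added to need_stop() in the next hunk is effectively disabled, because self.baseline is still None. A minimal sketch, run with the TuningConfig and TuningMonitor classes from this module; config1 is a placeholder standing in for a real BaseConfig instance:

    config1 = object()  # placeholder for a BaseConfig instance
    monitor = TuningMonitor(TuningConfig(max_trials=3, tolerable_loss=0.01))
    monitor.add_trial_result(1, 99.5, config1)
    assert not monitor.need_stop()   # baseline unknown, so only the max_trials check applies
    monitor.set_baseline(100.0)      # logs "Fp32 baseline is 100.0"
    assert monitor.need_stop()       # 99.5 >= 100.0 * (1 - 0.01)
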
@@ -260,8 +308,23 @@ def get_best_quant_config(self) -> BaseConfig:
        return sorted_trials_records[0].quant_config

    def need_stop(self) -> bool:
-        # TODO Support more stop criteria in the next PR, such as `reach accuracy goal`, `timeout`, and so on.
-        return self.trial_cnt >= self.tuning_config.max_trials
+        """Check whether tuning should stop: either the accuracy goal is met or the maximum number of trials is reached.
+
+        Returns:
+            bool: True if tuning should stop, otherwise False.
+        """
+
+        # TODO: Support more stop criteria in the next PR, such as `timeout`, and so on.
+        # Reached the maximum number of trials.
+        reach_max_trials = self.trial_cnt >= self.tuning_config.max_trials
+        # Reached the accuracy goal.
+        # ([-1] is the last element, i.e., the latest trial record.)
+        meet_accuracy_goal = (
+            False
+            if self.baseline is None
+            else self.tuning_history[-1].trial_result >= (self.baseline * (1 - self.tuning_config.tolerable_loss))
+        )
+        return reach_max_trials or meet_accuracy_goal


def init_tuning(tuning_config: TuningConfig) -> Tuple[ConfigLoader, TuningLogger, TuningMonitor]:
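
Taken together, the pieces touched by this patch would typically be driven by a loop of the following shape. This is only a sketch: evaluate, apply_config, fp32_model, and candidate_configs are hypothetical placeholders, not APIs defined in this file; only TuningConfig and TuningMonitor and their methods come from the code above.

    def evaluate(model):           # hypothetical metric function, returns a float accuracy
        return 99.5

    def apply_config(model, cfg):  # hypothetical helper: returns a quantized copy of `model`
        return model

    fp32_model = object()                      # placeholder model
    candidate_configs = [object(), object()]   # placeholders for BaseConfig instances

    tune_config = TuningConfig(config_set=candidate_configs, max_trials=3, tolerable_loss=0.01)
    monitor = TuningMonitor(tune_config)
    monitor.set_baseline(evaluate(fp32_model))  # FP32 reference metric, recorded before tuning
    for trial_index, quant_config in enumerate(candidate_configs):
        metric = evaluate(apply_config(fp32_model, quant_config))
        monitor.add_trial_result(trial_index, metric, quant_config)
        if monitor.need_stop():  # stop once the accuracy goal is met or max_trials is hit
            break
    best_config = monitor.get_best_quant_config()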