@@ -292,6 +292,9 @@ def __init__(self) -> None:
292
292
] = {}
293
293
self .module_type_config : dict [Callable , Optional [QuantizationConfig ]] = {}
294
294
self .module_name_config : dict [str , Optional [QuantizationConfig ]] = {}
295
+ # If specified, only quantize nodes that return true for the filter
296
+ # function.
297
+ self .filter_fn : Optional [Callable [[Node ], bool ]] = None
295
298
296
299
@classmethod
297
300
def get_supported_quantization_configs (cls ) -> list [QuantizationConfig ]:
@@ -355,6 +358,14 @@ def set_module_name(
355
358
self .module_name_config [module_name ] = quantization_config
356
359
return self
357
360
361
+ def set_filter_function (self , filter_fn : Callable [[Node ], bool ]):
362
+ """
363
+ Set the filter function. We only quantize nodes that return True for
364
+ the filter function.
365
+ """
366
+ self .filter_fn = filter_fn
367
+ return self
368
+
358
369
def transform_for_annotation (
359
370
self , model : torch .fx .GraphModule
360
371
) -> torch .fx .GraphModule :
@@ -378,17 +389,29 @@ def _annotate_all_patterns(
378
389
if quantization_config is None :
379
390
return model
380
391
392
+ # Create a combined filter function, which returns True only when
393
+ # both filter_fn and self.filter_fn return True.
394
+ def combined_filter_fn (n : Node ) -> bool :
395
+ combined_filter = [self .filter_fn , filter_fn ]
396
+ return all (f (n ) for f in combined_filter if f is not None )
397
+
381
398
for pattern in self .SUPPORTED_PATTERNS :
382
399
if operator_target and operator_target not in pattern .op_overloads :
383
400
# if operator_target is specified, skip patterns that aren't
384
401
# associated with that target
385
402
continue
386
403
if quantization_config .input_activation .is_dynamic and pattern .is_dynamic :
387
- OP_TO_ANNOTATOR [pattern .name ](model , quantization_config , filter_fn )
404
+ OP_TO_ANNOTATOR [pattern .name ](
405
+ model , quantization_config , combined_filter_fn
406
+ )
388
407
elif quantization_config .is_qat and pattern .is_qat :
389
- OP_TO_ANNOTATOR [pattern .name ](model , quantization_config , filter_fn )
408
+ OP_TO_ANNOTATOR [pattern .name ](
409
+ model , quantization_config , combined_filter_fn
410
+ )
390
411
elif not quantization_config .input_activation .is_dynamic :
391
- OP_TO_ANNOTATOR [pattern .name ](model , quantization_config , filter_fn )
412
+ OP_TO_ANNOTATOR [pattern .name ](
413
+ model , quantization_config , combined_filter_fn
414
+ )
392
415
393
416
return model
394
417
0 commit comments