Skip to content

Commit d7030aa

Browse files
authored
Add filter function to XNNPack Quantizer
Differential Revision: D73677442
Pull Request resolved: #10626
1 parent 184fba5 commit d7030aa

File tree

2 files changed

+56
-3
lines changed

2 files changed

+56
-3
lines changed

backends/xnnpack/quantizer/xnnpack_quantizer.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,9 @@ def __init__(self) -> None:
292292
] = {}
293293
self.module_type_config: dict[Callable, Optional[QuantizationConfig]] = {}
294294
self.module_name_config: dict[str, Optional[QuantizationConfig]] = {}
295+
# If specified, only quantize nodes for which the filter
296+
# function returns True.
297+
self.filter_fn: Optional[Callable[[Node], bool]] = None
295298

296299
@classmethod
297300
def get_supported_quantization_configs(cls) -> list[QuantizationConfig]:
@@ -355,6 +358,14 @@ def set_module_name(
355358
self.module_name_config[module_name] = quantization_config
356359
return self
357360

361+
def set_filter_function(self, filter_fn: Callable[[Node], bool]):
362+
"""
363+
Set the filter function. We only quantize nodes that return True for
364+
the filter function.
365+
"""
366+
self.filter_fn = filter_fn
367+
return self
368+
358369
def transform_for_annotation(
359370
self, model: torch.fx.GraphModule
360371
) -> torch.fx.GraphModule:
@@ -378,17 +389,29 @@ def _annotate_all_patterns(
378389
if quantization_config is None:
379390
return model
380391

392+
# Create a combined filter function, which returns True only when
393+
# every non-None filter among filter_fn and self.filter_fn accepts the node.
394+
def combined_filter_fn(n: Node) -> bool:
395+
combined_filter = [self.filter_fn, filter_fn]
396+
return all(f(n) for f in combined_filter if f is not None)
397+
381398
for pattern in self.SUPPORTED_PATTERNS:
382399
if operator_target and operator_target not in pattern.op_overloads:
383400
# if operator_target is specified, skip patterns that aren't
384401
# associated with that target
385402
continue
386403
if quantization_config.input_activation.is_dynamic and pattern.is_dynamic:
387-
OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn)
404+
OP_TO_ANNOTATOR[pattern.name](
405+
model, quantization_config, combined_filter_fn
406+
)
388407
elif quantization_config.is_qat and pattern.is_qat:
389-
OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn)
408+
OP_TO_ANNOTATOR[pattern.name](
409+
model, quantization_config, combined_filter_fn
410+
)
390411
elif not quantization_config.input_activation.is_dynamic:
391-
OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn)
412+
OP_TO_ANNOTATOR[pattern.name](
413+
model, quantization_config, combined_filter_fn
414+
)
392415

393416
return model
394417

backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,36 @@ def test_obs_sharing_ops(self):
297297
]
298298
self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list)
299299

300+
def test_set_filter_fn(self):
301+
quantizer = XNNPACKQuantizer()
302+
quantization_config = get_symmetric_quantization_config(is_per_channel=True)
303+
quantizer.set_global(quantization_config)
304+
m_eager = TestHelperModules.TwoLinearModule().eval()
305+
306+
# Set the filter function so that the second linear is not quantized
307+
def filter_fn(n):
308+
return n.name != "linear_1"
309+
310+
quantizer.set_filter_function(filter_fn)
311+
312+
# Test with 2d inputs
313+
example_inputs_2d = (torch.randn(9, 8),)
314+
node_occurrence = {
315+
# input and output of the first linear op will be (de)quantized
316+
torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
317+
torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
318+
# quantize_per_channel for weights are const propagated
319+
torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
320+
# weight for the first linear will be dequantized
321+
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
322+
}
323+
self._test_quantizer(
324+
m_eager,
325+
example_inputs_2d,
326+
quantizer,
327+
node_occurrence,
328+
)
329+
300330
def test_set_module_name(self):
301331
class Sub(torch.nn.Module):
302332
def __init__(self) -> None:

0 commit comments

Comments
 (0)