diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py index 0dcfb4484ed..2ebf69da4f5 100644 --- a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py +++ b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py @@ -238,7 +238,19 @@ def _do_annotate_conv( weight = conv_node.args[1] assert isinstance(weight, Node) - input_qspec_map[weight] = get_weight_qspec(quantization_config) + weight_qspec = get_weight_qspec(quantization_config) + if is_conv_transpose: + # transposed convs per output channel quantization + weight_qspec = QuantizationSpec( + dtype=weight_qspec.dtype, + quant_min=weight_qspec.quant_min, + quant_max=weight_qspec.quant_max, + qscheme=weight_qspec.qscheme, + ch_axis=1, + is_dynamic=False, + observer_or_fake_quant_ctr=weight_qspec.observer_or_fake_quant_ctr, + ) + input_qspec_map[weight] = weight_qspec # Only annotate dynamically quantized conv if it's 2D and not depthwise if ( @@ -311,7 +323,19 @@ def _do_annotate_conv_relu( weight = conv_node.args[1] assert isinstance(weight, Node) - input_qspec_map[weight] = get_weight_qspec(quantization_config) + weight_qspec = get_weight_qspec(quantization_config) + if is_conv_transpose: + # transposed convs per output channel quantization + weight_qspec = QuantizationSpec( + dtype=weight_qspec.dtype, + quant_min=weight_qspec.quant_min, + quant_max=weight_qspec.quant_max, + qscheme=weight_qspec.qscheme, + ch_axis=1, + is_dynamic=False, + observer_or_fake_quant_ctr=weight_qspec.observer_or_fake_quant_ctr, + ) + input_qspec_map[weight] = weight_qspec # adding weight node to the partition as well partition = [relu_node, conv_node, conv_node.args[1]] diff --git a/backends/xnnpack/test/ops/test_conv2d.py b/backends/xnnpack/test/ops/test_conv2d.py index 92bb03c907a..d838ef0ffe9 100644 --- a/backends/xnnpack/test/ops/test_conv2d.py +++ b/backends/xnnpack/test/ops/test_conv2d.py @@ -221,7 +221,6 @@ def _test( conv_count=1, dtype: torch.dtype = torch.float, check_quantized=True, - delegated=True, ): # pyre-fixme[29]: `Union[torch._tensor.Tensor, # torch.nn.modules.module.Module]` is not a function. @@ -240,29 +239,20 @@ def _test( (tester.export().check_count({op: conv_count}).to_edge_transform_and_lower()) - if delegated: - ( - tester.check_not( - ["executorch_exir_dialects_edge__ops_aten_convolution_default"] - ) - .check_not( - [ - "executorch_exir_dialects_edge__ops__native_batch_norm_legit_no_training_default" - ] - ) - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .serialize() - .run_method_and_compare_outputs(qtol=1) + ( + tester.check_not( + ["executorch_exir_dialects_edge__ops_aten_convolution_default"] ) - else: - # need quantize ops when ops are not delegated to xnnpack - if has_quantized_ops: - ( - tester.to_executorch() - .serialize() - .run_method_and_compare_outputs(qtol=1) - ) + .check_not( + [ + "executorch_exir_dialects_edge__ops__native_batch_norm_legit_no_training_default" + ] + ) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .serialize() + .run_method_and_compare_outputs(qtol=1) + ) def _test_dq( self, @@ -325,7 +315,6 @@ def test_qs8_conv2d_per_channel(self) -> None: self._test( Conv2d(transpose=transpose), quant_config=get_symmetric_quantization_config(is_per_channel=True), - delegated=not transpose, # XNNPACK does not support per input channel quantization for transpose convolutions with groups > 1 ) def test_fp32_conv2d_seq(self) -> None: @@ -485,7 +474,6 @@ def get_inputs(self): self._test( ConvReLU(transpose=transpose), quant_config=get_symmetric_quantization_config(is_per_channel=True), - delegated=not transpose, # XNNPACK does not support per input channel quantization for transpose convolutions with groups > 1 ) def test_qs8_conv2d_dw_relu(self): @@ -537,8 +525,6 @@ def get_inputs(self): quant_config=get_symmetric_quantization_config( is_per_channel=per_channel_quant ), - # XNNPACK does not support per input channel quantization for transpose convolutions with groups > 1 - delegated=not (transpose and per_channel_quant), ) def test_qs8_conv2d_relu_seq(self): @@ -593,7 +579,7 @@ def get_inputs(self): conv_count=2, ) - def test_qs8_conv_transpose_2d_quantize_per_channel(self): + def test_qs8_conv_transpose_2d_quantize_per_channel_multi_axis(self): class PerChannelConvTranspose2d(torch.nn.Module): def __init__(self, input_channels, output_channels, groups, axis): super().__init__() @@ -662,76 +648,24 @@ def get_inputs(self): ) for groups in (1, 2): - for axis in (0, 1): - self._test( - PerChannelConvTranspose2d(3 * groups, 5 * groups, groups, axis), - quant_config=None, - conv_count=1, - delegated=axis == 1 - and groups - == 1, # xnnpack only support output channel axis quantization with groups == 1 - ) - - def test_qs8_conv_transpose_2d_dqd_f32_weights(self): - class TransposeConv2dDQDf32weights(torch.nn.Module): - def __init__(self, input_channels, output_channels, groups, axis): - super().__init__() - self.input_channels = input_channels - self.output_channels = output_channels - self.axis = axis - self.groups = groups - self.transpose = True - self.weights = torch.nn.Parameter( - torch.randn((input_channels, output_channels // groups, 4, 4)), - requires_grad=False, - ) - - axis_size = self.weights.shape[axis] - self.scale = torch.nn.Parameter(torch.ones(axis_size) * 0.12345) - self.zero_point = torch.nn.Parameter( - torch.zeros((axis_size,), dtype=torch.int64), requires_grad=False - ) - - def forward(self, x): - dequantize_input = ( - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default( - x, 0.12345, 0, -127, 127, torch.int8 + for ch_axis in (1, 2): + if ch_axis == 1 and groups == 1: + self._test( + PerChannelConvTranspose2d( + 3 * groups, 5 * groups, groups, ch_axis + ), # ch_axis=0 + quant_config=None, + conv_count=1, ) - ) - x = torch.nn.functional.conv_transpose2d( - dequantize_input, self.weights, groups=self.groups - ) - - return exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default( - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default( - x, - 0.12345, - 0, - -127, - 127, - torch.int8, - ), - 0.12345, - 0, - -127, - 127, - torch.int8, - ) - - def get_inputs(self): - return ( - torch.randint( - low=-127, high=127, size=(3, self.input_channels, 4, 4) - ).type(dtype=torch.int8), - ) - - for groups in (1, 2): - for axis in (0, 1): - self._test( - TransposeConv2dDQDf32weights(3 * groups, 5 * groups, groups, axis), - quant_config=None, - conv_count=1, - ) + else: + with self.assertRaises(RuntimeError): + self._test( + PerChannelConvTranspose2d( + 3 * groups, 5 * groups, groups, ch_axis + ), # ch_axis=0 + quant_config=None, + conv_count=1, + ) def test_padded_output_tconv(self): class TConv2d(torch.nn.Module): @@ -761,7 +695,7 @@ def forward(self, x): (tester.export().check_count({op: conv_count}).to_edge_transform_and_lower()) - # tconv should not be offloaded to XNNPack, since output padding is not + # tconv should not be offloaded to XNNPack, since output padding is not supported ( tester.check( ["executorch_exir_dialects_edge__ops_aten_convolution_default"]