[Quantized DeConv Support] Enable Quantized Transposed Convs with groups==1 #11730

Merged: 2 commits, Jun 18, 2025
28 changes: 26 additions & 2 deletions backends/xnnpack/quantizer/xnnpack_quantizer_utils.py
@@ -238,7 +238,19 @@ def _do_annotate_conv(

    weight = conv_node.args[1]
    assert isinstance(weight, Node)
-    input_qspec_map[weight] = get_weight_qspec(quantization_config)
+    weight_qspec = get_weight_qspec(quantization_config)
+    if is_conv_transpose:
+        # transposed convs are quantized per output channel, which is axis 1
+        weight_qspec = QuantizationSpec(
+            dtype=weight_qspec.dtype,
+            quant_min=weight_qspec.quant_min,
+            quant_max=weight_qspec.quant_max,
+            qscheme=weight_qspec.qscheme,
+            ch_axis=1,
+            is_dynamic=False,
+            observer_or_fake_quant_ctr=weight_qspec.observer_or_fake_quant_ctr,
+        )
+    input_qspec_map[weight] = weight_qspec

    # Only annotate dynamically quantized conv if it's 2D and not depthwise
    if (
@@ -311,7 +323,19 @@ def _do_annotate_conv_relu(

    weight = conv_node.args[1]
    assert isinstance(weight, Node)
-    input_qspec_map[weight] = get_weight_qspec(quantization_config)
+    weight_qspec = get_weight_qspec(quantization_config)
+    if is_conv_transpose:
+        # transposed convs are quantized per output channel, which is axis 1
+        weight_qspec = QuantizationSpec(
+            dtype=weight_qspec.dtype,
+            quant_min=weight_qspec.quant_min,
+            quant_max=weight_qspec.quant_max,
+            qscheme=weight_qspec.qscheme,
+            ch_axis=1,
+            is_dynamic=False,
+            observer_or_fake_quant_ctr=weight_qspec.observer_or_fake_quant_ctr,
+        )
+    input_qspec_map[weight] = weight_qspec

    # adding weight node to the partition as well
    partition = [relu_node, conv_node, conv_node.args[1]]
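Why `ch_axis=1`: PyTorch stores regular conv weights as `(out_channels, in_channels / groups, kH, kW)` but transposed conv weights as `(in_channels, out_channels / groups, kH, kW)`, so the per-output-channel axis is 1 rather than 0. A minimal illustration in plain eager PyTorch (a sketch, not part of the change; the 0.1 scale values are arbitrary):

```python
import torch

# Regular conv: weight shape is (out_channels, in_channels / groups, kH, kW),
# so the per-output-channel axis is 0.
conv = torch.nn.Conv2d(in_channels=3, out_channels=5, kernel_size=4)
print(conv.weight.shape)  # torch.Size([5, 3, 4, 4]) -> out_channels on dim 0

# Transposed conv: weight shape is (in_channels, out_channels / groups, kH, kW),
# so the per-output-channel axis is 1, hence ch_axis=1 in the annotator.
tconv = torch.nn.ConvTranspose2d(in_channels=3, out_channels=5, kernel_size=4)
print(tconv.weight.shape)  # torch.Size([3, 5, 4, 4]) -> out_channels on dim 1

# Per-channel quantization of the transposed-conv weight along dim 1:
scales = torch.full((5,), 0.1)
zero_points = torch.zeros(5, dtype=torch.int64)
q_weight = torch.quantize_per_channel(
    tconv.weight.detach(), scales, zero_points, axis=1, dtype=torch.qint8
)
print(q_weight.q_per_channel_axis())  # 1
```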
130 changes: 32 additions & 98 deletions backends/xnnpack/test/ops/test_conv2d.py
@@ -221,7 +221,6 @@ def _test(
        conv_count=1,
        dtype: torch.dtype = torch.float,
        check_quantized=True,
-        delegated=True,
    ):
        # pyre-fixme[29]: `Union[torch._tensor.Tensor,
        # torch.nn.modules.module.Module]` is not a function.
@@ -240,29 +239,20 @@ def _test(

        (tester.export().check_count({op: conv_count}).to_edge_transform_and_lower())

-        if delegated:
-            (
-                tester.check_not(
-                    ["executorch_exir_dialects_edge__ops_aten_convolution_default"]
-                )
-                .check_not(
-                    [
-                        "executorch_exir_dialects_edge__ops__native_batch_norm_legit_no_training_default"
-                    ]
-                )
-                .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
-                .to_executorch()
-                .serialize()
-                .run_method_and_compare_outputs(qtol=1)
-            )
-        else:
-            # need quantize ops when ops are not delegated to xnnpack
-            if has_quantized_ops:
-                (
-                    tester.to_executorch()
-                    .serialize()
-                    .run_method_and_compare_outputs(qtol=1)
-                )
+        (
+            tester.check_not(
+                ["executorch_exir_dialects_edge__ops_aten_convolution_default"]
+            )
+            .check_not(
+                [
+                    "executorch_exir_dialects_edge__ops__native_batch_norm_legit_no_training_default"
+                ]
+            )
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs(qtol=1)
+        )

    def _test_dq(
        self,
@@ -325,7 +315,6 @@ def test_qs8_conv2d_per_channel(self) -> None:
            self._test(
                Conv2d(transpose=transpose),
                quant_config=get_symmetric_quantization_config(is_per_channel=True),
-                delegated=not transpose,  # XNNPACK does not support per input channel quantization for transpose convolutions with groups > 1
            )

    def test_fp32_conv2d_seq(self) -> None:
@@ -485,7 +474,6 @@ def get_inputs(self):
            self._test(
                ConvReLU(transpose=transpose),
                quant_config=get_symmetric_quantization_config(is_per_channel=True),
-                delegated=not transpose,  # XNNPACK does not support per input channel quantization for transpose convolutions with groups > 1
            )

    def test_qs8_conv2d_dw_relu(self):
@@ -537,8 +525,6 @@ def get_inputs(self):
                quant_config=get_symmetric_quantization_config(
                    is_per_channel=per_channel_quant
                ),
-                # XNNPACK does not support per input channel quantization for transpose convolutions with groups > 1
-                delegated=not (transpose and per_channel_quant),
            )

    def test_qs8_conv2d_relu_seq(self):
@@ -593,7 +579,7 @@ def get_inputs(self):
            conv_count=2,
        )

-    def test_qs8_conv_transpose_2d_quantize_per_channel(self):
+    def test_qs8_conv_transpose_2d_quantize_per_channel_multi_axis(self):
        class PerChannelConvTranspose2d(torch.nn.Module):
            def __init__(self, input_channels, output_channels, groups, axis):
                super().__init__()
@@ -662,76 +648,24 @@ def get_inputs(self):
                )

        for groups in (1, 2):
-            for axis in (0, 1):
-                self._test(
-                    PerChannelConvTranspose2d(3 * groups, 5 * groups, groups, axis),
-                    quant_config=None,
-                    conv_count=1,
-                    delegated=axis == 1
-                    and groups == 1,  # xnnpack only supports output channel axis quantization with groups == 1
-                )
-
-    def test_qs8_conv_transpose_2d_dqd_f32_weights(self):
-        class TransposeConv2dDQDf32weights(torch.nn.Module):
-            def __init__(self, input_channels, output_channels, groups, axis):
-                super().__init__()
-                self.input_channels = input_channels
-                self.output_channels = output_channels
-                self.axis = axis
-                self.groups = groups
-                self.transpose = True
-                self.weights = torch.nn.Parameter(
-                    torch.randn((input_channels, output_channels // groups, 4, 4)),
-                    requires_grad=False,
-                )
-
-                axis_size = self.weights.shape[axis]
-                self.scale = torch.nn.Parameter(torch.ones(axis_size) * 0.12345)
-                self.zero_point = torch.nn.Parameter(
-                    torch.zeros((axis_size,), dtype=torch.int64), requires_grad=False
-                )
-
-            def forward(self, x):
-                dequantize_input = (
-                    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default(
-                        x, 0.12345, 0, -127, 127, torch.int8
-                    )
-                )
-                x = torch.nn.functional.conv_transpose2d(
-                    dequantize_input, self.weights, groups=self.groups
-                )
-
-                return exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default(
-                    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default(
-                        x,
-                        0.12345,
-                        0,
-                        -127,
-                        127,
-                        torch.int8,
-                    ),
-                    0.12345,
-                    0,
-                    -127,
-                    127,
-                    torch.int8,
-                )
-
-            def get_inputs(self):
-                return (
-                    torch.randint(
-                        low=-127, high=127, size=(3, self.input_channels, 4, 4)
-                    ).type(dtype=torch.int8),
-                )
-
-        for groups in (1, 2):
-            for axis in (0, 1):
-                self._test(
-                    TransposeConv2dDQDf32weights(3 * groups, 5 * groups, groups, axis),
-                    quant_config=None,
-                    conv_count=1,
-                )
+            for ch_axis in (1, 2):
+                if ch_axis == 1 and groups == 1:
+                    self._test(
+                        PerChannelConvTranspose2d(
+                            3 * groups, 5 * groups, groups, ch_axis
+                        ),
+                        quant_config=None,
+                        conv_count=1,
+                    )
+                else:
+                    with self.assertRaises(RuntimeError):
+                        self._test(
+                            PerChannelConvTranspose2d(
+                                3 * groups, 5 * groups, groups, ch_axis
+                            ),
+                            quant_config=None,
+                            conv_count=1,
+                        )
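To exercise the newly supported case outside the test harness, the following is a rough sketch of the PT2E flow under the same quantizer. The `TConv` module is invented for illustration, and `torch.export.export_for_training` is assumed as the capture entry point, which varies across torch versions; the quantizer import path matches this repo's `backends/xnnpack/quantizer` package.

```python
import torch
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class TConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # groups == 1 is the configuration this PR enables for
        # per-channel quantized transposed convs.
        self.tconv = torch.nn.ConvTranspose2d(3, 5, kernel_size=4, groups=1)

    def forward(self, x):
        return self.tconv(x)


model = TConv().eval()
example_inputs = (torch.randn(1, 3, 8, 8),)

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

# Capture, annotate, calibrate, convert; the annotator change above
# picks ch_axis=1 for the transposed-conv weight.
gm = torch.export.export_for_training(model, example_inputs).module()
prepared = prepare_pt2e(gm, quantizer)
prepared(*example_inputs)
quantized = convert_pt2e(prepared)
```

From there, the usual `to_edge_transform_and_lower` path should produce a single `executorch_call_delegate`, as the updated tests assert.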

    def test_padded_output_tconv(self):
        class TConv2d(torch.nn.Module):
@@ -761,7 +695,7 @@ def forward(self, x):

        (tester.export().check_count({op: conv_count}).to_edge_transform_and_lower())

-        # tconv should not be offloaded to XNNPack, since output padding is not
+        # tconv should not be offloaded to XNNPack, since output padding is not supported
        (
            tester.check(
                ["executorch_exir_dialects_edge__ops_aten_convolution_default"]
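For background on that last check: `output_padding` is what makes a transposed conv's output "padded" here; it grows the output beyond what stride and kernel alone produce, which the comment above says XNNPACK does not support, so such graphs keep the ATen convolution instead of being delegated. An illustrative module of the kind this test targets (a sketch under that reading, not code from the diff):

```python
import torch


class PaddedOutputTConv(torch.nn.Module):
    """A transposed conv expected to stay on the portable ATen path,
    since its output_padding is not supported by the XNNPACK delegate."""

    def __init__(self):
        super().__init__()
        self.tconv = torch.nn.ConvTranspose2d(
            3, 5, kernel_size=3, stride=2, output_padding=1
        )

    def forward(self, x):
        return self.tconv(x)


m = PaddedOutputTConv()
# With padding=0 and dilation=1:
#   H_out = (H_in - 1) * stride + kernel_size + output_padding
#         = (8 - 1) * 2 + 3 + 1 = 18
print(m(torch.randn(1, 3, 8, 8)).shape)  # torch.Size([1, 5, 18, 18])
```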