
Commit ff3c3b6

pytorchbot, mcr229, and GregoryComer authored
[Quantized DeConv Support] Enable Quantized Transposed Convs with groups==1 (#11774)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #11730 by @mcr229
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/mcr229/31/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/mcr229/31/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/mcr229/31/orig
@diff-train-skip-merge
---------
Co-authored-by: Max Ren <[email protected]>
Co-authored-by: Gregory Comer <[email protected]>
1 parent 222d9e3 commit ff3c3b6
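As background, here is a minimal sketch of the flow this commit enables: PT2E static quantization of a groups==1 ConvTranspose2d through the XNNPACKQuantizer. The sketch is not part of this PR; the model is illustrative, and the export/prepare entry points shown are the stock PT2E ones, which vary somewhat across PyTorch releases.

    # Minimal sketch (not from this PR): quantizing a groups==1 transposed conv.
    import torch
    from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    class Upsample(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # groups defaults to 1, which is the case this PR enables
            self.deconv = torch.nn.ConvTranspose2d(8, 16, kernel_size=2, stride=2)

        def forward(self, x):
            return self.deconv(x)

    model = Upsample().eval()
    example_inputs = (torch.randn(1, 8, 32, 32),)

    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

    exported = torch.export.export_for_training(model, example_inputs).module()
    prepared = prepare_pt2e(exported, quantizer)
    prepared(*example_inputs)  # calibrate with representative data
    quantized = convert_pt2e(prepared)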

4 files changed (+172, -144 lines)

backends/xnnpack/quantizer/xnnpack_quantizer.py

Lines changed: 1 addition & 1 deletion
@@ -274,7 +274,7 @@ class XNNPACKQuantizer(Quantizer):
         QuantPattern("linear_relu", False, False, LINEAR_TARGETS),
         QuantPattern("linear", True, False, LINEAR_TARGETS),
         QuantPattern("conv", True, False, CONV_TARGETS),
-        QuantPattern("conv_transpose", False, False, CONV_TARGETS),
+        QuantPattern("conv_transpose", True, False, CONV_TARGETS),
         QuantPattern("conv_relu", False, False, CONV_TARGETS),
         QuantPattern("conv_transpose_relu", False, False, CONV_TARGETS),
         QuantPattern("adaptive_avg_pool2d", False, False, ADAPTIVE_AVG_POOL2D_TARGETS),

backends/xnnpack/quantizer/xnnpack_quantizer_utils.py

Lines changed: 59 additions & 13 deletions
@@ -4,7 +4,10 @@
 
 import torch
 import torch.nn.functional as F
-from executorch.backends.xnnpack.utils.utils import is_depthwise_conv
+from executorch.backends.xnnpack.utils.utils import (
+    get_groups_from_conv,
+    is_depthwise_conv,
+)
 from torch._subclasses import FakeTensor
 from torch.fx import Node
 from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
@@ -65,6 +68,28 @@ def decorator(annotator: AnnotatorType) -> None:
     return decorator
 
 
+def change_quantization_config(
+    original_qspec,
+    dtype=None,
+    quant_min=None,
+    quant_max=None,
+    qscheme=None,
+    ch_axis=None,
+    is_dynamic=None,
+    observer_or_fake_quant_ctr=None,
+):
+    return QuantizationSpec(
+        dtype=dtype or original_qspec.dtype,
+        quant_min=quant_min or original_qspec.quant_min,
+        quant_max=quant_max or original_qspec.quant_max,
+        qscheme=qscheme or original_qspec.qscheme,
+        ch_axis=ch_axis or original_qspec.ch_axis,
+        is_dynamic=is_dynamic or original_qspec.is_dynamic,
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr
+        or original_qspec.observer_or_fake_quant_ctr,
+    )
+
+
 def is_relu_node(node: Node) -> bool:
     """
     Check if a given node is a relu node
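A hedged usage sketch of the new helper: it clones a QuantizationSpec, overriding only the fields supplied. Because it combines values with `or`, falsy overrides (e.g. ch_axis=0 or is_dynamic=False) fall back to the original value, which is harmless for the ch_axis=1 call sites below. The import paths and observer choice here are the stock PT2E ones, shown for illustration:

    import torch
    from torch.ao.quantization.observer import PerChannelMinMaxObserver
    from torch.ao.quantization.quantizer import QuantizationSpec

    base = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-127,
        quant_max=127,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,  # output-channel axis of a regular conv weight
        is_dynamic=False,
        observer_or_fake_quant_ctr=PerChannelMinMaxObserver,
    )
    # Same spec, but quantized along axis 1 (transposed-conv output channels):
    deconv_qspec = change_quantization_config(base, ch_axis=1)
    assert deconv_qspec.ch_axis == 1 and deconv_qspec.dtype == base.dtype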
@@ -231,31 +256,44 @@ def _do_annotate_conv(
         if is_relu_node(user):
             continue
 
+        # Tracks conditions for whether or not to skip
+        skip = False
+
         input_qspec_map = {}
         input_act = conv_node.args[0]
         assert isinstance(input_act, Node)
         input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
 
         weight = conv_node.args[1]
         assert isinstance(weight, Node)
-        input_qspec_map[weight] = get_weight_qspec(quantization_config)
+        weight_qspec = get_weight_qspec(quantization_config)
+        num_groups = get_groups_from_conv(conv_node)
 
-        # Only annotate dynamically quantized conv if it's 2D and not depthwise
-        if (
+        # skip if transposed conv has more than 1 group
+        skip = skip or (is_conv_transpose and num_groups != 1)
+        print(f"{skip} conv transpose and num_groups")
+
+        if is_conv_transpose:
+            # transposed convs per output channel quantization
+            weight_qspec = change_quantization_config(weight_qspec, ch_axis=1)
+
+        input_qspec_map[weight] = weight_qspec
+        is_dynamic = (
             quantization_config
             and quantization_config.input_activation
             and quantization_config.input_activation.is_dynamic
-        ):
+        )
+
+        # Only annotate dynamically quantized conv if it's 2D and not depthwise
+        if is_dynamic:
             weight_val = weight.meta.get("val", None)
             weight_shape = getattr(weight_val, "shape", None)
-
             # Skip if not a 4D weight tensor (i.e. not conv2d)
-            if weight_shape is not None and len(weight_shape) != 4:
-                continue
-
+            skip = skip or (weight_shape is not None and len(weight_shape) != 4)
             # Skip if depthwise (default to groups=1 since it's not an arg)
-            if is_depthwise_conv(weight_shape, 1, is_conv_transpose):
-                continue
+            skip = skip or (
+                not is_conv_transpose and is_depthwise_conv(weight_shape, 1, False)
+            )
 
         # adding weight node to the partition as well
         partition = [conv_node, conv_node.args[1]]
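Why ch_axis=1: PyTorch stores regular Conv2d weights as (out_channels, in_channels/groups, kH, kW) but ConvTranspose2d weights as (in_channels, out_channels/groups, kH, kW), so per-output-channel quantization of a transposed conv must target axis 1. A quick self-contained check:

    import torch

    conv = torch.nn.Conv2d(8, 16, kernel_size=3)
    deconv = torch.nn.ConvTranspose2d(8, 16, kernel_size=3)
    print(conv.weight.shape)    # torch.Size([16, 8, 3, 3]) -> out channels at axis 0
    print(deconv.weight.shape)  # torch.Size([8, 16, 3, 3]) -> out channels at axis 1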
@@ -265,7 +303,7 @@ def _do_annotate_conv(
             input_qspec_map[bias] = get_bias_qspec(quantization_config)
             partition.append(bias)
 
-        if _is_annotated(partition):
+        if _is_annotated(partition) or skip:
             continue
 
         if filter_fn and any(not filter_fn(n) for n in partition):
@@ -311,7 +349,12 @@ def _do_annotate_conv_relu(
 
         weight = conv_node.args[1]
         assert isinstance(weight, Node)
-        input_qspec_map[weight] = get_weight_qspec(quantization_config)
+        weight_qspec = get_weight_qspec(quantization_config)
+        groups = get_groups_from_conv(conv_node)
+        if is_conv_transpose:
+            # transposed convs per output channel quantization
+            weight_qspec = change_quantization_config(weight_qspec, ch_axis=1)
+        input_qspec_map[weight] = weight_qspec
 
         # adding weight node to the partition as well
         partition = [relu_node, conv_node, conv_node.args[1]]
@@ -323,6 +366,9 @@ def _do_annotate_conv_relu(
         if _is_annotated(partition):
             continue
 
+        if is_conv_transpose and groups != 1:
+            continue
+
         if filter_fn and any(not filter_fn(n) for n in partition):
             continue
 
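Taken together, the annotators' gating after this commit can be sketched as follows (behavior inferred from the hunks above; the modules are illustrative):

    import torch

    # Annotated: transposed conv with the default groups=1, now quantized
    # per output channel along ch_axis=1.
    torch.nn.ConvTranspose2d(8, 16, kernel_size=2, stride=2)

    # Skipped: transposed conv with groups != 1 is left unquantized for now.
    torch.nn.ConvTranspose2d(8, 8, kernel_size=2, stride=2, groups=4)

    # Still skipped for dynamically quantized convs: depthwise regular convs.
    torch.nn.Conv2d(8, 8, kernel_size=3, groups=8)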
