Skip to content

Commit f502252

Browse files
authored
Revert "Support dynamically quantized 2D convolutions (#10347)"
This reverts commit cfd1be3.
1 parent 80fc3fc commit f502252

File tree

9 files changed: 14 additions (+14) and 235 deletions (-235).

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 1 addition & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@
88

99
import torch
1010
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
11-
from executorch.backends.xnnpack.utils.quant_utils import is_dynamic_qdq
1211
from executorch.backends.xnnpack.utils.utils import is_param_node
1312
from executorch.exir.dialects._ops import ops as exir_ops
1413
from executorch.exir.pass_base import PassResult
@@ -284,26 +283,14 @@ def input_to_nhwc(
284283
]
285284
else:
286285
# Need to create NHWC node
287-
# Check if input uses dynamic quantization
288-
is_dynamic_input = is_dynamic_qdq(input_node)
289-
290-
if is_dynamic_input:
291-
# Trace back to original source node
292-
while getattr(input_node, "args", None):
293-
input_node = input_node.args[0]
294-
295286
with graph_module.graph.inserting_after(input_node):
296287
input_node_nhwc = self.create_call_function_node(
297288
graph_module=graph_module,
298289
target=exir_ops.edge.aten._to_copy.default,
299290
args=(input_node,),
300291
memory_format=torch.channels_last,
301292
)
302-
303-
if is_dynamic_input:
304-
# Replace downstream input_nodes with NHWC node
305-
input_node.replace_all_uses_with(input_node_nhwc)
306-
input_node_nhwc.args = (input_node,)
293+
self.mark_as_nhwc_node(input_node_nhwc)
307294

308295
self.insert_copy_and_assign_partner_nodes_quantization_sensitive(
309296
graph_module=graph_module,

backends/xnnpack/operators/quant_params.py

Lines changed: 1 addition & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -141,27 +141,12 @@ def quantize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
141141
tensor, self.scale, self.zp, self.qmin, self.qmax, self.dtype
142142
)
143143

144-
# Temporary helper until non-batch dimensions can be inferred
145-
# Detects if a node feeds into a conv op by checking all downstream users
146-
@staticmethod
147-
def _feeds_into_conv(node: torch.fx.Node) -> bool:
148-
users_list = [node]
149-
150-
while users_list:
151-
current_user = users_list.pop()
152-
if "convolution" in str(current_user.target):
153-
return True
154-
users_list.extend(current_user.users)
155-
156-
return False
157-
158144
@classmethod
159145
def _from_dynamic_input_node(cls, quant_node: torch.fx.Node) -> QuantParams:
160146
q_input = quant_node.args[0] # fp32 input
161147
assert isinstance(q_input, torch.fx.Node)
162148
# TODO - materialize this from the quant_node scale count and val shape
163-
# Set non-batch dims to 3 if node feeds into conv (only 2D is supported), otherwise set to 1 for linear
164-
num_nonbatch_dims = 3 if cls._feeds_into_conv(quant_node) else 1
149+
num_nonbatch_dims = 1
165150

166151
return cls(
167152
per_channel=False, # True is not valid

backends/xnnpack/partition/config/gemm_configs.py

Lines changed: 3 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99
from typing import cast, List, Optional, Tuple
1010

1111
import torch
12-
from executorch.backends.transforms import get_shape
1312
from executorch.backends.xnnpack.operators.quant_params import QuantParams
1413
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
1514
ConfigPrecisionType,
@@ -28,7 +27,6 @@
2827
)
2928
from executorch.backends.xnnpack.utils.utils import (
3029
get_input_node,
31-
is_depthwise_conv,
3230
is_getitem,
3331
is_node,
3432
is_param_node,
@@ -361,23 +359,12 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
361359
return False # Only support 1D + 2D Conv
362360

363361
kernel_node = get_input_node(node, 1)
364-
kernel_shape = get_shape(kernel_node)
365362
weight_quant_params = QuantParams.from_weights(kernel_node, ep)
366-
groups = cast(int, node.args[8])
367-
is_transpose = node.args[6]
368363

369-
# XNNPACK does not support dynamic quantization convs that are not 2D or are depthwise
370-
if self._detect_precision(node) == ConfigPrecisionType.DYNAMIC_QUANT and (
371-
len(conv_stride) != 2
372-
or is_depthwise_conv(kernel_shape, groups, is_transpose)
373-
):
374-
why(
375-
node,
376-
"XNNPACK only supports standard 2D convolutions for dynamic quantization",
377-
)
378-
return False
364+
is_transpose = node.args[6]
365+
groups = cast(int, node.args[8])
379366

380-
# XNNPACK does not support non-zero output padding in transposed
367+
# XNNPack does not support non-zero output padding in transposed
381368
# convolutions.
382369
if is_transpose and any(
383370
out_pad != 0 for out_pad in cast(List[int], node.args[7])
@@ -407,7 +394,6 @@ def supported_precision_types(self):
407394
return [
408395
ConfigPrecisionType.FP32,
409396
ConfigPrecisionType.STATIC_QUANT,
410-
ConfigPrecisionType.DYNAMIC_QUANT,
411397
]
412398

413399

backends/xnnpack/quantizer/xnnpack_quantizer.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -265,7 +265,6 @@ class XNNPACKQuantizer(Quantizer):
265265

266266
DYNAMIC_OPS = [
267267
"linear",
268-
"conv",
269268
]
270269

271270
def __init__(self) -> None:

backends/xnnpack/quantizer/xnnpack_quantizer_utils.py

Lines changed: 1 addition & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@
66

77
import torch
88
import torch.nn.functional as F
9-
from executorch.backends.xnnpack.utils.utils import is_depthwise_conv
109
from torch._subclasses import FakeTensor
1110
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
1211
from torch.ao.quantization.pt2e.export_utils import _WrapperModule
@@ -30,6 +29,7 @@
3029
)
3130
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
3231

32+
3333
__all__ = [
3434
"OperatorConfig",
3535
"OperatorPatternType",
@@ -323,23 +323,6 @@ def _do_annotate_conv(
323323
assert isinstance(weight, Node)
324324
input_qspec_map[weight] = get_weight_qspec(quantization_config)
325325

326-
# Only annotate dynamically quantized conv if it's 2D and not depthwise
327-
if (
328-
quantization_config
329-
and quantization_config.input_activation
330-
and quantization_config.input_activation.is_dynamic
331-
):
332-
weight_val = weight.meta.get("val", None)
333-
weight_shape = getattr(weight_val, "shape", None)
334-
335-
# Skip if not a 4D weight tensor (i.e. not conv2d)
336-
if weight_shape is not None and len(weight_shape) != 4:
337-
continue
338-
339-
# Skip if depthwise (default to groups=1 since it's not an arg)
340-
if is_depthwise_conv(weight_shape, 1, is_conv_transpose):
341-
continue
342-
343326
# adding weight node to the partition as well
344327
partition = [conv_node, conv_node.args[1]]
345328

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -512,6 +512,11 @@ Error defineTensor(
512512
buffer_ptr == nullptr,
513513
Internal,
514514
"Dynamically quantized tensor should not have constant data but found non-nullptr");
515+
// TODO(T179441835): Dynamic Quantization with num_nonbatch_dims > 1
516+
ET_CHECK_OR_RETURN_ERROR(
517+
qparams->num_nonbatch_dims() == 1,
518+
Internal,
519+
"Dynamically Quantized Tensors currently only support per token quantization");
515520
status = xnn_define_dynamically_quantized_tensor_value(
516521
/*subgraph=*/subgraph_ptr,
517522
/*datatype=*/getDataType(tensor_value->datatype()),
@@ -1167,7 +1172,7 @@ Error defineStaticTransposeNode(
11671172
ET_CHECK_OR_RETURN_ERROR(
11681173
status == xnn_status_success,
11691174
Internal,
1170-
"Failed to create static transpose node %i with code: %s",
1175+
"Failed to create sigmoid node %i with code: %s",
11711176
node->debug_handle(),
11721177
xnn_status_to_string(status));
11731178

backends/xnnpack/test/ops/test_conv2d.py

Lines changed: 1 addition & 96 deletions
Original file line number | Diff line number | Diff line change
@@ -18,10 +18,6 @@
1818
except:
1919
has_quantized_ops = False
2020

21-
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
22-
ConfigPrecisionType,
23-
)
24-
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
2521
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
2622
get_symmetric_quantization_config,
2723
)
@@ -30,7 +26,7 @@
3026
)
3127
from executorch.backends.xnnpack.test.test_xnnpack_utils import randomize_bn
3228
from executorch.backends.xnnpack.test.tester import Quantize, Tester
33-
from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower
29+
3430
from executorch.exir.dialects._ops import ops as exir_ops
3531

3632

@@ -173,43 +169,6 @@ def get_inputs(self):
173169
return (torch.randn(2, 2, 4, 4),)
174170

175171

176-
class Conv2dDQSeq(torch.nn.Module):
177-
def __init__(self):
178-
super().__init__()
179-
self.first = torch.nn.Conv2d(
180-
in_channels=3, out_channels=8, kernel_size=3, padding=1
181-
)
182-
self.second = torch.nn.Conv2d(
183-
in_channels=8, out_channels=10, kernel_size=3, padding=1
184-
)
185-
186-
def forward(self, x):
187-
y = self.first(x)
188-
return self.second(y)
189-
190-
def get_inputs(self):
191-
return (torch.randn(1, 3, 8, 8),)
192-
193-
194-
class Conv2dDQParallel(torch.nn.Module):
195-
def __init__(self):
196-
super().__init__()
197-
self.first = torch.nn.Conv2d(
198-
in_channels=3, out_channels=8, kernel_size=3, padding=1
199-
)
200-
self.second = torch.nn.Conv2d(
201-
in_channels=3, out_channels=8, kernel_size=3, padding=1
202-
)
203-
204-
def forward(self, x):
205-
first = self.first(x)
206-
second = self.second(x)
207-
return first, second
208-
209-
def get_inputs(self):
210-
return (torch.randn(1, 3, 8, 8),)
211-
212-
213172
class TestConv2d(unittest.TestCase):
214173
def setUp(self):
215174
torch._dynamo.reset()
@@ -264,37 +223,6 @@ def _test(
264223
.run_method_and_compare_outputs(qtol=1)
265224
)
266225

267-
def _test_dq(
268-
self,
269-
m: torch.nn.Module,
270-
conv_count=1,
271-
dynamic_shapes=None,
272-
):
273-
quant_config = get_symmetric_quantization_config(
274-
is_per_channel=True,
275-
is_dynamic=True,
276-
)
277-
278-
DynamicallyQuantizedPartitioner = XnnpackPartitioner(
279-
config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
280-
per_op_mode=True,
281-
)
282-
283-
tester = Tester(m, m.get_inputs(), dynamic_shapes=dynamic_shapes)
284-
tester.quantize(Quantize(quantization_config=quant_config))
285-
tester.export()
286-
tester.check(["torch.ops.quantized_decomposed.choose_qparams"])
287-
tester.to_edge_transform_and_lower(
288-
ToEdgeTransformAndLower([DynamicallyQuantizedPartitioner])
289-
)
290-
tester.check_count(
291-
{"torch.ops.higher_order.executorch_call_delegate": conv_count}
292-
)
293-
tester.check_not(["executorch_exir_dialects_edge__ops_aten_conv2d_default"])
294-
tester.to_executorch()
295-
tester.serialize()
296-
tester.run_method_and_compare_outputs(qtol=1)
297-
298226
def test_fp16_conv2d(self) -> None:
299227
for transpose in (True, False):
300228
for has_bias in (True, False):
@@ -771,26 +699,3 @@ def forward(self, x):
771699
.serialize()
772700
.run_method_and_compare_outputs(qtol=1)
773701
)
774-
775-
def test_dq_conv2d(self) -> None:
776-
model = Conv2d(
777-
in_channels=3,
778-
out_channels=10,
779-
kernel_size=(3, 3),
780-
stride=(1, 1),
781-
padding=(0, 0),
782-
batches=1,
783-
width=8,
784-
height=8,
785-
)
786-
self._test_dq(model)
787-
788-
def test_dq_conv2d_seq(self) -> None:
789-
model = Conv2dDQSeq()
790-
conv_count = sum(1 for m in model.modules() if type(m) is torch.nn.Conv2d)
791-
self._test_dq(model, conv_count)
792-
793-
def test_dq_conv2d_parallel(self) -> None:
794-
model = Conv2dDQParallel()
795-
conv_count = sum(1 for m in model.modules() if type(m) is torch.nn.Conv2d)
796-
self._test_dq(model, conv_count)

backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py

Lines changed: 1 addition & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -10,13 +10,10 @@
1010
from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
1111
ChannelsLastTaggedReshapePass,
1212
)
13-
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
14-
get_symmetric_quantization_config,
15-
)
1613
from executorch.backends.xnnpack.test.test_xnnpack_utils_classes import (
1714
OpSequencesAddConv2d,
1815
)
19-
from executorch.backends.xnnpack.test.tester import Quantize, RunPasses, Tester
16+
from executorch.backends.xnnpack.test.tester import RunPasses, Tester
2017

2118

2219
class TestChannelsLastTaggedReshapePass(unittest.TestCase):
@@ -38,10 +35,6 @@ def setUp(self):
3835
dequant_name = "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default"
3936
conv_name = "executorch_exir_dialects_edge__ops_aten_convolution_default"
4037
relu_name = "executorch_exir_dialects_edge__ops_aten_relu_default"
41-
choose_qparams_name = (
42-
"executorch_exir_dialects_edge__ops_quantized_decomposed_choose_qparams_tensor"
43-
)
44-
dynamic_quant_name = "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_tensor"
4538

4639
def test_fp32_channels_last_tagged_reshape_pass(self):
4740
for module, num_reshape in self.modules.items():
@@ -186,37 +179,3 @@ def test_fp32_channels_last_tagged_reshape_pass_conv_bn_hardtanh_mean_seq(self):
186179
)
187180
.run_method_and_compare_outputs()
188181
)
189-
190-
class Conv2dDynamicQuant(torch.nn.Module):
191-
def __init__(self):
192-
super().__init__()
193-
self.conv = torch.nn.Conv2d(3, 10, 3)
194-
195-
def forward(self, x):
196-
return self.conv(x)
197-
198-
def test_dq_conv2d_channels_last_tagged_reshape_pass(self) -> None:
199-
(
200-
Tester(self.Conv2dDynamicQuant().eval(), (torch.randn(1, 3, 8, 8),))
201-
.quantize(
202-
Quantize(
203-
quantization_config=get_symmetric_quantization_config(
204-
is_dynamic=True
205-
)
206-
)
207-
)
208-
.export()
209-
.to_edge()
210-
.run_passes(self.PassStage)
211-
.check(
212-
[
213-
self.to_copy_name,
214-
self.choose_qparams_name,
215-
self.dynamic_quant_name,
216-
self.dequant_name,
217-
self.conv_name,
218-
self.to_copy_name,
219-
]
220-
)
221-
.run_method_and_compare_outputs()
222-
)

0 commit comments

Comments
 (0)