Revert "Support dynamically quantized 2D convolutions" #10397

Closed
15 changes: 1 addition & 14 deletions backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py
@@ -8,7 +8,6 @@

import torch
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.utils.quant_utils import is_dynamic_qdq
from executorch.backends.xnnpack.utils.utils import is_param_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult
@@ -284,26 +283,14 @@ def input_to_nhwc(
]
else:
# Need to create NHWC node
# Check if input uses dynamic quantization
is_dynamic_input = is_dynamic_qdq(input_node)

if is_dynamic_input:
# Trace back to original source node
while getattr(input_node, "args", None):
input_node = input_node.args[0]

with graph_module.graph.inserting_after(input_node):
input_node_nhwc = self.create_call_function_node(
graph_module=graph_module,
target=exir_ops.edge.aten._to_copy.default,
args=(input_node,),
memory_format=torch.channels_last,
)

if is_dynamic_input:
# Replace downstream input_nodes with NHWC node
input_node.replace_all_uses_with(input_node_nhwc)
input_node_nhwc.args = (input_node,)
self.mark_as_nhwc_node(input_node_nhwc)

self.insert_copy_and_assign_partner_nodes_quantization_sensitive(
graph_module=graph_module,
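Note: the branch removed above rewired every consumer of a dynamically quantized input onto a freshly inserted channels-last copy. For readers unfamiliar with that torch.fx idiom, here is a minimal, self-contained sketch of the same insert-and-rewire pattern in plain torch.fx (the module and names are illustrative, not part of this change):

```python
import torch
import torch.fx as fx


class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


gm = fx.symbolic_trace(M())
src = next(n for n in gm.graph.nodes if n.op == "placeholder")

with gm.graph.inserting_after(src):
    # Stand-in for the channels-last _to_copy node the pass creates.
    nhwc = gm.graph.call_function(
        torch.ops.aten._to_copy.default,
        args=(src,),
        kwargs={"memory_format": torch.channels_last},
    )

# Point every existing consumer of `src` at the new node ...
src.replace_all_uses_with(nhwc)
# ... which also rewrites `nhwc` itself, so restore its input to `src`.
nhwc.args = (src,)

gm.recompile()
print(gm(torch.randn(1, 3, 4, 4)).shape)
```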
17 changes: 1 addition & 16 deletions backends/xnnpack/operators/quant_params.py
@@ -141,27 +141,12 @@ def quantize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
tensor, self.scale, self.zp, self.qmin, self.qmax, self.dtype
)

# Temporary helper until non-batch dimensions can be inferred
# Detects if a node feeds into a conv op by checking all downstream users
@staticmethod
def _feeds_into_conv(node: torch.fx.Node) -> bool:
users_list = [node]

while users_list:
current_user = users_list.pop()
if "convolution" in str(current_user.target):
return True
users_list.extend(current_user.users)

return False

@classmethod
def _from_dynamic_input_node(cls, quant_node: torch.fx.Node) -> QuantParams:
q_input = quant_node.args[0] # fp32 input
assert isinstance(q_input, torch.fx.Node)
# TODO - materialize this from the quant_node scale count and val shape
# Set non-batch dims to 3 if node feeds into conv (only 2D is supported), otherwise set to 1 for linear
num_nonbatch_dims = 3 if cls._feeds_into_conv(quant_node) else 1
num_nonbatch_dims = 1

return cls(
per_channel=False, # True is not valid
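Note: with the conv-aware path gone, `num_nonbatch_dims` is pinned to 1, which lines up with the per-token constraint enforced in the runtime change below. As a rough, hedged illustration of what the parameter controls (one scale per slice spanning the trailing `num_nonbatch_dims` dimensions; the real scales come from the decomposed `choose_qparams` ops, not from this hypothetical helper):

```python
# Hedged sketch: the leading dims act as "token" dims and each token
# (the trailing num_nonbatch_dims dims) gets its own symmetric int8 scale.
import torch


def per_token_scales(x: torch.Tensor, num_nonbatch_dims: int, qmax: int = 127) -> torch.Tensor:
    batch_dims = x.dim() - num_nonbatch_dims
    flat = x.reshape(*x.shape[:batch_dims], -1)   # [token dims..., token_numel]
    absmax = flat.abs().amax(dim=-1)              # one value per token
    return absmax / qmax


lin_in = torch.randn(4, 16, 64)                   # e.g. [batch, seq, channels]
print(per_token_scales(lin_in, num_nonbatch_dims=1).shape)   # torch.Size([4, 16])

conv_in = torch.randn(2, 8, 8, 3)                 # NHWC conv input
print(per_token_scales(conv_in, num_nonbatch_dims=3).shape)  # torch.Size([2])
```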
20 changes: 3 additions & 17 deletions backends/xnnpack/partition/config/gemm_configs.py
@@ -9,7 +9,6 @@
from typing import cast, List, Optional, Tuple

import torch
from executorch.backends.transforms import get_shape
from executorch.backends.xnnpack.operators.quant_params import QuantParams
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
ConfigPrecisionType,
@@ -28,7 +27,6 @@
)
from executorch.backends.xnnpack.utils.utils import (
get_input_node,
is_depthwise_conv,
is_getitem,
is_node,
is_param_node,
@@ -361,23 +359,12 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
return False # Only support 1D + 2D Conv

kernel_node = get_input_node(node, 1)
kernel_shape = get_shape(kernel_node)
weight_quant_params = QuantParams.from_weights(kernel_node, ep)
groups = cast(int, node.args[8])
is_transpose = node.args[6]

# XNNPACK does not support dynamic quantization convs that are not 2D or are depthwise
if self._detect_precision(node) == ConfigPrecisionType.DYNAMIC_QUANT and (
len(conv_stride) != 2
or is_depthwise_conv(kernel_shape, groups, is_transpose)
):
why(
node,
"XNNPACK only supports standard 2D convolutions for dynamic quantization",
)
return False
is_transpose = node.args[6]
groups = cast(int, node.args[8])

# XNNPACK does not support non-zero output padding in transposed
# XNNPack does not support non-zero output padding in transposed
# convolutions.
if is_transpose and any(
out_pad != 0 for out_pad in cast(List[int], node.args[7])
@@ -407,7 +394,6 @@ def supported_precision_types(self):
return [
ConfigPrecisionType.FP32,
ConfigPrecisionType.STATIC_QUANT,
ConfigPrecisionType.DYNAMIC_QUANT,
]


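Note: the deleted constraint relied on the shape-based depthwise test `is_depthwise_conv` from `backends/xnnpack/utils/utils.py`. A hedged approximation of that test, capturing only the usual grouped-weight layout convention rather than the exact helper:

```python
# Hedged approximation of a depthwise-conv check based on weight layout
# conventions: conv2d weights are [out_c, in_c // groups, kH, kW] and
# transposed-conv weights are [in_c, out_c // groups, kH, kW]. A conv is
# depthwise when every group sees exactly one input channel, i.e. groups
# equals the number of input channels.
from typing import Sequence


def looks_depthwise(kernel_shape: Sequence[int], groups: int, is_transpose: bool) -> bool:
    if len(kernel_shape) < 2 or groups <= 1:
        return False
    in_channels = kernel_shape[0] if is_transpose else kernel_shape[1] * groups
    return in_channels == groups


# Depthwise conv2d: Conv2d(8, 8, 3, groups=8) -> weight [8, 1, 3, 3]
assert looks_depthwise([8, 1, 3, 3], groups=8, is_transpose=False)
# Standard conv2d: Conv2d(3, 8, 3) -> weight [8, 3, 3, 3]
assert not looks_depthwise([8, 3, 3, 3], groups=1, is_transpose=False)
```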
1 change: 0 additions & 1 deletion backends/xnnpack/quantizer/xnnpack_quantizer.py
@@ -265,7 +265,6 @@ class XNNPACKQuantizer(Quantizer):

DYNAMIC_OPS = [
"linear",
"conv",
]

def __init__(self) -> None:
19 changes: 1 addition & 18 deletions backends/xnnpack/quantizer/xnnpack_quantizer_utils.py
@@ -6,7 +6,6 @@

import torch
import torch.nn.functional as F
from executorch.backends.xnnpack.utils.utils import is_depthwise_conv
from torch._subclasses import FakeTensor
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
from torch.ao.quantization.pt2e.export_utils import _WrapperModule
@@ -30,6 +29,7 @@
)
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


__all__ = [
"OperatorConfig",
"OperatorPatternType",
@@ -323,23 +323,6 @@ def _do_annotate_conv(
assert isinstance(weight, Node)
input_qspec_map[weight] = get_weight_qspec(quantization_config)

# Only annotate dynamically quantized conv if it's 2D and not depthwise
if (
quantization_config
and quantization_config.input_activation
and quantization_config.input_activation.is_dynamic
):
weight_val = weight.meta.get("val", None)
weight_shape = getattr(weight_val, "shape", None)

# Skip if not a 4D weight tensor (i.e. not conv2d)
if weight_shape is not None and len(weight_shape) != 4:
continue

# Skip if depthwise (default to groups=1 since it's not an arg)
if is_depthwise_conv(weight_shape, 1, is_conv_transpose):
continue

# adding weight node to the partition as well
partition = [conv_node, conv_node.args[1]]

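Note: the removed guard distinguished conv2d from conv1d weights by reading the FakeTensor stored in `weight.meta["val"]` and checking its rank. A small, self-contained sketch of that idiom on an exported graph (module and shapes are illustrative):

```python
import torch
from torch.export import export


class TinyConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)

    def forward(self, x):
        return self.conv(x)


ep = export(TinyConv().eval(), (torch.randn(1, 3, 8, 8),))
for node in ep.graph.nodes:
    if node.op == "call_function" and "conv" in str(node.target):
        weight = node.args[1]
        val = weight.meta.get("val", None)   # FakeTensor recorded at trace time
        shape = getattr(val, "shape", None)
        is_conv2d_weight = shape is not None and len(shape) == 4
        print(node.target, shape, "-> conv2d weight" if is_conv2d_weight else "-> not a conv2d weight")
```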
7 changes: 6 additions & 1 deletion backends/xnnpack/runtime/XNNCompiler.cpp
@@ -512,6 +512,11 @@ Error defineTensor(
buffer_ptr == nullptr,
Internal,
"Dynamically quantized tensor should not have constant data but found non-nullptr");
// TODO(T179441835): Dynamic Quantization with num_nonbatch_dims > 1
ET_CHECK_OR_RETURN_ERROR(
qparams->num_nonbatch_dims() == 1,
Internal,
"Dynamically Quantized Tensors currently only support per token quantization");
status = xnn_define_dynamically_quantized_tensor_value(
/*subgraph=*/subgraph_ptr,
/*datatype=*/getDataType(tensor_value->datatype()),
@@ -1167,7 +1172,7 @@ Error defineStaticTransposeNode(
ET_CHECK_OR_RETURN_ERROR(
status == xnn_status_success,
Internal,
"Failed to create static transpose node %i with code: %s",
"Failed to create sigmoid node %i with code: %s",
node->debug_handle(),
xnn_status_to_string(status));

97 changes: 1 addition & 96 deletions backends/xnnpack/test/ops/test_conv2d.py
@@ -18,10 +18,6 @@
except:
has_quantized_ops = False

from executorch.backends.xnnpack.partition.config.xnnpack_config import (
ConfigPrecisionType,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
)
@@ -30,7 +26,7 @@
)
from executorch.backends.xnnpack.test.test_xnnpack_utils import randomize_bn
from executorch.backends.xnnpack.test.tester import Quantize, Tester
from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower

from executorch.exir.dialects._ops import ops as exir_ops


@@ -173,43 +169,6 @@ def get_inputs(self):
return (torch.randn(2, 2, 4, 4),)


class Conv2dDQSeq(torch.nn.Module):
def __init__(self):
super().__init__()
self.first = torch.nn.Conv2d(
in_channels=3, out_channels=8, kernel_size=3, padding=1
)
self.second = torch.nn.Conv2d(
in_channels=8, out_channels=10, kernel_size=3, padding=1
)

def forward(self, x):
y = self.first(x)
return self.second(y)

def get_inputs(self):
return (torch.randn(1, 3, 8, 8),)


class Conv2dDQParallel(torch.nn.Module):
def __init__(self):
super().__init__()
self.first = torch.nn.Conv2d(
in_channels=3, out_channels=8, kernel_size=3, padding=1
)
self.second = torch.nn.Conv2d(
in_channels=3, out_channels=8, kernel_size=3, padding=1
)

def forward(self, x):
first = self.first(x)
second = self.second(x)
return first, second

def get_inputs(self):
return (torch.randn(1, 3, 8, 8),)


class TestConv2d(unittest.TestCase):
def setUp(self):
torch._dynamo.reset()
@@ -264,37 +223,6 @@ def _test(
.run_method_and_compare_outputs(qtol=1)
)

def _test_dq(
self,
m: torch.nn.Module,
conv_count=1,
dynamic_shapes=None,
):
quant_config = get_symmetric_quantization_config(
is_per_channel=True,
is_dynamic=True,
)

DynamicallyQuantizedPartitioner = XnnpackPartitioner(
config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
per_op_mode=True,
)

tester = Tester(m, m.get_inputs(), dynamic_shapes=dynamic_shapes)
tester.quantize(Quantize(quantization_config=quant_config))
tester.export()
tester.check(["torch.ops.quantized_decomposed.choose_qparams"])
tester.to_edge_transform_and_lower(
ToEdgeTransformAndLower([DynamicallyQuantizedPartitioner])
)
tester.check_count(
{"torch.ops.higher_order.executorch_call_delegate": conv_count}
)
tester.check_not(["executorch_exir_dialects_edge__ops_aten_conv2d_default"])
tester.to_executorch()
tester.serialize()
tester.run_method_and_compare_outputs(qtol=1)

def test_fp16_conv2d(self) -> None:
for transpose in (True, False):
for has_bias in (True, False):
@@ -771,26 +699,3 @@ def forward(self, x):
.serialize()
.run_method_and_compare_outputs(qtol=1)
)

def test_dq_conv2d(self) -> None:
model = Conv2d(
in_channels=3,
out_channels=10,
kernel_size=(3, 3),
stride=(1, 1),
padding=(0, 0),
batches=1,
width=8,
height=8,
)
self._test_dq(model)

def test_dq_conv2d_seq(self) -> None:
model = Conv2dDQSeq()
conv_count = sum(1 for m in model.modules() if type(m) is torch.nn.Conv2d)
self._test_dq(model, conv_count)

def test_dq_conv2d_parallel(self) -> None:
model = Conv2dDQParallel()
conv_count = sum(1 for m in model.modules() if type(m) is torch.nn.Conv2d)
self._test_dq(model, conv_count)
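Note: the deleted `_test_dq` helper also doubles as a reference for the Tester flow itself, which remains valid for dynamically quantized linear layers. A hedged sketch reusing the same pipeline against a small linear module (the module, shapes, and expected delegate count here are illustrative, not part of this change):

```python
import torch

from executorch.backends.xnnpack.partition.config.xnnpack_config import ConfigPrecisionType
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)
from executorch.backends.xnnpack.test.tester import Quantize, Tester
from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower


class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32, 16)

    def forward(self, x):
        return self.linear(x)


quant_config = get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
partitioner = XnnpackPartitioner(
    config_precisions=ConfigPrecisionType.DYNAMIC_QUANT, per_op_mode=True
)

(
    Tester(TinyLinear(), (torch.randn(2, 32),))
    .quantize(Quantize(quantization_config=quant_config))
    .export()
    .check(["torch.ops.quantized_decomposed.choose_qparams"])
    .to_edge_transform_and_lower(ToEdgeTransformAndLower([partitioner]))
    .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
    .to_executorch()
    .serialize()
    .run_method_and_compare_outputs(qtol=1)
)
```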
backends/xnnpack/test/passes/test_channels_last_tagged_reshape_pass.py
@@ -10,13 +10,10 @@
from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
ChannelsLastTaggedReshapePass,
)
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
)
from executorch.backends.xnnpack.test.test_xnnpack_utils_classes import (
OpSequencesAddConv2d,
)
from executorch.backends.xnnpack.test.tester import Quantize, RunPasses, Tester
from executorch.backends.xnnpack.test.tester import RunPasses, Tester


class TestChannelsLastTaggedReshapePass(unittest.TestCase):
@@ -38,10 +35,6 @@ def setUp(self):
dequant_name = "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default"
conv_name = "executorch_exir_dialects_edge__ops_aten_convolution_default"
relu_name = "executorch_exir_dialects_edge__ops_aten_relu_default"
choose_qparams_name = (
"executorch_exir_dialects_edge__ops_quantized_decomposed_choose_qparams_tensor"
)
dynamic_quant_name = "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_tensor"

def test_fp32_channels_last_tagged_reshape_pass(self):
for module, num_reshape in self.modules.items():
@@ -186,37 +179,3 @@ def test_fp32_channels_last_tagged_reshape_pass_conv_bn_hardtanh_mean_seq(self):
)
.run_method_and_compare_outputs()
)

class Conv2dDynamicQuant(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 10, 3)

def forward(self, x):
return self.conv(x)

def test_dq_conv2d_channels_last_tagged_reshape_pass(self) -> None:
(
Tester(self.Conv2dDynamicQuant().eval(), (torch.randn(1, 3, 8, 8),))
.quantize(
Quantize(
quantization_config=get_symmetric_quantization_config(
is_dynamic=True
)
)
)
.export()
.to_edge()
.run_passes(self.PassStage)
.check(
[
self.to_copy_name,
self.choose_qparams_name,
self.dynamic_quant_name,
self.dequant_name,
self.conv_name,
self.to_copy_name,
]
)
.run_method_and_compare_outputs()
)