Support dynamically quantized 2D convolutions #10347


Merged

26 commits
0b5b0e8
WIP: add initial support for dq 2D conv
keyprocedure Apr 8, 2025
8fcb117
Permute before quant
keyprocedure Apr 12, 2025
4d064da
Refactor permute code
keyprocedure Apr 12, 2025
2905b98
Corrects input to conv
keyprocedure Apr 12, 2025
0fef04a
Add is_dequant check for trace back when inserting permute
keyprocedure Apr 12, 2025
f8f998c
Fix node identity check
keyprocedure Apr 12, 2025
2efe9bb
Use existing is_dequant check and update atol
keyprocedure Apr 13, 2025
3762e0d
Implement replace_all_uses_with function
keyprocedure Apr 15, 2025
4112c6a
Remove cmake file
keyprocedure Apr 15, 2025
cdd6f2d
Restore original supported conv2d operators
keyprocedure Apr 16, 2025
7150872
Add dynamic quant check before NHWC permute
keyprocedure Apr 16, 2025
6b44c4b
Refactor dq conv2d test
keyprocedure Apr 16, 2025
7054f2e
Revert formatting
keyprocedure Apr 16, 2025
fc48e03
Add check to only annotate dq conv2d
keyprocedure Apr 16, 2025
84b3634
Remove unused import
keyprocedure Apr 16, 2025
62e30e5
Add computation for non-batch dims; remove non-batch dims check
keyprocedure Apr 16, 2025
3c7fe32
Refactor test and imports
keyprocedure Apr 16, 2025
064671b
Update comments
keyprocedure Apr 16, 2025
228dc0b
Merge branch 'main' into support-dynamically-quantized-convolutions
keyprocedure Apr 16, 2025
b29030e
Add unit tests for dynamic quant sequential and parallel convs
keyprocedure Apr 20, 2025
6da8b7d
Add unit test for dynamic quant conv2d with channels-last permute
keyprocedure Apr 20, 2025
7c53454
Add check to determine if node feeds into conv and set non-batch dims…
keyprocedure Apr 20, 2025
eaba819
Add depthwise conv checks for dynamic quant
keyprocedure Apr 20, 2025
5a01127
Merge branch 'main' into support-dynamically-quantized-convolutions
keyprocedure Apr 21, 2025
e336df6
Move depthwise conv check to helper function in utils
keyprocedure Apr 21, 2025
d82e080
Use existing Conv2d class; get conv count from model
keyprocedure Apr 21, 2025
15 changes: 14 additions & 1 deletion backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py
@@ -8,6 +8,7 @@

import torch
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.utils.quant_utils import is_dynamic_qdq
from executorch.backends.xnnpack.utils.utils import is_param_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult
@@ -283,14 +284,26 @@ def input_to_nhwc(
]
else:
# Need to create NHWC node
# Check if input uses dynamic quantization
is_dynamic_input = is_dynamic_qdq(input_node)

if is_dynamic_input:
# Trace back to original source node
while getattr(input_node, "args", None):
input_node = input_node.args[0]

with graph_module.graph.inserting_after(input_node):
input_node_nhwc = self.create_call_function_node(
graph_module=graph_module,
target=exir_ops.edge.aten._to_copy.default,
args=(input_node,),
memory_format=torch.channels_last,
)
self.mark_as_nhwc_node(input_node_nhwc)

if is_dynamic_input:
# Redirect all downstream users of the original input to the NHWC node
input_node.replace_all_uses_with(input_node_nhwc)
input_node_nhwc.args = (input_node,)

self.insert_copy_and_assign_partner_nodes_quantization_sensitive(
graph_module=graph_module,
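Note: the hunk above traces a dynamically quantized input back to its original source, inserts the channels-last copy there, and then redirects the source's former users (the choose_qparams/quantize/dequantize chain) to the new node. A minimal, self-contained FX sketch of the same rewiring idea — illustrative only, not the ExecuTorch pass itself:

import torch
import torch.fx as fx

class Toy(torch.nn.Module):
    def forward(self, x):
        # stand-in for the choose_qparams -> quantize -> dequantize -> conv chain
        return x + 0.0

gm = fx.symbolic_trace(Toy())
for node in gm.graph.nodes:
    if node.op == "placeholder":
        # Insert the channels-last copy right after the original source node...
        with gm.graph.inserting_after(node):
            nhwc = gm.graph.call_function(
                torch.ops.aten._to_copy.default,
                args=(node,),
                kwargs={"memory_format": torch.channels_last},
            )
        # ...then point every former user of the source at the copy, and restore
        # the copy's own input (replace_all_uses_with also rewired the copy itself).
        node.replace_all_uses_with(nhwc)
        nhwc.args = (node,)
gm.recompile()
print(gm(torch.randn(1, 3, 8, 8)).shape)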
17 changes: 16 additions & 1 deletion backends/xnnpack/operators/quant_params.py
@@ -141,12 +141,27 @@ def quantize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
tensor, self.scale, self.zp, self.qmin, self.qmax, self.dtype
)

# Temporary helper until non-batch dimensions can be inferred
# Detects if a node feeds into a conv op by checking all downstream users
@staticmethod
def _feeds_into_conv(node: torch.fx.Node) -> bool:
users_list = [node]

while users_list:
current_user = users_list.pop()
if "convolution" in str(current_user.target):
return True
users_list.extend(current_user.users)

return False

@classmethod
def _from_dynamic_input_node(cls, quant_node: torch.fx.Node) -> QuantParams:
q_input = quant_node.args[0] # fp32 input
assert isinstance(q_input, torch.fx.Node)
# TODO - materialize this from the quant_node scale count and val shape
num_nonbatch_dims = 1
# Set non-batch dims to 3 if node feeds into conv (only 2D is supported), otherwise set to 1 for linear
num_nonbatch_dims = 3 if cls._feeds_into_conv(quant_node) else 1

return cls(
per_channel=False, # True is not valid
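The switch from 1 to 3 non-batch dimensions follows from how per-token dynamic quantization groups elements: the trailing num_nonbatch_dims dimensions share a single scale/zero-point, and each slice in front of them gets its own. A rough illustration of that bookkeeping, assuming conv activations are quantized in NHWC layout (the helper name here is made up for this sketch):

import math

def num_quant_tokens(shape, num_nonbatch_dims):
    """Number of scale/zero-point pairs a per-token dynamic quantizer would produce."""
    batch_dims = shape[: len(shape) - num_nonbatch_dims]
    return math.prod(batch_dims)

# Linear activation (B, F): one scale per row -> num_nonbatch_dims = 1
assert num_quant_tokens((4, 128), num_nonbatch_dims=1) == 4
# Conv2d activation in NHWC (N, H, W, C): one scale per image -> num_nonbatch_dims = 3
assert num_quant_tokens((2, 8, 8, 3), num_nonbatch_dims=3) == 2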
20 changes: 17 additions & 3 deletions backends/xnnpack/partition/config/gemm_configs.py
@@ -9,6 +9,7 @@
from typing import cast, List, Optional, Tuple

import torch
from executorch.backends.transforms import get_shape
from executorch.backends.xnnpack.operators.quant_params import QuantParams
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
ConfigPrecisionType,
@@ -27,6 +28,7 @@
)
from executorch.backends.xnnpack.utils.utils import (
get_input_node,
is_depthwise_conv,
is_getitem,
is_node,
is_param_node,
@@ -359,12 +361,23 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
return False # Only support 1D + 2D Conv

kernel_node = get_input_node(node, 1)
kernel_shape = get_shape(kernel_node)
weight_quant_params = QuantParams.from_weights(kernel_node, ep)

is_transpose = node.args[6]
groups = cast(int, node.args[8])
is_transpose = node.args[6]

# XNNPACK does not support dynamically quantized convs that are not 2D or are depthwise
if self._detect_precision(node) == ConfigPrecisionType.DYNAMIC_QUANT and (
len(conv_stride) != 2
or is_depthwise_conv(kernel_shape, groups, is_transpose)
):
why(
node,
"XNNPACK only supports standard 2D convolutions for dynamic quantization",
)
return False

# XNNPack does not support non-zero output padding in transposed
# XNNPACK does not support non-zero output padding in transposed
# convolutions.
if is_transpose and any(
out_pad != 0 for out_pad in cast(List[int], node.args[7])
@@ -394,6 +407,7 @@ def supported_precision_types(self):
return [
ConfigPrecisionType.FP32,
ConfigPrecisionType.STATIC_QUANT,
ConfigPrecisionType.DYNAMIC_QUANT,
]


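The new constraint leans on is_depthwise_conv from backends/xnnpack/utils/utils.py, which is not part of this diff. As background only — a plausible reconstruction, not the actual helper — a convolution is depthwise when each group sees exactly one input channel:

from typing import Optional, Sequence

def looks_depthwise(
    kernel_shape: Optional[Sequence[int]], groups: int, is_transpose: bool
) -> bool:
    # Regular conv weight:    (out_channels, in_channels // groups, kH, kW)
    # Transposed conv weight: (in_channels, out_channels // groups, kH, kW)
    if kernel_shape is None or len(kernel_shape) < 2 or groups <= 1:
        return False
    group_in_channels = kernel_shape[0] // groups if is_transpose else kernel_shape[1]
    return group_in_channels == 1

# 3x3 depthwise conv over 8 channels: weight (8, 1, 3, 3), groups=8
assert looks_depthwise((8, 1, 3, 3), groups=8, is_transpose=False)
# Standard conv: weight (8, 3, 3, 3), groups=1
assert not looks_depthwise((8, 3, 3, 3), groups=1, is_transpose=False)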
1 change: 1 addition & 0 deletions backends/xnnpack/quantizer/xnnpack_quantizer.py
@@ -265,6 +265,7 @@ class XNNPACKQuantizer(Quantizer):

DYNAMIC_OPS = [
"linear",
"conv",
]

def __init__(self) -> None:
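With "conv" added to DYNAMIC_OPS, a dynamic quantization config should now annotate standard 2D convolutions as well as linears. A hedged usage sketch (assuming the quantizer's usual set_global workflow):

from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

# Dynamic config with per-channel weights; activation scales are chosen at runtime.
quantizer = XNNPACKQuantizer()
quantizer.set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
)
# quantizer can then be passed to prepare_pt2e / convert_pt2e as usual.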
19 changes: 18 additions & 1 deletion backends/xnnpack/quantizer/xnnpack_quantizer_utils.py
@@ -6,6 +6,7 @@

import torch
import torch.nn.functional as F
from executorch.backends.xnnpack.utils.utils import is_depthwise_conv
from torch._subclasses import FakeTensor
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
from torch.ao.quantization.pt2e.export_utils import _WrapperModule
@@ -29,7 +30,6 @@
)
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


__all__ = [
"OperatorConfig",
"OperatorPatternType",
@@ -323,6 +323,23 @@ def _do_annotate_conv(
assert isinstance(weight, Node)
input_qspec_map[weight] = get_weight_qspec(quantization_config)

# Only annotate dynamically quantized conv if it's 2D and not depthwise
if (
quantization_config
and quantization_config.input_activation
and quantization_config.input_activation.is_dynamic
):
weight_val = weight.meta.get("val", None)
weight_shape = getattr(weight_val, "shape", None)

# Skip if not a 4D weight tensor (i.e. not conv2d)
if weight_shape is not None and len(weight_shape) != 4:
continue

# Skip if depthwise (default to groups=1 since it's not an arg)
if is_depthwise_conv(weight_shape, 1, is_conv_transpose):
continue

# adding weight node to the partition as well
partition = [conv_node, conv_node.args[1]]

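The 4-dimension check works because PyTorch conv weights carry the spatial rank in their shape; a quick reference using stock PyTorch modules:

import torch

# Conv2d weight: (out_channels, in_channels // groups, kH, kW) -> 4 dims
# Conv1d weight: (out_channels, in_channels // groups, kW)     -> 3 dims
assert torch.nn.Conv2d(3, 8, kernel_size=3).weight.dim() == 4
assert torch.nn.Conv1d(3, 8, kernel_size=3).weight.dim() == 3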
7 changes: 1 addition & 6 deletions backends/xnnpack/runtime/XNNCompiler.cpp
@@ -512,11 +512,6 @@ Error defineTensor(
buffer_ptr == nullptr,
Internal,
"Dynamically quantized tensor should not have constant data but found non-nullptr");
// TODO(T179441835): Dynamic Quantization with num_nonbatch_dims > 1
ET_CHECK_OR_RETURN_ERROR(
qparams->num_nonbatch_dims() == 1,
Internal,
"Dynamically Quantized Tensors currently only support per token quantization");
status = xnn_define_dynamically_quantized_tensor_value(
/*subgraph=*/subgraph_ptr,
/*datatype=*/getDataType(tensor_value->datatype()),
@@ -1172,7 +1167,7 @@ Error defineStaticTransposeNode(
ET_CHECK_OR_RETURN_ERROR(
status == xnn_status_success,
Internal,
"Failed to create sigmoid node %i with code: %s",
"Failed to create static transpose node %i with code: %s",
node->debug_handle(),
xnn_status_to_string(status));

97 changes: 96 additions & 1 deletion backends/xnnpack/test/ops/test_conv2d.py
@@ -18,6 +18,10 @@
except:
has_quantized_ops = False

from executorch.backends.xnnpack.partition.config.xnnpack_config import (
ConfigPrecisionType,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
)
@@ -26,7 +30,7 @@
)
from executorch.backends.xnnpack.test.test_xnnpack_utils import randomize_bn
from executorch.backends.xnnpack.test.tester import Quantize, Tester

from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower
from executorch.exir.dialects._ops import ops as exir_ops


@@ -169,6 +173,43 @@ def get_inputs(self):
return (torch.randn(2, 2, 4, 4),)


class Conv2dDQSeq(torch.nn.Module):
def __init__(self):
super().__init__()
self.first = torch.nn.Conv2d(
in_channels=3, out_channels=8, kernel_size=3, padding=1
)
self.second = torch.nn.Conv2d(
in_channels=8, out_channels=10, kernel_size=3, padding=1
)

def forward(self, x):
y = self.first(x)
return self.second(y)

def get_inputs(self):
return (torch.randn(1, 3, 8, 8),)


class Conv2dDQParallel(torch.nn.Module):
def __init__(self):
super().__init__()
self.first = torch.nn.Conv2d(
in_channels=3, out_channels=8, kernel_size=3, padding=1
)
self.second = torch.nn.Conv2d(
in_channels=3, out_channels=8, kernel_size=3, padding=1
)

def forward(self, x):
first = self.first(x)
second = self.second(x)
return first, second

def get_inputs(self):
return (torch.randn(1, 3, 8, 8),)


class TestConv2d(unittest.TestCase):
def setUp(self):
torch._dynamo.reset()
@@ -223,6 +264,37 @@ def _test(
.run_method_and_compare_outputs(qtol=1)
)

def _test_dq(
self,
m: torch.nn.Module,
conv_count=1,
dynamic_shapes=None,
):
quant_config = get_symmetric_quantization_config(
is_per_channel=True,
is_dynamic=True,
)

DynamicallyQuantizedPartitioner = XnnpackPartitioner(
config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
per_op_mode=True,
)
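# per_op_mode lowers each dynamically quantized op into its own delegate,
# so the executorch_call_delegate count below should equal conv_count.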

tester = Tester(m, m.get_inputs(), dynamic_shapes=dynamic_shapes)
tester.quantize(Quantize(quantization_config=quant_config))
tester.export()
tester.check(["torch.ops.quantized_decomposed.choose_qparams"])
tester.to_edge_transform_and_lower(
ToEdgeTransformAndLower([DynamicallyQuantizedPartitioner])
)
tester.check_count(
{"torch.ops.higher_order.executorch_call_delegate": conv_count}
)
tester.check_not(["executorch_exir_dialects_edge__ops_aten_conv2d_default"])
tester.to_executorch()
tester.serialize()
tester.run_method_and_compare_outputs(qtol=1)

def test_fp16_conv2d(self) -> None:
for transpose in (True, False):
for has_bias in (True, False):
@@ -699,3 +771,26 @@ def forward(self, x):
.serialize()
.run_method_and_compare_outputs(qtol=1)
)

def test_dq_conv2d(self) -> None:
model = Conv2d(
in_channels=3,
out_channels=10,
kernel_size=(3, 3),
stride=(1, 1),
padding=(0, 0),
batches=1,
width=8,
height=8,
)
self._test_dq(model)

def test_dq_conv2d_seq(self) -> None:
model = Conv2dDQSeq()
conv_count = sum(1 for m in model.modules() if type(m) is torch.nn.Conv2d)
self._test_dq(model, conv_count)

def test_dq_conv2d_parallel(self) -> None:
model = Conv2dDQParallel()
conv_count = sum(1 for m in model.modules() if type(m) is torch.nn.Conv2d)
self._test_dq(model, conv_count)
backends/xnnpack/test/passes/test_channels_last_tagged_reshape_pass.py
@@ -10,10 +10,13 @@
from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
ChannelsLastTaggedReshapePass,
)
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
)
from executorch.backends.xnnpack.test.test_xnnpack_utils_classes import (
OpSequencesAddConv2d,
)
from executorch.backends.xnnpack.test.tester import RunPasses, Tester
from executorch.backends.xnnpack.test.tester import Quantize, RunPasses, Tester


class TestChannelsLastTaggedReshapePass(unittest.TestCase):
@@ -35,6 +38,10 @@ def setUp(self):
dequant_name = "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default"
conv_name = "executorch_exir_dialects_edge__ops_aten_convolution_default"
relu_name = "executorch_exir_dialects_edge__ops_aten_relu_default"
choose_qparams_name = (
"executorch_exir_dialects_edge__ops_quantized_decomposed_choose_qparams_tensor"
)
dynamic_quant_name = "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_tensor"

def test_fp32_channels_last_tagged_reshape_pass(self):
for module, num_reshape in self.modules.items():
@@ -179,3 +186,37 @@ def test_fp32_channels_last_tagged_reshape_pass_conv_bn_hardtanh_mean_seq(self):
)
.run_method_and_compare_outputs()
)

class Conv2dDynamicQuant(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 10, 3)

def forward(self, x):
return self.conv(x)

def test_dq_conv2d_channels_last_tagged_reshape_pass(self) -> None:
(
Tester(self.Conv2dDynamicQuant().eval(), (torch.randn(1, 3, 8, 8),))
.quantize(
Quantize(
quantization_config=get_symmetric_quantization_config(
is_dynamic=True
)
)
)
.export()
.to_edge()
.run_passes(self.PassStage)
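# Expected order: NHWC copy first, then dynamic choose_qparams/quantize/dequantize,
# then the convolution, then a copy back to contiguous layout.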
.check(
[
self.to_copy_name,
self.choose_qparams_name,
self.dynamic_quant_name,
self.dequant_name,
self.conv_name,
self.to_copy_name,
]
)
.run_method_and_compare_outputs()
)