diff --git a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py
index 89a44f303df..768df1f4f04 100644
--- a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py
+++ b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py
@@ -8,6 +8,7 @@
 import torch
 
 from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
+from executorch.backends.xnnpack.utils.quant_utils import is_dynamic_qdq
 from executorch.backends.xnnpack.utils.utils import is_param_node
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import PassResult
@@ -283,6 +284,14 @@ def input_to_nhwc(
             ]
         else:
             # Need to create NHWC node
+            # Check if input uses dynamic quantization
+            is_dynamic_input = is_dynamic_qdq(input_node)
+
+            if is_dynamic_input:
+                # Trace back to original source node
+                while getattr(input_node, "args", None):
+                    input_node = input_node.args[0]
+
             with graph_module.graph.inserting_after(input_node):
                 input_node_nhwc = self.create_call_function_node(
                     graph_module=graph_module,
@@ -290,7 +299,11 @@ def input_to_nhwc(
                     args=(input_node,),
                     memory_format=torch.channels_last,
                 )
-            self.mark_as_nhwc_node(input_node_nhwc)
+
+            if is_dynamic_input:
+                # Replace downstream input_nodes with NHWC node
+                input_node.replace_all_uses_with(input_node_nhwc)
+                input_node_nhwc.args = (input_node,)
 
         self.insert_copy_and_assign_partner_nodes_quantization_sensitive(
             graph_module=graph_module,
diff --git a/backends/xnnpack/operators/quant_params.py b/backends/xnnpack/operators/quant_params.py
index e695b151560..fbee1d192cf 100644
--- a/backends/xnnpack/operators/quant_params.py
+++ b/backends/xnnpack/operators/quant_params.py
@@ -141,12 +141,27 @@ def quantize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
             tensor, self.scale, self.zp, self.qmin, self.qmax, self.dtype
         )
 
+    # Temporary helper until non-batch dimensions can be inferred
+    # Detects if a node feeds into a conv op by checking all downstream users
+    @staticmethod
+    def _feeds_into_conv(node: torch.fx.Node) -> bool:
+        users_list = [node]
+
+        while users_list:
+            current_user = users_list.pop()
+            if "convolution" in str(current_user.target):
+                return True
+            users_list.extend(current_user.users)
+
+        return False
+
     @classmethod
     def _from_dynamic_input_node(cls, quant_node: torch.fx.Node) -> QuantParams:
         q_input = quant_node.args[0]  # fp32 input
         assert isinstance(q_input, torch.fx.Node)
         # TODO - materialize this from the quant_node scale count and val shape
-        num_nonbatch_dims = 1
+        # Set non-batch dims to 3 if node feeds into conv (only 2D is supported), otherwise set to 1 for linear
+        num_nonbatch_dims = 3 if cls._feeds_into_conv(quant_node) else 1
 
         return cls(
             per_channel=False,  # True is not valid
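Note on the num_nonbatch_dims change above: a minimal, self-contained sketch (illustrative only, not part of the patch; variable names are hypothetical) of how a conv2d input splits into batch and non-batch dimensions under dynamic quantization:

import torch

# A conv2d input is (N, C, H, W); with num_nonbatch_dims = 3, only N is a
# batch dim, so one scale/zero-point pair is computed per batch entry.
x = torch.randn(1, 3, 8, 8)
num_nonbatch_dims = 3  # the value _feeds_into_conv() selects for conv consumers
batch_dims = tuple(x.shape[: x.dim() - num_nonbatch_dims])
nonbatch_dims = tuple(x.shape[x.dim() - num_nonbatch_dims :])
print(batch_dims, nonbatch_dims)  # (1,) (3, 8, 8)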
diff --git a/backends/xnnpack/partition/config/gemm_configs.py b/backends/xnnpack/partition/config/gemm_configs.py
index 8712c2709ac..67bccbc52d1 100644
--- a/backends/xnnpack/partition/config/gemm_configs.py
+++ b/backends/xnnpack/partition/config/gemm_configs.py
@@ -9,6 +9,7 @@
 from typing import cast, List, Optional, Tuple
 
 import torch
+from executorch.backends.transforms import get_shape
 from executorch.backends.xnnpack.operators.quant_params import QuantParams
 from executorch.backends.xnnpack.partition.config.xnnpack_config import (
     ConfigPrecisionType,
@@ -27,6 +28,7 @@
 )
 from executorch.backends.xnnpack.utils.utils import (
     get_input_node,
+    is_depthwise_conv,
     is_getitem,
     is_node,
     is_param_node,
@@ -359,12 +361,23 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
             return False  # Only support 1D + 2D Conv
 
         kernel_node = get_input_node(node, 1)
+        kernel_shape = get_shape(kernel_node)
         weight_quant_params = QuantParams.from_weights(kernel_node, ep)
-
-        is_transpose = node.args[6]
         groups = cast(int, node.args[8])
+        is_transpose = node.args[6]
+
+        # XNNPACK does not support dynamically quantized convs that are not 2D or are depthwise
+        if self._detect_precision(node) == ConfigPrecisionType.DYNAMIC_QUANT and (
+            len(conv_stride) != 2
+            or is_depthwise_conv(kernel_shape, groups, is_transpose)
+        ):
+            why(
+                node,
+                "XNNPACK only supports standard 2D convolutions for dynamic quantization",
+            )
+            return False
 
-        # XNNPack does not support non-zero output padding in transposed
+        # XNNPACK does not support non-zero output padding in transposed
         # convolutions.
         if is_transpose and any(
             out_pad != 0 for out_pad in cast(List[int], node.args[7])
@@ -394,6 +407,7 @@ def supported_precision_types(self):
         return [
             ConfigPrecisionType.FP32,
             ConfigPrecisionType.STATIC_QUANT,
+            ConfigPrecisionType.DYNAMIC_QUANT,
         ]
 
 
diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer.py b/backends/xnnpack/quantizer/xnnpack_quantizer.py
index 0ddee53a41a..fdabd0383e6 100644
--- a/backends/xnnpack/quantizer/xnnpack_quantizer.py
+++ b/backends/xnnpack/quantizer/xnnpack_quantizer.py
@@ -265,6 +265,7 @@ class XNNPACKQuantizer(Quantizer):
 
     DYNAMIC_OPS = [
         "linear",
+        "conv",
     ]
 
     def __init__(self) -> None:
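With "conv" added to DYNAMIC_OPS, a globally-set dynamic config now reaches conv nodes. A minimal sketch of requesting that config, mirroring the quant_config built in the tests further down (assumes an executorch install):

from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

# Per-channel weights + dynamic input activations; this is the same config
# the new conv2d tests below pass to the Quantize stage.
quantizer = XNNPACKQuantizer()
quantizer.set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
)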
diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py
index ce459806c6e..4b961bef81d 100644
--- a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py
+++ b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py
@@ -6,6 +6,7 @@
 
 import torch
 import torch.nn.functional as F
+from executorch.backends.xnnpack.utils.utils import is_depthwise_conv
 from torch._subclasses import FakeTensor
 from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
 from torch.ao.quantization.pt2e.export_utils import _WrapperModule
@@ -29,7 +30,6 @@
 )
 from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
 
-
 __all__ = [
     "OperatorConfig",
     "OperatorPatternType",
@@ -323,6 +323,23 @@ def _do_annotate_conv(
         assert isinstance(weight, Node)
         input_qspec_map[weight] = get_weight_qspec(quantization_config)
 
+        # Only annotate dynamically quantized conv if it's 2D and not depthwise
+        if (
+            quantization_config
+            and quantization_config.input_activation
+            and quantization_config.input_activation.is_dynamic
+        ):
+            weight_val = weight.meta.get("val", None)
+            weight_shape = getattr(weight_val, "shape", None)
+
+            # Skip if the weight shape is unknown or not 4D (i.e. not conv2d)
+            if weight_shape is None or len(weight_shape) != 4:
+                continue
+
+            # Skip if depthwise (default to groups=1 since it's not an arg)
+            if is_depthwise_conv(weight_shape, 1, is_conv_transpose):
+                continue
+
         # adding weight node to the partition as well
         partition = [conv_node, conv_node.args[1]]
 
diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp
index c0204831c07..0b187d05df0 100644
--- a/backends/xnnpack/runtime/XNNCompiler.cpp
+++ b/backends/xnnpack/runtime/XNNCompiler.cpp
@@ -512,11 +512,6 @@ Error defineTensor(
       buffer_ptr == nullptr,
       Internal,
       "Dynamically quantized tensor should not have constant data but found non-nullptr");
-  // TODO(T179441835): Dynamic Quantization with num_nonbatch_dims > 1
-  ET_CHECK_OR_RETURN_ERROR(
-      qparams->num_nonbatch_dims() == 1,
-      Internal,
-      "Dynamically Quantized Tensors currently only support per token quantization");
   status = xnn_define_dynamically_quantized_tensor_value(
       /*subgraph=*/subgraph_ptr,
       /*datatype=*/getDataType(tensor_value->datatype()),
@@ -1172,7 +1167,7 @@ Error defineStaticTransposeNode(
   ET_CHECK_OR_RETURN_ERROR(
       status == xnn_status_success,
       Internal,
-      "Failed to create sigmoid node %i with code: %s",
+      "Failed to create static transpose node %i with code: %s",
       node->debug_handle(),
       xnn_status_to_string(status));
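The _do_annotate_conv gate above reduces to two shape checks. A small, pure-PyTorch rehearsal of that logic (illustrative only; the real check also consults the quantization config):

import torch

modules = {
    "standard conv2d": torch.nn.Conv2d(3, 8, 3),  # annotated
    "depthwise conv2d": torch.nn.Conv2d(8, 8, 3, groups=8),  # skipped
    "conv1d": torch.nn.Conv1d(3, 8, 3),  # skipped (3D weight)
}
for name, m in modules.items():
    shape = m.weight.shape
    is_4d = len(shape) == 4
    is_depthwise = is_4d and shape[1] == 1  # in_channels per group == 1
    print(name, tuple(shape), "annotate:", is_4d and not is_depthwise)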
diff --git a/backends/xnnpack/test/ops/test_conv2d.py b/backends/xnnpack/test/ops/test_conv2d.py
index 80b731bd18e..92bb03c907a 100644
--- a/backends/xnnpack/test/ops/test_conv2d.py
+++ b/backends/xnnpack/test/ops/test_conv2d.py
@@ -18,6 +18,10 @@
 except:
     has_quantized_ops = False
 
+from executorch.backends.xnnpack.partition.config.xnnpack_config import (
+    ConfigPrecisionType,
+)
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
 from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
     get_symmetric_quantization_config,
 )
@@ -26,7 +30,7 @@
 )
 from executorch.backends.xnnpack.test.test_xnnpack_utils import randomize_bn
 from executorch.backends.xnnpack.test.tester import Quantize, Tester
-
+from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower
 from executorch.exir.dialects._ops import ops as exir_ops
 
 
@@ -169,6 +173,43 @@ def get_inputs(self):
         return (torch.randn(2, 2, 4, 4),)
 
 
+class Conv2dDQSeq(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.first = torch.nn.Conv2d(
+            in_channels=3, out_channels=8, kernel_size=3, padding=1
+        )
+        self.second = torch.nn.Conv2d(
+            in_channels=8, out_channels=10, kernel_size=3, padding=1
+        )
+
+    def forward(self, x):
+        y = self.first(x)
+        return self.second(y)
+
+    def get_inputs(self):
+        return (torch.randn(1, 3, 8, 8),)
+
+
+class Conv2dDQParallel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.first = torch.nn.Conv2d(
+            in_channels=3, out_channels=8, kernel_size=3, padding=1
+        )
+        self.second = torch.nn.Conv2d(
+            in_channels=3, out_channels=8, kernel_size=3, padding=1
+        )
+
+    def forward(self, x):
+        first = self.first(x)
+        second = self.second(x)
+        return first, second
+
+    def get_inputs(self):
+        return (torch.randn(1, 3, 8, 8),)
+
+
 class TestConv2d(unittest.TestCase):
     def setUp(self):
         torch._dynamo.reset()
@@ -223,6 +264,37 @@ def _test(
             .run_method_and_compare_outputs(qtol=1)
         )
 
+    def _test_dq(
+        self,
+        m: torch.nn.Module,
+        conv_count=1,
+        dynamic_shapes=None,
+    ):
+        quant_config = get_symmetric_quantization_config(
+            is_per_channel=True,
+            is_dynamic=True,
+        )
+
+        DynamicallyQuantizedPartitioner = XnnpackPartitioner(
+            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
+            per_op_mode=True,
+        )
+
+        tester = Tester(m, m.get_inputs(), dynamic_shapes=dynamic_shapes)
+        tester.quantize(Quantize(quantization_config=quant_config))
+        tester.export()
+        tester.check(["torch.ops.quantized_decomposed.choose_qparams"])
+        tester.to_edge_transform_and_lower(
+            ToEdgeTransformAndLower([DynamicallyQuantizedPartitioner])
+        )
+        tester.check_count(
+            {"torch.ops.higher_order.executorch_call_delegate": conv_count}
+        )
+        tester.check_not(["executorch_exir_dialects_edge__ops_aten_conv2d_default"])
+        tester.to_executorch()
+        tester.serialize()
+        tester.run_method_and_compare_outputs(qtol=1)
+
     def test_fp16_conv2d(self) -> None:
         for transpose in (True, False):
             for has_bias in (True, False):
@@ -699,3 +771,26 @@ def forward(self, x):
             .serialize()
             .run_method_and_compare_outputs(qtol=1)
         )
+
+    def test_dq_conv2d(self) -> None:
+        model = Conv2d(
+            in_channels=3,
+            out_channels=10,
+            kernel_size=(3, 3),
+            stride=(1, 1),
+            padding=(0, 0),
+            batches=1,
+            width=8,
+            height=8,
+        )
+        self._test_dq(model)
+
+    def test_dq_conv2d_seq(self) -> None:
+        model = Conv2dDQSeq()
+        conv_count = sum(1 for m in model.modules() if type(m) is torch.nn.Conv2d)
+        self._test_dq(model, conv_count)
+
+    def test_dq_conv2d_parallel(self) -> None:
+        model = Conv2dDQParallel()
+        conv_count = sum(1 for m in model.modules() if type(m) is torch.nn.Conv2d)
+        self._test_dq(model, conv_count)
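Outside the Tester harness, the same lowering can be driven with the public to_edge_transform_and_lower API. A hedged sketch (assumes an executorch install; the quantize step that _test_dq performs via the Quantize stage is elided here, so nothing would actually be delegated as DYNAMIC_QUANT):

import torch
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
    ConfigPrecisionType,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

model = torch.nn.Conv2d(3, 8, 3).eval()
exported = torch.export.export(model, (torch.randn(1, 3, 8, 8),))
# A real flow quantizes with XNNPACKQuantizer (is_dynamic=True) before this step.
edge = to_edge_transform_and_lower(
    exported,
    partitioner=[
        XnnpackPartitioner(
            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
            per_op_mode=True,
        )
    ],
)
program = edge.to_executorch()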
diff --git a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py
index 6d60f9d76b5..a00209f4ea6 100644
--- a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py
+++ b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py
@@ -10,10 +10,13 @@
 from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
     ChannelsLastTaggedReshapePass,
 )
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+)
 from executorch.backends.xnnpack.test.test_xnnpack_utils_classes import (
     OpSequencesAddConv2d,
 )
-from executorch.backends.xnnpack.test.tester import RunPasses, Tester
+from executorch.backends.xnnpack.test.tester import Quantize, RunPasses, Tester
 
 
 class TestChannelsLastTaggedReshapePass(unittest.TestCase):
@@ -35,6 +38,10 @@ def setUp(self):
     dequant_name = "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default"
     conv_name = "executorch_exir_dialects_edge__ops_aten_convolution_default"
     relu_name = "executorch_exir_dialects_edge__ops_aten_relu_default"
+    choose_qparams_name = (
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_choose_qparams_tensor"
+    )
+    dynamic_quant_name = "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_tensor"
 
     def test_fp32_channels_last_tagged_reshape_pass(self):
         for module, num_reshape in self.modules.items():
@@ -179,3 +186,37 @@ def test_fp32_channels_last_tagged_reshape_pass_conv_bn_hardtanh_mean_seq(self):
             )
             .run_method_and_compare_outputs()
         )
+
+    class Conv2dDynamicQuant(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 10, 3)
+
+        def forward(self, x):
+            return self.conv(x)
+
+    def test_dq_conv2d_channels_last_tagged_reshape_pass(self) -> None:
+        (
+            Tester(self.Conv2dDynamicQuant().eval(), (torch.randn(1, 3, 8, 8),))
+            .quantize(
+                Quantize(
+                    quantization_config=get_symmetric_quantization_config(
+                        is_dynamic=True
+                    )
+                )
+            )
+            .export()
+            .to_edge()
+            .run_passes(self.PassStage)
+            .check(
+                [
+                    self.to_copy_name,
+                    self.choose_qparams_name,
+                    self.dynamic_quant_name,
+                    self.dequant_name,
+                    self.conv_name,
+                    self.to_copy_name,
+                ]
+            )
+            .run_method_and_compare_outputs()
+        )
diff --git a/backends/xnnpack/utils/utils.py b/backends/xnnpack/utils/utils.py
index fab95618807..b23fd444117 100644
--- a/backends/xnnpack/utils/utils.py
+++ b/backends/xnnpack/utils/utils.py
@@ -158,3 +158,33 @@ def get_source_fn(node: torch.fx.Node) -> Optional[torch.fx.Node]:
         return None
     source_fn = source_fn_st[-1]
     return source_fn[1]
+
+
+def is_depthwise_conv(
+    kernel_shape: Tuple[int, ...], groups: int = 1, is_transpose: bool = False
+) -> bool:
+    """
+    A convolution is depthwise if:
+    1) groups = input_channels (i.e. group_input_channels = 1)
+    2) output_channels is a positive integer multiple of input channels
+
+    For standard convolutions:
+        weight shape = (out_channels, in_channels_per_group, height, width)
+    For transposed convolutions:
+        weight shape = (in_channels, out_channels_per_group, height, width)
+
+    Returns True if the convolution is depthwise
+    """
+    if len(kernel_shape) < 2 or groups < 1:
+        return False
+
+    if is_transpose:
+        group_input_channels = int(kernel_shape[0] / groups)
+        group_output_channels = kernel_shape[1]
+    else:
+        group_input_channels = kernel_shape[1]
+        group_output_channels = int(kernel_shape[0] / groups)
+
+    return (
+        group_input_channels == 1 and group_output_channels % group_input_channels == 0
+    )
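Quick sanity examples for the new is_depthwise_conv helper above (assumes an executorch install):

from executorch.backends.xnnpack.utils.utils import is_depthwise_conv

# Depthwise: 8 groups, one input channel per group -> weight (8, 1, 3, 3)
print(is_depthwise_conv((8, 1, 3, 3), groups=8))  # True
# Standard conv2d: weight (8, 3, 3, 3), groups=1 -> 3 input channels per group
print(is_depthwise_conv((8, 3, 3, 3), groups=1))  # False
# Transposed convs read (in_channels, out_channels_per_group, ...) instead
print(is_depthwise_conv((8, 1, 3, 3), groups=8, is_transpose=True))  # True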