diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py index fb1f985edb9..9c884d7ab93 100644 --- a/backends/qualcomm/_passes/__init__.py +++ b/backends/qualcomm/_passes/__init__.py @@ -4,51 +4,51 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from .annotate_decomposed import AnnotateDecomposed from .annotate_quant_attrs import AnnotateQuantAttrs -from .constant_i64_to_i32 import ConstantI64toI32 +from .annotate_stack import AnnotateStack +from .annotate_unbind import AnnotateUnbind from .convert_bmm_to_matmul import ConvertBmmToMatmul from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d -from .convert_to_linear import ConvertToLinear from .decompose_any import DecomposeAny from .decompose_einsum import DecomposeEinsum from .decompose_expm1 import DecomposeExpM1 from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm from .decompose_silu import DecomposeSilu from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape +from .fixed_linear_keep_dim import FixedLinearKeepDim from .fold_qdq import FoldQDQ from .fuse_consecutive_transpose import FuseConsecutiveTranspose +from .i64_to_i32 import I64toI32 from .insert_io_qdq import InsertIOQDQ from .insert_requantize import InsertRequantize from .layout_transform import LayoutTransform from .lift_constant_scalar_operands import LiftConstantScalarOperands from .recompose_pixel_unshuffle import RecomposePixelUnshuffle -from .recompose_prelu import RecomposePReLU from .recompose_rms_norm import RecomposeRmsNorm from .reduce_dynamic_range import ReduceDynamicRange from .remove_redundancy import RemoveRedundancy from .replace_arange_args import ReplaceArangeArgs from .replace_index_put_input import ReplaceIndexPutInput from .replace_inf_values import ReplaceInfValues -from .tensor_i64_to_i32 import TensorI64toI32 +from .tag_quant_io import TagQuantIO __all__ = [ - AnnotateDecomposed, AnnotateQuantAttrs, - ConstantI64toI32, + AnnotateStack, + AnnotateUnbind, ConvertBmmToMatmul, ConvertConv1dToConv2d, - RecomposePReLU, - ConvertToLinear, DecomposeAny, DecomposeEinsum, DecomposeExpM1, DecomposeLinalgVectorNorm, DecomposeSilu, ExpandBroadcastTensorShape, + FixedLinearKeepDim, FoldQDQ, FuseConsecutiveTranspose, + I64toI32, InsertIOQDQ, InsertRequantize, LayoutTransform, @@ -60,5 +60,5 @@ ReplaceArangeArgs, ReplaceIndexPutInput, ReplaceInfValues, - TensorI64toI32, + TagQuantIO, ] diff --git a/backends/qualcomm/_passes/annotate_decomposed.py b/backends/qualcomm/_passes/annotate_stack.py similarity index 63% rename from backends/qualcomm/_passes/annotate_decomposed.py rename to backends/qualcomm/_passes/annotate_stack.py index 918b705e5e9..c42804af2f2 100644 --- a/backends/qualcomm/_passes/annotate_decomposed.py +++ b/backends/qualcomm/_passes/annotate_stack.py @@ -8,31 +8,21 @@ from executorch.exir.pass_base import ExportPass, PassResult from torch.fx.passes.utils.source_matcher_utils import get_source_partitions -from .utils import dq_ops, get_quant_attrs, q_ops +from .utils import get_quant_attrs, q_ops -class AnnotateDecomposed(ExportPass): +class AnnotateStack(ExportPass): """ Add "quant_attrs" to graph nodes' meta from the QDQ information generated after quantization process. 
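For context on the split above: AnnotateStack and AnnotateUnbind both locate their target nodes through torch.fx source partitions. A minimal standalone sketch of that matching step, using a toy module rather than anything from this backend:

import torch
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class StackModule(torch.nn.Module):
    def forward(self, x, y):
        return torch.stack([x, y], dim=0)


ep = torch.export.export(StackModule(), (torch.randn(2, 3), torch.randn(2, 3)))
# Nodes traced from torch.stack are grouped into SourcePartition objects;
# the pass walks partition.nodes to copy quantization attributes onto them.
partitions = get_source_partitions(ep.graph_module.graph, [torch.stack, "stack"])
for source_fn, parts in partitions.items():
    for partition in parts:
        print(source_fn, [n.name for n in partition.nodes])
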
""" - decomp_ops = [torch.ops.aten.stack.default, torch.ops.aten.unbind.int] + decomp_ops = [torch.ops.aten.unbind.int] def __init__(self, edge_program: torch.export.ExportedProgram): - super(AnnotateDecomposed, self).__init__() + super(AnnotateStack, self).__init__() self.edge_program = edge_program - def _annotate_unbind(self, graph_module: torch.fx.GraphModule): - partitions = get_source_partitions(graph_module.graph, [torch.unbind, "unbind"]) - for _, src_partitions in partitions.items(): - for src_partition in src_partitions: - if src_partition.input_nodes[0].target in dq_ops: - q_node = src_partition.input_nodes[0].args[0] - quant_attrs = get_quant_attrs(self.edge_program, q_node) - for n in src_partition.nodes: - n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy() - def _annotate_stack(self, graph_module: torch.fx.GraphModule): partitions = get_source_partitions(graph_module.graph, [torch.stack, "stack"]) for _, src_partitions in partitions.items(): @@ -46,7 +36,6 @@ def _annotate_stack(self, graph_module: torch.fx.GraphModule): n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy() def call(self, graph_module: torch.fx.GraphModule): - self._annotate_unbind(graph_module) self._annotate_stack(graph_module) graph_module.recompile() return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/annotate_unbind.py b/backends/qualcomm/_passes/annotate_unbind.py new file mode 100644 index 00000000000..0efa1638bc4 --- /dev/null +++ b/backends/qualcomm/_passes/annotate_unbind.py @@ -0,0 +1,39 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import torch +from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS +from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx.passes.utils.source_matcher_utils import get_source_partitions + +from .utils import dq_ops, get_quant_attrs + + +class AnnotateUnbind(ExportPass): + """ + Add "quant_attrs" to graph nodes' meta from the QDQ information + generated after quantization process. 
+ """ + + decomp_ops = [torch.ops.aten.unbind.int] + + def __init__(self, edge_program: torch.export.ExportedProgram): + super(AnnotateUnbind, self).__init__() + self.edge_program = edge_program + + def _annotate_unbind(self, graph_module: torch.fx.GraphModule): + partitions = get_source_partitions(graph_module.graph, [torch.unbind, "unbind"]) + for _, src_partitions in partitions.items(): + for src_partition in src_partitions: + if src_partition.input_nodes[0].target in dq_ops: + q_node = src_partition.input_nodes[0].args[0] + quant_attrs = get_quant_attrs(self.edge_program, q_node) + for n in src_partition.nodes: + n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy() + + def call(self, graph_module: torch.fx.GraphModule): + self._annotate_unbind(graph_module) + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/build_quant_io.py b/backends/qualcomm/_passes/build_quant_io.py index b627b9b3052..b34d50c4e24 100644 --- a/backends/qualcomm/_passes/build_quant_io.py +++ b/backends/qualcomm/_passes/build_quant_io.py @@ -27,7 +27,7 @@ def _make_spec(self, x): return None def _build(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: - # forcely update delegate node's meta['spec'] to get correct output + # Forcedly update delegate node's meta['spec'] to get correct output # tensor size in runtime call_delegate = [ node @@ -35,17 +35,9 @@ def _build(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: if node.op == "call_function" and node.name == "executorch_call_delegate" ] assert len(call_delegate) == 1 - spec = [] for n in graph_module.graph.nodes: if QCOM_QUANTIZED_IO in n.meta: n.meta["val"] = n.meta["val"].to(dtype=n.meta[QCOM_QUANTIZED_IO]) - if n.op == "call_function" and "getitem" in n.name: - fake_tensor = n.meta["val"] - if QCOM_QUANTIZED_IO in n.meta: - fake_tensor = fake_tensor.to(dtype=n.meta[QCOM_QUANTIZED_IO]) - spec.append(self._make_spec(fake_tensor)) - - call_delegate[0].meta["spec"] = tuple(spec) def call(self, graph_module: torch.fx.GraphModule): self._build(graph_module) diff --git a/backends/qualcomm/_passes/constant_i64_to_i32.py b/backends/qualcomm/_passes/constant_i64_to_i32.py deleted file mode 100644 index 9b5178b386e..00000000000 --- a/backends/qualcomm/_passes/constant_i64_to_i32.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -from typing import FrozenSet - -import torch -from executorch.backends.qualcomm.builders.utils import get_parameter, is_constant -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult -from torch._subclasses.fake_tensor import FakeTensor - - -class ConstantI64toI32(ExportPass): - """ - Cast unsupported int64 datatype into int32. - This will only be applied on constant nodes such as weights. 
- """ - - def __init__( - self, - edge_program: torch.export.ExportedProgram, - skip_node: FrozenSet[str] = frozenset(), - ): - super(ConstantI64toI32, self).__init__() - self.edge_program = edge_program - self.skip_node = skip_node - # pyre-ignore[4] - self.copy_op = exir_ops.edge.aten._to_copy.default - - def _update_meta(self, node: torch.fx.node) -> None: - meta_val = node.meta["val"] - if isinstance(meta_val, tuple): - node.meta["val"] = ( - ( - fake_tensor.to(torch.int32) - if fake_tensor.dtype == torch.int64 - else fake_tensor - ) - for fake_tensor in meta_val - ) - else: - if meta_val.dtype == torch.int64: - node.meta["val"] = meta_val.to(torch.float) - - # pyre-ignore[2] - def _is_tensor_of_dtype(self, node_val, dtype: torch.dtype) -> bool: - return isinstance(node_val, FakeTensor) and node_val.dtype == dtype - - def _cast_to_int32(self, graph_module: torch.fx.GraphModule): - for n in graph_module.graph.nodes: - if n.target in self.skip_node: - continue - if is_constant(n, self.edge_program): - param = get_parameter(n, self.edge_program) - if param.dtype == torch.int64: - # QNN does not support int64 - self._update_meta(n) - elif n.op == "placeholder": - node_val = n.meta["val"] - if self._is_tensor_of_dtype(node_val, torch.int64): - with graph_module.graph.inserting_after(n): - args = (n,) - to_dst_node = graph_module.graph.create_node( - "call_function", - self.copy_op, - args, - {"dtype": torch.int32}, - ) - to_dst_node.meta["val"] = node_val.to(torch.int32) - - # Replace usage of the src dtype result with the dst dtype result. - n.replace_all_uses_with(to_dst_node) - to_dst_node.args = (n,) - - def call(self, graph_module: torch.fx.GraphModule): - self._cast_to_int32(graph_module) - graph_module.recompile() - graph_module = super().call(graph_module).graph_module - return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/convert_to_linear.py b/backends/qualcomm/_passes/convert_to_linear.py deleted file mode 100644 index 48883571a0c..00000000000 --- a/backends/qualcomm/_passes/convert_to_linear.py +++ /dev/null @@ -1,231 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
-from collections import Counter -from typing import Callable, List - -import torch -from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS -from executorch.backends.transforms.addmm_mm_to_linear import ( - apply_addmm_mm_to_linear_transform, -) -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.dialects.edge._ops import EdgeOpOverload as edge_op -from executorch.exir.pass_base import ExportPass, PassResult -from executorch.exir.passes import dead_code_elimination_pass - -from torch.fx.passes.utils.source_matcher_utils import ( - get_source_partitions, - SourcePartition, -) - -from .utils import dq_ops, get_quant_attrs, q_ops - - -class ConvertToLinear(ExportPass): - """ - Handle missing quantization tag for addmm op after decomposing - """ - - view_copy = exir_ops.edge.aten.view_copy.default - permute_copy = exir_ops.edge.aten.permute_copy.default - expand_copy = exir_ops.edge.aten.expand_copy.default - linear = exir_ops.edge.aten.linear.default - add = exir_ops.edge.aten.add.Tensor - addmm = exir_ops.edge.aten.addmm.default - bmm = exir_ops.edge.aten.bmm.default - mm = exir_ops.edge.aten.mm.default - - addmm_patterns = [ - {view_copy: 1, permute_copy: 1, addmm: 1}, - {view_copy: 2, permute_copy: 1, addmm: 1}, - {permute_copy: 1, addmm: 1}, - ] - - bmm_patterns = [ - {view_copy: 3, permute_copy: 1, expand_copy: 2, add: 1, bmm: 1}, - {view_copy: 3, permute_copy: 1, expand_copy: 2, bmm: 1}, - ] - - mm_patterns = [ - {view_copy: 2, permute_copy: 1, mm: 1}, - {permute_copy: 1, mm: 1}, - ] - - def __init__(self): - super(ConvertToLinear, self).__init__() - - def _get_original_input( - self, inputs: List[torch.fx.Node], cur_node: torch.fx.Node - ) -> torch.fx.Node: - while cur_node not in inputs and cur_node.args: - cur_node = cur_node.args[0] - return cur_node - - def _convert_to_linear( - self, - gm: torch.fx.GraphModule, - src_partition: SourcePartition, - extract_ops_fn: Callable, - ): - inputs = src_partition.input_nodes - # output_nodes contains output node and input buffer such as argX_X - outputs = [ - node - for node in src_partition.output_nodes - if node.target != torch.ops.aten.sym_size.int and node.op != "placeholder" - ] - assert ( - len(outputs) == 1 - ), f"Unexpected number of outputs for a torch.nn.Linear module, expecting 1 but got {outputs}" - output = outputs[0] - - ops = extract_ops_fn(src_partition.nodes) - input_node, weight_node, fn_node = ops[:3] - bias_node = None if len(ops) == 3 else ops[3] - - # qnn htp does not support keepdim, the view_copy(reshape) should exist for now - if self._get_original_input(inputs, input_node).target in dq_ops: - input_node.meta[QCOM_QUANT_ATTRS] = get_quant_attrs( - gm, self._get_original_input(inputs, input_node).args[0] - ) - args = [input_node, weight_node] - if bias_node: - args.append(bias_node) - - # We need a view copy node after linear op - with gm.graph.inserting_before(output): - linear_node = gm.graph.create_node( - "call_function", self.linear, tuple(args) - ) - linear_node.meta = fn_node.meta - if list(output.users)[0].target in q_ops: - linear_node.meta[QCOM_QUANT_ATTRS] = get_quant_attrs( - gm, list(output.users)[0] - ) - for user in fn_node.users.copy(): - user.replace_input_with(fn_node, linear_node) - - # Since QNN has no keep dims for linear op, we will need to add squeeze and unsqueeze around linear node - # TODO: Find a more general conditional statement. 
- linear_output = linear_node.meta["val"] - if linear_output.dim() >= 3: - with gm.graph.inserting_after(input_node): - input_users = list(input_node.users.keys()) - input_tensor = input_node.meta["val"] - squeeze_dim = (-1, input_tensor.shape[-1]) - squeeze_node = gm.graph.create_node( - "call_function", - self.view_copy, - ( - input_node, - squeeze_dim, - ), - ) - # meta needs to be copied elementwisely for fake-tensor - # to be updated correctly and not affect meta of input_node - for k, v in input_node.meta.items(): - squeeze_node.meta[k] = v - squeeze_node.meta["val"] = input_tensor.reshape(squeeze_dim) - for user in input_users: - if user == linear_node: - user.replace_input_with(input_node, squeeze_node) - - with gm.graph.inserting_after(linear_node): - output_users = list(linear_node.users.keys()) - unsqueeze_dim = linear_output.shape - unsqueeze_node = gm.graph.create_node( - "call_function", - self.view_copy, - ( - linear_node, - unsqueeze_dim, - ), - ) - # meta needs to be copied elementwisely for fake-tensor - # to be updated correctly and not affect meta of unsqueeze_node - for k, v in linear_node.meta.items(): - unsqueeze_node.meta[k] = v - # update linear node's shape - linear_node.meta["val"] = linear_output.reshape( - (squeeze_node.meta["val"].shape[0], linear_output.shape[-1]) - ) - for user in output_users: - user.replace_input_with(linear_node, unsqueeze_node) - - def _extract_mm_ops(self, partitioned_nodes: List[edge_op]) -> List[torch.fx.Node]: - mm_node = [n for n in partitioned_nodes if n.target == self.mm][0] - # weight -> permute -> input of mm - weight_node = mm_node.args[1].args[0] - input_node = mm_node.args[0] - return [input_node, weight_node, mm_node] - - def _extract_addmm_ops( - self, partitioned_nodes: List[edge_op] - ) -> List[torch.fx.Node]: - addmm_node = [n for n in partitioned_nodes if n.target == self.addmm][0] - # weight -> permute -> input of addmm - weight_node = addmm_node.args[2].args[0] - input_node = addmm_node.args[1] - bias_node = addmm_node.args[0] - return [input_node, weight_node, addmm_node, bias_node] - - def _extract_bmm_ops(self, partitioned_nodes: List[edge_op]) -> List[torch.fx.Node]: - bmm_node = [n for n in partitioned_nodes if n.target == self.bmm][0] - add_node = [n for n in partitioned_nodes if n.target == self.add] - - # weight -> expand_copy -> view_copy -> input of bmm - weight_node = bmm_node.args[1].args[0].args[0].args[0] - # input -> expand_copy -> view_copy -> input of bmm - input_node = bmm_node.args[0].args[0].args[0] - - ret = [input_node, weight_node, bmm_node] - if add_node: - bias_node = add_node[0].args[1] - ret = [input_node, weight_node, add_node[0], bias_node] - else: - ret = [input_node, weight_node, bmm_node] - - return ret - - def _convert(self, graph_module: torch.fx.GraphModule): - partitions = get_source_partitions( - graph_module.graph, [torch.nn.Linear, torch.ops.aten.linear.default] - ) - for _, src_partitions in partitions.items(): - for src_partition in src_partitions: - op_cnt = Counter( - [ - n.target - for n in src_partition.nodes - if isinstance(n.target, edge_op) - ] - ) - if self.linear in op_cnt: - continue - elif op_cnt in self.addmm_patterns: - self._convert_to_linear( - graph_module, src_partition, self._extract_addmm_ops - ) - elif op_cnt in self.mm_patterns: - self._convert_to_linear( - graph_module, src_partition, self._extract_mm_ops - ) - elif op_cnt in self.bmm_patterns: - self._convert_to_linear( - graph_module, src_partition, self._extract_bmm_ops - ) - else: - raise 
AssertionError( - "Found a new pattern needed be converted to linear op" - ) - - def call(self, graph_module: torch.fx.GraphModule): - self._convert(graph_module) - # We could not use get_source_partitions because it is the same source for MultiheadAttention - apply_addmm_mm_to_linear_transform(graph_module.graph) - dead_code_elimination_pass(graph_module) - graph_module.recompile() - return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/decompose_expm1.py b/backends/qualcomm/_passes/decompose_expm1.py index 8fe6ebdec5b..41499e8a4b0 100644 --- a/backends/qualcomm/_passes/decompose_expm1.py +++ b/backends/qualcomm/_passes/decompose_expm1.py @@ -15,7 +15,7 @@ class DecomposeExpM1(ExportPass): Decompose for expm1 to exponential and minus 1. """ - def __init__(self, quantization_capture=False) -> None: + def __init__(self) -> None: super().__init__() def call(self, graph_module: torch.fx.GraphModule) -> PassResult: diff --git a/backends/qualcomm/_passes/decompose_linalg_vector_norm.py b/backends/qualcomm/_passes/decompose_linalg_vector_norm.py index 4a54c2aa50c..7d70f5c9342 100644 --- a/backends/qualcomm/_passes/decompose_linalg_vector_norm.py +++ b/backends/qualcomm/_passes/decompose_linalg_vector_norm.py @@ -32,9 +32,9 @@ class DecomposeLinalgVectorNorm(ExportPass): Decompose for math equivalent op. """ - def __init__(self, aten_dialect_capture=False) -> None: + def __init__(self, quantization_capture=False) -> None: super().__init__() - self.aten_dialect_capture = aten_dialect_capture + self.quantization_capture = quantization_capture def call(self, graph_module: torch.fx.GraphModule) -> PassResult: graph = graph_module.graph @@ -44,7 +44,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: dim = node.args[2] if len(node.args) > 2 else None keepdim = node.args[3] if len(node.args) > 3 else False model = LinalgVectorNorm(ord, dim, keepdim) - if self.aten_dialect_capture: + if self.quantization_capture: decomposed_module = torch.export.export( model, (node.args[0].meta["val"],), strict=True ).module() diff --git a/backends/qualcomm/_passes/fixed_linear_keep_dim.py b/backends/qualcomm/_passes/fixed_linear_keep_dim.py new file mode 100644 index 00000000000..4f625b96f0e --- /dev/null +++ b/backends/qualcomm/_passes/fixed_linear_keep_dim.py @@ -0,0 +1,87 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import torch + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.passes import dead_code_elimination_pass + +from torch.fx.passes.utils.source_matcher_utils import get_source_partitions + + +class FixedLinearKeepDim(ExportPass): + """ + Add squeeze and unsqueeze around linear node since QNN has no keep dims for linear op. 
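The numerical intuition behind this pass, shown with plain tensors (illustrative only, not backend code): collapsing the leading dimensions before a linear layer and restoring them afterwards matches applying the layer directly, which is what the inserted view_copy pair expresses in the graph:

import torch

linear = torch.nn.Linear(8, 4)
x = torch.randn(2, 5, 8)                         # rank-3 input

direct = linear(x)                               # keep-dim result
flattened = linear(x.reshape(-1, x.shape[-1]))   # "squeeze" to rank-2
restored = flattened.reshape(*x.shape[:-1], 4)   # "unsqueeze" back

assert torch.allclose(direct, restored)
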
+ """ + + view_copy = exir_ops.edge.aten.view_copy.default + linear = exir_ops.edge.aten.linear.default + + def __init__(self): + super(FixedLinearKeepDim, self).__init__() + + def _fixed_keep_dim(self, graph_module: torch.fx.GraphModule): + partitions = get_source_partitions( + graph_module.graph, [torch.nn.Linear, torch.ops.aten.linear.default] + ) + for _, src_partitions in partitions.items(): + for src_partition in src_partitions: + linear_node = [ + n for n in src_partition.nodes if n.target == self.linear + ][0] + input_node = linear_node.args[0] + # Since QNN has no keep dims for linear op, we will need to add squeeze and unsqueeze around linear node + # TODO: Find a more general conditional statement. + linear_output = linear_node.meta["val"] + if linear_output.dim() >= 3: + with graph_module.graph.inserting_after(input_node): + input_users = list(input_node.users.keys()) + input_tensor = input_node.meta["val"] + squeeze_dim = (-1, input_tensor.shape[-1]) + squeeze_node = graph_module.graph.create_node( + "call_function", + self.view_copy, + ( + input_node, + squeeze_dim, + ), + ) + # meta needs to be copied elementwisely for fake-tensor + # to be updated correctly and not affect meta of input_node + for k, v in input_node.meta.items(): + squeeze_node.meta[k] = v + squeeze_node.meta["val"] = input_tensor.reshape(squeeze_dim) + for user in input_users: + if user == linear_node: + user.replace_input_with(input_node, squeeze_node) + + with graph_module.graph.inserting_after(linear_node): + output_users = list(linear_node.users.keys()) + unsqueeze_dim = linear_output.shape + unsqueeze_node = graph_module.graph.create_node( + "call_function", + self.view_copy, + ( + linear_node, + unsqueeze_dim, + ), + ) + # meta needs to be copied elementwisely for fake-tensor + # to be updated correctly and not affect meta of unsqueeze_node + for k, v in linear_node.meta.items(): + unsqueeze_node.meta[k] = v + # update linear node's shape + linear_node.meta["val"] = linear_output.reshape( + (squeeze_node.meta["val"].shape[0], linear_output.shape[-1]) + ) + for user in output_users: + user.replace_input_with(linear_node, unsqueeze_node) + + def call(self, graph_module: torch.fx.GraphModule): + self._fixed_keep_dim(graph_module) + dead_code_elimination_pass(graph_module) + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/tensor_i64_to_i32.py b/backends/qualcomm/_passes/i64_to_i32.py similarity index 50% rename from backends/qualcomm/_passes/tensor_i64_to_i32.py rename to backends/qualcomm/_passes/i64_to_i32.py index baddd747f99..f13b035552c 100644 --- a/backends/qualcomm/_passes/tensor_i64_to_i32.py +++ b/backends/qualcomm/_passes/i64_to_i32.py @@ -5,41 +5,57 @@ # LICENSE file in the root directory of this source tree. import logging +from typing import FrozenSet import torch -from executorch.backends.qualcomm.builders.utils import is_graph_output +from executorch.backends.qualcomm.builders.utils import ( + get_parameter, + is_constant, + is_graph_output, +) from executorch.backends.qualcomm.utils.constants import QCOM_ORIG_DTYPE -from executorch.exir import ExirExportedProgram from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult -from executorch.exir.program._program import _get_updated_graph_signature from torch._subclasses.fake_tensor import FakeTensor -class TensorI64toI32(ExportPass): +class I64toI32(ExportPass): """ Insert a cast node to cast dtype from int64 to int32. 
- This will only be applied on fake tensors. + This will be applied on operator and constant nodes such as weights. """ - cast_ops = { - torch.ops.aten.argmin.default, - torch.ops.aten.arange.start_step, - torch.ops.aten.full.default, - torch.ops.aten.scalar_tensor.default, + I64_OPS = { + exir_ops.edge.aten.argmin.default, + exir_ops.edge.aten.arange.start_step, + exir_ops.edge.aten.full.default, + exir_ops.edge.aten.scalar_tensor.default, } + copy_op = exir_ops.edge.aten._to_copy.default - def __init__(self, edge_program): - super(TensorI64toI32, self).__init__() + def __init__( + self, + edge_program, + skip_node: FrozenSet[str] = frozenset(), + ): + super(I64toI32, self).__init__() self.edge_program = edge_program + self.skip_node = skip_node # pyre-ignore[2] def _is_tensor_of_dtype(self, node_val, dtype: torch.dtype) -> bool: return isinstance(node_val, FakeTensor) and node_val.dtype == dtype - def _cast_to_int32(self, core_ep: ExirExportedProgram): - copy_op = torch.ops.aten._to_copy.default - for n in core_ep.exported_program.graph.nodes: + def call_operator(self, op, args, kwargs, meta): + if op in self.I64_OPS and self._is_tensor_of_dtype(meta["val"], torch.int64): + res = super().call_operator(op, args, kwargs, meta) + return super().call_operator( + self.copy_op, (res,), {"dtype": torch.int32}, meta + ) + return super().call_operator(op, args, kwargs, meta) + + def _record_original_output_dtype(self, graph_module: torch.fx.GraphModule): + for n in graph_module.graph.nodes: # Keep track of original output dtype so we ensure the dtype of the graph is consistent with nn.Module if is_graph_output(n): if isinstance(n.meta["val"], (tuple, list)): @@ -47,42 +63,8 @@ def _cast_to_int32(self, core_ep: ExirExportedProgram): n.meta[QCOM_ORIG_DTYPE] = dtype_list else: n.meta[QCOM_ORIG_DTYPE] = n.meta["val"].dtype - continue - if n.target in self.cast_ops: - node_val = n.meta["val"] - if self._is_tensor_of_dtype(node_val, torch.int64): - with core_ep.exported_program.graph.inserting_after(n): - users = list(n.users.keys()) - args = (n,) - cast_node = core_ep.exported_program.graph.create_node( - "call_function", - copy_op, - args, - {"dtype": torch.int32}, - ) - cast_node.meta["val"] = node_val.to(torch.int32) - cast_node.args = args - - for user in users: - # _assert_tensor_metadata is used to check dtype, which will cause lowering to fail since we are changing int64 to int32 - # We also skip if the next op is already a cast op, which prevents redundant casting. 
- if user.target not in { - torch.ops.aten._assert_tensor_metadata.default, - torch.ops.aten._to_copy.default, - }: - user.replace_input_with(n, cast_node) - - core_ep.exported_program._graph_signature = _get_updated_graph_signature( - core_ep.exported_program._graph_signature, - core_ep.exported_program.graph_module, - ) - core_ep.exported_program._validate() - def _preserve_output_dtype( - self, exported_program: torch.export.exported_program.ExportedProgram - ): - graph_module = exported_program.graph_module - copy_op = exir_ops.edge.aten._to_copy.default + def _preserve_output_dtype(self, graph_module: torch.fx.GraphModule): for n in graph_module.graph.nodes: if is_graph_output(n) and QCOM_ORIG_DTYPE in n.meta: if isinstance(n.meta["val"], (tuple, list)): @@ -107,7 +89,7 @@ def _preserve_output_dtype( ] cast_node = graph_module.graph.create_node( "call_function", - copy_op, + self.copy_op, args, {"dtype": orig_dtype}, ) @@ -116,22 +98,55 @@ def _preserve_output_dtype( for user in output_users: user.replace_input_with(n, cast_node) - def call(self, graph_module: torch.fx.GraphModule): - # Stage 1: _cast_to_int32 - # We add to_copy after the desired operations during this stage because the data type only propagates before to_edge. - # If we don't add to_copy here but do it after to_edge, the next operation after to_copy() will still expect int64 as its output. - # Stage 2: _preserve_output_dtype - # We will tag the output dtype during stage 1, and we will ensure that if user expects int64 as output, - # we need to convert the output back to int64 if it is casted from int64->int32 during stage 1. - if isinstance(self.edge_program, ExirExportedProgram): - self._cast_to_int32(self.edge_program) - self.edge_program.exported_program.graph_module.recompile() - elif isinstance( - self.edge_program, torch.export.exported_program.ExportedProgram - ): - self._preserve_output_dtype(self.edge_program) - else: - raise AssertionError( - "Should be ExirExportedProgram at stage 1 and torch.export.exported_program.ExportedProgram at stage 2" + def _update_meta(self, node: torch.fx.node) -> None: + meta_val = node.meta["val"] + if isinstance(meta_val, tuple): + node.meta["val"] = ( + ( + fake_tensor.to(torch.int32) + if fake_tensor.dtype == torch.int64 + else fake_tensor + ) + for fake_tensor in meta_val ) + else: + if meta_val.dtype == torch.int64: + # TODO This trick seems to use in mobilebert. + # It would be better to convert to torch.int32 + node.meta["val"] = meta_val.to(torch.float) + + def _cast_constant_to_int32(self, graph_module: torch.fx.GraphModule): + for n in graph_module.graph.nodes: + if n.target in self.skip_node: + continue + if is_constant(n, self.edge_program): + param = get_parameter(n, self.edge_program) + if param.dtype == torch.int64: + # QNN does not support int64 + self._update_meta(n) + elif n.op == "placeholder": + node_val = n.meta["val"] + if self._is_tensor_of_dtype(node_val, torch.int64): + with graph_module.graph.inserting_after(n): + args = (n,) + to_dst_node = graph_module.graph.create_node( + "call_function", + self.copy_op, + args, + {"dtype": torch.int32}, + ) + to_dst_node.meta["val"] = node_val.to(torch.int32) + + # Replace usage of the src dtype result with the dst dtype result. + n.replace_all_uses_with(to_dst_node) + to_dst_node.args = (n,) + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + # Record original output dtype to ensure that if user expects int64 as output, + # convert the output back to int64 if it is casted from int64->int32. 
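A standalone illustration of the dtype problem these stages address (plain torch, not the pass itself): index-producing ops such as argmin emit int64, which QNN cannot consume, so the pass appends an aten._to_copy cast to int32 and restores int64 at the graph output when the original module expects it:

import torch

x = torch.randn(4, 8)
idx64 = torch.argmin(x, dim=1)                   # aten indexing ops emit int64
idx32 = torch.ops.aten._to_copy.default(idx64, dtype=torch.int32)

assert idx32.dtype == torch.int32
assert torch.equal(idx32.to(torch.int64), idx64)  # value-preserving cast
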
+ self._record_original_output_dtype(graph_module) + self._cast_constant_to_int32(graph_module) + graph_module = super().call(graph_module).graph_module + self._preserve_output_dtype(graph_module) + graph_module.recompile() return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/insert_requantize.py b/backends/qualcomm/_passes/insert_requantize.py index 83b729f3c44..372e7eb4d76 100644 --- a/backends/qualcomm/_passes/insert_requantize.py +++ b/backends/qualcomm/_passes/insert_requantize.py @@ -36,10 +36,8 @@ class InsertRequantize(ExportPass): def __init__( self, - edge_program: torch.export.ExportedProgram, ): super(InsertRequantize, self).__init__() - self.edge_program = edge_program def _make_hashable(self, value): if isinstance(value, dict): diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py index 64fdcb2bb88..17960a6029b 100644 --- a/backends/qualcomm/_passes/layout_transform.py +++ b/backends/qualcomm/_passes/layout_transform.py @@ -87,6 +87,7 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten._softmax.default, # TODO: Need to find a new solution to do "axis_order" to transform axis. exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.split_with_sizes.default, + exir_ops.edge.aten.split_with_sizes_copy.default, exir_ops.edge.aten.sqrt.default, exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.sum.dim_IntList, diff --git a/backends/qualcomm/_passes/lift_constant_scalar_operands.py b/backends/qualcomm/_passes/lift_constant_scalar_operands.py index cef28988520..93abfe621bc 100644 --- a/backends/qualcomm/_passes/lift_constant_scalar_operands.py +++ b/backends/qualcomm/_passes/lift_constant_scalar_operands.py @@ -47,6 +47,7 @@ class TensorOpInfo: aten.pow.Tensor_Scalar: TensorOpInfo(aten.pow.Tensor_Tensor, False, False), # The scalar number arg[1] is missing when using default. Result in a corner case to deal aten.leaky_relu.default: TensorOpInfo(aten.prelu.default, True, False), + aten.leaky_relu_.default: TensorOpInfo(aten.prelu.default, True, False), aten.where.ScalarOther: TensorOpInfo(aten.where.self, False, True), aten.where.Scalar: TensorOpInfo(aten.where.self, False, True), } @@ -57,7 +58,9 @@ class TensorOpInfo: class LiftConstantScalarOperands(ExportPass): """ - Lift constant scalar so that we can use observer of quantizer + Lift constant scalar so that we can use observer of quantizer. + For floating point model, lift constant scalar to avoid + creating temporary tensors for scalar node in the operation builder """ def __init__(self): diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py new file mode 100644 index 00000000000..ab2c86102df --- /dev/null +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -0,0 +1,204 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import inspect +from collections import OrderedDict +from typing import Dict + +from executorch.backends.qualcomm._passes import ( + AnnotateQuantAttrs, + AnnotateStack, + AnnotateUnbind, + ConvertBmmToMatmul, + ConvertConv1dToConv2d, + DecomposeAny, + DecomposeEinsum, + DecomposeExpM1, + DecomposeLinalgVectorNorm, + DecomposeSilu, + ExpandBroadcastTensorShape, + FixedLinearKeepDim, + FoldQDQ, + FuseConsecutiveTranspose, + I64toI32, + InsertIOQDQ, + InsertRequantize, + LayoutTransform, + LiftConstantScalarOperands, + RecomposePixelUnshuffle, + RecomposeRmsNorm, + ReduceDynamicRange, + RemoveRedundancy, + ReplaceArangeArgs, + ReplaceIndexPutInput, + ReplaceInfValues, + TagQuantIO, +) +from executorch.backends.qualcomm._passes.utils import ( + get_passes_dependency_for_capture_program, +) +from executorch.backends.qualcomm.utils.constants import ( + QCOM_PASS_ACTIVATE_KEY, + QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY, +) +from executorch.backends.transforms.decompose_sdpa import ( + DecomposeScaledDotProductAttention, +) +from executorch.exir import ExportedProgram +from executorch.exir.pass_manager import PassManager +from executorch.exir.program._program import ( + _get_updated_graph_signature, + lift_constant_tensor_pass, +) +from torch.fx import GraphModule +from torch.fx.passes.infra.pass_manager import this_before_that_pass_constraint + + +def get_capture_program_passes(): + """ + Defines and returns the default ordered passes for the capture program. + This function creates an OrderedDict containing a series of default passes. + + Returns: + OrderedDict: An ordered dictionary containing all default passes along with their activation status and initialization parameters. + """ + + # The second value in each tuple in `default_passes_and_setting` indicates whether the corresponding pass is activated by default. + # If a pass is activated, it will be executed by default. 
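In terms of usage, the intended flow (also exercised by the test changes further below) is to fetch this table, flip activation flags, and hand it back through the passes_job argument of to_edge_transform_and_lower_to_qnn. A small sketch under that assumption; ExpandBroadcastTensorShape is simply one of the optional entries in the list that follows:

from executorch.backends.qualcomm._passes import ExpandBroadcastTensorShape
from executorch.backends.qualcomm._passes.qnn_pass_manager import (
    get_capture_program_passes,
)
from executorch.backends.qualcomm.utils.constants import QCOM_PASS_ACTIVATE_KEY

# Fetch the defaults and activate an optional pass; the resulting OrderedDict
# is then passed as to_edge_transform_and_lower_to_qnn(..., passes_job=passes_job).
passes_job = get_capture_program_passes()
passes_job[ExpandBroadcastTensorShape][QCOM_PASS_ACTIVATE_KEY] = True
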
+ default_passes_and_setting = [ + (AnnotateQuantAttrs, True), + (AnnotateStack, False), + (AnnotateUnbind, True), + (ConvertBmmToMatmul, True), + (ConvertConv1dToConv2d, True), + (DecomposeAny, True), + (ExpandBroadcastTensorShape, False), + (FixedLinearKeepDim, True), + (FoldQDQ, True), + (I64toI32, True), + (LayoutTransform, True), + (RecomposePixelUnshuffle, True), + (RecomposeRmsNorm, False), + (RemoveRedundancy, True), + (ReplaceIndexPutInput, True), + (TagQuantIO, False), + ] + + passes = OrderedDict() + for p, act in default_passes_and_setting: + init_signature = inspect.signature(p.__init__) + + args_kwargs_defaults = { + k: v.default if v.default is not inspect.Parameter.empty else None + for k, v in init_signature.parameters.items() + if k != "self" + } + + passes[p] = { + QCOM_PASS_ACTIVATE_KEY: act, + QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY: args_kwargs_defaults, + } + + return passes + + +class QnnPassManager(PassManager): + + def __init__(self) -> None: + super().__init__() + + def _transform(self, graph_module: GraphModule): + return self(graph_module).graph_module + + # TODO: Move these passes into qnn_partitioner and qnn_preprocess to + # prevent users from needing to call custom APIs like capture_program + def get_to_edge_transform_passes( + self, + exported_program: ExportedProgram, + passes_job: OrderedDict = None, + dep_table: Dict = None, + ): + # TODO: remove this workaround when target could be correctly detected + from executorch.backends.qualcomm._passes import utils + from executorch.exir.dialects._ops import ops as exir_ops + + utils.q_ops.add(exir_ops.edge.pt2e_quant.quantize_affine.default) + utils.dq_ops.add(exir_ops.edge.pt2e_quant.dequantize_affine.default) + + passes_job = ( + passes_job if passes_job is not None else get_capture_program_passes() + ) + dep_table = ( + dep_table + if dep_table is not None + else get_passes_dependency_for_capture_program() + ) + for that, these in dep_table.items(): + for this in these: + self.add_constraint(this_before_that_pass_constraint(this, that)) + for p in passes_job: + self.add_pass(p) + self.solve_constraints() + + sorted_passes = self.passes + self.passes = [] + for p in sorted_passes: + if not passes_job[p][QCOM_PASS_ACTIVATE_KEY]: + continue + + kwargs = passes_job[p][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY] + if "edge_program" in kwargs: + kwargs["edge_program"] = exported_program + self.add_pass(p(**kwargs)) + return self.passes + + def transform_for_to_edge_pipeline( + self, + exported_program: ExportedProgram, + passes_job: OrderedDict = None, + dep_table: Dict = None, + ): + transform_passes = self.get_to_edge_transform_passes( + exported_program, passes_job=passes_job, dep_table=dep_table + ) + for p in transform_passes: + p(exported_program.graph_module) + exported_program._graph_signature = _get_updated_graph_signature( + exported_program.graph_signature, + exported_program.graph_module, + ) + exported_program._validate() + + return exported_program + + def transform_for_export_pipeline(self, exported_program: ExportedProgram): + self.add_pass(DecomposeScaledDotProductAttention()) + self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True)) + self.add_pass(DecomposeExpM1()) + self.add_pass(LiftConstantScalarOperands()) + self._transform(exported_program.graph_module) + ep = lift_constant_tensor_pass(exported_program) + return ep + + def transform_for_preprocess_pipeline(self, exported_program: ExportedProgram): + self.add_pass(InsertRequantize()) + self.add_pass(InsertIOQDQ(exported_program)) + 
self.add_pass(LayoutTransform(exported_program, insert_permute=True)) + self.add_pass(FuseConsecutiveTranspose()) + return self._transform(exported_program.graph_module) + + def transform_for_annotation_pipeline(self, graph_module: GraphModule): + self.add_pass(ReduceDynamicRange()) + self.add_pass(RecomposePixelUnshuffle(quantization_capture=True)) + self.add_pass(ReplaceArangeArgs()) + self.add_pass(DecomposeScaledDotProductAttention()) + self.add_pass(DecomposeSilu()) + self.add_pass(DecomposeEinsum()) + self.add_pass(DecomposeExpM1()) + self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True)) + self.add_pass(ReplaceInfValues()) + self.add_pass(LiftConstantScalarOperands()) + return self._transform(graph_module) diff --git a/backends/qualcomm/_passes/recompose_prelu.py b/backends/qualcomm/_passes/recompose_prelu.py deleted file mode 100644 index 082b9c83b27..00000000000 --- a/backends/qualcomm/_passes/recompose_prelu.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -from typing import List - -import torch -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult -from torch.fx.passes.utils.source_matcher_utils import get_source_partitions - - -class RecomposePReLU(ExportPass): - """ - Merge decomposed operators from prelu back to one super node. - """ - - def __init__(self, edge_program: torch.export.ExportedProgram): - super(RecomposePReLU, self).__init__() - self.edge_program = edge_program - - def _get_coeff_node(self, nodes: List[torch.fx.Node]): - for node in nodes: - if node.target == exir_ops.edge.aten.view_copy.default: - return node.args[0] - - def _get_input_node(self, nodes: List[torch.fx.Node], coeff_node): - return [n for n in nodes if n != coeff_node][0] - - def call(self, graph_module: torch.fx.GraphModule): - graph = graph_module.graph - partitions = get_source_partitions(graph, [torch.nn.PReLU, torch.nn.LeakyReLU]) - for _, src_partitions in partitions.items(): - for src_partition in src_partitions: - # somehow op might not be decomposed, skip it - if len(src_partition.nodes) == 1: - continue - - coeff_node = self._get_coeff_node(src_partition.nodes) - input_node = self._get_input_node(src_partition.input_nodes, coeff_node) - output_node = src_partition.output_nodes[0] - - with graph.inserting_before(output_node): - prelu_op = exir_ops.edge.aten.prelu.default - prelu_node = graph.create_node( - "call_function", prelu_op, (input_node, coeff_node) - ) - users = output_node.users.copy() - for user in users: - user.replace_input_with(output_node, prelu_node) - # copy metadata - prelu_node.meta = output_node.meta - - graph.eliminate_dead_code() - graph_module.recompile() - return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/recompose_rms_norm.py b/backends/qualcomm/_passes/recompose_rms_norm.py index 77feecf9c1f..a5db826ab28 100644 --- a/backends/qualcomm/_passes/recompose_rms_norm.py +++ b/backends/qualcomm/_passes/recompose_rms_norm.py @@ -15,6 +15,8 @@ class RecomposeRmsNorm(ExportPass): """ Merge decomposed operators back to one super node. 
+ TODO: After replacing export_to_edge with to_edge_transform_and_lowering + in examples/models/llama/export_llama_lib.py, this pass can be removed """ def __init__(self, edge_program: torch.export.ExportedProgram): diff --git a/backends/qualcomm/_passes/replace_index_put_input.py b/backends/qualcomm/_passes/replace_index_put_input.py index 1eb210cf67e..dcdf2bb3a7f 100644 --- a/backends/qualcomm/_passes/replace_index_put_input.py +++ b/backends/qualcomm/_passes/replace_index_put_input.py @@ -22,9 +22,8 @@ class ReplaceIndexPutInput(ExportPass): exir_ops.edge.quantized_decomposed.dequantize_per_channel.default: exir_ops.edge.quantized_decomposed.quantize_per_channel.default, } - def __init__(self, edge_program: torch.export.ExportedProgram): + def __init__(self): super(ReplaceIndexPutInput, self).__init__() - self.edge_program = edge_program def call(self, graph_module: torch.fx.GraphModule): graph = graph_module.graph diff --git a/backends/qualcomm/_passes/tag_quant_io.py b/backends/qualcomm/_passes/tag_quant_io.py new file mode 100644 index 00000000000..bb23d18b0ad --- /dev/null +++ b/backends/qualcomm/_passes/tag_quant_io.py @@ -0,0 +1,45 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import Callable + +import torch +from executorch.backends.qualcomm.utils.constants import ( + QCOM_QUANT_ATTRS, + QCOM_QUANT_ATTRS_MAP, + QCOM_QUANTIZED_IO, +) +from executorch.exir.pass_base import ExportPass, PassResult + + +class TagQuantIO(ExportPass): + """ + Tag the IO nodes that handle quantized tensors to avoid inserting Q/DQ operations in qnn_preprocess. + """ + + def __init__(self, get_quant_io_dtype_fn: Callable = None): + super(TagQuantIO, self).__init__() + self.get_quant_io_dtype_fn = get_quant_io_dtype_fn + + def _tag_quant_io(self, gm: torch.fx.GraphModule): + for node in gm.graph.nodes: + if dtype := self.get_quant_io_dtype_fn(node): + node.meta[QCOM_QUANTIZED_IO] = dtype + + def _record_output_quant_attrs_map(self, gm: torch.fx.GraphModule): + for node in gm.graph.nodes: + if node.op == "output": + node.meta.setdefault(QCOM_QUANT_ATTRS_MAP, {}) + for arg in node.args[0]: + if QCOM_QUANT_ATTRS in arg.meta: + node.meta[QCOM_QUANT_ATTRS_MAP][arg] = arg.meta[ + QCOM_QUANT_ATTRS + ] + + def call(self, graph_module: torch.fx.GraphModule): + self._tag_quant_io(graph_module) + self._record_output_quant_attrs_map(graph_module) + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index 0c838e9a676..d538fe0d34f 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -63,62 +63,61 @@ def get_quant_attrs( def get_passes_dependency_for_capture_program(): """ - This function records the dependencies for passes used in the capture_program. + This function records the dependencies for passes used in the to_edge_transform_and_lower_to_qnn. It returns a dictionary where the keys are pass classes and the values are lists of dependencies required by each pass. This helps in managing and organizing the sequence - of passes needed for the capture_program to function correctly. + of passes needed for the to_edge_transform_and_lower_to_qnn to function correctly. Returns: dict: A dictionary mapping each pass to its corresponding list of dependencies. 
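Because TagQuantIO (added above) takes a user callback, enabling it goes through the same passes_job table, with the callback placed into the recorded kwargs. A hedged sketch; the tag_graph_inputs policy below is hypothetical and only tags graph inputs as uint8:

import torch
from executorch.backends.qualcomm._passes import TagQuantIO
from executorch.backends.qualcomm._passes.qnn_pass_manager import (
    get_capture_program_passes,
)
from executorch.backends.qualcomm.utils.constants import (
    QCOM_PASS_ACTIVATE_KEY,
    QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY,
)


def tag_graph_inputs(node):
    # Hypothetical policy: keep graph inputs quantized as uint8; returning
    # None leaves a node untagged.
    return torch.uint8 if node.op == "placeholder" else None


passes_job = get_capture_program_passes()
passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True
passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][
    "get_quant_io_dtype_fn"
] = tag_graph_inputs
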
""" from executorch.backends.qualcomm._passes import ( - AnnotateDecomposed, AnnotateQuantAttrs, - ConstantI64toI32, + AnnotateStack, + AnnotateUnbind, ConvertBmmToMatmul, ConvertConv1dToConv2d, - ConvertToLinear, DecomposeAny, DecomposeLinalgVectorNorm, ExpandBroadcastTensorShape, + FixedLinearKeepDim, FoldQDQ, + I64toI32, LayoutTransform, RecomposePixelUnshuffle, - RecomposePReLU, RecomposeRmsNorm, RemoveRedundancy, ReplaceIndexPutInput, - TensorI64toI32, + TagQuantIO, ) return { - AnnotateDecomposed: [RemoveRedundancy], AnnotateQuantAttrs: [ RecomposePixelUnshuffle, - RecomposeRmsNorm, - ConvertToLinear, - RecomposePReLU, ConvertBmmToMatmul, + RemoveRedundancy, ], - ConstantI64toI32: [RemoveRedundancy], - ConvertBmmToMatmul: [ConvertToLinear], + AnnotateStack: [RemoveRedundancy], + AnnotateUnbind: [RemoveRedundancy], + ConvertBmmToMatmul: [RecomposePixelUnshuffle], ConvertConv1dToConv2d: [FoldQDQ], - ConvertToLinear: [RecomposePixelUnshuffle], DecomposeAny: [RemoveRedundancy], DecomposeLinalgVectorNorm: [RemoveRedundancy], ExpandBroadcastTensorShape: [RemoveRedundancy], - FoldQDQ: [AnnotateQuantAttrs, AnnotateDecomposed], + FixedLinearKeepDim: [FoldQDQ], + FoldQDQ: [AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind], + I64toI32: [RemoveRedundancy], LayoutTransform: [ AnnotateQuantAttrs, ConvertConv1dToConv2d, ExpandBroadcastTensorShape, + FixedLinearKeepDim, ], RecomposePixelUnshuffle: [RemoveRedundancy], - RecomposePReLU: [RemoveRedundancy], RecomposeRmsNorm: [RemoveRedundancy], ReplaceIndexPutInput: [LayoutTransform], - TensorI64toI32: [RemoveRedundancy], + TagQuantIO: [ReplaceIndexPutInput], } diff --git a/backends/qualcomm/builders/op_split_with_sizes.py b/backends/qualcomm/builders/op_split_with_sizes.py index 629110b3084..138f6ed60ec 100644 --- a/backends/qualcomm/builders/op_split_with_sizes.py +++ b/backends/qualcomm/builders/op_split_with_sizes.py @@ -17,7 +17,7 @@ @register_node_visitor class SplitWithSizes(NodeVisitor): - target = ["aten.split_with_sizes.default"] + target = ["aten.split_with_sizes.default", "aten.split_with_sizes_copy.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/partition/qnn_partitioner.py b/backends/qualcomm/partition/qnn_partitioner.py index 93b1d50f5fe..7b5e72d461d 100644 --- a/backends/qualcomm/partition/qnn_partitioner.py +++ b/backends/qualcomm/partition/qnn_partitioner.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import copy from collections import defaultdict -from typing import Any, Dict, List +from typing import Any, Callable, Dict, List, Optional, Tuple import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,6 +24,7 @@ PartitionResult, ) from executorch.exir.backend.utils import tag_constant_data +from torch.export.exported_program import ExportedProgram from torch.fx.passes.infra.partitioner import Partition from torch.fx.passes.operator_support import OperatorSupportBase @@ -33,7 +34,7 @@ not_supported_operator, to_be_implemented_operator, ) -from .utils import generate_qnn_executorch_option +from .utils import generate_qnn_executorch_option, get_skip_decomp_table class QnnOperatorSupport(OperatorSupportBase): @@ -174,3 +175,11 @@ def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResu return PartitionResult( tagged_exported_program=edge_program, partition_tags=self.partition_tags ) + + # override + def ops_to_not_decompose( + self, ep: ExportedProgram + ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: + do_not_decompose = get_skip_decomp_table() + + return do_not_decompose, None diff --git a/backends/qualcomm/partition/utils.py b/backends/qualcomm/partition/utils.py index 88b922d4e1f..1e2b17b2a69 100644 --- a/backends/qualcomm/partition/utils.py +++ b/backends/qualcomm/partition/utils.py @@ -6,6 +6,8 @@ from typing import List +import torch + from executorch.backends.qualcomm.utils.constants import QCOM_QNN_COMPILE_SPEC from executorch.exir.backend.compile_spec_schema import CompileSpec @@ -20,3 +22,26 @@ def generate_qnn_executorch_option( else: raise ValueError(f"unknown compiler spec key value: {compiler_spec.key}") return qnn_compile_spec_buffer + + +def get_skip_decomp_table() -> List[torch._ops.OperatorBase]: + do_not_decompose = [ + torch.ops.aten.adaptive_avg_pool2d.default, + torch.ops.aten.elu.default, + torch.ops.aten.hardsigmoid.default, + torch.ops.aten.hardswish.default, + torch.ops.aten.instance_norm.default, + torch.ops.aten.leaky_relu.default, + torch.ops.aten.linear.default, + torch.ops.aten.pixel_shuffle.default, + torch.ops.aten.pixel_unshuffle.default, + torch.ops.aten.prelu.default, + torch.ops.aten.rms_norm.default, + torch.ops.aten._safe_softmax.default, + torch.ops.aten.stack.default, + # This request is ignored because it is in a blocklist. 
Refer to exir/program/_program.py + # torch.ops.aten.unbind.int, + torch.ops.pt2e_quant.quantize_affine.default, + torch.ops.pt2e_quant.dequantize_affine.default, + ] + return do_not_decompose diff --git a/backends/qualcomm/qnn_preprocess.py b/backends/qualcomm/qnn_preprocess.py index 8afb4851814..4a11bf050a2 100644 --- a/backends/qualcomm/qnn_preprocess.py +++ b/backends/qualcomm/qnn_preprocess.py @@ -11,12 +11,7 @@ import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch # noqa: F401 -from executorch.backends.qualcomm._passes import ( - FuseConsecutiveTranspose, - InsertIOQDQ, - InsertRequantize, - LayoutTransform, -) +from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager from executorch.backends.qualcomm.builders.node_visitor import get_node_visitors from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader from executorch.backends.qualcomm.partition.utils import generate_qnn_executorch_option @@ -25,7 +20,6 @@ CompileSpec, PreprocessResult, ) -from executorch.exir.passes import PassManager from torch.export.exported_program import ExportedProgram DEFAULT_DEBUG_HANDLE = 65535 @@ -46,17 +40,8 @@ def preprocess( qnn_manager.Init() # QNN Delegate Specific Passes - qnn_compiler_passes = PassManager( - passes=[ - InsertRequantize(edge_program), - InsertIOQDQ(edge_program), - LayoutTransform(edge_program, insert_permute=True), - FuseConsecutiveTranspose(), - ] - ) - - pass_result = qnn_compiler_passes(edge_program.graph_module) - assert pass_result is not None + graph_module = QnnPassManager().transform_for_preprocess_pipeline(edge_program) + assert graph_module is not None enable_tensor_dump = qnn_manager.IsTensorDump() nodes_to_wrappers = defaultdict(dict) @@ -64,7 +49,7 @@ def preprocess( edge_program, enable_tensor_dump=enable_tensor_dump ) py_op_wrapper_list = [] - for node in pass_result.graph_module.graph.nodes: + for node in graph_module.graph.nodes: if node.op == "call_function": logger.info(f"Visiting: {node}, {node.target.__name__}") if node.target.__name__ in node_visitors: diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 028ffb69f1d..3620841aff9 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -8,20 +8,7 @@ from typing import Callable, Dict, Optional, Sequence, Set, Tuple import torch -from executorch.backends.qualcomm._passes import ( - DecomposeEinsum, - DecomposeExpM1, - DecomposeLinalgVectorNorm, - DecomposeSilu, - LiftConstantScalarOperands, - RecomposePixelUnshuffle, - ReduceDynamicRange, - ReplaceArangeArgs, - ReplaceInfValues, -) -from executorch.backends.transforms.decompose_sdpa import ( - DecomposeScaledDotProductAttention, -) +from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager from torch._ops import OpOverload from torch.ao.quantization.quantizer import Quantizer @@ -273,17 +260,7 @@ def set_per_channel_linear_quant(self, enable: bool) -> None: self._update_per_channel_weight_quant_ops(linear_ops, enable) def transform_for_annotation(self, model: GraphModule) -> GraphModule: - model = ReduceDynamicRange()(model).graph_module - model = RecomposePixelUnshuffle(quantization_capture=True)(model).graph_module - model = ReplaceArangeArgs()(model).graph_module - model = DecomposeScaledDotProductAttention()(model).graph_module - model = DecomposeSilu()(model).graph_module - model = DecomposeEinsum()(model).graph_module - model = 
DecomposeExpM1()(model).graph_module - model = DecomposeLinalgVectorNorm(aten_dialect_capture=True)(model).graph_module - model = ReplaceInfValues()(model).graph_module - model = LiftConstantScalarOperands()(model).graph_module - return model + return QnnPassManager().transform_for_annotation_pipeline(model) def validate(self, model: GraphModule) -> None: pass diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index c3c439261d2..0857a597d88 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -100,7 +100,7 @@ def forward(self, x): class ArgminViewSqueezeConv2D(torch.nn.Module): def __init__(self): - # This model is mainly to test the PASS TensorI64toI32 + # This model is mainly to test the PASS I64toI32 super().__init__() self.conv = torch.nn.Conv2d( in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1 @@ -235,9 +235,7 @@ class CompositeDelegateModule(torch.nn.Module): def __init__( self, compiler_specs, - partitioner_type, - capture_method, - lowered_method, + to_edge_transform_and_lower_method, quantize_method=None, ) -> None: super().__init__() @@ -255,15 +253,15 @@ def __init__( ] self.lowered_modules = [] for module, sample_input in zip(self.modules, self.sample_inputs): - partitioner = partitioner_type(compiler_specs) if quantize_method: module = quantize_method(module, sample_input) - edge_prog = capture_method(module, sample_input) - edge_prog.exported_program = lowered_method( - edge_prog.exported_program, partitioner + edge_prog = to_edge_transform_and_lower_method( + module, sample_input, compiler_specs ) self.lowered_modules.append( - edge_prog.exported_program.graph_module._modules.get("lowered_module_0") + edge_prog.exported_program().graph_module._modules.get( + "lowered_module_0" + ) ) def forward(self, x, y): @@ -873,9 +871,9 @@ def forward(self, x): class LeakyReLUCustom(torch.nn.Module): - def __init__(self, coeff): + def __init__(self, coeff, inplace=False): super().__init__() - self.leaky_relu = torch.nn.LeakyReLU(coeff) + self.leaky_relu = torch.nn.LeakyReLU(coeff, inplace=inplace) def forward(self, x): return self.leaky_relu(x) diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 05e368f372e..795459a9f77 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -13,18 +13,28 @@ from pathlib import Path import torch +from executorch.backends.qualcomm._passes.qnn_pass_manager import ( + get_capture_program_passes, + QnnPassManager, +) + +from executorch.backends.qualcomm._passes.utils import ( + get_passes_dependency_for_capture_program, +) + from executorch.backends.qualcomm.tests.utils import ( generate_context_binary, - QnnPartitioner, QuantDtype, TestQNN, - to_backend, validate_context_binary, validate_qcir, ) from executorch.backends.qualcomm.utils.constants import ( QCOM_ANNOTATION, QCOM_MODULE, + QCOM_PASS_ACTIVATE_KEY, + QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY, + QCOM_QUANT_ATTRS_MAP, QCOM_QUANT_DTYPE, QCOM_SAMPLE_INPUTS, ) @@ -37,6 +47,7 @@ generate_qnn_executorch_compiler_spec, PyQnnManagerAdaptor, skip_annotation, + to_edge_transform_and_lower_to_qnn, update_spill_fill_size, ) @@ -57,12 +68,7 @@ from collections import defaultdict from typing import List -from executorch.backends.qualcomm._passes import ( - FuseConsecutiveTranspose, - InsertIOQDQ, - InsertRequantize, - LayoutTransform, -) +from executorch.backends.qualcomm._passes import FoldQDQ, TagQuantIO from 
executorch.backends.qualcomm.builders.node_visitor import get_node_visitors from executorch.backends.qualcomm.debugger.utils import DrawGraph from executorch.examples.models.deeplab_v3 import DeepLabV3ResNet101Model @@ -78,7 +84,6 @@ from executorch.examples.models.wav2letter import Wav2LetterModel from executorch.exir import to_edge from executorch.exir.backend.backend_api import disable_validation -from executorch.exir.passes import PassManager class TestQNNFloatingPointOperator(TestQNN): @@ -555,7 +560,11 @@ def test_qnn_backend_leaky_relu(self): QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)], }, { - QCOM_MODULE: [LeakyReLUCustom(0.05)], # noqa: F405 + QCOM_MODULE: [LeakyReLUCustom(0.05, inplace=False)], # noqa: F405 + QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)], + }, + { + QCOM_MODULE: [LeakyReLUCustom(0.05, inplace=True)], # noqa: F405 QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)], }, ] @@ -1642,6 +1651,10 @@ def test_qnn_backend_leaky_relu(self): QCOM_MODULE: [LeakyReLUCustom(0.05)], # noqa: F405 QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)], }, + { + QCOM_MODULE: [LeakyReLUCustom(0.05, inplace=True)], # noqa: F405 + QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)], + }, ] index = 0 @@ -2254,8 +2267,6 @@ def test_qnn_backend_spill_fill_buffer_size(self): # TODO: Fix self.assertNotEqual(0, max_sf_size) module = LargeTensorLinear() # noqa: F405 sample_input = (torch.randn(1, 256, 512),) - edge_prog = capture_program(module, sample_input) - backend_options = generate_htp_compiler_spec( use_fp16=True, use_multi_contexts=True, @@ -2264,17 +2275,16 @@ def test_qnn_backend_spill_fill_buffer_size(self): soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, ) - partitioner = QnnPartitioner(compiler_specs) - edge_prog.exported_program = to_backend(edge_prog.exported_program, partitioner) - max_sf_size = update_spill_fill_size(edge_prog.exported_program) + edge_prog = to_edge_transform_and_lower_to_qnn( + module, sample_input, compiler_specs + ) + + max_sf_size = update_spill_fill_size(edge_prog.exported_program()) self.assertNotEqual(0, max_sf_size) def test_qnn_backend_multi_contexts(self): module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) - edge_prog = capture_program(module, sample_input) - self.split_graph(edge_prog.exported_program.graph_module, 4) - backend_options = generate_htp_compiler_spec( use_fp16=True, use_dlbc=True, @@ -2284,9 +2294,20 @@ def test_qnn_backend_multi_contexts(self): soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, ) - partitioner = QnnPartitioner(compiler_specs) - edge_prog.exported_program = to_backend(edge_prog.exported_program, partitioner) - update_spill_fill_size(edge_prog.exported_program) + pass_jobs = get_capture_program_passes() + split_graph_pass, setting = self.split_graph(4) + pass_jobs[split_graph_pass] = setting + dep_table = get_passes_dependency_for_capture_program() + dep_table[split_graph_pass] = [FoldQDQ] + edge_prog = to_edge_transform_and_lower_to_qnn( + module, + sample_input, + compiler_specs, + dep_table=dep_table, + passes_job=pass_jobs, + ) + + update_spill_fill_size(edge_prog.exported_program()) exec_prog = edge_prog.to_executorch() self.verify_output(module, sample_input, exec_prog) @@ -2302,9 +2323,7 @@ def test_qnn_backend_multi_contexts_composite(self): ) module = CompositeDelegateModule( # noqa: F405 compiler_specs=compiler_specs, - partitioner_type=QnnPartitioner, - capture_method=capture_program, - 
lowered_method=to_backend, + to_edge_transform_and_lower_method=to_edge_transform_and_lower_to_qnn, ) sample_input = module.get_random_input() edge_prog = to_edge( @@ -2323,10 +2342,6 @@ def test_qnn_backend_multi_graphs(self): modules = [seq_conv, seq_conv.second] sample_inputs = [(torch.randn([1, 1, 3, 3]),), (torch.randn([1, 3, 3, 3]),)] graph_names = ["seq_conv", "single_conv"] - edge_progs = [ - capture_program(module, sample_input) - for module, sample_input in zip(modules, sample_inputs) - ] backend_options = generate_htp_compiler_spec( use_fp16=True, ) @@ -2340,15 +2355,17 @@ def test_qnn_backend_multi_graphs(self): ) for graph_name in graph_names ] - exported_programs = [ - to_backend(edge_prog.exported_program, QnnPartitioner(compiler_specs[i])) - for i, edge_prog in enumerate(edge_progs) + edge_progs = [ + to_edge_transform_and_lower_to_qnn(module, sample_input, compiler_spec) + for module, sample_input, compiler_spec in zip( + modules, sample_inputs, compiler_specs + ) ] prog_mgr, _ = generate_multi_graph_program( compiler_specs=compiler_specs[0], processed_bytes=[ - prog.graph_module.lowered_module_0.processed_bytes - for prog in exported_programs + edge_prog.exported_program().graph_module.lowered_module_0.processed_bytes + for edge_prog in edge_progs ], ) for index, module in enumerate(modules): @@ -2428,8 +2445,6 @@ def test_qnn_backend_context_direct(self): ) def test_qnn_backend_context_extraction(self): - from executorch.exir import EdgeCompileConfig, EdgeProgramManager - module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) backend_options = generate_htp_compiler_spec(use_fp16=True) @@ -2444,12 +2459,9 @@ def test_qnn_backend_context_extraction(self): validators = [validate_context_binary, validate_qcir] for compiler_spec, validate in zip(compiler_specs, validators): - edge_prog_mgr = EdgeProgramManager( - edge_programs={ - "forward": capture_program(module, sample_input).exported_program - }, - compile_config=EdgeCompileConfig(_use_edge_ops=False), - ).to_backend(QnnPartitioner(compiler_spec)) + edge_prog_mgr = to_edge_transform_and_lower_to_qnn( + module, sample_input, compiler_spec + ) lowered_module = edge_prog_mgr.exported_program().graph_module._modules[ "lowered_module_0" ] @@ -2461,8 +2473,6 @@ def test_qnn_backend_context_extraction(self): validate(binary) def test_qnn_backend_dump_context_from_pte(self): - from executorch.exir import EdgeCompileConfig, EdgeProgramManager - module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) backend_options = generate_htp_compiler_spec(use_fp16=True) @@ -2477,18 +2487,9 @@ def test_qnn_backend_dump_context_from_pte(self): validators = [validate_context_binary, validate_qcir] for compiler_spec, validate in zip(compiler_specs, validators): - edge_prog_mgr = ( - EdgeProgramManager( - edge_programs={ - "forward": capture_program( - module, sample_input - ).exported_program - }, - compile_config=EdgeCompileConfig(_use_edge_ops=False), - ) - .to_backend(QnnPartitioner(compiler_spec)) - .to_executorch() - ) + edge_prog_mgr = to_edge_transform_and_lower_to_qnn( + module, sample_input, compiler_spec + ).to_executorch() with tempfile.TemporaryDirectory() as tmp_dir: pte_path = f"{tmp_dir}/model.pte" @@ -2599,25 +2600,15 @@ def test_qnn_backend_draw_graph(self): """ module = DrawGraphModel() # noqa: F405 sample_input = (torch.randn(1, 32, 28, 28),) + # TODO: Figure out how to get the original graph module with the 
to_edge_transform_and_lowering_to_qnn API delegated_program = capture_program(module, sample_input) """ This piece of code simulates the behavior of the final preprocessing step to obtain the op wrapper list. In practice, users need to set a breakpoint in the preprocessing step and use the DrawGraph tool to visualize the graph. """ - qnn_compiler_passes = PassManager( - passes=[ - InsertRequantize(delegated_program.exported_program), - InsertIOQDQ(delegated_program.exported_program), - LayoutTransform( - delegated_program.exported_program, insert_permute=True - ), - FuseConsecutiveTranspose(), - ] - ) - - pass_result = qnn_compiler_passes( - delegated_program.exported_program.graph_module + graph_module = QnnPassManager().transform_for_preprocess_pipeline( + delegated_program.exported_program ) nodes_to_wrappers = defaultdict(dict) node_visitors = get_node_visitors( @@ -2625,7 +2616,7 @@ def test_qnn_backend_draw_graph(self): ) py_op_wrapper_list = [] - for node in pass_result.graph_module.graph.nodes: + for node in graph_module.graph.nodes: if node.op == "call_function": if node.target.__name__ in node_visitors: py_op_wrapper = node_visitors[node.target.__name__].define_node( @@ -2689,12 +2680,7 @@ def test_qnn_backend_dynamic_shape(self): QCOM_DTYPE, QCOM_QUANT_ATTRS, ) - from executorch.backends.qualcomm.utils.utils import tag_quant_io - from executorch.exir.capture._config import ( - EdgeCompileConfig, - ExecutorchBackendConfig, - ) - from executorch.exir.program import EdgeProgramManager + from executorch.exir.capture._config import ExecutorchBackendConfig module = Add() # noqa: F405 last_dim = torch.export.Dim("last_dim", min=1, max=8) @@ -2714,31 +2700,32 @@ def test_qnn_backend_dynamic_shape(self): ) # only few ops with 16bit are supported with dynamic shape now # strip unsupported quantize / dequantize ops generated in preprocess - prog = capture_program(module, sample_input, dynamic_shapes=dynamic_shapes) - tag_quant_io( - prog.exported_program.graph_module, - lambda n: ( - torch.uint16 - if any(name in n.name for name in ["x", "y", "add"]) - else None - ), + pass_jobs = get_capture_program_passes() + pass_jobs[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True + pass_jobs[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][ + "get_quant_io_dtype_fn" + ] = lambda n: ( + torch.uint16 if any(name in n.name for name in ["x", "y", "add"]) else None ) + edge_prog_mgr = to_edge_transform_and_lower_to_qnn( + module, + sample_input, + self.compiler_specs, + dynamic_shapes=dynamic_shapes, + passes_job=pass_jobs, + ) + # collect encodings for ios input_encodings, output_encodings = [], [] - for n in prog.exported_program.graph.nodes: + for n in edge_prog_mgr.exported_program().graph.nodes: if n.op == "placeholder": input_encodings.append(n.meta[QCOM_QUANT_ATTRS]) input_encodings[-1][QCOM_DTYPE] = torch.uint16 elif n.op == "output": - for arg in n.args[0]: - output_encodings.append(arg.meta[QCOM_QUANT_ATTRS]) - output_encodings[-1][QCOM_DTYPE] = torch.uint16 + output_encodings = n.meta[QCOM_QUANT_ATTRS_MAP].values() + for output_encoding in output_encodings: + output_encoding[QCOM_DTYPE] = torch.uint16 - edge_prog_mgr = EdgeProgramManager( - edge_programs={"forward": prog.exported_program}, - compile_config=EdgeCompileConfig(_check_ir_validity=False), - ) - edge_prog_mgr = edge_prog_mgr.to_backend(QnnPartitioner(self.compiler_specs)) exec_prog = edge_prog_mgr.to_executorch( ExecutorchBackendConfig(passes=[BuildQuantIo()]) ) @@ -2771,7 +2758,7 @@ def test_qnn_backend_skip_node_id_quantizer(self): module 
= SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) - # define partitioner + # define compile specs backend_options = generate_htp_compiler_spec( use_fp16=False, ) @@ -2779,7 +2766,6 @@ def test_qnn_backend_skip_node_id_quantizer(self): soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, ) - partitioner = QnnPartitioner(compiler_specs) # define quantizer quantizer = make_quantizer() @@ -2788,15 +2774,15 @@ def calibrator(gm): gm(*sample_input) # get partially lowererd graph module - graph_module, exported_progs = skip_annotation( + graph_module, edge_prog_mgrs = skip_annotation( nn_module=module, quantizer=quantizer, - partitioner=partitioner, + compiler_specs=compiler_specs, sample_input=sample_input, calibration_cb=calibrator, fp_node_id_set={"conv2d"}, ) - self.assertEqual(len(exported_progs), 1) + self.assertEqual(len(edge_prog_mgrs), 1) # lower all graph again, the skipped operators will be left in CPU exec_prog = to_edge( torch.export.export(graph_module, sample_input, strict=True), @@ -2818,7 +2804,7 @@ def test_qnn_backend_skip_node_op_quantizer(self): module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) - # define partitioner + # define compile specs backend_options = generate_htp_compiler_spec( use_fp16=False, ) @@ -2826,7 +2812,6 @@ def test_qnn_backend_skip_node_op_quantizer(self): soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, ) - partitioner = QnnPartitioner(compiler_specs) # define quantizer quantizer = make_quantizer() @@ -2835,15 +2820,15 @@ def calibrator(gm): gm(*sample_input) # get partially lowererd graph module - graph_module, exported_progs = skip_annotation( + graph_module, edge_prog_mgrs = skip_annotation( nn_module=module, quantizer=quantizer, - partitioner=partitioner, + compiler_specs=compiler_specs, sample_input=sample_input, calibration_cb=calibrator, fp_node_op_set={torch.ops.aten.add.Tensor}, ) - self.assertEqual(len(exported_progs), 2) + self.assertEqual(len(edge_prog_mgrs), 2) # lower all graph again, the skipped operators will be left in CPU exec_prog = exec_prog = to_edge( torch.export.export(graph_module, sample_input, strict=True), @@ -2856,8 +2841,6 @@ def test_qnn_backend_spill_fill_buffer_size(self): module = LargeTensorLinear() # noqa: F405 sample_input = (torch.randn(1, 256, 512),) module = self.get_qdq_module(module, sample_input) - edge_prog = capture_program(module, sample_input) - backend_options = generate_htp_compiler_spec( use_fp16=False, use_multi_contexts=True, @@ -2866,16 +2849,18 @@ def test_qnn_backend_spill_fill_buffer_size(self): soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, ) - partitioner = QnnPartitioner(compiler_specs) - edge_prog.exported_program = to_backend(edge_prog.exported_program, partitioner) - max_sf_size = update_spill_fill_size(edge_prog.exported_program) + edge_prog = to_edge_transform_and_lower_to_qnn( + module, sample_input, compiler_specs + ) + + max_sf_size = update_spill_fill_size(edge_prog.exported_program()) self.assertNotEqual(0, max_sf_size) def test_qnn_backend_graph_level_mixed_precision(self): module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) - # define partitioner + # define compile spec backend_options = generate_htp_compiler_spec( use_fp16=False, ) @@ -2883,7 +2868,6 @@ def test_qnn_backend_graph_level_mixed_precision(self): 
soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, ) - partitioner = QnnPartitioner(compiler_specs) # define quantizer quantizer = make_quantizer() @@ -2892,16 +2876,16 @@ def calibrator(gm): gm(*sample_input) # get partially lowererd graph module - graph_module, exported_progs = skip_annotation( + graph_module, edge_prog_mgrs = skip_annotation( nn_module=module, quantizer=quantizer, - partitioner=partitioner, + compiler_specs=compiler_specs, sample_input=sample_input, calibration_cb=calibrator, fp_node_id_set={"add", "mean"}, fallback_to_cpu=False, ) - self.assertEqual(len(exported_progs), 5) + self.assertEqual(len(edge_prog_mgrs), 5) # lower all graph again, the skipped operators will be delegated with fp16 exec_prog = to_edge( torch.export.export(graph_module, sample_input, strict=True), @@ -2912,9 +2896,6 @@ def test_qnn_backend_multi_contexts(self): module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) module = self.get_qdq_module(module, sample_input) - edge_prog = capture_program(module, sample_input) - self.split_graph(edge_prog.exported_program.graph_module, 4) - backend_options = generate_htp_compiler_spec( use_fp16=False, use_dlbc=True, @@ -2924,9 +2905,20 @@ def test_qnn_backend_multi_contexts(self): soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, ) - partitioner = QnnPartitioner(compiler_specs) - edge_prog.exported_program = to_backend(edge_prog.exported_program, partitioner) - update_spill_fill_size(edge_prog.exported_program) + pass_jobs = get_capture_program_passes() + split_graph_pass, setting = self.split_graph(4) + pass_jobs[split_graph_pass] = setting + dep_table = get_passes_dependency_for_capture_program() + dep_table[split_graph_pass] = [FoldQDQ] + edge_prog = to_edge_transform_and_lower_to_qnn( + module, + sample_input, + compiler_specs, + dep_table=dep_table, + passes_job=pass_jobs, + ) + + update_spill_fill_size(edge_prog.exported_program()) exec_prog = edge_prog.to_executorch() self.verify_output(module, sample_input, exec_prog) @@ -2942,9 +2934,7 @@ def test_qnn_backend_multi_contexts_composite(self): ) module = CompositeDelegateModule( # noqa: F405 compiler_specs=compiler_specs, - partitioner_type=QnnPartitioner, - capture_method=capture_program, - lowered_method=to_backend, + to_edge_transform_and_lower_method=to_edge_transform_and_lower_to_qnn, quantize_method=self.get_qdq_module, ) sample_input = module.get_random_input() @@ -2964,10 +2954,6 @@ def test_qnn_backend_multi_graphs(self): modules = [seq_conv, seq_conv.second] sample_inputs = [(torch.randn([1, 1, 3, 3]),), (torch.randn([1, 3, 3, 3]),)] graph_names = ["seq_conv", "single_conv"] - edge_progs = [ - capture_program(self.get_qdq_module(module, sample_input), sample_input) - for module, sample_input in zip(modules, sample_inputs) - ] backend_options = generate_htp_compiler_spec( use_fp16=False, ) @@ -2981,15 +2967,19 @@ def test_qnn_backend_multi_graphs(self): ) for graph_name in graph_names ] - exported_programs = [ - to_backend(edge_prog.exported_program, QnnPartitioner(compiler_specs[i])) - for i, edge_prog in enumerate(edge_progs) + edge_progs = [ + to_edge_transform_and_lower_to_qnn( + self.get_qdq_module(module, sample_input), sample_input, compiler_spec + ) + for module, sample_input, compiler_spec in zip( + modules, sample_inputs, compiler_specs + ) ] prog_mgr, _ = generate_multi_graph_program( compiler_specs=compiler_specs[0], processed_bytes=[ - 
prog.graph_module.lowered_module_0.processed_bytes - for prog in exported_programs + edge_prog.exported_program().graph_module.lowered_module_0.processed_bytes + for edge_prog in edge_progs ], ) for index, module in enumerate(modules): @@ -3072,8 +3062,6 @@ def test_qnn_backend_context_direct(self): ) def test_qnn_backend_context_extraction(self): - from executorch.exir import EdgeCompileConfig, EdgeProgramManager - module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) module = self.get_qdq_module(module, sample_input) @@ -3089,12 +3077,9 @@ def test_qnn_backend_context_extraction(self): validators = [validate_context_binary, validate_qcir] for compiler_spec, validate in zip(compiler_specs, validators): - edge_prog_mgr = EdgeProgramManager( - edge_programs={ - "forward": capture_program(module, sample_input).exported_program - }, - compile_config=EdgeCompileConfig(_use_edge_ops=False), - ).to_backend(QnnPartitioner(compiler_spec)) + edge_prog_mgr = to_edge_transform_and_lower_to_qnn( + module, sample_input, compiler_spec + ) lowered_module = edge_prog_mgr.exported_program().graph_module._modules[ "lowered_module_0" ] @@ -3106,8 +3091,6 @@ def test_qnn_backend_context_extraction(self): validate(binary) def test_qnn_backend_dump_context_from_pte(self): - from executorch.exir import EdgeCompileConfig, EdgeProgramManager - module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) module = self.get_qdq_module(module, sample_input) @@ -3123,18 +3106,9 @@ def test_qnn_backend_dump_context_from_pte(self): validators = [validate_context_binary, validate_qcir] for compiler_spec, validate in zip(compiler_specs, validators): - edge_prog_mgr = ( - EdgeProgramManager( - edge_programs={ - "forward": capture_program( - module, sample_input - ).exported_program - }, - compile_config=EdgeCompileConfig(_use_edge_ops=False), - ) - .to_backend(QnnPartitioner(compiler_spec)) - .to_executorch() - ) + edge_prog_mgr = to_edge_transform_and_lower_to_qnn( + module, sample_input, compiler_spec + ).to_executorch() with tempfile.TemporaryDirectory() as tmp_dir: pte_path = f"{tmp_dir}/model.pte" @@ -3167,9 +3141,9 @@ def test_qnn_backend_draw_graph(self):
name: quantized_decomposed_quantize_per_tensor_default_8_0 |
name: quantized_decomposed_quantize_per_tensor_default_0 |
data_type: Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8 |
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE |
dims: [1, 32, 28, 28] |
dims: [32] |
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET |
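For reviewers following the migration above, here is a minimal sketch (not part of the patch) of the new lowering flow that the updated tests exercise: a single to_edge_transform_and_lower_to_qnn call replaces the old capture_program + QnnPartitioner + to_backend sequence, and optional passes are toggled through the pass-job dictionary returned by get_capture_program_passes. The lower_module helper, the soc_model argument, and the example get_quant_io_dtype_fn lambda are illustrative; the import paths follow the existing utils/_passes modules touched in this diff and should be verified against the tree.

import torch

from executorch.backends.qualcomm._passes import TagQuantIO
from executorch.backends.qualcomm._passes.qnn_pass_manager import (
    get_capture_program_passes,
)
from executorch.backends.qualcomm.utils.constants import (
    QCOM_PASS_ACTIVATE_KEY,
    QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY,
)
from executorch.backends.qualcomm.utils.utils import (
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
    to_edge_transform_and_lower_to_qnn,
)


def lower_module(module: torch.nn.Module, sample_input, soc_model):
    # Compile specs are passed directly; no QnnPartitioner is constructed by hand.
    backend_options = generate_htp_compiler_spec(use_fp16=True)
    compiler_specs = generate_qnn_executorch_compiler_spec(
        soc_model=soc_model,
        backend_options=backend_options,
    )

    # Start from the default capture-program pass jobs and opt into TagQuantIO,
    # mirroring the dynamic-shape test in this patch.
    pass_jobs = get_capture_program_passes()
    pass_jobs[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True
    pass_jobs[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][
        "get_quant_io_dtype_fn"
    ] = lambda n: torch.uint16 if "add" in n.name else None

    # One call captures, transforms, and lowers the module to QNN; the result is
    # an EdgeProgramManager, so exported_program() / to_executorch() apply as usual.
    edge_prog_mgr = to_edge_transform_and_lower_to_qnn(
        module, sample_input, compiler_specs, passes_job=pass_jobs
    )
    return edge_prog_mgr.to_executorch()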