From 37322421570e421553cfaee1b8046b8a9e8f1202 Mon Sep 17 00:00:00 2001
From: shewu-quic
Date: Wed, 12 Mar 2025 18:19:29 +0800
Subject: [PATCH] Qualcomm AI Engine Direct - Add QNN support for to_edge_transform_and_lower

Summary:
- Support `to_edge_transform_and_lower`
- Replace capture_program with the new API `to_edge_transform_and_lower_to_qnn`
- Replace capture_program with to_edge_transform_and_lower_to_qnn in unit tests
- Replace capture_program with to_edge_transform_and_lower_to_qnn in examples
- Replace capture_program with to_edge_transform_and_lower_to_qnn in llama
- Add QnnPassManager to manage all passes across the different stages
- Deprecate _transform in export_llama_lib in favor of qnn_pass_manager
- Add transform_for_export_pipeline for LiftConstantScalarOperands to avoid creating temporary tensors in the operation builder. However, this pass creates a get_attr node, which should be converted into a lifted tensor constant by lift_constant_tensor_pass. If it were placed in the to_edge_transform_passes, it would run after lift_constant_tensor_pass, so the operation builder would fail to retrieve the parameter via get_parameter for the get_attr node.
- Refactor the passes
- Fix output dtype mismatch at runtime after build_quant_io
- Combine constant_i64_to_i32 and tensor_i64_to_i32 into i64_to_i32
- Replace convert_to_linear pass with fixed_linear_keep_dim pass
  - Since the QNN linear op has no keep dims, we need to add squeeze and unsqueeze around the linear node
- Add TagQuantIO pass to tag IO nodes to avoid inserting Q/DQ ops in qnn_preprocess
- Add prelu, leaky_relu, linear, rms_norm to the skip-decomposition table
- Remove recompose_prelu.py
- Remove unused variables in insert_requantize.py and replace_index_put_input.py
- Support aten.split_with_sizes_copy.default
- Support leaky_relu with inplace=True
---
 backends/qualcomm/_passes/__init__.py | 20 +-
 ...notate_decomposed.py => annotate_stack.py} | 19 +-
 backends/qualcomm/_passes/annotate_unbind.py | 39 +++
 backends/qualcomm/_passes/build_quant_io.py | 10 +-
 .../qualcomm/_passes/constant_i64_to_i32.py | 81 -----
 .../qualcomm/_passes/convert_to_linear.py | 231 -------------
 backends/qualcomm/_passes/decompose_expm1.py | 2 +-
 .../_passes/decompose_linalg_vector_norm.py | 6 +-
 .../qualcomm/_passes/fixed_linear_keep_dim.py | 87 +++++
 .../{tensor_i64_to_i32.py => i64_to_i32.py} | 151 +++++----
 .../qualcomm/_passes/insert_requantize.py | 2 -
 backends/qualcomm/_passes/layout_transform.py | 1 +
 .../_passes/lift_constant_scalar_operands.py | 5 +-
 backends/qualcomm/_passes/qnn_pass_manager.py | 204 ++++++++++++
 backends/qualcomm/_passes/recompose_prelu.py | 57 ----
 .../qualcomm/_passes/recompose_rms_norm.py | 2 +
 .../_passes/replace_index_put_input.py | 3 +-
 backends/qualcomm/_passes/tag_quant_io.py | 45 +++
 backends/qualcomm/_passes/utils.py | 33 +-
 .../qualcomm/builders/op_split_with_sizes.py | 2 +-
 .../qualcomm/partition/qnn_partitioner.py | 13 +-
 backends/qualcomm/partition/utils.py | 25 ++
 backends/qualcomm/qnn_preprocess.py | 23 +-
 backends/qualcomm/quantizer/quantizer.py | 27 +-
 backends/qualcomm/tests/models.py | 20 +-
 backends/qualcomm/tests/test_qnn_delegate.py | 306 ++++++++----------
 backends/qualcomm/tests/utils.py | 76 ++---
 backends/qualcomm/utils/constants.py | 1 +
 backends/qualcomm/utils/utils.py | 302 +++++++----------
 examples/models/llama/export_llama_lib.py | 60 ++--
 .../executor_runner/qnn_executor_runner.cpp | 32 +-
 examples/qualcomm/oss_scripts/fastvit.py | 9 +-
 examples/qualcomm/oss_scripts/llama/llama.py |
269 ++++++++------- examples/qualcomm/scripts/export_example.py | 31 +- examples/qualcomm/utils.py | 86 ++--- extension/llm/custom_ops/model_sharding.py | 14 +- 36 files changed, 1063 insertions(+), 1231 deletions(-) rename backends/qualcomm/_passes/{annotate_decomposed.py => annotate_stack.py} (63%) create mode 100644 backends/qualcomm/_passes/annotate_unbind.py delete mode 100644 backends/qualcomm/_passes/constant_i64_to_i32.py delete mode 100644 backends/qualcomm/_passes/convert_to_linear.py create mode 100644 backends/qualcomm/_passes/fixed_linear_keep_dim.py rename backends/qualcomm/_passes/{tensor_i64_to_i32.py => i64_to_i32.py} (50%) create mode 100644 backends/qualcomm/_passes/qnn_pass_manager.py delete mode 100644 backends/qualcomm/_passes/recompose_prelu.py create mode 100644 backends/qualcomm/_passes/tag_quant_io.py diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py index fb1f985edb9..9c884d7ab93 100644 --- a/backends/qualcomm/_passes/__init__.py +++ b/backends/qualcomm/_passes/__init__.py @@ -4,51 +4,51 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from .annotate_decomposed import AnnotateDecomposed from .annotate_quant_attrs import AnnotateQuantAttrs -from .constant_i64_to_i32 import ConstantI64toI32 +from .annotate_stack import AnnotateStack +from .annotate_unbind import AnnotateUnbind from .convert_bmm_to_matmul import ConvertBmmToMatmul from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d -from .convert_to_linear import ConvertToLinear from .decompose_any import DecomposeAny from .decompose_einsum import DecomposeEinsum from .decompose_expm1 import DecomposeExpM1 from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm from .decompose_silu import DecomposeSilu from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape +from .fixed_linear_keep_dim import FixedLinearKeepDim from .fold_qdq import FoldQDQ from .fuse_consecutive_transpose import FuseConsecutiveTranspose +from .i64_to_i32 import I64toI32 from .insert_io_qdq import InsertIOQDQ from .insert_requantize import InsertRequantize from .layout_transform import LayoutTransform from .lift_constant_scalar_operands import LiftConstantScalarOperands from .recompose_pixel_unshuffle import RecomposePixelUnshuffle -from .recompose_prelu import RecomposePReLU from .recompose_rms_norm import RecomposeRmsNorm from .reduce_dynamic_range import ReduceDynamicRange from .remove_redundancy import RemoveRedundancy from .replace_arange_args import ReplaceArangeArgs from .replace_index_put_input import ReplaceIndexPutInput from .replace_inf_values import ReplaceInfValues -from .tensor_i64_to_i32 import TensorI64toI32 +from .tag_quant_io import TagQuantIO __all__ = [ - AnnotateDecomposed, AnnotateQuantAttrs, - ConstantI64toI32, + AnnotateStack, + AnnotateUnbind, ConvertBmmToMatmul, ConvertConv1dToConv2d, - RecomposePReLU, - ConvertToLinear, DecomposeAny, DecomposeEinsum, DecomposeExpM1, DecomposeLinalgVectorNorm, DecomposeSilu, ExpandBroadcastTensorShape, + FixedLinearKeepDim, FoldQDQ, FuseConsecutiveTranspose, + I64toI32, InsertIOQDQ, InsertRequantize, LayoutTransform, @@ -60,5 +60,5 @@ ReplaceArangeArgs, ReplaceIndexPutInput, ReplaceInfValues, - TensorI64toI32, + TagQuantIO, ] diff --git a/backends/qualcomm/_passes/annotate_decomposed.py b/backends/qualcomm/_passes/annotate_stack.py similarity index 63% rename from backends/qualcomm/_passes/annotate_decomposed.py rename to 
backends/qualcomm/_passes/annotate_stack.py index 918b705e5e9..c42804af2f2 100644 --- a/backends/qualcomm/_passes/annotate_decomposed.py +++ b/backends/qualcomm/_passes/annotate_stack.py @@ -8,31 +8,21 @@ from executorch.exir.pass_base import ExportPass, PassResult from torch.fx.passes.utils.source_matcher_utils import get_source_partitions -from .utils import dq_ops, get_quant_attrs, q_ops +from .utils import get_quant_attrs, q_ops -class AnnotateDecomposed(ExportPass): +class AnnotateStack(ExportPass): """ Add "quant_attrs" to graph nodes' meta from the QDQ information generated after quantization process. """ - decomp_ops = [torch.ops.aten.stack.default, torch.ops.aten.unbind.int] + decomp_ops = [torch.ops.aten.unbind.int] def __init__(self, edge_program: torch.export.ExportedProgram): - super(AnnotateDecomposed, self).__init__() + super(AnnotateStack, self).__init__() self.edge_program = edge_program - def _annotate_unbind(self, graph_module: torch.fx.GraphModule): - partitions = get_source_partitions(graph_module.graph, [torch.unbind, "unbind"]) - for _, src_partitions in partitions.items(): - for src_partition in src_partitions: - if src_partition.input_nodes[0].target in dq_ops: - q_node = src_partition.input_nodes[0].args[0] - quant_attrs = get_quant_attrs(self.edge_program, q_node) - for n in src_partition.nodes: - n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy() - def _annotate_stack(self, graph_module: torch.fx.GraphModule): partitions = get_source_partitions(graph_module.graph, [torch.stack, "stack"]) for _, src_partitions in partitions.items(): @@ -46,7 +36,6 @@ def _annotate_stack(self, graph_module: torch.fx.GraphModule): n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy() def call(self, graph_module: torch.fx.GraphModule): - self._annotate_unbind(graph_module) self._annotate_stack(graph_module) graph_module.recompile() return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/annotate_unbind.py b/backends/qualcomm/_passes/annotate_unbind.py new file mode 100644 index 00000000000..0efa1638bc4 --- /dev/null +++ b/backends/qualcomm/_passes/annotate_unbind.py @@ -0,0 +1,39 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import torch +from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS +from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx.passes.utils.source_matcher_utils import get_source_partitions + +from .utils import dq_ops, get_quant_attrs + + +class AnnotateUnbind(ExportPass): + """ + Add "quant_attrs" to graph nodes' meta from the QDQ information + generated after quantization process. 
+ """ + + decomp_ops = [torch.ops.aten.unbind.int] + + def __init__(self, edge_program: torch.export.ExportedProgram): + super(AnnotateUnbind, self).__init__() + self.edge_program = edge_program + + def _annotate_unbind(self, graph_module: torch.fx.GraphModule): + partitions = get_source_partitions(graph_module.graph, [torch.unbind, "unbind"]) + for _, src_partitions in partitions.items(): + for src_partition in src_partitions: + if src_partition.input_nodes[0].target in dq_ops: + q_node = src_partition.input_nodes[0].args[0] + quant_attrs = get_quant_attrs(self.edge_program, q_node) + for n in src_partition.nodes: + n.meta[QCOM_QUANT_ATTRS] = quant_attrs.copy() + + def call(self, graph_module: torch.fx.GraphModule): + self._annotate_unbind(graph_module) + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/build_quant_io.py b/backends/qualcomm/_passes/build_quant_io.py index b627b9b3052..b34d50c4e24 100644 --- a/backends/qualcomm/_passes/build_quant_io.py +++ b/backends/qualcomm/_passes/build_quant_io.py @@ -27,7 +27,7 @@ def _make_spec(self, x): return None def _build(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: - # forcely update delegate node's meta['spec'] to get correct output + # Forcedly update delegate node's meta['spec'] to get correct output # tensor size in runtime call_delegate = [ node @@ -35,17 +35,9 @@ def _build(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: if node.op == "call_function" and node.name == "executorch_call_delegate" ] assert len(call_delegate) == 1 - spec = [] for n in graph_module.graph.nodes: if QCOM_QUANTIZED_IO in n.meta: n.meta["val"] = n.meta["val"].to(dtype=n.meta[QCOM_QUANTIZED_IO]) - if n.op == "call_function" and "getitem" in n.name: - fake_tensor = n.meta["val"] - if QCOM_QUANTIZED_IO in n.meta: - fake_tensor = fake_tensor.to(dtype=n.meta[QCOM_QUANTIZED_IO]) - spec.append(self._make_spec(fake_tensor)) - - call_delegate[0].meta["spec"] = tuple(spec) def call(self, graph_module: torch.fx.GraphModule): self._build(graph_module) diff --git a/backends/qualcomm/_passes/constant_i64_to_i32.py b/backends/qualcomm/_passes/constant_i64_to_i32.py deleted file mode 100644 index 9b5178b386e..00000000000 --- a/backends/qualcomm/_passes/constant_i64_to_i32.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -from typing import FrozenSet - -import torch -from executorch.backends.qualcomm.builders.utils import get_parameter, is_constant -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult -from torch._subclasses.fake_tensor import FakeTensor - - -class ConstantI64toI32(ExportPass): - """ - Cast unsupported int64 datatype into int32. - This will only be applied on constant nodes such as weights. 
- """ - - def __init__( - self, - edge_program: torch.export.ExportedProgram, - skip_node: FrozenSet[str] = frozenset(), - ): - super(ConstantI64toI32, self).__init__() - self.edge_program = edge_program - self.skip_node = skip_node - # pyre-ignore[4] - self.copy_op = exir_ops.edge.aten._to_copy.default - - def _update_meta(self, node: torch.fx.node) -> None: - meta_val = node.meta["val"] - if isinstance(meta_val, tuple): - node.meta["val"] = ( - ( - fake_tensor.to(torch.int32) - if fake_tensor.dtype == torch.int64 - else fake_tensor - ) - for fake_tensor in meta_val - ) - else: - if meta_val.dtype == torch.int64: - node.meta["val"] = meta_val.to(torch.float) - - # pyre-ignore[2] - def _is_tensor_of_dtype(self, node_val, dtype: torch.dtype) -> bool: - return isinstance(node_val, FakeTensor) and node_val.dtype == dtype - - def _cast_to_int32(self, graph_module: torch.fx.GraphModule): - for n in graph_module.graph.nodes: - if n.target in self.skip_node: - continue - if is_constant(n, self.edge_program): - param = get_parameter(n, self.edge_program) - if param.dtype == torch.int64: - # QNN does not support int64 - self._update_meta(n) - elif n.op == "placeholder": - node_val = n.meta["val"] - if self._is_tensor_of_dtype(node_val, torch.int64): - with graph_module.graph.inserting_after(n): - args = (n,) - to_dst_node = graph_module.graph.create_node( - "call_function", - self.copy_op, - args, - {"dtype": torch.int32}, - ) - to_dst_node.meta["val"] = node_val.to(torch.int32) - - # Replace usage of the src dtype result with the dst dtype result. - n.replace_all_uses_with(to_dst_node) - to_dst_node.args = (n,) - - def call(self, graph_module: torch.fx.GraphModule): - self._cast_to_int32(graph_module) - graph_module.recompile() - graph_module = super().call(graph_module).graph_module - return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/convert_to_linear.py b/backends/qualcomm/_passes/convert_to_linear.py deleted file mode 100644 index 48883571a0c..00000000000 --- a/backends/qualcomm/_passes/convert_to_linear.py +++ /dev/null @@ -1,231 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
-from collections import Counter -from typing import Callable, List - -import torch -from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS -from executorch.backends.transforms.addmm_mm_to_linear import ( - apply_addmm_mm_to_linear_transform, -) -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.dialects.edge._ops import EdgeOpOverload as edge_op -from executorch.exir.pass_base import ExportPass, PassResult -from executorch.exir.passes import dead_code_elimination_pass - -from torch.fx.passes.utils.source_matcher_utils import ( - get_source_partitions, - SourcePartition, -) - -from .utils import dq_ops, get_quant_attrs, q_ops - - -class ConvertToLinear(ExportPass): - """ - Handle missing quantization tag for addmm op after decomposing - """ - - view_copy = exir_ops.edge.aten.view_copy.default - permute_copy = exir_ops.edge.aten.permute_copy.default - expand_copy = exir_ops.edge.aten.expand_copy.default - linear = exir_ops.edge.aten.linear.default - add = exir_ops.edge.aten.add.Tensor - addmm = exir_ops.edge.aten.addmm.default - bmm = exir_ops.edge.aten.bmm.default - mm = exir_ops.edge.aten.mm.default - - addmm_patterns = [ - {view_copy: 1, permute_copy: 1, addmm: 1}, - {view_copy: 2, permute_copy: 1, addmm: 1}, - {permute_copy: 1, addmm: 1}, - ] - - bmm_patterns = [ - {view_copy: 3, permute_copy: 1, expand_copy: 2, add: 1, bmm: 1}, - {view_copy: 3, permute_copy: 1, expand_copy: 2, bmm: 1}, - ] - - mm_patterns = [ - {view_copy: 2, permute_copy: 1, mm: 1}, - {permute_copy: 1, mm: 1}, - ] - - def __init__(self): - super(ConvertToLinear, self).__init__() - - def _get_original_input( - self, inputs: List[torch.fx.Node], cur_node: torch.fx.Node - ) -> torch.fx.Node: - while cur_node not in inputs and cur_node.args: - cur_node = cur_node.args[0] - return cur_node - - def _convert_to_linear( - self, - gm: torch.fx.GraphModule, - src_partition: SourcePartition, - extract_ops_fn: Callable, - ): - inputs = src_partition.input_nodes - # output_nodes contains output node and input buffer such as argX_X - outputs = [ - node - for node in src_partition.output_nodes - if node.target != torch.ops.aten.sym_size.int and node.op != "placeholder" - ] - assert ( - len(outputs) == 1 - ), f"Unexpected number of outputs for a torch.nn.Linear module, expecting 1 but got {outputs}" - output = outputs[0] - - ops = extract_ops_fn(src_partition.nodes) - input_node, weight_node, fn_node = ops[:3] - bias_node = None if len(ops) == 3 else ops[3] - - # qnn htp does not support keepdim, the view_copy(reshape) should exist for now - if self._get_original_input(inputs, input_node).target in dq_ops: - input_node.meta[QCOM_QUANT_ATTRS] = get_quant_attrs( - gm, self._get_original_input(inputs, input_node).args[0] - ) - args = [input_node, weight_node] - if bias_node: - args.append(bias_node) - - # We need a view copy node after linear op - with gm.graph.inserting_before(output): - linear_node = gm.graph.create_node( - "call_function", self.linear, tuple(args) - ) - linear_node.meta = fn_node.meta - if list(output.users)[0].target in q_ops: - linear_node.meta[QCOM_QUANT_ATTRS] = get_quant_attrs( - gm, list(output.users)[0] - ) - for user in fn_node.users.copy(): - user.replace_input_with(fn_node, linear_node) - - # Since QNN has no keep dims for linear op, we will need to add squeeze and unsqueeze around linear node - # TODO: Find a more general conditional statement. 
- linear_output = linear_node.meta["val"] - if linear_output.dim() >= 3: - with gm.graph.inserting_after(input_node): - input_users = list(input_node.users.keys()) - input_tensor = input_node.meta["val"] - squeeze_dim = (-1, input_tensor.shape[-1]) - squeeze_node = gm.graph.create_node( - "call_function", - self.view_copy, - ( - input_node, - squeeze_dim, - ), - ) - # meta needs to be copied elementwisely for fake-tensor - # to be updated correctly and not affect meta of input_node - for k, v in input_node.meta.items(): - squeeze_node.meta[k] = v - squeeze_node.meta["val"] = input_tensor.reshape(squeeze_dim) - for user in input_users: - if user == linear_node: - user.replace_input_with(input_node, squeeze_node) - - with gm.graph.inserting_after(linear_node): - output_users = list(linear_node.users.keys()) - unsqueeze_dim = linear_output.shape - unsqueeze_node = gm.graph.create_node( - "call_function", - self.view_copy, - ( - linear_node, - unsqueeze_dim, - ), - ) - # meta needs to be copied elementwisely for fake-tensor - # to be updated correctly and not affect meta of unsqueeze_node - for k, v in linear_node.meta.items(): - unsqueeze_node.meta[k] = v - # update linear node's shape - linear_node.meta["val"] = linear_output.reshape( - (squeeze_node.meta["val"].shape[0], linear_output.shape[-1]) - ) - for user in output_users: - user.replace_input_with(linear_node, unsqueeze_node) - - def _extract_mm_ops(self, partitioned_nodes: List[edge_op]) -> List[torch.fx.Node]: - mm_node = [n for n in partitioned_nodes if n.target == self.mm][0] - # weight -> permute -> input of mm - weight_node = mm_node.args[1].args[0] - input_node = mm_node.args[0] - return [input_node, weight_node, mm_node] - - def _extract_addmm_ops( - self, partitioned_nodes: List[edge_op] - ) -> List[torch.fx.Node]: - addmm_node = [n for n in partitioned_nodes if n.target == self.addmm][0] - # weight -> permute -> input of addmm - weight_node = addmm_node.args[2].args[0] - input_node = addmm_node.args[1] - bias_node = addmm_node.args[0] - return [input_node, weight_node, addmm_node, bias_node] - - def _extract_bmm_ops(self, partitioned_nodes: List[edge_op]) -> List[torch.fx.Node]: - bmm_node = [n for n in partitioned_nodes if n.target == self.bmm][0] - add_node = [n for n in partitioned_nodes if n.target == self.add] - - # weight -> expand_copy -> view_copy -> input of bmm - weight_node = bmm_node.args[1].args[0].args[0].args[0] - # input -> expand_copy -> view_copy -> input of bmm - input_node = bmm_node.args[0].args[0].args[0] - - ret = [input_node, weight_node, bmm_node] - if add_node: - bias_node = add_node[0].args[1] - ret = [input_node, weight_node, add_node[0], bias_node] - else: - ret = [input_node, weight_node, bmm_node] - - return ret - - def _convert(self, graph_module: torch.fx.GraphModule): - partitions = get_source_partitions( - graph_module.graph, [torch.nn.Linear, torch.ops.aten.linear.default] - ) - for _, src_partitions in partitions.items(): - for src_partition in src_partitions: - op_cnt = Counter( - [ - n.target - for n in src_partition.nodes - if isinstance(n.target, edge_op) - ] - ) - if self.linear in op_cnt: - continue - elif op_cnt in self.addmm_patterns: - self._convert_to_linear( - graph_module, src_partition, self._extract_addmm_ops - ) - elif op_cnt in self.mm_patterns: - self._convert_to_linear( - graph_module, src_partition, self._extract_mm_ops - ) - elif op_cnt in self.bmm_patterns: - self._convert_to_linear( - graph_module, src_partition, self._extract_bmm_ops - ) - else: - raise 
AssertionError( - "Found a new pattern needed be converted to linear op" - ) - - def call(self, graph_module: torch.fx.GraphModule): - self._convert(graph_module) - # We could not use get_source_partitions because it is the same source for MultiheadAttention - apply_addmm_mm_to_linear_transform(graph_module.graph) - dead_code_elimination_pass(graph_module) - graph_module.recompile() - return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/decompose_expm1.py b/backends/qualcomm/_passes/decompose_expm1.py index 8fe6ebdec5b..41499e8a4b0 100644 --- a/backends/qualcomm/_passes/decompose_expm1.py +++ b/backends/qualcomm/_passes/decompose_expm1.py @@ -15,7 +15,7 @@ class DecomposeExpM1(ExportPass): Decompose for expm1 to exponential and minus 1. """ - def __init__(self, quantization_capture=False) -> None: + def __init__(self) -> None: super().__init__() def call(self, graph_module: torch.fx.GraphModule) -> PassResult: diff --git a/backends/qualcomm/_passes/decompose_linalg_vector_norm.py b/backends/qualcomm/_passes/decompose_linalg_vector_norm.py index 4a54c2aa50c..7d70f5c9342 100644 --- a/backends/qualcomm/_passes/decompose_linalg_vector_norm.py +++ b/backends/qualcomm/_passes/decompose_linalg_vector_norm.py @@ -32,9 +32,9 @@ class DecomposeLinalgVectorNorm(ExportPass): Decompose for math equivalent op. """ - def __init__(self, aten_dialect_capture=False) -> None: + def __init__(self, quantization_capture=False) -> None: super().__init__() - self.aten_dialect_capture = aten_dialect_capture + self.quantization_capture = quantization_capture def call(self, graph_module: torch.fx.GraphModule) -> PassResult: graph = graph_module.graph @@ -44,7 +44,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: dim = node.args[2] if len(node.args) > 2 else None keepdim = node.args[3] if len(node.args) > 3 else False model = LinalgVectorNorm(ord, dim, keepdim) - if self.aten_dialect_capture: + if self.quantization_capture: decomposed_module = torch.export.export( model, (node.args[0].meta["val"],), strict=True ).module() diff --git a/backends/qualcomm/_passes/fixed_linear_keep_dim.py b/backends/qualcomm/_passes/fixed_linear_keep_dim.py new file mode 100644 index 00000000000..4f625b96f0e --- /dev/null +++ b/backends/qualcomm/_passes/fixed_linear_keep_dim.py @@ -0,0 +1,87 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import torch + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.passes import dead_code_elimination_pass + +from torch.fx.passes.utils.source_matcher_utils import get_source_partitions + + +class FixedLinearKeepDim(ExportPass): + """ + Add squeeze and unsqueeze around linear node since QNN has no keep dims for linear op. 
+ """ + + view_copy = exir_ops.edge.aten.view_copy.default + linear = exir_ops.edge.aten.linear.default + + def __init__(self): + super(FixedLinearKeepDim, self).__init__() + + def _fixed_keep_dim(self, graph_module: torch.fx.GraphModule): + partitions = get_source_partitions( + graph_module.graph, [torch.nn.Linear, torch.ops.aten.linear.default] + ) + for _, src_partitions in partitions.items(): + for src_partition in src_partitions: + linear_node = [ + n for n in src_partition.nodes if n.target == self.linear + ][0] + input_node = linear_node.args[0] + # Since QNN has no keep dims for linear op, we will need to add squeeze and unsqueeze around linear node + # TODO: Find a more general conditional statement. + linear_output = linear_node.meta["val"] + if linear_output.dim() >= 3: + with graph_module.graph.inserting_after(input_node): + input_users = list(input_node.users.keys()) + input_tensor = input_node.meta["val"] + squeeze_dim = (-1, input_tensor.shape[-1]) + squeeze_node = graph_module.graph.create_node( + "call_function", + self.view_copy, + ( + input_node, + squeeze_dim, + ), + ) + # meta needs to be copied elementwisely for fake-tensor + # to be updated correctly and not affect meta of input_node + for k, v in input_node.meta.items(): + squeeze_node.meta[k] = v + squeeze_node.meta["val"] = input_tensor.reshape(squeeze_dim) + for user in input_users: + if user == linear_node: + user.replace_input_with(input_node, squeeze_node) + + with graph_module.graph.inserting_after(linear_node): + output_users = list(linear_node.users.keys()) + unsqueeze_dim = linear_output.shape + unsqueeze_node = graph_module.graph.create_node( + "call_function", + self.view_copy, + ( + linear_node, + unsqueeze_dim, + ), + ) + # meta needs to be copied elementwisely for fake-tensor + # to be updated correctly and not affect meta of unsqueeze_node + for k, v in linear_node.meta.items(): + unsqueeze_node.meta[k] = v + # update linear node's shape + linear_node.meta["val"] = linear_output.reshape( + (squeeze_node.meta["val"].shape[0], linear_output.shape[-1]) + ) + for user in output_users: + user.replace_input_with(linear_node, unsqueeze_node) + + def call(self, graph_module: torch.fx.GraphModule): + self._fixed_keep_dim(graph_module) + dead_code_elimination_pass(graph_module) + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/tensor_i64_to_i32.py b/backends/qualcomm/_passes/i64_to_i32.py similarity index 50% rename from backends/qualcomm/_passes/tensor_i64_to_i32.py rename to backends/qualcomm/_passes/i64_to_i32.py index baddd747f99..f13b035552c 100644 --- a/backends/qualcomm/_passes/tensor_i64_to_i32.py +++ b/backends/qualcomm/_passes/i64_to_i32.py @@ -5,41 +5,57 @@ # LICENSE file in the root directory of this source tree. import logging +from typing import FrozenSet import torch -from executorch.backends.qualcomm.builders.utils import is_graph_output +from executorch.backends.qualcomm.builders.utils import ( + get_parameter, + is_constant, + is_graph_output, +) from executorch.backends.qualcomm.utils.constants import QCOM_ORIG_DTYPE -from executorch.exir import ExirExportedProgram from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult -from executorch.exir.program._program import _get_updated_graph_signature from torch._subclasses.fake_tensor import FakeTensor -class TensorI64toI32(ExportPass): +class I64toI32(ExportPass): """ Insert a cast node to cast dtype from int64 to int32. 
- This will only be applied on fake tensors. + This will be applied on operator and constant nodes such as weights. """ - cast_ops = { - torch.ops.aten.argmin.default, - torch.ops.aten.arange.start_step, - torch.ops.aten.full.default, - torch.ops.aten.scalar_tensor.default, + I64_OPS = { + exir_ops.edge.aten.argmin.default, + exir_ops.edge.aten.arange.start_step, + exir_ops.edge.aten.full.default, + exir_ops.edge.aten.scalar_tensor.default, } + copy_op = exir_ops.edge.aten._to_copy.default - def __init__(self, edge_program): - super(TensorI64toI32, self).__init__() + def __init__( + self, + edge_program, + skip_node: FrozenSet[str] = frozenset(), + ): + super(I64toI32, self).__init__() self.edge_program = edge_program + self.skip_node = skip_node # pyre-ignore[2] def _is_tensor_of_dtype(self, node_val, dtype: torch.dtype) -> bool: return isinstance(node_val, FakeTensor) and node_val.dtype == dtype - def _cast_to_int32(self, core_ep: ExirExportedProgram): - copy_op = torch.ops.aten._to_copy.default - for n in core_ep.exported_program.graph.nodes: + def call_operator(self, op, args, kwargs, meta): + if op in self.I64_OPS and self._is_tensor_of_dtype(meta["val"], torch.int64): + res = super().call_operator(op, args, kwargs, meta) + return super().call_operator( + self.copy_op, (res,), {"dtype": torch.int32}, meta + ) + return super().call_operator(op, args, kwargs, meta) + + def _record_original_output_dtype(self, graph_module: torch.fx.GraphModule): + for n in graph_module.graph.nodes: # Keep track of original output dtype so we ensure the dtype of the graph is consistent with nn.Module if is_graph_output(n): if isinstance(n.meta["val"], (tuple, list)): @@ -47,42 +63,8 @@ def _cast_to_int32(self, core_ep: ExirExportedProgram): n.meta[QCOM_ORIG_DTYPE] = dtype_list else: n.meta[QCOM_ORIG_DTYPE] = n.meta["val"].dtype - continue - if n.target in self.cast_ops: - node_val = n.meta["val"] - if self._is_tensor_of_dtype(node_val, torch.int64): - with core_ep.exported_program.graph.inserting_after(n): - users = list(n.users.keys()) - args = (n,) - cast_node = core_ep.exported_program.graph.create_node( - "call_function", - copy_op, - args, - {"dtype": torch.int32}, - ) - cast_node.meta["val"] = node_val.to(torch.int32) - cast_node.args = args - - for user in users: - # _assert_tensor_metadata is used to check dtype, which will cause lowering to fail since we are changing int64 to int32 - # We also skip if the next op is already a cast op, which prevents redundant casting. 
- if user.target not in { - torch.ops.aten._assert_tensor_metadata.default, - torch.ops.aten._to_copy.default, - }: - user.replace_input_with(n, cast_node) - - core_ep.exported_program._graph_signature = _get_updated_graph_signature( - core_ep.exported_program._graph_signature, - core_ep.exported_program.graph_module, - ) - core_ep.exported_program._validate() - def _preserve_output_dtype( - self, exported_program: torch.export.exported_program.ExportedProgram - ): - graph_module = exported_program.graph_module - copy_op = exir_ops.edge.aten._to_copy.default + def _preserve_output_dtype(self, graph_module: torch.fx.GraphModule): for n in graph_module.graph.nodes: if is_graph_output(n) and QCOM_ORIG_DTYPE in n.meta: if isinstance(n.meta["val"], (tuple, list)): @@ -107,7 +89,7 @@ def _preserve_output_dtype( ] cast_node = graph_module.graph.create_node( "call_function", - copy_op, + self.copy_op, args, {"dtype": orig_dtype}, ) @@ -116,22 +98,55 @@ def _preserve_output_dtype( for user in output_users: user.replace_input_with(n, cast_node) - def call(self, graph_module: torch.fx.GraphModule): - # Stage 1: _cast_to_int32 - # We add to_copy after the desired operations during this stage because the data type only propagates before to_edge. - # If we don't add to_copy here but do it after to_edge, the next operation after to_copy() will still expect int64 as its output. - # Stage 2: _preserve_output_dtype - # We will tag the output dtype during stage 1, and we will ensure that if user expects int64 as output, - # we need to convert the output back to int64 if it is casted from int64->int32 during stage 1. - if isinstance(self.edge_program, ExirExportedProgram): - self._cast_to_int32(self.edge_program) - self.edge_program.exported_program.graph_module.recompile() - elif isinstance( - self.edge_program, torch.export.exported_program.ExportedProgram - ): - self._preserve_output_dtype(self.edge_program) - else: - raise AssertionError( - "Should be ExirExportedProgram at stage 1 and torch.export.exported_program.ExportedProgram at stage 2" + def _update_meta(self, node: torch.fx.node) -> None: + meta_val = node.meta["val"] + if isinstance(meta_val, tuple): + node.meta["val"] = ( + ( + fake_tensor.to(torch.int32) + if fake_tensor.dtype == torch.int64 + else fake_tensor + ) + for fake_tensor in meta_val ) + else: + if meta_val.dtype == torch.int64: + # TODO This trick seems to use in mobilebert. + # It would be better to convert to torch.int32 + node.meta["val"] = meta_val.to(torch.float) + + def _cast_constant_to_int32(self, graph_module: torch.fx.GraphModule): + for n in graph_module.graph.nodes: + if n.target in self.skip_node: + continue + if is_constant(n, self.edge_program): + param = get_parameter(n, self.edge_program) + if param.dtype == torch.int64: + # QNN does not support int64 + self._update_meta(n) + elif n.op == "placeholder": + node_val = n.meta["val"] + if self._is_tensor_of_dtype(node_val, torch.int64): + with graph_module.graph.inserting_after(n): + args = (n,) + to_dst_node = graph_module.graph.create_node( + "call_function", + self.copy_op, + args, + {"dtype": torch.int32}, + ) + to_dst_node.meta["val"] = node_val.to(torch.int32) + + # Replace usage of the src dtype result with the dst dtype result. + n.replace_all_uses_with(to_dst_node) + to_dst_node.args = (n,) + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + # Record original output dtype to ensure that if user expects int64 as output, + # convert the output back to int64 if it is casted from int64->int32. 
+ self._record_original_output_dtype(graph_module) + self._cast_constant_to_int32(graph_module) + graph_module = super().call(graph_module).graph_module + self._preserve_output_dtype(graph_module) + graph_module.recompile() return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/insert_requantize.py b/backends/qualcomm/_passes/insert_requantize.py index 83b729f3c44..372e7eb4d76 100644 --- a/backends/qualcomm/_passes/insert_requantize.py +++ b/backends/qualcomm/_passes/insert_requantize.py @@ -36,10 +36,8 @@ class InsertRequantize(ExportPass): def __init__( self, - edge_program: torch.export.ExportedProgram, ): super(InsertRequantize, self).__init__() - self.edge_program = edge_program def _make_hashable(self, value): if isinstance(value, dict): diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py index 64fdcb2bb88..17960a6029b 100644 --- a/backends/qualcomm/_passes/layout_transform.py +++ b/backends/qualcomm/_passes/layout_transform.py @@ -87,6 +87,7 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten._softmax.default, # TODO: Need to find a new solution to do "axis_order" to transform axis. exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.split_with_sizes.default, + exir_ops.edge.aten.split_with_sizes_copy.default, exir_ops.edge.aten.sqrt.default, exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.sum.dim_IntList, diff --git a/backends/qualcomm/_passes/lift_constant_scalar_operands.py b/backends/qualcomm/_passes/lift_constant_scalar_operands.py index cef28988520..93abfe621bc 100644 --- a/backends/qualcomm/_passes/lift_constant_scalar_operands.py +++ b/backends/qualcomm/_passes/lift_constant_scalar_operands.py @@ -47,6 +47,7 @@ class TensorOpInfo: aten.pow.Tensor_Scalar: TensorOpInfo(aten.pow.Tensor_Tensor, False, False), # The scalar number arg[1] is missing when using default. Result in a corner case to deal aten.leaky_relu.default: TensorOpInfo(aten.prelu.default, True, False), + aten.leaky_relu_.default: TensorOpInfo(aten.prelu.default, True, False), aten.where.ScalarOther: TensorOpInfo(aten.where.self, False, True), aten.where.Scalar: TensorOpInfo(aten.where.self, False, True), } @@ -57,7 +58,9 @@ class TensorOpInfo: class LiftConstantScalarOperands(ExportPass): """ - Lift constant scalar so that we can use observer of quantizer + Lift constant scalar so that we can use observer of quantizer. + For floating point model, lift constant scalar to avoid + creating temporary tensors for scalar node in the operation builder """ def __init__(self): diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py new file mode 100644 index 00000000000..ab2c86102df --- /dev/null +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -0,0 +1,204 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import inspect +from collections import OrderedDict +from typing import Dict + +from executorch.backends.qualcomm._passes import ( + AnnotateQuantAttrs, + AnnotateStack, + AnnotateUnbind, + ConvertBmmToMatmul, + ConvertConv1dToConv2d, + DecomposeAny, + DecomposeEinsum, + DecomposeExpM1, + DecomposeLinalgVectorNorm, + DecomposeSilu, + ExpandBroadcastTensorShape, + FixedLinearKeepDim, + FoldQDQ, + FuseConsecutiveTranspose, + I64toI32, + InsertIOQDQ, + InsertRequantize, + LayoutTransform, + LiftConstantScalarOperands, + RecomposePixelUnshuffle, + RecomposeRmsNorm, + ReduceDynamicRange, + RemoveRedundancy, + ReplaceArangeArgs, + ReplaceIndexPutInput, + ReplaceInfValues, + TagQuantIO, +) +from executorch.backends.qualcomm._passes.utils import ( + get_passes_dependency_for_capture_program, +) +from executorch.backends.qualcomm.utils.constants import ( + QCOM_PASS_ACTIVATE_KEY, + QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY, +) +from executorch.backends.transforms.decompose_sdpa import ( + DecomposeScaledDotProductAttention, +) +from executorch.exir import ExportedProgram +from executorch.exir.pass_manager import PassManager +from executorch.exir.program._program import ( + _get_updated_graph_signature, + lift_constant_tensor_pass, +) +from torch.fx import GraphModule +from torch.fx.passes.infra.pass_manager import this_before_that_pass_constraint + + +def get_capture_program_passes(): + """ + Defines and returns the default ordered passes for the capture program. + This function creates an OrderedDict containing a series of default passes. + + Returns: + OrderedDict: An ordered dictionary containing all default passes along with their activation status and initialization parameters. + """ + + # The second value in each tuple in `default_passes_and_setting` indicates whether the corresponding pass is activated by default. + # If a pass is activated, it will be executed by default. 
+ default_passes_and_setting = [ + (AnnotateQuantAttrs, True), + (AnnotateStack, False), + (AnnotateUnbind, True), + (ConvertBmmToMatmul, True), + (ConvertConv1dToConv2d, True), + (DecomposeAny, True), + (ExpandBroadcastTensorShape, False), + (FixedLinearKeepDim, True), + (FoldQDQ, True), + (I64toI32, True), + (LayoutTransform, True), + (RecomposePixelUnshuffle, True), + (RecomposeRmsNorm, False), + (RemoveRedundancy, True), + (ReplaceIndexPutInput, True), + (TagQuantIO, False), + ] + + passes = OrderedDict() + for p, act in default_passes_and_setting: + init_signature = inspect.signature(p.__init__) + + args_kwargs_defaults = { + k: v.default if v.default is not inspect.Parameter.empty else None + for k, v in init_signature.parameters.items() + if k != "self" + } + + passes[p] = { + QCOM_PASS_ACTIVATE_KEY: act, + QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY: args_kwargs_defaults, + } + + return passes + + +class QnnPassManager(PassManager): + + def __init__(self) -> None: + super().__init__() + + def _transform(self, graph_module: GraphModule): + return self(graph_module).graph_module + + # TODO: Move these passes into qnn_partitioner and qnn_preprocess to + # prevent users from needing to call custom APIs like capture_program + def get_to_edge_transform_passes( + self, + exported_program: ExportedProgram, + passes_job: OrderedDict = None, + dep_table: Dict = None, + ): + # TODO: remove this workaround when target could be correctly detected + from executorch.backends.qualcomm._passes import utils + from executorch.exir.dialects._ops import ops as exir_ops + + utils.q_ops.add(exir_ops.edge.pt2e_quant.quantize_affine.default) + utils.dq_ops.add(exir_ops.edge.pt2e_quant.dequantize_affine.default) + + passes_job = ( + passes_job if passes_job is not None else get_capture_program_passes() + ) + dep_table = ( + dep_table + if dep_table is not None + else get_passes_dependency_for_capture_program() + ) + for that, these in dep_table.items(): + for this in these: + self.add_constraint(this_before_that_pass_constraint(this, that)) + for p in passes_job: + self.add_pass(p) + self.solve_constraints() + + sorted_passes = self.passes + self.passes = [] + for p in sorted_passes: + if not passes_job[p][QCOM_PASS_ACTIVATE_KEY]: + continue + + kwargs = passes_job[p][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY] + if "edge_program" in kwargs: + kwargs["edge_program"] = exported_program + self.add_pass(p(**kwargs)) + return self.passes + + def transform_for_to_edge_pipeline( + self, + exported_program: ExportedProgram, + passes_job: OrderedDict = None, + dep_table: Dict = None, + ): + transform_passes = self.get_to_edge_transform_passes( + exported_program, passes_job=passes_job, dep_table=dep_table + ) + for p in transform_passes: + p(exported_program.graph_module) + exported_program._graph_signature = _get_updated_graph_signature( + exported_program.graph_signature, + exported_program.graph_module, + ) + exported_program._validate() + + return exported_program + + def transform_for_export_pipeline(self, exported_program: ExportedProgram): + self.add_pass(DecomposeScaledDotProductAttention()) + self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True)) + self.add_pass(DecomposeExpM1()) + self.add_pass(LiftConstantScalarOperands()) + self._transform(exported_program.graph_module) + ep = lift_constant_tensor_pass(exported_program) + return ep + + def transform_for_preprocess_pipeline(self, exported_program: ExportedProgram): + self.add_pass(InsertRequantize()) + self.add_pass(InsertIOQDQ(exported_program)) + 
self.add_pass(LayoutTransform(exported_program, insert_permute=True)) + self.add_pass(FuseConsecutiveTranspose()) + return self._transform(exported_program.graph_module) + + def transform_for_annotation_pipeline(self, graph_module: GraphModule): + self.add_pass(ReduceDynamicRange()) + self.add_pass(RecomposePixelUnshuffle(quantization_capture=True)) + self.add_pass(ReplaceArangeArgs()) + self.add_pass(DecomposeScaledDotProductAttention()) + self.add_pass(DecomposeSilu()) + self.add_pass(DecomposeEinsum()) + self.add_pass(DecomposeExpM1()) + self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True)) + self.add_pass(ReplaceInfValues()) + self.add_pass(LiftConstantScalarOperands()) + return self._transform(graph_module) diff --git a/backends/qualcomm/_passes/recompose_prelu.py b/backends/qualcomm/_passes/recompose_prelu.py deleted file mode 100644 index 082b9c83b27..00000000000 --- a/backends/qualcomm/_passes/recompose_prelu.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -from typing import List - -import torch -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult -from torch.fx.passes.utils.source_matcher_utils import get_source_partitions - - -class RecomposePReLU(ExportPass): - """ - Merge decomposed operators from prelu back to one super node. - """ - - def __init__(self, edge_program: torch.export.ExportedProgram): - super(RecomposePReLU, self).__init__() - self.edge_program = edge_program - - def _get_coeff_node(self, nodes: List[torch.fx.Node]): - for node in nodes: - if node.target == exir_ops.edge.aten.view_copy.default: - return node.args[0] - - def _get_input_node(self, nodes: List[torch.fx.Node], coeff_node): - return [n for n in nodes if n != coeff_node][0] - - def call(self, graph_module: torch.fx.GraphModule): - graph = graph_module.graph - partitions = get_source_partitions(graph, [torch.nn.PReLU, torch.nn.LeakyReLU]) - for _, src_partitions in partitions.items(): - for src_partition in src_partitions: - # somehow op might not be decomposed, skip it - if len(src_partition.nodes) == 1: - continue - - coeff_node = self._get_coeff_node(src_partition.nodes) - input_node = self._get_input_node(src_partition.input_nodes, coeff_node) - output_node = src_partition.output_nodes[0] - - with graph.inserting_before(output_node): - prelu_op = exir_ops.edge.aten.prelu.default - prelu_node = graph.create_node( - "call_function", prelu_op, (input_node, coeff_node) - ) - users = output_node.users.copy() - for user in users: - user.replace_input_with(output_node, prelu_node) - # copy metadata - prelu_node.meta = output_node.meta - - graph.eliminate_dead_code() - graph_module.recompile() - return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/recompose_rms_norm.py b/backends/qualcomm/_passes/recompose_rms_norm.py index 77feecf9c1f..a5db826ab28 100644 --- a/backends/qualcomm/_passes/recompose_rms_norm.py +++ b/backends/qualcomm/_passes/recompose_rms_norm.py @@ -15,6 +15,8 @@ class RecomposeRmsNorm(ExportPass): """ Merge decomposed operators back to one super node. 
+ TODO: After replacing export_to_edge with to_edge_transform_and_lowering + in examples/models/llama/export_llama_lib.py, this pass can be removed """ def __init__(self, edge_program: torch.export.ExportedProgram): diff --git a/backends/qualcomm/_passes/replace_index_put_input.py b/backends/qualcomm/_passes/replace_index_put_input.py index 1eb210cf67e..dcdf2bb3a7f 100644 --- a/backends/qualcomm/_passes/replace_index_put_input.py +++ b/backends/qualcomm/_passes/replace_index_put_input.py @@ -22,9 +22,8 @@ class ReplaceIndexPutInput(ExportPass): exir_ops.edge.quantized_decomposed.dequantize_per_channel.default: exir_ops.edge.quantized_decomposed.quantize_per_channel.default, } - def __init__(self, edge_program: torch.export.ExportedProgram): + def __init__(self): super(ReplaceIndexPutInput, self).__init__() - self.edge_program = edge_program def call(self, graph_module: torch.fx.GraphModule): graph = graph_module.graph diff --git a/backends/qualcomm/_passes/tag_quant_io.py b/backends/qualcomm/_passes/tag_quant_io.py new file mode 100644 index 00000000000..bb23d18b0ad --- /dev/null +++ b/backends/qualcomm/_passes/tag_quant_io.py @@ -0,0 +1,45 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import Callable + +import torch +from executorch.backends.qualcomm.utils.constants import ( + QCOM_QUANT_ATTRS, + QCOM_QUANT_ATTRS_MAP, + QCOM_QUANTIZED_IO, +) +from executorch.exir.pass_base import ExportPass, PassResult + + +class TagQuantIO(ExportPass): + """ + Tag the IO nodes that handle quantized tensors to avoid inserting Q/DQ operations in qnn_preprocess. + """ + + def __init__(self, get_quant_io_dtype_fn: Callable = None): + super(TagQuantIO, self).__init__() + self.get_quant_io_dtype_fn = get_quant_io_dtype_fn + + def _tag_quant_io(self, gm: torch.fx.GraphModule): + for node in gm.graph.nodes: + if dtype := self.get_quant_io_dtype_fn(node): + node.meta[QCOM_QUANTIZED_IO] = dtype + + def _record_output_quant_attrs_map(self, gm: torch.fx.GraphModule): + for node in gm.graph.nodes: + if node.op == "output": + node.meta.setdefault(QCOM_QUANT_ATTRS_MAP, {}) + for arg in node.args[0]: + if QCOM_QUANT_ATTRS in arg.meta: + node.meta[QCOM_QUANT_ATTRS_MAP][arg] = arg.meta[ + QCOM_QUANT_ATTRS + ] + + def call(self, graph_module: torch.fx.GraphModule): + self._tag_quant_io(graph_module) + self._record_output_quant_attrs_map(graph_module) + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index 0c838e9a676..d538fe0d34f 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -63,62 +63,61 @@ def get_quant_attrs( def get_passes_dependency_for_capture_program(): """ - This function records the dependencies for passes used in the capture_program. + This function records the dependencies for passes used in the to_edge_transform_and_lower_to_qnn. It returns a dictionary where the keys are pass classes and the values are lists of dependencies required by each pass. This helps in managing and organizing the sequence - of passes needed for the capture_program to function correctly. + of passes needed for the to_edge_transform_and_lower_to_qnn to function correctly. Returns: dict: A dictionary mapping each pass to its corresponding list of dependencies. 
""" from executorch.backends.qualcomm._passes import ( - AnnotateDecomposed, AnnotateQuantAttrs, - ConstantI64toI32, + AnnotateStack, + AnnotateUnbind, ConvertBmmToMatmul, ConvertConv1dToConv2d, - ConvertToLinear, DecomposeAny, DecomposeLinalgVectorNorm, ExpandBroadcastTensorShape, + FixedLinearKeepDim, FoldQDQ, + I64toI32, LayoutTransform, RecomposePixelUnshuffle, - RecomposePReLU, RecomposeRmsNorm, RemoveRedundancy, ReplaceIndexPutInput, - TensorI64toI32, + TagQuantIO, ) return { - AnnotateDecomposed: [RemoveRedundancy], AnnotateQuantAttrs: [ RecomposePixelUnshuffle, - RecomposeRmsNorm, - ConvertToLinear, - RecomposePReLU, ConvertBmmToMatmul, + RemoveRedundancy, ], - ConstantI64toI32: [RemoveRedundancy], - ConvertBmmToMatmul: [ConvertToLinear], + AnnotateStack: [RemoveRedundancy], + AnnotateUnbind: [RemoveRedundancy], + ConvertBmmToMatmul: [RecomposePixelUnshuffle], ConvertConv1dToConv2d: [FoldQDQ], - ConvertToLinear: [RecomposePixelUnshuffle], DecomposeAny: [RemoveRedundancy], DecomposeLinalgVectorNorm: [RemoveRedundancy], ExpandBroadcastTensorShape: [RemoveRedundancy], - FoldQDQ: [AnnotateQuantAttrs, AnnotateDecomposed], + FixedLinearKeepDim: [FoldQDQ], + FoldQDQ: [AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind], + I64toI32: [RemoveRedundancy], LayoutTransform: [ AnnotateQuantAttrs, ConvertConv1dToConv2d, ExpandBroadcastTensorShape, + FixedLinearKeepDim, ], RecomposePixelUnshuffle: [RemoveRedundancy], - RecomposePReLU: [RemoveRedundancy], RecomposeRmsNorm: [RemoveRedundancy], ReplaceIndexPutInput: [LayoutTransform], - TensorI64toI32: [RemoveRedundancy], + TagQuantIO: [ReplaceIndexPutInput], } diff --git a/backends/qualcomm/builders/op_split_with_sizes.py b/backends/qualcomm/builders/op_split_with_sizes.py index 629110b3084..138f6ed60ec 100644 --- a/backends/qualcomm/builders/op_split_with_sizes.py +++ b/backends/qualcomm/builders/op_split_with_sizes.py @@ -17,7 +17,7 @@ @register_node_visitor class SplitWithSizes(NodeVisitor): - target = ["aten.split_with_sizes.default"] + target = ["aten.split_with_sizes.default", "aten.split_with_sizes_copy.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/partition/qnn_partitioner.py b/backends/qualcomm/partition/qnn_partitioner.py index 93b1d50f5fe..7b5e72d461d 100644 --- a/backends/qualcomm/partition/qnn_partitioner.py +++ b/backends/qualcomm/partition/qnn_partitioner.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import copy from collections import defaultdict -from typing import Any, Dict, List +from typing import Any, Callable, Dict, List, Optional, Tuple import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch @@ -24,6 +24,7 @@ PartitionResult, ) from executorch.exir.backend.utils import tag_constant_data +from torch.export.exported_program import ExportedProgram from torch.fx.passes.infra.partitioner import Partition from torch.fx.passes.operator_support import OperatorSupportBase @@ -33,7 +34,7 @@ not_supported_operator, to_be_implemented_operator, ) -from .utils import generate_qnn_executorch_option +from .utils import generate_qnn_executorch_option, get_skip_decomp_table class QnnOperatorSupport(OperatorSupportBase): @@ -174,3 +175,11 @@ def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResu return PartitionResult( tagged_exported_program=edge_program, partition_tags=self.partition_tags ) + + # override + def ops_to_not_decompose( + self, ep: ExportedProgram + ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: + do_not_decompose = get_skip_decomp_table() + + return do_not_decompose, None diff --git a/backends/qualcomm/partition/utils.py b/backends/qualcomm/partition/utils.py index 88b922d4e1f..1e2b17b2a69 100644 --- a/backends/qualcomm/partition/utils.py +++ b/backends/qualcomm/partition/utils.py @@ -6,6 +6,8 @@ from typing import List +import torch + from executorch.backends.qualcomm.utils.constants import QCOM_QNN_COMPILE_SPEC from executorch.exir.backend.compile_spec_schema import CompileSpec @@ -20,3 +22,26 @@ def generate_qnn_executorch_option( else: raise ValueError(f"unknown compiler spec key value: {compiler_spec.key}") return qnn_compile_spec_buffer + + +def get_skip_decomp_table() -> List[torch._ops.OperatorBase]: + do_not_decompose = [ + torch.ops.aten.adaptive_avg_pool2d.default, + torch.ops.aten.elu.default, + torch.ops.aten.hardsigmoid.default, + torch.ops.aten.hardswish.default, + torch.ops.aten.instance_norm.default, + torch.ops.aten.leaky_relu.default, + torch.ops.aten.linear.default, + torch.ops.aten.pixel_shuffle.default, + torch.ops.aten.pixel_unshuffle.default, + torch.ops.aten.prelu.default, + torch.ops.aten.rms_norm.default, + torch.ops.aten._safe_softmax.default, + torch.ops.aten.stack.default, + # This request is ignored because it is in a blocklist. 
Refer to exir/program/_program.py + # torch.ops.aten.unbind.int, + torch.ops.pt2e_quant.quantize_affine.default, + torch.ops.pt2e_quant.dequantize_affine.default, + ] + return do_not_decompose diff --git a/backends/qualcomm/qnn_preprocess.py b/backends/qualcomm/qnn_preprocess.py index 8afb4851814..4a11bf050a2 100644 --- a/backends/qualcomm/qnn_preprocess.py +++ b/backends/qualcomm/qnn_preprocess.py @@ -11,12 +11,7 @@ import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager import torch # noqa: F401 -from executorch.backends.qualcomm._passes import ( - FuseConsecutiveTranspose, - InsertIOQDQ, - InsertRequantize, - LayoutTransform, -) +from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager from executorch.backends.qualcomm.builders.node_visitor import get_node_visitors from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader from executorch.backends.qualcomm.partition.utils import generate_qnn_executorch_option @@ -25,7 +20,6 @@ CompileSpec, PreprocessResult, ) -from executorch.exir.passes import PassManager from torch.export.exported_program import ExportedProgram DEFAULT_DEBUG_HANDLE = 65535 @@ -46,17 +40,8 @@ def preprocess( qnn_manager.Init() # QNN Delegate Specific Passes - qnn_compiler_passes = PassManager( - passes=[ - InsertRequantize(edge_program), - InsertIOQDQ(edge_program), - LayoutTransform(edge_program, insert_permute=True), - FuseConsecutiveTranspose(), - ] - ) - - pass_result = qnn_compiler_passes(edge_program.graph_module) - assert pass_result is not None + graph_module = QnnPassManager().transform_for_preprocess_pipeline(edge_program) + assert graph_module is not None enable_tensor_dump = qnn_manager.IsTensorDump() nodes_to_wrappers = defaultdict(dict) @@ -64,7 +49,7 @@ def preprocess( edge_program, enable_tensor_dump=enable_tensor_dump ) py_op_wrapper_list = [] - for node in pass_result.graph_module.graph.nodes: + for node in graph_module.graph.nodes: if node.op == "call_function": logger.info(f"Visiting: {node}, {node.target.__name__}") if node.target.__name__ in node_visitors: diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 028ffb69f1d..3620841aff9 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -8,20 +8,7 @@ from typing import Callable, Dict, Optional, Sequence, Set, Tuple import torch -from executorch.backends.qualcomm._passes import ( - DecomposeEinsum, - DecomposeExpM1, - DecomposeLinalgVectorNorm, - DecomposeSilu, - LiftConstantScalarOperands, - RecomposePixelUnshuffle, - ReduceDynamicRange, - ReplaceArangeArgs, - ReplaceInfValues, -) -from executorch.backends.transforms.decompose_sdpa import ( - DecomposeScaledDotProductAttention, -) +from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager from torch._ops import OpOverload from torch.ao.quantization.quantizer import Quantizer @@ -273,17 +260,7 @@ def set_per_channel_linear_quant(self, enable: bool) -> None: self._update_per_channel_weight_quant_ops(linear_ops, enable) def transform_for_annotation(self, model: GraphModule) -> GraphModule: - model = ReduceDynamicRange()(model).graph_module - model = RecomposePixelUnshuffle(quantization_capture=True)(model).graph_module - model = ReplaceArangeArgs()(model).graph_module - model = DecomposeScaledDotProductAttention()(model).graph_module - model = DecomposeSilu()(model).graph_module - model = DecomposeEinsum()(model).graph_module - model = 
DecomposeExpM1()(model).graph_module - model = DecomposeLinalgVectorNorm(aten_dialect_capture=True)(model).graph_module - model = ReplaceInfValues()(model).graph_module - model = LiftConstantScalarOperands()(model).graph_module - return model + return QnnPassManager().transform_for_annotation_pipeline(model) def validate(self, model: GraphModule) -> None: pass diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index c3c439261d2..0857a597d88 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -100,7 +100,7 @@ def forward(self, x): class ArgminViewSqueezeConv2D(torch.nn.Module): def __init__(self): - # This model is mainly to test the PASS TensorI64toI32 + # This model is mainly to test the PASS I64toI32 super().__init__() self.conv = torch.nn.Conv2d( in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1 @@ -235,9 +235,7 @@ class CompositeDelegateModule(torch.nn.Module): def __init__( self, compiler_specs, - partitioner_type, - capture_method, - lowered_method, + to_edge_transform_and_lower_method, quantize_method=None, ) -> None: super().__init__() @@ -255,15 +253,15 @@ def __init__( ] self.lowered_modules = [] for module, sample_input in zip(self.modules, self.sample_inputs): - partitioner = partitioner_type(compiler_specs) if quantize_method: module = quantize_method(module, sample_input) - edge_prog = capture_method(module, sample_input) - edge_prog.exported_program = lowered_method( - edge_prog.exported_program, partitioner + edge_prog = to_edge_transform_and_lower_method( + module, sample_input, compiler_specs ) self.lowered_modules.append( - edge_prog.exported_program.graph_module._modules.get("lowered_module_0") + edge_prog.exported_program().graph_module._modules.get( + "lowered_module_0" + ) ) def forward(self, x, y): @@ -873,9 +871,9 @@ def forward(self, x): class LeakyReLUCustom(torch.nn.Module): - def __init__(self, coeff): + def __init__(self, coeff, inplace=False): super().__init__() - self.leaky_relu = torch.nn.LeakyReLU(coeff) + self.leaky_relu = torch.nn.LeakyReLU(coeff, inplace=inplace) def forward(self, x): return self.leaky_relu(x) diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 05e368f372e..795459a9f77 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -13,18 +13,28 @@ from pathlib import Path import torch +from executorch.backends.qualcomm._passes.qnn_pass_manager import ( + get_capture_program_passes, + QnnPassManager, +) + +from executorch.backends.qualcomm._passes.utils import ( + get_passes_dependency_for_capture_program, +) + from executorch.backends.qualcomm.tests.utils import ( generate_context_binary, - QnnPartitioner, QuantDtype, TestQNN, - to_backend, validate_context_binary, validate_qcir, ) from executorch.backends.qualcomm.utils.constants import ( QCOM_ANNOTATION, QCOM_MODULE, + QCOM_PASS_ACTIVATE_KEY, + QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY, + QCOM_QUANT_ATTRS_MAP, QCOM_QUANT_DTYPE, QCOM_SAMPLE_INPUTS, ) @@ -37,6 +47,7 @@ generate_qnn_executorch_compiler_spec, PyQnnManagerAdaptor, skip_annotation, + to_edge_transform_and_lower_to_qnn, update_spill_fill_size, ) @@ -57,12 +68,7 @@ from collections import defaultdict from typing import List -from executorch.backends.qualcomm._passes import ( - FuseConsecutiveTranspose, - InsertIOQDQ, - InsertRequantize, - LayoutTransform, -) +from executorch.backends.qualcomm._passes import FoldQDQ, TagQuantIO from 
executorch.backends.qualcomm.builders.node_visitor import get_node_visitors from executorch.backends.qualcomm.debugger.utils import DrawGraph from executorch.examples.models.deeplab_v3 import DeepLabV3ResNet101Model @@ -78,7 +84,6 @@ from executorch.examples.models.wav2letter import Wav2LetterModel from executorch.exir import to_edge from executorch.exir.backend.backend_api import disable_validation -from executorch.exir.passes import PassManager class TestQNNFloatingPointOperator(TestQNN): @@ -555,7 +560,11 @@ def test_qnn_backend_leaky_relu(self): QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)], }, { - QCOM_MODULE: [LeakyReLUCustom(0.05)], # noqa: F405 + QCOM_MODULE: [LeakyReLUCustom(0.05, inplace=False)], # noqa: F405 + QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)], + }, + { + QCOM_MODULE: [LeakyReLUCustom(0.05, inplace=True)], # noqa: F405 QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)], }, ] @@ -1642,6 +1651,10 @@ def test_qnn_backend_leaky_relu(self): QCOM_MODULE: [LeakyReLUCustom(0.05)], # noqa: F405 QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)], }, + { + QCOM_MODULE: [LeakyReLUCustom(0.05, inplace=True)], # noqa: F405 + QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)], + }, ] index = 0 @@ -2254,8 +2267,6 @@ def test_qnn_backend_spill_fill_buffer_size(self): # TODO: Fix self.assertNotEqual(0, max_sf_size) module = LargeTensorLinear() # noqa: F405 sample_input = (torch.randn(1, 256, 512),) - edge_prog = capture_program(module, sample_input) - backend_options = generate_htp_compiler_spec( use_fp16=True, use_multi_contexts=True, @@ -2264,17 +2275,16 @@ def test_qnn_backend_spill_fill_buffer_size(self): soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, ) - partitioner = QnnPartitioner(compiler_specs) - edge_prog.exported_program = to_backend(edge_prog.exported_program, partitioner) - max_sf_size = update_spill_fill_size(edge_prog.exported_program) + edge_prog = to_edge_transform_and_lower_to_qnn( + module, sample_input, compiler_specs + ) + + max_sf_size = update_spill_fill_size(edge_prog.exported_program()) self.assertNotEqual(0, max_sf_size) def test_qnn_backend_multi_contexts(self): module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) - edge_prog = capture_program(module, sample_input) - self.split_graph(edge_prog.exported_program.graph_module, 4) - backend_options = generate_htp_compiler_spec( use_fp16=True, use_dlbc=True, @@ -2284,9 +2294,20 @@ def test_qnn_backend_multi_contexts(self): soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, ) - partitioner = QnnPartitioner(compiler_specs) - edge_prog.exported_program = to_backend(edge_prog.exported_program, partitioner) - update_spill_fill_size(edge_prog.exported_program) + pass_jobs = get_capture_program_passes() + split_graph_pass, setting = self.split_graph(4) + pass_jobs[split_graph_pass] = setting + dep_table = get_passes_dependency_for_capture_program() + dep_table[split_graph_pass] = [FoldQDQ] + edge_prog = to_edge_transform_and_lower_to_qnn( + module, + sample_input, + compiler_specs, + dep_table=dep_table, + passes_job=pass_jobs, + ) + + update_spill_fill_size(edge_prog.exported_program()) exec_prog = edge_prog.to_executorch() self.verify_output(module, sample_input, exec_prog) @@ -2302,9 +2323,7 @@ def test_qnn_backend_multi_contexts_composite(self): ) module = CompositeDelegateModule( # noqa: F405 compiler_specs=compiler_specs, - partitioner_type=QnnPartitioner, - capture_method=capture_program, - 
lowered_method=to_backend, + to_edge_transform_and_lower_method=to_edge_transform_and_lower_to_qnn, ) sample_input = module.get_random_input() edge_prog = to_edge( @@ -2323,10 +2342,6 @@ def test_qnn_backend_multi_graphs(self): modules = [seq_conv, seq_conv.second] sample_inputs = [(torch.randn([1, 1, 3, 3]),), (torch.randn([1, 3, 3, 3]),)] graph_names = ["seq_conv", "single_conv"] - edge_progs = [ - capture_program(module, sample_input) - for module, sample_input in zip(modules, sample_inputs) - ] backend_options = generate_htp_compiler_spec( use_fp16=True, ) @@ -2340,15 +2355,17 @@ def test_qnn_backend_multi_graphs(self): ) for graph_name in graph_names ] - exported_programs = [ - to_backend(edge_prog.exported_program, QnnPartitioner(compiler_specs[i])) - for i, edge_prog in enumerate(edge_progs) + edge_progs = [ + to_edge_transform_and_lower_to_qnn(module, sample_input, compiler_spec) + for module, sample_input, compiler_spec in zip( + modules, sample_inputs, compiler_specs + ) ] prog_mgr, _ = generate_multi_graph_program( compiler_specs=compiler_specs[0], processed_bytes=[ - prog.graph_module.lowered_module_0.processed_bytes - for prog in exported_programs + edge_prog.exported_program().graph_module.lowered_module_0.processed_bytes + for edge_prog in edge_progs ], ) for index, module in enumerate(modules): @@ -2428,8 +2445,6 @@ def test_qnn_backend_context_direct(self): ) def test_qnn_backend_context_extraction(self): - from executorch.exir import EdgeCompileConfig, EdgeProgramManager - module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) backend_options = generate_htp_compiler_spec(use_fp16=True) @@ -2444,12 +2459,9 @@ def test_qnn_backend_context_extraction(self): validators = [validate_context_binary, validate_qcir] for compiler_spec, validate in zip(compiler_specs, validators): - edge_prog_mgr = EdgeProgramManager( - edge_programs={ - "forward": capture_program(module, sample_input).exported_program - }, - compile_config=EdgeCompileConfig(_use_edge_ops=False), - ).to_backend(QnnPartitioner(compiler_spec)) + edge_prog_mgr = to_edge_transform_and_lower_to_qnn( + module, sample_input, compiler_spec + ) lowered_module = edge_prog_mgr.exported_program().graph_module._modules[ "lowered_module_0" ] @@ -2461,8 +2473,6 @@ def test_qnn_backend_context_extraction(self): validate(binary) def test_qnn_backend_dump_context_from_pte(self): - from executorch.exir import EdgeCompileConfig, EdgeProgramManager - module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) backend_options = generate_htp_compiler_spec(use_fp16=True) @@ -2477,18 +2487,9 @@ def test_qnn_backend_dump_context_from_pte(self): validators = [validate_context_binary, validate_qcir] for compiler_spec, validate in zip(compiler_specs, validators): - edge_prog_mgr = ( - EdgeProgramManager( - edge_programs={ - "forward": capture_program( - module, sample_input - ).exported_program - }, - compile_config=EdgeCompileConfig(_use_edge_ops=False), - ) - .to_backend(QnnPartitioner(compiler_spec)) - .to_executorch() - ) + edge_prog_mgr = to_edge_transform_and_lower_to_qnn( + module, sample_input, compiler_spec + ).to_executorch() with tempfile.TemporaryDirectory() as tmp_dir: pte_path = f"{tmp_dir}/model.pte" @@ -2599,25 +2600,15 @@ def test_qnn_backend_draw_graph(self): """ module = DrawGraphModel() # noqa: F405 sample_input = (torch.randn(1, 32, 28, 28),) + # TODO: Figure out how to get the original graph module with the 
to_edge_transform_and_lowering_to_qnn API delegated_program = capture_program(module, sample_input) """ This piece of code simulates the behavior of the final preprocessing step to obtain the op wrapper list. In practice, users need to set a breakpoint in the preprocessing step and use the DrawGraph tool to visualize the graph. """ - qnn_compiler_passes = PassManager( - passes=[ - InsertRequantize(delegated_program.exported_program), - InsertIOQDQ(delegated_program.exported_program), - LayoutTransform( - delegated_program.exported_program, insert_permute=True - ), - FuseConsecutiveTranspose(), - ] - ) - - pass_result = qnn_compiler_passes( - delegated_program.exported_program.graph_module + graph_module = QnnPassManager().transform_for_preprocess_pipeline( + delegated_program.exported_program ) nodes_to_wrappers = defaultdict(dict) node_visitors = get_node_visitors( @@ -2625,7 +2616,7 @@ def test_qnn_backend_draw_graph(self): ) py_op_wrapper_list = [] - for node in pass_result.graph_module.graph.nodes: + for node in graph_module.graph.nodes: if node.op == "call_function": if node.target.__name__ in node_visitors: py_op_wrapper = node_visitors[node.target.__name__].define_node( @@ -2689,12 +2680,7 @@ def test_qnn_backend_dynamic_shape(self): QCOM_DTYPE, QCOM_QUANT_ATTRS, ) - from executorch.backends.qualcomm.utils.utils import tag_quant_io - from executorch.exir.capture._config import ( - EdgeCompileConfig, - ExecutorchBackendConfig, - ) - from executorch.exir.program import EdgeProgramManager + from executorch.exir.capture._config import ExecutorchBackendConfig module = Add() # noqa: F405 last_dim = torch.export.Dim("last_dim", min=1, max=8) @@ -2714,31 +2700,32 @@ def test_qnn_backend_dynamic_shape(self): ) # only few ops with 16bit are supported with dynamic shape now # strip unsupported quantize / dequantize ops generated in preprocess - prog = capture_program(module, sample_input, dynamic_shapes=dynamic_shapes) - tag_quant_io( - prog.exported_program.graph_module, - lambda n: ( - torch.uint16 - if any(name in n.name for name in ["x", "y", "add"]) - else None - ), + pass_jobs = get_capture_program_passes() + pass_jobs[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True + pass_jobs[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][ + "get_quant_io_dtype_fn" + ] = lambda n: ( + torch.uint16 if any(name in n.name for name in ["x", "y", "add"]) else None ) + edge_prog_mgr = to_edge_transform_and_lower_to_qnn( + module, + sample_input, + self.compiler_specs, + dynamic_shapes=dynamic_shapes, + passes_job=pass_jobs, + ) + # collect encodings for ios input_encodings, output_encodings = [], [] - for n in prog.exported_program.graph.nodes: + for n in edge_prog_mgr.exported_program().graph.nodes: if n.op == "placeholder": input_encodings.append(n.meta[QCOM_QUANT_ATTRS]) input_encodings[-1][QCOM_DTYPE] = torch.uint16 elif n.op == "output": - for arg in n.args[0]: - output_encodings.append(arg.meta[QCOM_QUANT_ATTRS]) - output_encodings[-1][QCOM_DTYPE] = torch.uint16 + output_encodings = n.meta[QCOM_QUANT_ATTRS_MAP].values() + for output_encoding in output_encodings: + output_encoding[QCOM_DTYPE] = torch.uint16 - edge_prog_mgr = EdgeProgramManager( - edge_programs={"forward": prog.exported_program}, - compile_config=EdgeCompileConfig(_check_ir_validity=False), - ) - edge_prog_mgr = edge_prog_mgr.to_backend(QnnPartitioner(self.compiler_specs)) exec_prog = edge_prog_mgr.to_executorch( ExecutorchBackendConfig(passes=[BuildQuantIo()]) ) @@ -2771,7 +2758,7 @@ def test_qnn_backend_skip_node_id_quantizer(self): module 
= SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) - # define partitioner + # define compile specs backend_options = generate_htp_compiler_spec( use_fp16=False, ) @@ -2779,7 +2766,6 @@ def test_qnn_backend_skip_node_id_quantizer(self): soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, ) - partitioner = QnnPartitioner(compiler_specs) # define quantizer quantizer = make_quantizer() @@ -2788,15 +2774,15 @@ def calibrator(gm): gm(*sample_input) # get partially lowererd graph module - graph_module, exported_progs = skip_annotation( + graph_module, edge_prog_mgrs = skip_annotation( nn_module=module, quantizer=quantizer, - partitioner=partitioner, + compiler_specs=compiler_specs, sample_input=sample_input, calibration_cb=calibrator, fp_node_id_set={"conv2d"}, ) - self.assertEqual(len(exported_progs), 1) + self.assertEqual(len(edge_prog_mgrs), 1) # lower all graph again, the skipped operators will be left in CPU exec_prog = to_edge( torch.export.export(graph_module, sample_input, strict=True), @@ -2818,7 +2804,7 @@ def test_qnn_backend_skip_node_op_quantizer(self): module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) - # define partitioner + # define compile specs backend_options = generate_htp_compiler_spec( use_fp16=False, ) @@ -2826,7 +2812,6 @@ def test_qnn_backend_skip_node_op_quantizer(self): soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, ) - partitioner = QnnPartitioner(compiler_specs) # define quantizer quantizer = make_quantizer() @@ -2835,15 +2820,15 @@ def calibrator(gm): gm(*sample_input) # get partially lowererd graph module - graph_module, exported_progs = skip_annotation( + graph_module, edge_prog_mgrs = skip_annotation( nn_module=module, quantizer=quantizer, - partitioner=partitioner, + compiler_specs=compiler_specs, sample_input=sample_input, calibration_cb=calibrator, fp_node_op_set={torch.ops.aten.add.Tensor}, ) - self.assertEqual(len(exported_progs), 2) + self.assertEqual(len(edge_prog_mgrs), 2) # lower all graph again, the skipped operators will be left in CPU exec_prog = exec_prog = to_edge( torch.export.export(graph_module, sample_input, strict=True), @@ -2856,8 +2841,6 @@ def test_qnn_backend_spill_fill_buffer_size(self): module = LargeTensorLinear() # noqa: F405 sample_input = (torch.randn(1, 256, 512),) module = self.get_qdq_module(module, sample_input) - edge_prog = capture_program(module, sample_input) - backend_options = generate_htp_compiler_spec( use_fp16=False, use_multi_contexts=True, @@ -2866,16 +2849,18 @@ def test_qnn_backend_spill_fill_buffer_size(self): soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, ) - partitioner = QnnPartitioner(compiler_specs) - edge_prog.exported_program = to_backend(edge_prog.exported_program, partitioner) - max_sf_size = update_spill_fill_size(edge_prog.exported_program) + edge_prog = to_edge_transform_and_lower_to_qnn( + module, sample_input, compiler_specs + ) + + max_sf_size = update_spill_fill_size(edge_prog.exported_program()) self.assertNotEqual(0, max_sf_size) def test_qnn_backend_graph_level_mixed_precision(self): module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) - # define partitioner + # define compile spec backend_options = generate_htp_compiler_spec( use_fp16=False, ) @@ -2883,7 +2868,6 @@ def test_qnn_backend_graph_level_mixed_precision(self): 
soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, ) - partitioner = QnnPartitioner(compiler_specs) # define quantizer quantizer = make_quantizer() @@ -2892,16 +2876,16 @@ def calibrator(gm): gm(*sample_input) # get partially lowererd graph module - graph_module, exported_progs = skip_annotation( + graph_module, edge_prog_mgrs = skip_annotation( nn_module=module, quantizer=quantizer, - partitioner=partitioner, + compiler_specs=compiler_specs, sample_input=sample_input, calibration_cb=calibrator, fp_node_id_set={"add", "mean"}, fallback_to_cpu=False, ) - self.assertEqual(len(exported_progs), 5) + self.assertEqual(len(edge_prog_mgrs), 5) # lower all graph again, the skipped operators will be delegated with fp16 exec_prog = to_edge( torch.export.export(graph_module, sample_input, strict=True), @@ -2912,9 +2896,6 @@ def test_qnn_backend_multi_contexts(self): module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) module = self.get_qdq_module(module, sample_input) - edge_prog = capture_program(module, sample_input) - self.split_graph(edge_prog.exported_program.graph_module, 4) - backend_options = generate_htp_compiler_spec( use_fp16=False, use_dlbc=True, @@ -2924,9 +2905,20 @@ def test_qnn_backend_multi_contexts(self): soc_model=self.chipset_table[TestQNN.model], backend_options=backend_options, ) - partitioner = QnnPartitioner(compiler_specs) - edge_prog.exported_program = to_backend(edge_prog.exported_program, partitioner) - update_spill_fill_size(edge_prog.exported_program) + pass_jobs = get_capture_program_passes() + split_graph_pass, setting = self.split_graph(4) + pass_jobs[split_graph_pass] = setting + dep_table = get_passes_dependency_for_capture_program() + dep_table[split_graph_pass] = [FoldQDQ] + edge_prog = to_edge_transform_and_lower_to_qnn( + module, + sample_input, + compiler_specs, + dep_table=dep_table, + passes_job=pass_jobs, + ) + + update_spill_fill_size(edge_prog.exported_program()) exec_prog = edge_prog.to_executorch() self.verify_output(module, sample_input, exec_prog) @@ -2942,9 +2934,7 @@ def test_qnn_backend_multi_contexts_composite(self): ) module = CompositeDelegateModule( # noqa: F405 compiler_specs=compiler_specs, - partitioner_type=QnnPartitioner, - capture_method=capture_program, - lowered_method=to_backend, + to_edge_transform_and_lower_method=to_edge_transform_and_lower_to_qnn, quantize_method=self.get_qdq_module, ) sample_input = module.get_random_input() @@ -2964,10 +2954,6 @@ def test_qnn_backend_multi_graphs(self): modules = [seq_conv, seq_conv.second] sample_inputs = [(torch.randn([1, 1, 3, 3]),), (torch.randn([1, 3, 3, 3]),)] graph_names = ["seq_conv", "single_conv"] - edge_progs = [ - capture_program(self.get_qdq_module(module, sample_input), sample_input) - for module, sample_input in zip(modules, sample_inputs) - ] backend_options = generate_htp_compiler_spec( use_fp16=False, ) @@ -2981,15 +2967,19 @@ def test_qnn_backend_multi_graphs(self): ) for graph_name in graph_names ] - exported_programs = [ - to_backend(edge_prog.exported_program, QnnPartitioner(compiler_specs[i])) - for i, edge_prog in enumerate(edge_progs) + edge_progs = [ + to_edge_transform_and_lower_to_qnn( + self.get_qdq_module(module, sample_input), sample_input, compiler_spec + ) + for module, sample_input, compiler_spec in zip( + modules, sample_inputs, compiler_specs + ) ] prog_mgr, _ = generate_multi_graph_program( compiler_specs=compiler_specs[0], processed_bytes=[ - 
prog.graph_module.lowered_module_0.processed_bytes - for prog in exported_programs + edge_prog.exported_program().graph_module.lowered_module_0.processed_bytes + for edge_prog in edge_progs ], ) for index, module in enumerate(modules): @@ -3072,8 +3062,6 @@ def test_qnn_backend_context_direct(self): ) def test_qnn_backend_context_extraction(self): - from executorch.exir import EdgeCompileConfig, EdgeProgramManager - module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) module = self.get_qdq_module(module, sample_input) @@ -3089,12 +3077,9 @@ def test_qnn_backend_context_extraction(self): validators = [validate_context_binary, validate_qcir] for compiler_spec, validate in zip(compiler_specs, validators): - edge_prog_mgr = EdgeProgramManager( - edge_programs={ - "forward": capture_program(module, sample_input).exported_program - }, - compile_config=EdgeCompileConfig(_use_edge_ops=False), - ).to_backend(QnnPartitioner(compiler_spec)) + edge_prog_mgr = to_edge_transform_and_lower_to_qnn( + module, sample_input, compiler_spec + ) lowered_module = edge_prog_mgr.exported_program().graph_module._modules[ "lowered_module_0" ] @@ -3106,8 +3091,6 @@ def test_qnn_backend_context_extraction(self): validate(binary) def test_qnn_backend_dump_context_from_pte(self): - from executorch.exir import EdgeCompileConfig, EdgeProgramManager - module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) module = self.get_qdq_module(module, sample_input) @@ -3123,18 +3106,9 @@ def test_qnn_backend_dump_context_from_pte(self): validators = [validate_context_binary, validate_qcir] for compiler_spec, validate in zip(compiler_specs, validators): - edge_prog_mgr = ( - EdgeProgramManager( - edge_programs={ - "forward": capture_program( - module, sample_input - ).exported_program - }, - compile_config=EdgeCompileConfig(_use_edge_ops=False), - ) - .to_backend(QnnPartitioner(compiler_spec)) - .to_executorch() - ) + edge_prog_mgr = to_edge_transform_and_lower_to_qnn( + module, sample_input, compiler_spec + ).to_executorch() with tempfile.TemporaryDirectory() as tmp_dir: pte_path = f"{tmp_dir}/model.pte" @@ -3167,9 +3141,9 @@ def test_qnn_backend_draw_graph(self): dims: [1, 28, 28, 32] quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET > color=black fillcolor=transparent shape=box style=rounded] - quantized_decomposed_quantize_per_tensor_default_8_0 [label=< + quantized_decomposed_quantize_per_tensor_default_0 [label=< - + @@ -3247,12 +3221,12 @@ def test_qnn_backend_draw_graph(self):
name: quantized_decomposed_quantize_per_tensor_default_8_0
name: quantized_decomposed_quantize_per_tensor_default_0
data_type: Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
dims: [1, 32, 28, 28]
dims: [32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET
> color=black fillcolor=transparent shape=box style=rounded] - quantized_decomposed_quantize_per_tensor_default_8_0 -> aten_convolution_default_0 - input_0_x_0 -> quantized_decomposed_quantize_per_tensor_default_8_0 + quantized_decomposed_quantize_per_tensor_default_0 -> aten_convolution_default_0 + input_0_x_0 -> quantized_decomposed_quantize_per_tensor_default_0 b__frozen_param0_0 -> aten_convolution_default_0 b__frozen_param1_0 -> aten_convolution_default_0 aten_convolution_default_0 -> aten_relu_default_0 - quantized_decomposed_quantize_per_tensor_default_8_0 -> aten_convolution_default_1_0 + quantized_decomposed_quantize_per_tensor_default_0 -> aten_convolution_default_1_0 b__frozen_param2_0 -> aten_convolution_default_1_0 b__frozen_param3_0 -> aten_convolution_default_1_0 aten_convolution_default_1_0 -> aten_relu_default_1_0 @@ -3264,25 +3238,15 @@ def test_qnn_backend_draw_graph(self): module = DrawGraphModel() # noqa: F405 sample_input = (torch.randn(1, 32, 28, 28),) module = self.get_qdq_module(module, sample_input) + # TODO: Figure out how to get the original graph module with to_edge_transform_and_lowering_to_qnn delegated_program = capture_program(module, sample_input) """ This piece of code simulates the behavior of the final preprocessing step to obtain the op wrapper list. In practice, users need to set a breakpoint in the preprocessing step and use the DrawGraph tool to visualize the graph. """ - qnn_compiler_passes = PassManager( - passes=[ - InsertRequantize(delegated_program.exported_program), - InsertIOQDQ(delegated_program.exported_program), - LayoutTransform( - delegated_program.exported_program, insert_permute=True - ), - FuseConsecutiveTranspose(), - ] - ) - - pass_result = qnn_compiler_passes( - delegated_program.exported_program.graph_module + graph_module = QnnPassManager().transform_for_preprocess_pipeline( + delegated_program.exported_program ) nodes_to_wrappers = defaultdict(dict) node_visitors = get_node_visitors( @@ -3290,7 +3254,7 @@ def test_qnn_backend_draw_graph(self): ) py_op_wrapper_list = [] - for node in pass_result.graph_module.graph.nodes: + for node in graph_module.graph.nodes: if node.op == "call_function": if node.target.__name__ in node_visitors: py_op_wrapper = node_visitors[node.target.__name__].define_node( diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index 769f24ba0d8..932a1e2d6fb 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -15,18 +15,19 @@ import torch from executorch import exir -from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner from executorch.backends.qualcomm.qnn_preprocess import QnnBackend from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset from executorch.backends.qualcomm.utils.constants import ( QCOM_DTYPE, + QCOM_PASS_ACTIVATE_KEY, + QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY, QCOM_SCALE, QCOM_ZERO_POINT, ) from executorch.backends.qualcomm.utils.utils import ( - capture_program, get_soc_to_chipset_map, + to_edge_transform_and_lower_to_qnn, ) from executorch.devtools import generate_etrecord, Inspector from executorch.examples.qualcomm.utils import ( @@ -36,7 +37,6 @@ SimpleADB, ) -from executorch.exir.backend.backend_api import to_backend from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -47,6 +47,7 @@ 
prepare_pt2e, prepare_qat_pt2e, ) +from torch.fx.passes.infra.pass_base import PassResult def generate_context_binary( @@ -359,19 +360,6 @@ def validate_intermediate_tensor(): cmd.append(f"--{name}_path") cmd.append(f"{tmp_dir}/{name}.txt") - dtype_info = { - "input_type_size": input_encodings, - "output_type_size": output_encodings, - } - for name, encodings in dtype_info.items(): - with open(f"{tmp_dir}/{name}.txt", "w") as f: - for e in encodings: - f.write( - f"{torch.tensor([], dtype=e[QCOM_DTYPE]).element_size()}\n" - ) - cmd.append(f"--{name}_path") - cmd.append(f"{tmp_dir}/{name}.txt") - env = dict(os.environ) env["LD_LIBRARY_PATH"] = f"{qnn_sdk}/lib/{target}/:{build_folder}/lib" proc = subprocess.run( @@ -423,16 +411,6 @@ def validate_intermediate_tensor(): if check_io_shape else None ), - expected_input_dtype=( - (encoding[QCOM_DTYPE] for encoding in input_encodings) - if check_io_shape - else None - ), - expected_output_dtype=( - (encoding[QCOM_DTYPE] for encoding in output_encodings) - if check_io_shape - else None - ), ) adb.push(inputs=[processed_inputs], input_list=input_list) adb.execute(method_index=method_index) @@ -461,19 +439,18 @@ def lower_module_and_test_output( skip_node_op_set: set = None, dynamic_shapes: Dict = None, ): - qnn_partitioner = QnnPartitioner( - self.compiler_specs, skip_node_id_set, skip_node_op_set - ) - delegated_program = capture_program( - module, sample_inputs, dynamic_shapes=dynamic_shapes + delegated_program = to_edge_transform_and_lower_to_qnn( + module, + sample_inputs, + self.compiler_specs, + dynamic_shapes=dynamic_shapes, + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, ) # this is needed for the ETRecord as lowering modifies the graph in-place edge_copy = copy.deepcopy(delegated_program) - delegated_program.exported_program = to_backend( - delegated_program.exported_program, qnn_partitioner - ) exec_prog = delegated_program.to_executorch( exir.ExecutorchBackendConfig( # For shared buffer, user must pass the memory address @@ -489,12 +466,12 @@ def lower_module_and_test_output( # Assert the backend name is qnn self.assertEqual( - len(exec_prog.program.execution_plan[0].delegates), + len(exec_prog.executorch_program.execution_plan[0].delegates), expected_partitions, ) for i in range(expected_partitions): self.assertEqual( - exec_prog.program.execution_plan[0].delegates[i].id, + exec_prog.executorch_program.execution_plan[0].delegates[i].id, QnnBackend.__name__, ) @@ -599,24 +576,33 @@ def get_converted_sgd_trained_module( optimizer.step() return torch.ao.quantization.quantize_pt2e.convert_pt2e(prepared) - def split_graph(self, graph_module: torch.fx.GraphModule, division: int): + def split_graph(self, division: int): class SplitGraph(ExportPass): """ Split graph based on number of nodes. 
""" - def __init__(self, shares): + def __init__(self, division): super().__init__() - self.shares = shares + self.division = division def _insert_clone( self, graph_module: torch.fx.GraphModule ) -> torch.fx.GraphModule: + # Count the total of nodes in the graph + num_graph_nodes = 0 + for node in graph_module.graph.nodes: + num_graph_nodes += 1 if node.op == "call_function" else 0 + + # Compute how many nodes in one share + shares = num_graph_nodes // self.division + + # Insert clone op to split model based on the shares num_graph_nodes = 0 for node in graph_module.graph.nodes: num_graph_nodes += 1 if node.op == "call_function" else 0 - if num_graph_nodes % self.shares != 0 or node.op != "call_function": + if num_graph_nodes % shares != 0 or node.op != "call_function": continue with graph_module.graph.inserting_after(node): @@ -635,9 +621,9 @@ def _insert_clone( def call(self, graph_module: torch.fx.GraphModule): self._insert_clone(graph_module) graph_module.recompile() + return PassResult(graph_module, True) - num_graph_nodes = 0 - for node in graph_module.graph.nodes: - num_graph_nodes += 1 if node.op == "call_function" else 0 - - SplitGraph(-(num_graph_nodes // -division))(graph_module) + return SplitGraph, { + QCOM_PASS_ACTIVATE_KEY: True, + QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY: {"division": division}, + } diff --git a/backends/qualcomm/utils/constants.py b/backends/qualcomm/utils/constants.py index 2e364c37119..ce917bf4115 100644 --- a/backends/qualcomm/utils/constants.py +++ b/backends/qualcomm/utils/constants.py @@ -25,6 +25,7 @@ QCOM_ORIG_DTYPE = "orig_dtype" QCOM_QUANTIZED_IO = "q_tensor_io" QCOM_QUANT_ATTRS = "quant_attrs" +QCOM_QUANT_ATTRS_MAP = "quant_attrs_map" QCOM_QUANT_MIN = "quant_min" QCOM_QUANT_MAX = "quant_max" QCOM_REQUANTIZE = "requantize" diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 7033f30997a..f7b966ee8ea 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -3,43 +3,21 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-import inspect import operator import re import time import warnings from collections import OrderedDict -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor import executorch.exir as exir import torch -from executorch.backends.qualcomm._passes import ( - AnnotateDecomposed, - AnnotateQuantAttrs, - ConstantI64toI32, - ConvertBmmToMatmul, - ConvertConv1dToConv2d, - ConvertToLinear, - DecomposeAny, - DecomposeExpM1, - DecomposeLinalgVectorNorm, - ExpandBroadcastTensorShape, - FoldQDQ, - LayoutTransform, - LiftConstantScalarOperands, - RecomposePixelUnshuffle, - RecomposePReLU, - RecomposeRmsNorm, - RemoveRedundancy, - ReplaceIndexPutInput, -) -from executorch.backends.qualcomm._passes.tensor_i64_to_i32 import TensorI64toI32 -from executorch.backends.qualcomm._passes.utils import ( - get_passes_dependency_for_capture_program, -) + +from executorch.backends.qualcomm._passes import AnnotateStack +from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager from executorch.backends.qualcomm.builders.node_visitor import ( QNN_QUANT_TYPE_MAP, @@ -48,6 +26,7 @@ from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader from executorch.backends.qualcomm.partition.qnn_partitioner import ( generate_qnn_executorch_option, + get_skip_decomp_table, QnnPartitioner, ) from executorch.backends.qualcomm.serialization.qc_schema import ( @@ -68,14 +47,9 @@ option_to_flatbuffer, ) from executorch.backends.qualcomm.utils.constants import ( - QCOM_PASS_ACTIVATE_KEY, - QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY, QCOM_QNN_COMPILE_SPEC, QCOM_QUANTIZED_IO, ) -from executorch.backends.transforms.decompose_sdpa import ( - DecomposeScaledDotProductAttention, -) from executorch.exir import ( EdgeCompileConfig, @@ -86,12 +60,13 @@ from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.capture import ExecutorchBackendConfig from executorch.exir.lowered_backend_module import LoweredBackendModule -from executorch.exir.passes import PassManager -from executorch.exir.program._program import _get_updated_graph_signature +from executorch.exir.program._program import ( + EdgeProgramManager, + to_edge_transform_and_lower, +) from torch._decomp import core_aten_decompositions, remove_decompositions from torch.export.exported_program import ExportedProgram from torch.fx import passes -from torch.fx.passes.infra.pass_manager import this_before_that_pass_constraint from torch.fx.passes.operator_support import OperatorSupportBase from torch.library import Library @@ -326,160 +301,114 @@ def canonicalize_program(obj): def get_decomp_table(passes_job) -> Dict[torch._ops.OperatorBase, Callable]: source_decompositions = core_aten_decompositions() # The below super ops are supported by QNN - skip_decompositions = [ - torch.ops.aten.adaptive_avg_pool2d.default, - torch.ops.aten.elu.default, - torch.ops.aten.instance_norm.default, - torch.ops.aten.pixel_shuffle.default, - torch.ops.aten.pixel_unshuffle.default, - torch.ops.aten.hardsigmoid.default, - torch.ops.aten.hardswish.default, - torch.ops.pt2e_quant.quantize_affine.default, - torch.ops.pt2e_quant.dequantize_affine.default, - torch.ops.aten._safe_softmax.default, - torch.ops.aten.stack.default, - torch.ops.aten.unbind.int, - ] + skip_decompositions = get_skip_decomp_table() # If we want to annotate the decomposed ops, then we should decompose the operation. 
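For orientation, the removed inline skip list and the new get_skip_decomp_table() cover the same ops on two paths: capture_program strips those entries from the core decomposition table before run_decompositions, while QnnPartitioner.ops_to_not_decompose reports the same list to to_edge_transform_and_lower. A rough sketch of the decomposition side, assuming an already exported `exported_program`:

    from torch._decomp import core_aten_decompositions, remove_decompositions

    from executorch.backends.qualcomm.partition.utils import get_skip_decomp_table

    # Ops on the skip list survive as single aten nodes so the QNN builders
    # can consume them directly instead of their decomposed forms.
    table = core_aten_decompositions()
    remove_decompositions(table, get_skip_decomp_table())
    decomposed_ep = exported_program.run_decompositions(table)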
- if passes_job and passes_job.get(AnnotateDecomposed, False): + if passes_job and passes_job.get(AnnotateStack, False): skip_decompositions = [ skip_decomp_op for skip_decomp_op in skip_decompositions - if skip_decomp_op not in AnnotateDecomposed.decomp_ops + if skip_decomp_op not in AnnotateStack.decomp_ops ] remove_decompositions(source_decompositions, skip_decompositions) return source_decompositions -def get_capture_program_passes(): +def to_edge_transform_and_lower_to_qnn( + module: Union[torch.nn.Module, torch.fx.GraphModule], + inputs: Tuple[torch.Tensor], + compiler_specs: List[CompileSpec], + constant_methods: Optional[Dict[str, Any]] = None, + dynamic_shapes: Optional[Dict] = None, + dep_table: Optional[Dict] = None, + passes_job: Optional[OrderedDict] = None, + skip_node_id_set: Optional[set] = None, + skip_node_op_set: Optional[set] = None, +) -> EdgeProgramManager: """ - Defines and returns the default ordered passes for the capture program. - This function creates an OrderedDict containing a series of default passes. + Transforms and lowers a given PyTorch module to QNN backend. + + Args: + module (Union[torch.nn.Module, torch.fx.GraphModule]): The PyTorch module or fx.GraphModule to be transformed. + inputs (Tuple[torch.Tensor]): The input tensors for the module. + compiler_specs (List[CompileSpec]): Compiler specs for Qualcomm AI Engine Direct. + constant_methods (Optional[Dict[str, Any]]): An optional dictionary of method name to the constant value + returned by that method in eager mode. Often used to store config information on + Edge models. + dynamic_shapes (Optional[Dict]): Information about dynamic shapes. + dep_table (Optional[Dict]): Dependency table for the transformation passes. + passes_job (Optional[OrderedDict]): Ordered dictionary of transformation passes. + skip_node_id_set (Optional[set]): Set of node IDs to skip during partitioning. + skip_node_op_set (Optional[set]): Set of node operations to skip during partitioning. Returns: - OrderedDict: An ordered dictionary containing all default passes along with their activation status and initialization parameters. + EdgeProgramManager: The manager for the edge program after transformation and lowering. """ - - # The second value in each tuple in `default_passes_and_setting` indicates whether the corresponding pass is activated by default. - # If a pass is activated, it will be executed by default. 
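A minimal end-to-end sketch of the call sequence this docstring describes, mirroring the updated unit tests; `MyModule` and its sample input are placeholders, and optional arguments such as dynamic_shapes, skip_node_id_set, skip_node_op_set, dep_table, or passes_job can be threaded through the same call:

    import torch

    from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
    from executorch.backends.qualcomm.utils.utils import (
        generate_htp_compiler_spec,
        generate_qnn_executorch_compiler_spec,
        to_edge_transform_and_lower_to_qnn,
    )

    module = MyModule().eval()                    # placeholder eager module
    sample_input = (torch.randn(1, 32, 28, 28),)  # placeholder example input

    compiler_specs = generate_qnn_executorch_compiler_spec(
        soc_model=QcomChipset.SM8650,
        backend_options=generate_htp_compiler_spec(use_fp16=True),
    )

    # Export, run the QNN pass pipeline, partition, and lower in one call,
    # then serialize the lowered program.
    edge_prog_mgr = to_edge_transform_and_lower_to_qnn(
        module, sample_input, compiler_specs
    )
    exec_prog = edge_prog_mgr.to_executorch()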
- default_passes_and_setting = [ - (AnnotateDecomposed, False), - (AnnotateQuantAttrs, True), - (ConstantI64toI32, True), - (ConvertBmmToMatmul, True), - (ConvertConv1dToConv2d, True), - (ConvertToLinear, True), - (DecomposeAny, True), - (DecomposeLinalgVectorNorm, True), - (ExpandBroadcastTensorShape, False), - (FoldQDQ, True), - (LayoutTransform, True), - (RecomposePReLU, True), - (RecomposePixelUnshuffle, True), - (RecomposeRmsNorm, True), - (RemoveRedundancy, True), - (ReplaceIndexPutInput, True), - (TensorI64toI32, True), - ] - - passes = OrderedDict() - for p, act in default_passes_and_setting: - init_signature = inspect.signature(p.__init__) - - args_kwargs_defaults = { - k: v.default if v.default is not inspect.Parameter.empty else None - for k, v in init_signature.parameters.items() - if k != "self" - } - - passes[p] = { - QCOM_PASS_ACTIVATE_KEY: act, - QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY: args_kwargs_defaults, - } - - return passes - - -def _topological_sort_passes(passes: OrderedDict): - dep_table = get_passes_dependency_for_capture_program() - pm = PassManager() - for p in passes: - pm.add_pass(p) - - for that, these in dep_table.items(): - for this in these: - pm.add_constraint(this_before_that_pass_constraint(this, that)) - - pm.solve_constraints() - sorted_passes = OrderedDict() - for p in pm.passes: - sorted_passes[p] = passes[p] - return sorted_passes - - -def _transform( - edge_program: ExportedProgram, passes_job: OrderedDict = None -) -> ExportedProgram: - # TODO: remove this workaround when target could be correclty detected - from executorch.backends.qualcomm._passes import utils - from executorch.exir.dialects._ops import ops as exir_ops - - utils.q_ops.add(exir_ops.edge.pt2e_quant.quantize_affine.default) - utils.dq_ops.add(exir_ops.edge.pt2e_quant.dequantize_affine.default) - - # currently ExirExportedProgram.transform does not accept - # changes of input number which was caused by FoldQDQ - # apply passes one by one here to avoid IR capture failure - graph_module = edge_program.graph_module - passes_job = passes_job if passes_job is not None else get_capture_program_passes() - passes_job = _topological_sort_passes(passes_job) - for p in passes_job: - if not passes_job[p][QCOM_PASS_ACTIVATE_KEY]: - continue - - kwargs = passes_job[p][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY] - if "edge_program" in kwargs: - kwargs["edge_program"] = edge_program - p(**kwargs)(graph_module) - - # Since QDQ nodes are stripped, update graph signature again to validate program - edge_program._graph_signature = _get_updated_graph_signature( - edge_program.graph_signature, - edge_program.graph_module, + ep = torch.export.export(module, inputs, dynamic_shapes=dynamic_shapes, strict=True) + # This transformation is primarily intended for the LiftConstantScalarOperands pass + # to avoid creating temporary tensors in the operation builder. + # However, this pass will create a get_attr node, which should be converted + # into a lifted tensor constant by the lift_constant_tensor_pass. + # If placed in the to_edge_transform_passes, it will be executed + # after the lift_constant_tensor_pass, causing the operation builder + # to fail to correctly retrieve the parameter by the get_parameter. 
+ ep = QnnPassManager().transform_for_export_pipeline(ep) + transform_passes = QnnPassManager().get_to_edge_transform_passes( + ep, passes_job=passes_job, dep_table=dep_table ) - edge_program._validate() - return edge_program - - -# Modify the fx graph at very beginning for floating point model -# Aim to reduce registration of scalar at graph_module or program -def _preprocess_module(module: torch.nn.Module, inputs: Tuple[torch.Tensor]): - if isinstance(module, torch.fx.graph_module.GraphModule): - return module - module = torch.export.export(module, inputs, strict=True).module() - module = DecomposeScaledDotProductAttention()(module).graph_module - module = DecomposeLinalgVectorNorm(True)(module).graph_module - module = DecomposeExpM1()(module).graph_module - module = LiftConstantScalarOperands()(module).graph_module - return module + qnn_partitioner = QnnPartitioner( + compiler_specs, + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + ) + edge_program_manager = to_edge_transform_and_lower( + ep, + transform_passes=transform_passes, + partitioner=[qnn_partitioner], + constant_methods=constant_methods, + compile_config=qnn_edge_config(), + ) + return edge_program_manager def capture_program( - module: torch.nn.Module, + module: Union[torch.nn.Module, torch.fx.GraphModule], inputs: Tuple[torch.Tensor], passes_job: OrderedDict = None, dynamic_shapes: Dict = None, ) -> exir.ExirExportedProgram: - module = _preprocess_module(module, inputs) + """ + TODO: Deprecated capture_program with to_edge_transform_and_lower_to_qnn + + Captures and transforms a PyTorch module into an Exir exported program. + + Args: + module (Union[torch.nn.Module, torch.fx.GraphModule]): The PyTorch module or fx.GraphModule to be captured. + inputs (Tuple[torch.Tensor]): The input tensors for the module. + passes_job (OrderedDict, optional): Ordered dictionary of transformation passes. + dynamic_shapes (Dict, optional): Information about dynamic shapes. + + Returns: + exir.ExirExportedProgram: The transformed Exir exported program ready for lowering to QNN backend. + """ + warnings.warn( + "capture_program is deprecated. Use to_edge_transform_and_lower_to_qnn instead.", + DeprecationWarning, + stacklevel=1, + ) ep = torch.export.export(module, inputs, dynamic_shapes=dynamic_shapes, strict=True) - # TODO: Handle stack op. If we want to run annotate_decomposed pass for stack op, we need to make stack op decompose, which means we need to find a method to remove it from skip_decomp table + ep = QnnPassManager().transform_for_export_pipeline(ep) + # TODO: Handle stack op. 
If we want to run annotate_decomposed pass for stack op, + # we need to make stack op decompose, which means we need to find a method to + # remove it from skip_decomp table decomposed_ep = ep.run_decompositions(get_decomp_table(passes_job)) core_ep = ExirExportedProgram(decomposed_ep, False) - core_ep.transform(TensorI64toI32(edge_program=core_ep)) edge_ep = core_ep.to_edge(qnn_edge_config()) - _transform(edge_ep.exported_program, passes_job) + transform_passes = QnnPassManager().get_to_edge_transform_passes( + edge_ep.exported_program, passes_job=passes_job + ) + edge_ep.transform(*transform_passes) return edge_ep @@ -523,11 +452,9 @@ def _partition_graph_into_submodules(gm, subgm_tag, subgm_cb, ptn): return gm -def _canonicalize_graph_with_lowered_module(gm, subgm_tag, ptn): - from executorch.exir.backend.backend_api import to_backend - +def _canonicalize_graph_with_lowered_module(gm, subgm_tag, compiler_specs): # return lowered program for user to debug - exported_progs = [] + edge_prog_mgrs = [] # partition each submodule which went through convert_pt2e for node in gm.graph.nodes: if node.op == "call_module" and subgm_tag in node.name: @@ -536,14 +463,16 @@ def _canonicalize_graph_with_lowered_module(gm, subgm_tag, ptn): torch.ones(arg.meta["val"].shape, dtype=arg.meta["val"].dtype) for arg in node.args ] - # program meets QNN backend requirement - sub_prog = capture_program(gm.get_submodule(node.name), tuple(subgm_input)) # start lowering with given partitioner - exported_progs.append(to_backend(sub_prog.exported_program, ptn)) + edge_prog_mgrs.append( + to_edge_transform_and_lower_to_qnn( + gm.get_submodule(node.name), tuple(subgm_input), compiler_specs + ) + ) # replace submodule with lowered module gm.set_submodule( node.name, - exported_progs[-1].graph_module, + edge_prog_mgrs[-1].exported_program().graph_module, ) # if node has multiple outputs, getitems will be default generated if all(n.target != operator.getitem for n in node.users): @@ -559,13 +488,13 @@ def _canonicalize_graph_with_lowered_module(gm, subgm_tag, ptn): ) gm.recompile() - return gm, exported_progs + return gm, edge_prog_mgrs def skip_annotation( nn_module: torch.nn.Module, quantizer, - partitioner, + compiler_specs, sample_input: Tuple[torch.Tensor, ...], calibration_cb: Callable[[torch.fx.GraphModule], None], fp_node_id_set: set = None, @@ -655,7 +584,7 @@ def skip_annotation( Args: nn_module (torch.nn.Module): The module to be lowered. quantizer (QnnQuantizer): Instance of QnnQuantizer. - partitioner (QnnPartitioner): Instance of QnnPartitioner. + compiler_specs (List[CompileSpec]): Compiler specs for Qualcomm AI Engine Direct. sample_input ((torch.Tensor, ...)): Sample input tensors for graph exporting. calibration_cb (callable): Callback function for user-defined calibration. fp_node_id_set ({str, ...}): Set of operator names to be left in fp precision. 
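A sketch of how a caller might drive skip_annotation after this signature change, following the skip-node tests; `module`, `sample_input`, and `compiler_specs` are assumed to exist as in the earlier sketches, and `quantizer` is a configured QnnQuantizer (the unit tests build one with make_quantizer()):

    import torch

    from executorch.backends.qualcomm.utils.utils import skip_annotation
    from executorch.exir import to_edge

    def calibrator(gm):
        gm(*sample_input)

    # Quantize everything except the listed nodes, then lower to QNN.
    graph_module, edge_prog_mgrs = skip_annotation(
        nn_module=module,
        quantizer=quantizer,
        compiler_specs=compiler_specs,
        sample_input=sample_input,
        calibration_cb=calibrator,
        fp_node_id_set={"conv2d"},  # node names kept in floating point
    )

    # Lower the whole graph again; skipped operators stay on CPU
    # (or are delegated as fp16 when fallback_to_cpu=False).
    exec_prog = to_edge(
        torch.export.export(graph_module, sample_input, strict=True)
    ).to_executorch()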
@@ -707,37 +636,26 @@ def prepare_subgm(subgm, subgm_name): node.name, convert_pt2e(graph_module.get_submodule(node.name)) ) # canonicalize graph for lowering again - graph_module, exported_progs = _canonicalize_graph_with_lowered_module( + graph_module, edge_prog_mgrs = _canonicalize_graph_with_lowered_module( gm=graph_module, subgm_tag=subgm_tag, - ptn=partitioner, + compiler_specs=compiler_specs, ) if not fallback_to_cpu: try: - from executorch.exir.backend.partitioner import DelegationSpec - # change HTP compiler spec for hardware to enable fp16 - qnn_option = generate_qnn_executorch_option( - partitioner.compiler_specs_snapshot - ) + qnn_option = generate_qnn_executorch_option(compiler_specs) compile_option = flatbuffer_to_option(qnn_option) htp_options = compile_option.backend_options.htp_options htp_options.precision = QnnExecuTorchHtpPrecision.kHtpFp16 - partitioner.delegation_spec = DelegationSpec( - "QnnBackend", - [ - CompileSpec( - QCOM_QNN_COMPILE_SPEC, option_to_flatbuffer(compile_option) - ) - ], - ) + compiler_specs[0].value = option_to_flatbuffer(compile_option) except: print( "Failed to change HTP compiler spec with 'use_fp16' as True," " skipped operators will fallback to cpu," ) - return graph_module, exported_progs + return graph_module, edge_prog_mgrs # try lowering skipped operator into fp16 capability_partitioner = CapabilityBasedPartitioner( @@ -752,14 +670,14 @@ def prepare_subgm(subgm, subgm_name): subgm_cb=lambda subgm, _: subgm, ptn=capability_partitioner, ) - graph_module, exported_progs_fp = _canonicalize_graph_with_lowered_module( + graph_module, edge_prog_mgrs_fp = _canonicalize_graph_with_lowered_module( gm=graph_module, subgm_tag=subgm_tag, - ptn=partitioner, + compiler_specs=compiler_specs, ) - exported_progs.extend(exported_progs_fp) + edge_prog_mgrs.extend(edge_prog_mgrs_fp) - return graph_module, exported_progs + return graph_module, edge_prog_mgrs def from_context_binary( # noqa: C901 diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index f1c5c3a73f1..86e30f0ac62 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -825,33 +825,32 @@ def _to_edge_and_lower_llama( # noqa: C901 args.use_kv_cache, args.pt2e_quantize, args.num_sharding, args.soc_model ) ) - # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils` - from executorch.backends.qualcomm._passes.annotate_decomposed import ( - AnnotateDecomposed, + # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm._passes` + from executorch.backends.qualcomm._passes import ( + AnnotateStack, + FoldQDQ, + RecomposeRmsNorm, + TagQuantIO, ) - from executorch.backends.qualcomm.utils.constants import QCOM_PASS_ACTIVATE_KEY - from executorch.backends.qualcomm.utils.utils import ( - _transform, + + # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm._passes.qnn_pass_manager` + from executorch.backends.qualcomm._passes.qnn_pass_manager import ( get_capture_program_passes, - tag_quant_io, + get_passes_dependency_for_capture_program, + QnnPassManager, ) - passes_job = get_capture_program_passes() - passes_job[AnnotateDecomposed][QCOM_PASS_ACTIVATE_KEY] = True - _transform(builder_exported_to_edge.edge_manager.exported_program(), passes_job) - - if args.num_sharding > 0: - model_sharding.split_graph( - 
builder_exported_to_edge.edge_manager.exported_program(), - builder_exported_to_edge.metadata["get_n_layers"], - shares=args.num_sharding, - ) - # pyre-ignore from executorch.backends.qualcomm.quantizer.custom_annotation import ( get_custom_quant_ios_dtype, ) + # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.constants` + from executorch.backends.qualcomm.utils.constants import ( + QCOM_PASS_ACTIVATE_KEY, + QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY, + ) + atten = builder_exported_to_edge.model.layers[0].attention if args.use_qnn_sha: cache_shape = torch.Size( @@ -866,9 +865,28 @@ def _to_edge_and_lower_llama( # noqa: C901 atten.head_dim, ) ) - tag_quant_io( - builder_exported_to_edge.edge_manager.exported_program().graph_module, - partial(get_custom_quant_ios_dtype, cache_shape), + + # TODO: Use to_edge_lower_and_transform for QNN + passes_job = get_capture_program_passes() + dep_table = get_passes_dependency_for_capture_program() + passes_job[AnnotateStack][QCOM_PASS_ACTIVATE_KEY] = True + passes_job[RecomposeRmsNorm][QCOM_PASS_ACTIVATE_KEY] = True + passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True + passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][ + "get_quant_io_dtype_fn" + ] = partial(get_custom_quant_ios_dtype, cache_shape) + if args.num_sharding > 0: + SplitGraph, setting = model_sharding.get_split_graph_pass( + builder_exported_to_edge.metadata["get_n_layers"], + shares=args.num_sharding, + ) + passes_job[SplitGraph] = setting + dep_table[SplitGraph] = [FoldQDQ] + dep_table[TagQuantIO] = [SplitGraph] + QnnPassManager().transform_for_to_edge_pipeline( + builder_exported_to_edge.edge_manager.exported_program(), + dep_table=dep_table, + passes_job=passes_job, ) logging.info("Lowering model using following partitioner(s): ") diff --git a/examples/qualcomm/executor_runner/qnn_executor_runner.cpp b/examples/qualcomm/executor_runner/qnn_executor_runner.cpp index f3b83700ce2..83478bd8e68 100644 --- a/examples/qualcomm/executor_runner/qnn_executor_runner.cpp +++ b/examples/qualcomm/executor_runner/qnn_executor_runner.cpp @@ -78,16 +78,6 @@ DEFINE_string( "", "Path to file with output shapes specified (used in dynamic shape scenario)."); -DEFINE_string( - input_type_size_path, - "", - "Path to file with input dtype sizes specified."); - -DEFINE_string( - output_type_size_path, - "", - "Path to file with output dtype sizes specified."); - DEFINE_int32( debug_buffer_size, 20000000, // 20MB @@ -368,23 +358,6 @@ int main(int argc, char** argv) { expected_output_shapes.emplace_back(std::move(shape)); } } - // currently only expected_output_type_sizes is used - // TODO: remove following when meta could be correctly propagated - std::vector expected_input_type_sizes, expected_output_type_sizes; - if (!FLAGS_input_type_size_path.empty() && - !FLAGS_output_type_size_path.empty()) { - std::ifstream input_type_size_list(FLAGS_input_type_size_path); - std::ifstream output_type_size_list(FLAGS_output_type_size_path); - std::string type_sizes_content; - while (std::getline(input_type_size_list, type_sizes_content)) { - expected_input_type_sizes.push_back( - std::stoi(split(type_sizes_content, ", ")[0])); - } - while (std::getline(output_type_size_list, type_sizes_content)) { - expected_output_type_sizes.push_back( - std::stoi(split(type_sizes_content, ", ")[0])); - } - } std::string file_path; int inference_index = 0; @@ -495,10 +468,7 @@ int main(int argc, char** argv) { nbytes = std::accumulate( 
expected_output_shapes[output_index].begin(), expected_output_shapes[output_index].end(), - !expected_output_type_sizes.empty() - ? expected_output_type_sizes[output_index] - : executorch::runtime::elementSize( - output_tensor.scalar_type()), + executorch::runtime::elementSize(output_tensor.scalar_type()), std::multiplies()); } auto output_file_name = FLAGS_output_folder_path + "/output_" + diff --git a/examples/qualcomm/oss_scripts/fastvit.py b/examples/qualcomm/oss_scripts/fastvit.py index f0d2f4c3f09..501ea522acd 100644 --- a/examples/qualcomm/oss_scripts/fastvit.py +++ b/examples/qualcomm/oss_scripts/fastvit.py @@ -13,6 +13,10 @@ from executorch.backends.qualcomm._passes.expand_broadcast_tensor_shape import ( ExpandBroadcastTensorShape, ) + +from executorch.backends.qualcomm._passes.qnn_pass_manager import ( + get_capture_program_passes, +) from executorch.backends.qualcomm.quantizer.annotators import ( QuantizationConfig, QuantizationSpec, @@ -27,10 +31,7 @@ from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.utils.constants import QCOM_PASS_ACTIVATE_KEY -from executorch.backends.qualcomm.utils.utils import ( - convert_linear_to_conv2d, - get_capture_program_passes, -) +from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d from executorch.examples.qualcomm.utils import ( build_executorch_binary, get_imagenet_dataset, diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index a999270c15b..c1c3a48fd72 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -15,14 +15,20 @@ import subprocess import sys import time -from collections import OrderedDict from functools import partial from multiprocessing.connection import Client import torch -from executorch.backends.qualcomm._passes.constant_i64_to_i32 import ConstantI64toI32 +from executorch.backends.qualcomm._passes import FoldQDQ, TagQuantIO +from executorch.backends.qualcomm._passes.i64_to_i32 import I64toI32 +from executorch.backends.qualcomm._passes.qnn_pass_manager import ( + get_capture_program_passes, +) +from executorch.backends.qualcomm._passes.utils import ( + get_passes_dependency_for_capture_program, +) -from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner +from executorch.backends.qualcomm.builders.utils import is_graph_output from executorch.backends.qualcomm.quantizer.custom_annotation import ( annotate_linear_16a8w_in_affine_layer, @@ -38,18 +44,18 @@ option_to_flatbuffer, ) from executorch.backends.qualcomm.utils.constants import ( + QCOM_PASS_ACTIVATE_KEY, QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY, - QCOM_QUANTIZED_IO, + QCOM_QUANT_ATTRS_MAP, ) from executorch.backends.qualcomm.utils.utils import ( - capture_program, convert_linear_to_conv2d, generate_composite_llama_program, generate_htp_compiler_spec, generate_multi_graph_program, generate_qnn_executorch_compiler_spec, - get_capture_program_passes, get_soc_to_chipset_map, + to_edge_transform_and_lower_to_qnn, update_spill_fill_size, ) @@ -68,8 +74,6 @@ setup_common_args_and_variables, SimpleADB, ) -from executorch.exir import EdgeCompileConfig, EdgeProgramManager -from executorch.exir.backend.backend_api import to_backend from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass @@ -288,6 +292,9 @@ class SingleLlama: def 
__init__(self, llama_model, pte_filename) -> None: super().__init__() self.llama_model = llama_model + self.passes_job = get_capture_program_passes() + self.dep_table = get_passes_dependency_for_capture_program() + self.quant_attrs = None self.quant_dtype = None self.llama_meta = self.llama_model.get_metadata() self.has_quant_io = False @@ -301,8 +308,16 @@ def __init__(self, llama_model, pte_filename) -> None: tokens, atten_mask = self.get_example_inputs(use_kv_cache=False) self.inputs = (tokens, atten_mask) self.llama_graph_module = llama_model + self.io_shape = { + # logit output + ( + self.llama_meta["get_max_batch_size"], + self.llama_meta["get_ar_len"], + self.llama_meta["get_vocab_size"], + ), + } - def _tag_ios(self, gm: torch.fx.GraphModule, fixed_point_type): + def _tag_ios(self, node, fixed_point_type): if not self.has_quant_io: return @@ -315,14 +330,6 @@ def _tag_ios(self, gm: torch.fx.GraphModule, fixed_point_type): (self.llama_meta["get_head_dim"], self.llama_meta["get_ar_len"]), (self.llama_meta["get_ar_len"], self.llama_meta["get_head_dim"]), } - io_shape = { - # logit output - ( - self.llama_meta["get_max_batch_size"], - self.llama_meta["get_ar_len"], - self.llama_meta["get_vocab_size"], - ), - } atten_mask_shape = { ( @@ -339,37 +346,35 @@ def _tag_ios(self, gm: torch.fx.GraphModule, fixed_point_type): freq_op = { exir_ops.edge.aten.select.int, } - - for n in gm.graph.nodes: - if n.op == "placeholder": - if ( - len(users := list(n.users)) == 1 - and users[0].meta["val"].size()[-2:] in kv_cache_shape - ): - n.meta[QCOM_QUANTIZED_IO] = fixed_point_type["kv_type"] - elif n.meta["val"].size() in io_shape: - n.meta[QCOM_QUANTIZED_IO] = fixed_point_type["io_type"] - elif n.meta["val"].size() in atten_mask_shape: - n.meta[QCOM_QUANTIZED_IO] = fixed_point_type["io_type"] - elif n.op == "output": - for a in n.args[0]: - if a.meta["val"].size()[-2:] in kv_cache_shape: - a.meta[QCOM_QUANTIZED_IO] = fixed_point_type["kv_type"] - elif a.meta["val"].size() in io_shape: - a.meta[QCOM_QUANTIZED_IO] = fixed_point_type["io_type"] - quant_attrs = a.meta["quant_attrs"] - - # Tag sharding io - if exir_ops.edge.llama.fallback.default in [ - u.target for u in list(n.users.keys()) - ] + [n.target]: - n.meta[QCOM_QUANTIZED_IO] = fixed_point_type["io_type"] - - # Tag select op as quantized tensors for freq_sin and freq_cos. It is caused by sharding - if n.target in freq_op and n.meta["val"].size() in freq_shape: - n.meta[QCOM_QUANTIZED_IO] = fixed_point_type["io_type"] - - return quant_attrs + quant_io_type = None + + if node.op == "placeholder": + if ( + len(users := list(node.users)) == 1 + and users[0].meta["val"].size()[-2:] in kv_cache_shape + ): + quant_io_type = fixed_point_type["kv_type"] + elif node.meta["val"].size() in self.io_shape: + quant_io_type = fixed_point_type["io_type"] + elif node.meta["val"].size() in atten_mask_shape: + quant_io_type = fixed_point_type["io_type"] + if is_graph_output(node): + if node.meta["val"].size()[-2:] in kv_cache_shape: + quant_io_type = fixed_point_type["kv_type"] + elif node.meta["val"].size() in self.io_shape: + quant_io_type = fixed_point_type["io_type"] + + # Tag sharding io + if exir_ops.edge.llama.fallback.default in [ + u.target for u in list(node.users.keys()) + ] + [node.target]: + quant_io_type = fixed_point_type["io_type"] + + # Tag select op as quantized tensors for freq_sin and freq_cos. 
It is caused by sharding + if node.target in freq_op and node.meta["val"].size() in freq_shape: + quant_io_type = fixed_point_type["io_type"] + + return quant_io_type def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()): self.quant_dtype = quant_dtype @@ -407,11 +412,9 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()): def lowering_modules( self, work_space, - fixed_point_type, use_fp16=False, soc_model=QcomChipset.SM8650, num_sharding=1, - passes_job=OrderedDict(), shared_buffer=False, verbose=False, ): @@ -437,32 +440,22 @@ def lowering_modules( shared_buffer=shared_buffer, ) skip_node_op_set = {"llama.fallback.default"} - partitioner = QnnPartitioner( - compiler_specs, skip_node_op_set=skip_node_op_set - ) - edge_prog = capture_program( + edge_prog_mgr = to_edge_transform_and_lower_to_qnn( self.llama_graph_module, self.inputs, - passes_job, + compiler_specs, + constant_methods=self.llama_meta, + dep_table=self.dep_table, + passes_job=self.passes_job, + skip_node_op_set=skip_node_op_set, ) - if num_sharding > 1: - model_sharding.split_graph( - edge_prog.exported_program, - self.llama_meta["get_n_layers"], - shares=num_sharding, - ) + for n in edge_prog_mgr.exported_program().graph.nodes: + if n.op == "output": + for node, output_encoding in n.meta[QCOM_QUANT_ATTRS_MAP].items(): + if node.meta["val"].size() in self.io_shape: + self.quant_attrs = output_encoding - self.quant_attrs = self._tag_ios( - edge_prog.exported_program.graph_module, - fixed_point_type=fixed_point_type, - ) - edge_prog_mgr = EdgeProgramManager( - edge_programs={"forward": edge_prog.exported_program}, - constant_methods=self.llama_meta, - compile_config=EdgeCompileConfig(_check_ir_validity=False), - ) - edge_prog_mgr = edge_prog_mgr.to_backend(partitioner) if num_sharding > 1: update_spill_fill_size(edge_prog_mgr.exported_program()) @@ -595,7 +588,6 @@ def permute(w, heads): assert args.tokenizer_model is not None, "Need tokenizer model for calibration" - passes_job = get_capture_program_passes() if args.dtype_override is not None: dtype_override = DType[args.dtype_override] for i in range(len(llama_instance_list)): @@ -608,13 +600,14 @@ def permute(w, heads): llama_instance_list[i] = get_quant_embedding_transform(args)( llama_instance_list[i] ) - passes_job[ConstantI64toI32][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][ - "skip_node" - ] = {"tokens"} llama_instance_list[i] = convert_linear_to_conv2d(llama_instance_list[i]) llama_instance_list[i] = SingleLlama( llama_instance_list[i].eval(), pte_filename ) + if args.embedding_quantize: + llama_instance_list[i].passes_job[I64toI32][ + QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY + ]["skip_node"] = {"tokens"} if args.ptq: start_quantize_ts = time.time() @@ -623,44 +616,54 @@ def permute(w, heads): custom_annotations = custom_annotations + ( annotate_linear_16a8w_in_affine_layer, ) - if args.ptq != None: - kv_quant_attrs = {} - for i, llama_instance in enumerate(llama_instance_list): - llama_instance.quantize( - quant_dtype=quant_dtype, - args=args, - tokenizer=tokenizer, - custom_annotations=custom_annotations, + kv_quant_attrs = {} + for i, llama_instance in enumerate(llama_instance_list): + llama_instance.quantize( + quant_dtype=quant_dtype, + args=args, + tokenizer=tokenizer, + custom_annotations=custom_annotations, + ) + # If hybrid mode, we store kv output quant_attrs and apply to prefill output quant_attrs later + if i == 0 and args.model_mode == "hybrid": + output_indices = 0 + for node in llama_instance.llama_graph_module.graph.nodes: + 
if node.op == "output": + for output in node.args[0]: + kv_quant_attrs[output_indices] = output.args[1:] + output_indices += 1 + break + custom_annotations = custom_annotations + ( + partial( + annotate_prefill_kv_output, + kv_quant_attrs=kv_quant_attrs, + ), ) - # If hybrid mode, we store kv output quant_attrs and apply to prefill output quant_attrs later - if i == 0 and args.model_mode == "hybrid": - output_indices = 0 - for node in llama_instance.llama_graph_module.graph.nodes: - if node.op == "output": - for output in node.args[0]: - kv_quant_attrs[output_indices] = output.args[1:] - output_indices += 1 - break - custom_annotations = custom_annotations + ( - partial( - annotate_prefill_kv_output, - kv_quant_attrs=kv_quant_attrs, - ), - ) + llama_instance.passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True + llama_instance.passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][ + "get_quant_io_dtype_fn" + ] = partial(llama_instance._tag_ios, fixed_point_type=fixed_point_type) end_quantize_ts = time.time() logging.info(f"Time for quantizing: {end_quantize_ts - start_quantize_ts}") start_lowering_ts = time.time() quant_attrs = None + if args.num_sharding > 1: + for llama_instance in llama_instance_list: + SplitGraph, setting = model_sharding.get_split_graph_pass( + llama_instance.llama_meta["get_n_layers"], + shares=args.num_sharding, + ) + llama_instance.passes_job[SplitGraph] = setting + llama_instance.dep_table[SplitGraph] = [FoldQDQ] + llama_instance.dep_table[TagQuantIO] = [SplitGraph] if args.model_mode in ["kv"]: llama_instance_list[0].lowering_modules( args.artifact, - fixed_point_type, use_fp16=use_fp16, soc_model=get_soc_to_chipset_map()[args.model], num_sharding=args.num_sharding, - passes_job=passes_job, shared_buffer=args.shared_buffer, ) quant_attrs = llama_instance_list[0].get_quant_attrs() @@ -668,30 +671,6 @@ def permute(w, heads): sample_inputs_list = [ llama_instace.inputs for llama_instace in llama_instance_list ] - edge_progs = [ - capture_program( - llama_instance.llama_graph_module, - sample_input, - passes_job=passes_job, - ) - for llama_instance, sample_input in zip( - llama_instance_list, sample_inputs_list - ) - ] - - if args.num_sharding > 1: - for i in range(len(llama_instance_list)): - model_sharding.split_graph( - edge_progs[i].exported_program, - llama_instance_list[i].llama_meta["get_n_layers"], - shares=args.num_sharding, - ) - - for i in range(len(llama_instance_list)): - quant_attrs = llama_instance_list[i]._tag_ios( - edge_progs[i].exported_program.graph_module, - fixed_point_type, - ) backend_options = generate_htp_compiler_spec( use_fp16=use_fp16, use_multi_contexts=args.num_sharding > 1 ) @@ -708,15 +687,29 @@ def permute(w, heads): for graph_name in graph_names ] skip_node_op_set = {"llama.fallback.default"} - exported_programs = [ - to_backend( - edge_prog.exported_program, - QnnPartitioner(compiler_specs[i], skip_node_op_set=skip_node_op_set), + edge_prog_mgrs = [ + to_edge_transform_and_lower_to_qnn( + llama_instance.llama_graph_module, + sample_input, + compile_spec, + dep_table=llama_instance.dep_table, + passes_job=llama_instance.passes_job, + skip_node_op_set=skip_node_op_set, + ) + for llama_instance, sample_input, compile_spec in zip( + llama_instance_list, sample_inputs_list, compiler_specs ) - for i, edge_prog in enumerate(edge_progs) ] + for n in edge_prog_mgrs[0].exported_program().graph.nodes: + if n.op == "output": + for node, output_encoding in n.meta[QCOM_QUANT_ATTRS_MAP].items(): + if node.meta["val"].size() in 
llama_instance_list[0].io_shape: + quant_attrs = output_encoding + if args.num_sharding > 1: - max_sf_size = update_spill_fill_size(exported_programs) + max_sf_size = update_spill_fill_size( + [edge_prog_mgr.exported_program() for edge_prog_mgr in edge_prog_mgrs] + ) qnn_executorch_options = flatbuffer_to_option(compiler_specs[0][0].value) qnn_executorch_options.backend_options.htp_options.max_sf_buf_size = ( max_sf_size @@ -724,8 +717,8 @@ def permute(w, heads): compiler_specs[0][0].value = option_to_flatbuffer(qnn_executorch_options) if args.verbose: - for exported_program in exported_programs: - print_delegation_info(exported_program.graph_module) + for edge_prog_mgr in edge_prog_mgrs: + print_delegation_info(edge_prog_mgr.exported_program().graph_module) executorch_config = ExecutorchBackendConfig( # For shared buffer, user must pass the memory address @@ -745,8 +738,8 @@ def permute(w, heads): call_delegate_node_name_dict = {name: [] for name in graph_names} outputs_dict = {name: [] for name in graph_names} input_nodes_dict = {name: [] for name in graph_names} - for prog, graph_name in zip(exported_programs, graph_names): - for node in prog.graph_module.graph.nodes: + for prog, graph_name in zip(edge_prog_mgrs, graph_names): + for node in prog.exported_program().graph_module.graph.nodes: if ( node.op == "call_function" and "executorch_call_delegate" in node.name @@ -777,13 +770,15 @@ def permute(w, heads): outputs_dict[graph_name].append((arg.args[0].name, arg.args[1])) for num in range(args.num_sharding - 1, -1, -1): processed_bytes = [] - for prog, graph_name in zip(exported_programs, graph_names): + for prog, graph_name in zip(edge_prog_mgrs, graph_names): processed_bytes.append( - getattr(prog.graph_module, f"lowered_module_{num}").processed_bytes + getattr( + prog.exported_program().graph_module, f"lowered_module_{num}" + ).processed_bytes ) call_delegate_node = [ list(node.users.keys())[0] - for node in prog.graph_module.graph.nodes + for node in prog.exported_program().graph_module.graph.nodes if node.op == "get_attr" and node.name == f"lowered_module_{num}" ] input_nodes_dict[graph_name] = [ diff --git a/examples/qualcomm/scripts/export_example.py b/examples/qualcomm/scripts/export_example.py index 23f1f59a7dd..acf8a9ab468 100644 --- a/examples/qualcomm/scripts/export_example.py +++ b/examples/qualcomm/scripts/export_example.py @@ -3,18 +3,16 @@ import copy import torch -from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset from executorch.backends.qualcomm.utils.utils import ( - capture_program, generate_htp_compiler_spec, generate_qnn_executorch_compiler_spec, + to_edge_transform_and_lower_to_qnn, ) from executorch.devtools import generate_etrecord from executorch.examples.models import MODEL_NAME_TO_MODEL from executorch.examples.models.model_factory import EagerModelFactory -from executorch.exir.backend.backend_api import to_backend, validation_disabled from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.extension.export_util.utils import save_pte_program @@ -68,27 +66,20 @@ def main() -> None: # Get the quantized model m = convert_pt2e(m) - # Capture program for edge IR - edge_program = capture_program(m, example_inputs) - - # this is needed for the ETRecord as lowering modifies the graph in-place - edge_copy = copy.deepcopy(edge_program) - - # Delegate to QNN 
backend + # Capture program for edge IR and delegate to QNN backend backend_options = generate_htp_compiler_spec( use_fp16=False, ) - qnn_partitioner = QnnPartitioner( - generate_qnn_executorch_compiler_spec( - soc_model=QcomChipset.SM8550, - backend_options=backend_options, - ) + compile_spec = generate_qnn_executorch_compiler_spec( + soc_model=QcomChipset.SM8550, + backend_options=backend_options, ) - with validation_disabled(): - delegated_program = edge_program - delegated_program.exported_program = to_backend( - edge_program.exported_program, qnn_partitioner - ) + delegated_program = to_edge_transform_and_lower_to_qnn( + m, example_inputs, compile_spec + ) + + # this is needed for the ETRecord as lowering modifies the graph in-place + edge_copy = copy.deepcopy(delegated_program) executorch_program = delegated_program.to_executorch( config=ExecutorchBackendConfig(extract_delegate_segments=False) diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index 29f5e96dcd0..2b2f32b037b 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -19,17 +19,14 @@ import numpy as np import torch -from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset from executorch.backends.qualcomm.utils.utils import ( - capture_program, generate_htp_compiler_spec, generate_qnn_executorch_compiler_spec, get_soc_to_arch_map, + to_edge_transform_and_lower_to_qnn, ) -from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge -from executorch.exir.backend.backend_api import to_backend from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass from torch.ao.quantization.observer import MovingAverageMinMaxObserver @@ -76,8 +73,6 @@ def __init__( runner="examples/qualcomm/executor_runner/qnn_executor_runner", expected_input_shape=None, expected_output_shape=None, - expected_input_dtype=None, - expected_output_dtype=None, ): self.qnn_sdk = qnn_sdk self.build_path = build_path @@ -97,8 +92,6 @@ def __init__( self.runner = runner self.expected_input_shape = expected_input_shape self.expected_output_shape = expected_output_shape - self.expected_input_dtype = expected_input_dtype - self.expected_output_dtype = expected_output_dtype self.extra_cmds = "" def _adb(self, cmd): @@ -162,18 +155,6 @@ def push(self, inputs=None, input_list=None, files=None): self._adb(["push", f"{tmp_dir}/{name}.txt", self.workspace]) self.extra_cmds += f" --{name}_path {name}.txt" - if self.expected_input_dtype and self.expected_output_dtype: - dtype_info = { - "input_type_size": self.expected_input_dtype, - "output_type_size": self.expected_output_dtype, - } - for name, dtypes in dtype_info.items(): - with open(f"{tmp_dir}/{name}.txt", "w") as f: - for dtype in dtypes: - f.write(f"{torch.tensor([], dtype=dtype).element_size()}\n") - self._adb(["push", f"{tmp_dir}/{name}.txt", self.workspace]) - self.extra_cmds += f" --{name}_path {name}.txt" - # custom files if files is not None: for file_name in files: @@ -326,6 +307,15 @@ def build_executorch_binary( Returns: None: The function writes the output to a specified .pte file. 
""" + backend_options = generate_htp_compiler_spec( + use_fp16=False if quant_dtype else True + ) + compile_spec = generate_qnn_executorch_compiler_spec( + soc_model=getattr(QcomChipset, soc_model), + backend_options=backend_options, + shared_buffer=shared_buffer, + dump_intermediate_outputs=dump_intermediate_outputs, + ) if quant_dtype is not None: captured_model = torch.export.export(model, inputs, strict=True).module() if qat_training_data: @@ -342,23 +332,25 @@ def build_executorch_binary( annotated_model = ptq_calibrate(captured_model, quantizer, dataset) quantized_model = convert_pt2e(annotated_model) - edge_prog = capture_program(quantized_model, inputs, passes_job) + edge_prog_mgr = to_edge_transform_and_lower_to_qnn( + quantized_model, + inputs, + compile_spec, + constant_methods=metadata, + passes_job=passes_job, + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + ) else: - edge_prog = capture_program(model, inputs, passes_job) - - backend_options = generate_htp_compiler_spec( - use_fp16=False if quant_dtype else True - ) - qnn_partitioner = QnnPartitioner( - generate_qnn_executorch_compiler_spec( - soc_model=getattr(QcomChipset, soc_model), - backend_options=backend_options, - shared_buffer=shared_buffer, - dump_intermediate_outputs=dump_intermediate_outputs, - ), - skip_node_id_set, - skip_node_op_set, - ) + edge_prog_mgr = to_edge_transform_and_lower_to_qnn( + model, + inputs, + compile_spec, + constant_methods=metadata, + passes_job=passes_job, + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + ) executorch_config = ExecutorchBackendConfig( # For shared buffer, user must pass the memory address @@ -370,24 +362,10 @@ def build_executorch_binary( alloc_graph_output=not shared_buffer, ), ) - - if metadata is None: - exported_program = to_backend(edge_prog.exported_program, qnn_partitioner) - exported_program.graph_module.graph.print_tabular() - exec_prog = to_edge(exported_program).to_executorch(config=executorch_config) - with open(f"{file_name}.pte", "wb") as file: - file.write(exec_prog.buffer) - else: - edge_prog_mgr = EdgeProgramManager( - edge_programs={"forward": edge_prog.exported_program}, - constant_methods=metadata, - compile_config=EdgeCompileConfig(_check_ir_validity=False), - ) - - edge_prog_mgr = edge_prog_mgr.to_backend(qnn_partitioner) - exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config) - with open(f"{file_name}.pte", "wb") as file: - file.write(exec_prog_mgr.buffer) + pte_name = f"{file_name}.pte" + exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config) + with open(pte_name, "wb") as file: + exec_prog_mgr.write_to_file(file) def make_output_dir(path: str): diff --git a/extension/llm/custom_ops/model_sharding.py b/extension/llm/custom_ops/model_sharding.py index 244c036c9b7..df87274a115 100644 --- a/extension/llm/custom_ops/model_sharding.py +++ b/extension/llm/custom_ops/model_sharding.py @@ -9,7 +9,11 @@ import torch -from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS +from executorch.backends.qualcomm.utils.constants import ( + QCOM_PASS_ACTIVATE_KEY, + QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY, + QCOM_QUANT_ATTRS, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from torch.export.exported_program import ExportedProgram @@ -103,3 +107,11 @@ def split_graph(edge_program: ExportedProgram, num_layers: int, shares: int): graph_module = edge_program.graph_module shard_layers = list(range(0, 
num_layers, int(num_layers / shares)))
     return SplitGraph(shard_layers)(graph_module)
+
+
+def get_split_graph_pass(num_layers: int, shares: int):
+    shard_layers = list(range(0, num_layers, int(num_layers / shares)))
+    return SplitGraph, {
+        QCOM_PASS_ACTIVATE_KEY: True,
+        QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY: {"shard_layers": shard_layers},
+    }
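
For reference, a minimal sketch of the single-call lowering flow this patch moves the examples onto, assuming an SM8650 target; the toy module, sample inputs, and output file name are placeholders, not part of the patch:

    import torch
    from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
    from executorch.backends.qualcomm.utils.utils import (
        generate_htp_compiler_spec,
        generate_qnn_executorch_compiler_spec,
        to_edge_transform_and_lower_to_qnn,
    )

    # Placeholder module and sample inputs; any exportable nn.Module follows the same path.
    model = torch.nn.Linear(32, 32).eval()
    example_inputs = (torch.randn(1, 32),)

    compile_spec = generate_qnn_executorch_compiler_spec(
        soc_model=QcomChipset.SM8650,
        backend_options=generate_htp_compiler_spec(use_fp16=True),
    )

    # One call replaces the old capture_program -> QnnPartitioner -> to_backend sequence.
    edge_prog_mgr = to_edge_transform_and_lower_to_qnn(model, example_inputs, compile_spec)
    exec_prog_mgr = edge_prog_mgr.to_executorch()
    with open("model.pte", "wb") as f:
        exec_prog_mgr.write_to_file(f)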
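
Continuing the sketch above, this is roughly how the `passes_job`/`dep_table` hooks and `get_split_graph_pass` compose, mirroring the llama.py changes in this patch. The layer count, sharding factor, and the no-op tagging callback are illustrative placeholders, and the `model_sharding` import path is assumed from the file location shown above:

    from executorch.backends.qualcomm._passes import FoldQDQ, TagQuantIO
    from executorch.backends.qualcomm._passes.qnn_pass_manager import get_capture_program_passes
    from executorch.backends.qualcomm._passes.utils import (
        get_passes_dependency_for_capture_program,
    )
    from executorch.backends.qualcomm.utils.constants import (
        QCOM_PASS_ACTIVATE_KEY,
        QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY,
    )
    from executorch.extension.llm.custom_ops import model_sharding

    def quant_io_dtype_fn(node):
        # Placeholder: return a torch dtype (e.g. torch.uint16) to tag a node's IO,
        # or None to leave the node untouched.
        return None

    n_layers, num_sharding = 16, 2  # illustrative values

    passes_job = get_capture_program_passes()
    dep_table = get_passes_dependency_for_capture_program()

    # Activate TagQuantIO and hand it the callback that decides per-node IO dtypes.
    passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True
    passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][
        "get_quant_io_dtype_fn"
    ] = quant_io_dtype_fn

    # Register the sharding pass so it runs after FoldQDQ and before TagQuantIO.
    # SplitGraph only has an effect on llama-style models with fallback points.
    SplitGraph, setting = model_sharding.get_split_graph_pass(n_layers, shares=num_sharding)
    passes_job[SplitGraph] = setting
    dep_table[SplitGraph] = [FoldQDQ]
    dep_table[TagQuantIO] = [SplitGraph]

    edge_prog_mgr = to_edge_transform_and_lower_to_qnn(
        model, example_inputs, compile_spec, dep_table=dep_table, passes_job=passes_job
    )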