diff --git a/backends/arm/_passes/convert_split_to_slice.py b/backends/arm/_passes/convert_split_to_slice.py index ed2dcd4008..06104811be 100644 --- a/backends/arm/_passes/convert_split_to_slice.py +++ b/backends/arm/_passes/convert_split_to_slice.py @@ -1,5 +1,4 @@ -# Copyright 2024 Arm Limited and/or its affiliates. -# All rights reserved. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -7,8 +6,10 @@ # pyre-unsafe import torch.fx -from executorch.backends.arm._passes.arm_pass_utils import create_node -from executorch.backends.arm.tosa_mapping import extract_tensor_meta +from executorch.backends.arm._passes.arm_pass_utils import ( + create_node, + get_first_fake_tensor, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -34,7 +35,7 @@ def call(self, graph_module: torch.fx.GraphModule): split_node = node input_node = split_node.all_input_nodes[0] output_nodes = split_node.users.copy() - _, shape, _ = extract_tensor_meta(input_node.meta) + shape = get_first_fake_tensor(input_node).shape rank = len(shape) split_lengths = split_node.args[1] dim = split_node.args[2] if len(split_node.args) > 2 else 0 diff --git a/backends/arm/operator_support/slice_copy_support.py b/backends/arm/operator_support/slice_copy_support.py index ea18c40814..3c0c69969c 100644 --- a/backends/arm/operator_support/slice_copy_support.py +++ b/backends/arm/operator_support/slice_copy_support.py @@ -12,7 +12,6 @@ SupportedTOSAOperatorCheck, ) from executorch.backends.arm.tosa_specification import TosaSpecification -from executorch.backends.arm.tosa_utils import getNodeArgs from executorch.exir.dialects._ops import ops as exir_ops logger = logging.getLogger(__name__) @@ -33,8 +32,8 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> if tosa_spec not in self.tosa_specs: return False - inputs = getNodeArgs(node) - if len(inputs) == 5 and (step := inputs[4].number) != 1: + args = node.args + if len(args) == 5 and (step := args[4]) != 1: logging.warning(f"{node.target} with step size of {step} not supported.") return False return True diff --git a/backends/arm/operators/op_rescale.py b/backends/arm/operators/op_rescale.py index 33bf6b8fb6..3c9abe1ba5 100644 --- a/backends/arm/operators/op_rescale.py +++ b/backends/arm/operators/op_rescale.py @@ -13,7 +13,7 @@ NodeVisitor, register_node_visitor, ) -from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg +from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import create_const_ops_for_rescale from executorch.backends.arm.tosa_specification import TosaSpecification @@ -35,15 +35,15 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - input_dtype = inputs[0].dtype + input_dtype = node.all_input_nodes[0].meta["val"].dtype output_dtype = cast(torch.dtype, node.args[1]) scale = cast(float, node.args[2]) input_zp = cast(int, node.args[3]) output_zp = cast(int, node.args[4]) - if input_dtype != map_dtype(torch.int8) and input_zp != 0: + if input_dtype != torch.int8 and input_zp != 0: raise ValueError( - f"If input dtype is not int8, input_zp must be 0. Got input_dtype{ts.DTypeNames[input_dtype]}, {input_zp=}" + f"If input dtype is not int8, input_zp must be 0. Got input_dtype{input_dtype=}, {input_zp=}" ) if output_dtype != torch.int8 and output_zp != 0: raise ValueError( @@ -91,15 +91,15 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore from tosa.RoundingMode import RoundingMode # type: ignore - input_dtype = inputs[0].dtype + input_dtype = node.all_input_nodes[0].meta["val"].dtype output_dtype = cast(torch.dtype, node.args[1]) scale = cast(float, node.args[2]) input_zp = cast(int, node.args[3]) output_zp = cast(int, node.args[4]) - if input_dtype != map_dtype(torch.int8) and input_zp != 0: + if input_dtype != torch.int8 and input_zp != 0: raise ValueError( - f"If input dtype is not int8, input_zp must be 0. Got input_dtype{ts.DTypeNames[input_dtype]}, {input_zp=}" + f"If input dtype is not int8, input_zp must be 0. Got input_dtype{input_dtype=}, {input_zp=}" ) if output_dtype != torch.int8 and output_zp != 0: raise ValueError( diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 6692b75c89..0994079c4a 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -36,11 +36,11 @@ def process_call_function( tosa_spec: TosaSpecification, ): # Unpack arguments and convert - inputs = getNodeArgs(node) + inputs = getNodeArgs(node, tosa_spec) # Convert output (this node itself) try: - output = TosaArg(node) + output = TosaArg(node, tosa_spec) except ValueError as e: raise ValueError( f"Failed processing call_function: {node.name}. " @@ -78,7 +78,7 @@ def process_inputs( f"Expected dim_order: {tuple(range(meta.dim()))}, but got: {meta.dim_order()} for node {node.name}" ) try: - tosa_arg = TosaArg(node) + tosa_arg = TosaArg(node, tosa_spec) except ValueError as e: raise ValueError( f"Failed processing input placeholder: {node.name}. " @@ -112,7 +112,7 @@ def process_inputs_to_parameters( ): """Serialize bias and non-quantized weights""" try: - tosa_arg = TosaArg(node) + tosa_arg = TosaArg(node, tosa_spec) except ValueError as e: raise ValueError( f"Failed processing parameter placeholder: {node.name}. " @@ -137,10 +137,11 @@ def process_inputs_to_buffers( node: torch.fx.Node, tosa_graph: Any, edge_program: ExportedProgram, + tosa_spec: TosaSpecification, ): """Serialize quantized weights""" try: - tosa_arg = TosaArg(node) + tosa_arg = TosaArg(node, tosa_spec) except ValueError as e: raise ValueError( f"Failed processing buffer placeholder: {node.name}. " @@ -165,9 +166,10 @@ def process_inputs_to_lifted_tensor_constants( node: torch.fx.Node, tosa_graph: Any, edge_program: ExportedProgram, + tosa_spec: TosaSpecification, ): try: - tosa_arg = TosaArg(node) + tosa_arg = TosaArg(node, tosa_spec) except ValueError as e: raise ValueError( f"Failed processing lifted tensor constant placeholder: {node.name}. " @@ -196,9 +198,11 @@ def process_placeholder( elif is_param(edge_program, node): process_inputs_to_parameters(node, tosa_graph, edge_program, tosa_spec) elif is_buffer(edge_program, node): - process_inputs_to_buffers(node, tosa_graph, edge_program) + process_inputs_to_buffers(node, tosa_graph, edge_program, tosa_spec) elif is_lifted_tensor_constant(edge_program, node): - process_inputs_to_lifted_tensor_constants(node, tosa_graph, edge_program) + process_inputs_to_lifted_tensor_constants( + node, tosa_graph, edge_program, tosa_spec + ) elif node.name in edge_program.graph_signature.inputs_to_lifted_custom_objs: raise NotImplementedError( "Placeholder is of type 'lifted custom object' which is not supported." diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index d7434e7149..9e7e7450b7 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -48,6 +48,7 @@ ) from executorch.backends.arm.tosa_mapping import extract_tensor_meta from executorch.backends.arm.tosa_partitioner import TOSAPartitioner +from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.xnnpack.test.tester import Tester from executorch.devtools.backend_debug import get_delegation_info @@ -564,7 +565,10 @@ def dump_dtype_distribution( ) graph = self.get_graph(self.cur) - dtype_dist_placeholders, dtype_dirst_tensors = _get_dtype_distribution(graph) + tosa_spec = get_tosa_spec(self.compile_spec) + dtype_dist_placeholders, dtype_dirst_tensors = _get_dtype_distribution( + graph, tosa_spec + ) all_dtypes = set(dtype_dist_placeholders.keys()) | set( dtype_dirst_tensors.keys() ) @@ -659,7 +663,9 @@ def _compare_outputs( raise e -def _get_dtype_distribution(graph: Graph) -> tuple[dict, dict]: +def _get_dtype_distribution( + graph: Graph, tosa_spec: TosaSpecification +) -> tuple[dict, dict]: """Counts the occurences of placeholder and call_function dtypes in a graph. The result is a tuple of Counters (placeholder_distribution, call_function_distribution) """ @@ -670,7 +676,7 @@ def _get_dtype_distribution(graph: Graph) -> tuple[dict, dict]: placeholder_dtypes.append(str(node.meta["val"].dtype)) if node.op == "call_function": if "val" in node.meta: - dtype, _, _ = extract_tensor_meta(node.meta) + dtype, _, _ = extract_tensor_meta(node.meta, tosa_spec) call_function_dtypes.append(ts.DTypeNames[dtype]) return Counter(placeholder_dtypes), Counter(call_function_dtypes) diff --git a/backends/arm/tosa_mapping.py b/backends/arm/tosa_mapping.py index 26441cbfb0..18abe1a754 100644 --- a/backends/arm/tosa_mapping.py +++ b/backends/arm/tosa_mapping.py @@ -11,12 +11,14 @@ # the standardised TOSA representation. # -from typing import Any, Sequence +from typing import Any, Optional, Sequence import torch - -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - +from executorch.backends.arm.tosa_specification import ( + Tosa_0_80, + Tosa_1_00, + TosaSpecification, +) UNSUPPORTED_DTYPES = ( torch.float64, @@ -30,33 +32,39 @@ torch.long, ) -DTYPE_MAP = { - torch.float32: ts.DType.FP32, - torch.float: ts.DType.FP32, - torch.float16: ts.DType.FP16, - torch.half: ts.DType.FP16, - torch.bfloat16: ts.DType.BF16, - torch.int8: ts.DType.INT8, - torch.int16: ts.DType.INT16, - torch.short: ts.DType.INT16, - torch.int32: ts.DType.INT32, - torch.int: ts.DType.INT32, - torch.bool: ts.DType.BOOL, -} - - -def map_dtype(data_type: torch.dtype) -> ts.DType: + +def map_dtype(data_type: torch.dtype, tosa_spec: TosaSpecification) -> Any: if data_type in UNSUPPORTED_DTYPES: raise ValueError(f"Unsupported type: {data_type}") - if data_type not in DTYPE_MAP: + if isinstance(tosa_spec, Tosa_0_80): + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + elif isinstance(tosa_spec, Tosa_1_00): + import serializer.tosa_serializer as ts # type: ignore + else: + raise RuntimeError(f"Unsupported tosa_spec: {tosa_spec}") + + dtype_map = { + torch.float32: ts.DType.FP32, + torch.float: ts.DType.FP32, + torch.float16: ts.DType.FP16, + torch.half: ts.DType.FP16, + torch.bfloat16: ts.DType.BF16, + torch.int8: ts.DType.INT8, + torch.int16: ts.DType.INT16, + torch.short: ts.DType.INT16, + torch.int32: ts.DType.INT32, + torch.int: ts.DType.INT32, + torch.bool: ts.DType.BOOL, + } + if data_type not in dtype_map: raise ValueError(f"Unknown type: {data_type}") - return DTYPE_MAP[data_type] + return dtype_map[data_type] # Returns the shape and type of a node # TODO: other types, can be # SymInt, FakeTensor, a List[Union[FakeTensor, SymInt]], or None -def extract_tensor_meta(meta): +def extract_tensor_meta(meta, tosa_spec: TosaSpecification): assert meta.get("val") is not None val = meta["val"] if type(val) is tuple: @@ -67,7 +75,7 @@ def extract_tensor_meta(meta): raise ValueError( f"Expected first value in node.meta['val'] to be FakeTensor, got {val.__class__}" ) - dtype = map_dtype(val.dtype) + dtype = map_dtype(val.dtype, tosa_spec) shape = tuple(val.size()) if meta.get("tosa_dim_order") is not None: @@ -81,7 +89,9 @@ def extract_tensor_meta(meta): class TosaArg: def __process_node(self, argument: torch.fx.Node): self.name: str = argument.name - self.dtype, self.shape, self.dim_order = extract_tensor_meta(argument.meta) + self.dtype, self.shape, self.dim_order = extract_tensor_meta( + argument.meta, self.tosa_spec + ) def __process_list(self, argument): self.special: list = list(argument) @@ -89,9 +99,18 @@ def __process_list(self, argument): def __process_number(self, argument: float | int): self.number: float | int = argument - def __init__(self, argument: Any) -> None: + def __init__( + self, argument: Any, tosa_spec: Optional[TosaSpecification] = None + ) -> None: if argument is None: return + if tosa_spec is None: + raise ValueError("tosa_spec is None") + elif not isinstance(tosa_spec, TosaSpecification): + raise ValueError( + f"Expected tosa_spec to be a TosaSpecification, but got {tosa_spec}" + ) + self.tosa_spec = tosa_spec if isinstance(argument, torch.fx.Node): self.__process_node(argument) @@ -116,6 +135,12 @@ def __repr__(self): if self.name is not None: attrs.append(f"name={self.name!r}") if self.dtype is not None: + if isinstance(self.tosa_spec, Tosa_0_80): + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + elif isinstance(self.tosa_spec, Tosa_1_00): + import serializer.tosa_serializer as ts # type: ignore + else: + raise RuntimeError(f"Unsupported tosa_spec: {self.tosa_spec}") attrs.append(f"dtype={ts.DTypeNames[self.dtype]}") if self.shape is not None: attrs.append(f"shape={self.shape!r}") @@ -125,4 +150,6 @@ def __repr__(self): attrs.append(f"special={self.special!r}") if hasattr(self, "number") and self.number is not None: attrs.append(f"number={self.number!r}") + if hasattr(self, "tosa_spec") and self.tosa_spec is not None: + attrs.append(f"tosa_spec={self.tosa_spec!r}") return f"{self.__class__.__name__}({', '.join(attrs)})" diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index 4d0f33003b..c5546647b1 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -10,10 +10,11 @@ from typing import Any, Optional, Tuple import torch - import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification + from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.print_program import inspect_node from torch.fx import Node @@ -93,9 +94,9 @@ def dbg_fail( dbg_node(node, graph_module) -def getNodeArgs(node: Node) -> list[TosaArg]: +def getNodeArgs(node: Node, tosa_spec: TosaSpecification) -> list[TosaArg]: try: - return [TosaArg(arg) for arg in node.args] + return [TosaArg(arg, tosa_spec) for arg in node.args] except ValueError as e: raise ValueError(f"Failed processing args to op:\n{node}") from e @@ -153,14 +154,14 @@ def get_new_shape(l_rank_in, h_rank_in): return reshaped, input2 -def is_consumer_node_depthwise_conv2d(node): +def is_consumer_node_depthwise_conv2d(node: Node): consumer_node = list(node.users)[0] if consumer_node.target == exir_ops.edge.aten.convolution.default: - inputs = getNodeArgs(consumer_node) - group = inputs[-1] - in_channels = inputs[0].shape[1] - out_channels = inputs[1].shape[0] - if (in_channels == group.number) and (out_channels % in_channels) == 0: + consumer_node_inputs = consumer_node.all_input_nodes + groups = consumer_node.args[-1] + in_channels = consumer_node_inputs[0].meta["val"].shape[1] + out_channels = consumer_node_inputs[1].meta["val"].shape[0] + if (in_channels == groups) and (out_channels % in_channels) == 0: return True return False