Arm backend: Refactor TosaArg to use TosaSpecification #10655


Open · wants to merge 3 commits into main
Changes from all commits
11 changes: 6 additions & 5 deletions backends/arm/_passes/convert_split_to_slice.py
@@ -1,14 +1,15 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
-# All rights reserved.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch.fx
-from executorch.backends.arm._passes.arm_pass_utils import create_node
-from executorch.backends.arm.tosa_mapping import extract_tensor_meta
+from executorch.backends.arm._passes.arm_pass_utils import (
+create_node,
+get_first_fake_tensor,
+)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

@@ -34,7 +35,7 @@ def call(self, graph_module: torch.fx.GraphModule):
split_node = node
input_node = split_node.all_input_nodes[0]
output_nodes = split_node.users.copy()
-_, shape, _ = extract_tensor_meta(input_node.meta)
+shape = get_first_fake_tensor(input_node).shape
rank = len(shape)
split_lengths = split_node.args[1]
dim = split_node.args[2] if len(split_node.args) > 2 else 0
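
The pass now pulls the input shape straight from the FakeTensor recorded in node metadata, so it no longer depends on tosa_mapping.extract_tensor_meta (which, after this refactor, requires a TosaSpecification). A minimal sketch of the idea; get_first_fake_tensor is approximated here rather than imported, since the real helper lives in executorch.backends.arm._passes.arm_pass_utils:

```python
import torch
from torch.fx import Node


def get_first_fake_tensor_sketch(node: Node) -> torch.Tensor:
    """Rough stand-in for arm_pass_utils.get_first_fake_tensor: fetch the
    FakeTensor recorded under node.meta["val"] during export."""
    val = node.meta["val"]
    # Multi-output ops store a tuple of FakeTensors; take the first entry.
    return val[0] if isinstance(val, (tuple, list)) else val


# The pass only needs the shape, so no TOSA dtype mapping (and with it no
# TosaSpecification) has to be threaded through:
#   shape = get_first_fake_tensor_sketch(input_node).shape
#   rank = len(shape)
```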
5 changes: 2 additions & 3 deletions backends/arm/operator_support/slice_copy_support.py
@@ -12,7 +12,6 @@
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
-from executorch.backends.arm.tosa_utils import getNodeArgs
from executorch.exir.dialects._ops import ops as exir_ops

logger = logging.getLogger(__name__)
@@ -33,8 +32,8 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification) ->
if tosa_spec not in self.tosa_specs:
return False

-inputs = getNodeArgs(node)
-if len(inputs) == 5 and (step := inputs[4].number) != 1:
+args = node.args
+if len(args) == 5 and (step := args[4]) != 1:
logging.warning(f"{node.target} with step size of {step} not supported.")
return False
return True
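
The support check now inspects node.args directly instead of wrapping every argument in a TosaArg via getNodeArgs; after this refactor that wrapper demands a TosaSpecification, which an operator-support query has no reason to construct. The core logic, reduced to a standalone sketch (the (self, dim, start, end, step) argument layout is assumed from the slice_copy signature):

```python
def slice_step_is_supported(args: tuple) -> bool:
    # A fifth argument is the slice step; TOSA SLICE has no stride, so
    # any step other than 1 is rejected.
    if len(args) == 5 and args[4] != 1:
        return False
    return True


assert slice_step_is_supported(("x", 0, 0, 4))         # no explicit step
assert not slice_step_is_supported(("x", 0, 0, 4, 2))  # step of 2
```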
8 changes: 4 additions & 4 deletions backends/arm/operators/op_rescale.py
@@ -16,7 +16,7 @@
NodeVisitor,
register_node_visitor,
)
-from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
+from executorch.backends.arm.tosa_mapping import TosaArg
from torch.fx import Node


@@ -32,15 +32,15 @@ def define_node(
output: TosaArg,
) -> None:

-input_dtype = inputs[0].dtype
+input_dtype = node.all_input_nodes[0].meta["val"].dtype
output_dtype = cast(torch.dtype, node.args[1])
scale = cast(float, node.args[2])
input_zp = cast(int, node.args[3])
output_zp = cast(int, node.args[4])

-if input_dtype != map_dtype(torch.int8) and input_zp != 0:
+if input_dtype != torch.int8 and input_zp != 0:
raise ValueError(
-f"If input dtype is not int8, input_zp must be 0. Got input_dtype{ts.DTypeNames[input_dtype]}, {input_zp=}"
+f"If input dtype is not int8, input_zp must be 0. Got {input_dtype=}, {input_zp=}"
)
if output_dtype != torch.int8 and output_zp != 0:
raise ValueError(
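
define_node now validates zero points entirely in torch dtype space: the input dtype is read from the node's FakeTensor metadata and compared against torch.int8 directly, so neither map_dtype nor a serializer import is needed just to phrase an error message. The check, lifted into a standalone sketch:

```python
import torch


def check_rescale_zero_points(
    input_dtype: torch.dtype, input_zp: int,
    output_dtype: torch.dtype, output_zp: int,
) -> None:
    # Mirrors the validation above: only int8 operands may carry a
    # nonzero zero point.
    if input_dtype != torch.int8 and input_zp != 0:
        raise ValueError(f"non-int8 input requires input_zp == 0, got {input_zp}")
    if output_dtype != torch.int8 and output_zp != 0:
        raise ValueError(f"non-int8 output requires output_zp == 0, got {output_zp}")


check_rescale_zero_points(torch.int32, 0, torch.int8, -128)  # passes silently
```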
20 changes: 12 additions & 8 deletions backends/arm/process_node.py
@@ -36,11 +36,11 @@ def process_call_function(
tosa_spec: TosaSpecification,
):
# Unpack arguments and convert
-inputs = getNodeArgs(node)
+inputs = getNodeArgs(node, tosa_spec)

# Convert output (this node itself)
try:
-output = TosaArg(node)
+output = TosaArg(node, tosa_spec)
except ValueError as e:
raise ValueError(
f"Failed processing call_function: {node.name}. "
@@ -78,7 +78,7 @@ def process_inputs(
f"Expected dim_order: {tuple(range(meta.dim()))}, but got: {meta.dim_order()} for node {node.name}"
)
try:
-tosa_arg = TosaArg(node)
+tosa_arg = TosaArg(node, tosa_spec)
except ValueError as e:
raise ValueError(
f"Failed processing input placeholder: {node.name}. "
@@ -112,7 +112,7 @@ def process_inputs_to_parameters(
):
"""Serialize bias and non-quantized weights"""
try:
-tosa_arg = TosaArg(node)
+tosa_arg = TosaArg(node, tosa_spec)
except ValueError as e:
raise ValueError(
f"Failed processing parameter placeholder: {node.name}. "
@@ -137,10 +137,11 @@ def process_inputs_to_buffers(
node: torch.fx.Node,
tosa_graph: Any,
edge_program: ExportedProgram,
+tosa_spec: TosaSpecification,
):
"""Serialize quantized weights"""
try:
-tosa_arg = TosaArg(node)
+tosa_arg = TosaArg(node, tosa_spec)
except ValueError as e:
raise ValueError(
f"Failed processing buffer placeholder: {node.name}. "
@@ -165,9 +166,10 @@ def process_inputs_to_lifted_tensor_constants(
node: torch.fx.Node,
tosa_graph: Any,
edge_program: ExportedProgram,
+tosa_spec: TosaSpecification,
):
try:
-tosa_arg = TosaArg(node)
+tosa_arg = TosaArg(node, tosa_spec)
except ValueError as e:
raise ValueError(
f"Failed processing lifted tensor constant placeholder: {node.name}. "
@@ -196,9 +198,11 @@ def process_placeholder(
elif is_param(edge_program, node):
process_inputs_to_parameters(node, tosa_graph, edge_program, tosa_spec)
elif is_buffer(edge_program, node):
-process_inputs_to_buffers(node, tosa_graph, edge_program)
+process_inputs_to_buffers(node, tosa_graph, edge_program, tosa_spec)
elif is_lifted_tensor_constant(edge_program, node):
-process_inputs_to_lifted_tensor_constants(node, tosa_graph, edge_program)
+process_inputs_to_lifted_tensor_constants(
+node, tosa_graph, edge_program, tosa_spec
+)
elif node.name in edge_program.graph_signature.inputs_to_lifted_custom_objs:
raise NotImplementedError(
"Placeholder is of type 'lifted custom object' which is not supported."
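
With these changes every placeholder helper takes the TosaSpecification explicitly, so a single spec flows from process_placeholder down to each TosaArg it creates. A hedged usage sketch, assuming an executorch checkout (the spec string is illustrative):

```python
from executorch.backends.arm.tosa_specification import TosaSpecification

# One spec per lowering; the same instance reaches every helper, so all
# TosaArg objects agree on the TOSA version being targeted.
tosa_spec = TosaSpecification.create_from_string("TOSA-0.80+BI")

# process_placeholder(node, tosa_graph, edge_program, tosa_spec) then
# dispatches on the placeholder kind, e.g.:
#   process_inputs_to_buffers(node, tosa_graph, edge_program, tosa_spec)
#   process_inputs_to_lifted_tensor_constants(
#       node, tosa_graph, edge_program, tosa_spec
#   )
```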
12 changes: 9 additions & 3 deletions backends/arm/test/tester/arm_tester.py
@@ -48,6 +48,7 @@
)
from executorch.backends.arm.tosa_mapping import extract_tensor_meta
from executorch.backends.arm.tosa_partitioner import TOSAPartitioner
+from executorch.backends.arm.tosa_specification import TosaSpecification

from executorch.backends.xnnpack.test.tester import Tester
from executorch.devtools.backend_debug import get_delegation_info
@@ -564,7 +565,10 @@ def dump_dtype_distribution(
)

graph = self.get_graph(self.cur)
-dtype_dist_placeholders, dtype_dirst_tensors = _get_dtype_distribution(graph)
+tosa_spec = get_tosa_spec(self.compile_spec)
+dtype_dist_placeholders, dtype_dirst_tensors = _get_dtype_distribution(
+graph, tosa_spec
+)
all_dtypes = set(dtype_dist_placeholders.keys()) | set(
dtype_dirst_tensors.keys()
)
@@ -659,7 +663,9 @@ def _compare_outputs(
raise e


-def _get_dtype_distribution(graph: Graph) -> tuple[dict, dict]:
+def _get_dtype_distribution(
+graph: Graph, tosa_spec: TosaSpecification
+) -> tuple[dict, dict]:
"""Counts the occurences of placeholder and call_function dtypes in a graph.
The result is a tuple of Counters (placeholder_distribution, call_function_distribution)
"""
@@ -670,7 +676,7 @@ def _get_dtype_distribution(graph: Graph) -> tuple[dict, dict]:
placeholder_dtypes.append(str(node.meta["val"].dtype))
if node.op == "call_function":
if "val" in node.meta:
-dtype, _, _ = extract_tensor_meta(node.meta)
+dtype, _, _ = extract_tensor_meta(node.meta, tosa_spec)
call_function_dtypes.append(ts.DTypeNames[dtype])
return Counter(placeholder_dtypes), Counter(call_function_dtypes)
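
The histogram helper needs the spec because extract_tensor_meta now returns serializer dtype enums, and the DTypeNames table that turns them back into strings depends on which serializer the spec selects (the spec itself comes from get_tosa_spec(self.compile_spec) above). The counting is plain Counter arithmetic, as a minimal illustrative sketch shows:

```python
from collections import Counter

# Illustrative data: the distribution is just a Counter over dtype-name
# strings; the spec only decides how enum values are named.
placeholder_dtypes = ["torch.int8", "torch.int8", "torch.float32"]
print(Counter(placeholder_dtypes))
# Counter({'torch.int8': 2, 'torch.float32': 1})
```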

79 changes: 53 additions & 26 deletions backends/arm/tosa_mapping.py
@@ -11,12 +11,14 @@
# the standardised TOSA representation.
#

-from typing import Any, Sequence
+from typing import Any, Optional, Sequence

import torch

-import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

+from executorch.backends.arm.tosa_specification import (
+Tosa_0_80,
+Tosa_1_00,
+TosaSpecification,
+)

UNSUPPORTED_DTYPES = (
torch.float64,
@@ -30,33 +32,39 @@
torch.long,
)

-DTYPE_MAP = {
-torch.float32: ts.DType.FP32,
-torch.float: ts.DType.FP32,
-torch.float16: ts.DType.FP16,
-torch.half: ts.DType.FP16,
-torch.bfloat16: ts.DType.BF16,
-torch.int8: ts.DType.INT8,
-torch.int16: ts.DType.INT16,
-torch.short: ts.DType.INT16,
-torch.int32: ts.DType.INT32,
-torch.int: ts.DType.INT32,
-torch.bool: ts.DType.BOOL,
-}


-def map_dtype(data_type: torch.dtype) -> ts.DType:

+def map_dtype(data_type: torch.dtype, tosa_spec: TosaSpecification) -> Any:
if data_type in UNSUPPORTED_DTYPES:
raise ValueError(f"Unsupported type: {data_type}")
-if data_type not in DTYPE_MAP:
+if isinstance(tosa_spec, Tosa_0_80):
+import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
+elif isinstance(tosa_spec, Tosa_1_00):
+import serializer.tosa_serializer as ts # type: ignore
+else:
+raise RuntimeError(f"Unsupported tosa_spec: {tosa_spec}")

+dtype_map = {
+torch.float32: ts.DType.FP32,
+torch.float: ts.DType.FP32,
+torch.float16: ts.DType.FP16,
+torch.half: ts.DType.FP16,
+torch.bfloat16: ts.DType.BF16,
+torch.int8: ts.DType.INT8,
+torch.int16: ts.DType.INT16,
+torch.short: ts.DType.INT16,
+torch.int32: ts.DType.INT32,
+torch.int: ts.DType.INT32,
+torch.bool: ts.DType.BOOL,
+}
+if data_type not in dtype_map:
raise ValueError(f"Unknown type: {data_type}")
-return DTYPE_MAP[data_type]
+return dtype_map[data_type]


# Returns the shape and type of a node
# TODO: other types, can be
# SymInt, FakeTensor, a List[Union[FakeTensor, SymInt]], or None
-def extract_tensor_meta(meta):
+def extract_tensor_meta(meta, tosa_spec: TosaSpecification):
assert meta.get("val") is not None
val = meta["val"]
if type(val) is tuple:
@@ -67,7 +75,7 @@ def extract_tensor_meta(meta):
raise ValueError(
f"Expected first value in node.meta['val'] to be FakeTensor, got {val.__class__}"
)
-dtype = map_dtype(val.dtype)
+dtype = map_dtype(val.dtype, tosa_spec)
shape = tuple(val.size())

if meta.get("tosa_dim_order") is not None:
@@ -81,17 +89,28 @@
class TosaArg:
def __process_node(self, argument: torch.fx.Node):
self.name: str = argument.name
-self.dtype, self.shape, self.dim_order = extract_tensor_meta(argument.meta)
+self.dtype, self.shape, self.dim_order = extract_tensor_meta(
+argument.meta, self.tosa_spec
+)

def __process_list(self, argument):
self.special: list = list(argument)

def __process_number(self, argument: float | int):
self.number: float | int = argument

-def __init__(self, argument: Any) -> None:
+def __init__(
+self, argument: Any, tosa_spec: Optional[TosaSpecification] = None
+) -> None:
if argument is None:
return
+if tosa_spec is None:
+raise ValueError("tosa_spec is None")
+elif not isinstance(tosa_spec, TosaSpecification):
+raise ValueError(
+f"Expected tosa_spec to be a TosaSpecification, but got {tosa_spec}"
+)
+self.tosa_spec = tosa_spec

if isinstance(argument, torch.fx.Node):
self.__process_node(argument)
@@ -116,6 +135,12 @@ def __repr__(self):
if self.name is not None:
attrs.append(f"name={self.name!r}")
if self.dtype is not None:
+if isinstance(self.tosa_spec, Tosa_0_80):
+import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
+elif isinstance(self.tosa_spec, Tosa_1_00):
+import serializer.tosa_serializer as ts # type: ignore
+else:
+raise RuntimeError(f"Unsupported tosa_spec: {self.tosa_spec}")
attrs.append(f"dtype={ts.DTypeNames[self.dtype]}")
if self.shape is not None:
attrs.append(f"shape={self.shape!r}")
Expand All @@ -125,4 +150,6 @@ def __repr__(self):
attrs.append(f"special={self.special!r}")
if hasattr(self, "number") and self.number is not None:
attrs.append(f"number={self.number!r}")
if hasattr(self, "tosa_spec") and self.tosa_spec is not None:
attrs.append(f"tosa_spec={self.tosa_spec!r}")
return f"{self.__class__.__name__}({', '.join(attrs)})"
19 changes: 10 additions & 9 deletions backends/arm/tosa_utils.py
@@ -10,10 +10,11 @@
from typing import Any, Optional, Tuple

import torch

-import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.tosa_mapping import TosaArg

+from executorch.backends.arm.tosa_specification import TosaSpecification

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.print_program import inspect_node
from torch.fx import Node
@@ -93,9 +94,9 @@ def dbg_fail(
dbg_node(node, graph_module)


-def getNodeArgs(node: Node) -> list[TosaArg]:
+def getNodeArgs(node: Node, tosa_spec: TosaSpecification) -> list[TosaArg]:
try:
-return [TosaArg(arg) for arg in node.args]
+return [TosaArg(arg, tosa_spec) for arg in node.args]
except ValueError as e:
raise ValueError(f"Failed processing args to op:\n{node}") from e

@@ -153,14 +154,14 @@ def get_new_shape(l_rank_in, h_rank_in):
return reshaped, input2


-def is_consumer_node_depthwise_conv2d(node):
+def is_consumer_node_depthwise_conv2d(node: Node):
consumer_node = list(node.users)[0]
if consumer_node.target == exir_ops.edge.aten.convolution.default:
-inputs = getNodeArgs(consumer_node)
-group = inputs[-1]
-in_channels = inputs[0].shape[1]
-out_channels = inputs[1].shape[0]
-if (in_channels == group.number) and (out_channels % in_channels) == 0:
+consumer_node_inputs = consumer_node.all_input_nodes
+groups = consumer_node.args[-1]
+in_channels = consumer_node_inputs[0].meta["val"].shape[1]
+out_channels = consumer_node_inputs[1].meta["val"].shape[0]
+if (in_channels == groups) and (out_channels % in_channels) == 0:
return True

return False
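
is_consumer_node_depthwise_conv2d no longer builds TosaArg wrappers: the groups count comes from the raw args and the channel counts from FakeTensor metadata, leaving getNodeArgs (now with a mandatory spec) as the only TosaArg producer in this file. The depthwise condition itself, as a standalone sketch:

```python
def is_depthwise(groups: int, in_channels: int, out_channels: int) -> bool:
    # A convolution is depthwise when every input channel forms its own
    # group and the output channels are an integer multiple of them.
    return in_channels == groups and out_channels % in_channels == 0


assert is_depthwise(groups=32, in_channels=32, out_channels=64)
assert not is_depthwise(groups=1, in_channels=32, out_channels=64)
```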