Skip to content

Commit f785386

Browse files
authored
Arm backend: Update operator support for TOSA-1.0+INT+u55 (#10849)
### Summary Make is_U55_subset part of the base TOSA specification class as the subset is not tied to a specific specification. Using the is_U55_subset attribute the supported checks for U55 subset are updated to catch TOSA 1.0 as well. ### Test plan Tested on internal and external CI. Signed-off-by: Per Åstrand <[email protected]>
1 parent a21022c commit f785386

File tree

9 files changed

+32
-40
lines changed

9 files changed

+32
-40
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
UnsqueezeScalarPlaceholdersPass,
6060
)
6161

62-
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
62+
from executorch.backends.arm.tosa_specification import TosaSpecification
6363
from executorch.backends.transforms.decompose_sdpa import (
6464
DecomposeScaledDotProductAttention,
6565
)
@@ -92,7 +92,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
9292
self.add_pass(ConvertMinMaxPass())
9393
self.add_pass(ConvertAnyDefaultDimDimsPass())
9494
self.add_pass(MatchWhereSelfDtypePass())
95-
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
95+
if self.tosa_spec.is_U55_subset:
9696
self.add_pass(CastToInt32Pass())
9797

9898
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
@@ -210,7 +210,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
210210
self.add_pass(DecomposeSqrtPass())
211211
self.add_pass(DecomposeSiluPass())
212212

213-
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
213+
if self.tosa_spec.is_U55_subset:
214214
# Numerically stable softmax uses amax which is not supported on Ethos-U55
215215
self.add_pass(DecomposeSoftmaxUnstablePass())
216216
else:

backends/arm/operator_support/convolution_support.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,8 @@
1111
register_tosa_support_check,
1212
SupportedTOSAOperatorCheck,
1313
)
14-
from executorch.backends.arm.tosa_specification import (
15-
Tosa_0_80,
16-
Tosa_1_00,
17-
TosaSpecification,
18-
)
14+
from executorch.backends.arm.tosa_specification import TosaSpecification
15+
1916
from executorch.exir.dialects._ops import ops as exir_ops
2017

2118

@@ -46,13 +43,10 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
4643
return False
4744

4845
# Hardware specific constraints
49-
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
50-
# TODO remove this once TOSA 1.0 support for u55 is added.
51-
if isinstance(tosa_spec, Tosa_1_00) and "u55" in tosa_spec.extensions:
52-
return False
53-
return True
54-
else:
46+
if tosa_spec.is_U55_subset:
5547
return self._is_node_supported_u55(node)
48+
else:
49+
return True
5650

5751
def _is_node_supported_u55(self, node: fx.Node):
5852
"""Hardware constraints for Ethos-U-55 case, Vela 4.2.0 (25.02 release)"""

backends/arm/operator_support/pool_2d_support.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
register_tosa_support_check,
1212
SupportedTOSAOperatorCheck,
1313
)
14-
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
14+
from executorch.backends.arm.tosa_specification import TosaSpecification
1515
from executorch.exir.dialects._ops import ops as exir_ops
1616

1717

@@ -46,7 +46,7 @@ class AvgPool2dSupported(SupportedTOSAOperatorCheck):
4646
]
4747

4848
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
49-
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
49+
if not tosa_spec.is_U55_subset:
5050
return True
5151

5252
# U55 case, Vela 4.2.0 (25.02 release)
@@ -104,7 +104,7 @@ class MaxPool2dSupported(SupportedTOSAOperatorCheck):
104104
]
105105

106106
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
107-
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
107+
if not tosa_spec.is_U55_subset:
108108
return True
109109

110110
# U55 case, Vela 4.2.0 (25.02 release)

backends/arm/operator_support/reduce_sum_support.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
register_tosa_support_check,
1111
SupportedTOSAOperatorCheck,
1212
)
13-
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
13+
from executorch.backends.arm.tosa_specification import TosaSpecification
1414
from executorch.exir.dialects._ops import ops as exir_ops
1515

1616

@@ -26,7 +26,7 @@ class SumSupported(SupportedTOSAOperatorCheck):
2626
]
2727

2828
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
29-
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
29+
if not tosa_spec.is_U55_subset:
3030
return True
3131

3232
# U55 case, Vela 4.2.0 (25.02 release)

backends/arm/operator_support/right_shift_support.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
register_tosa_support_check,
1414
SupportedTOSAOperatorCheck,
1515
)
16-
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
16+
from executorch.backends.arm.tosa_specification import TosaSpecification
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818

1919
logger = logging.getLogger(__name__)
@@ -36,6 +36,6 @@ class RightShiftSupported(SupportedTOSAOperatorCheck):
3636
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
3737

3838
# TODO MLETORCH-525 Remove warning
39-
if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset:
39+
if tosa_spec.is_U55_subset:
4040
logging.warning(f"{node.target} may introduce one-off errors.")
4141
return True

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,7 @@
2424
EthosU55NotSupported,
2525
EthosU55TransposeCheck,
2626
)
27-
from executorch.backends.arm.tosa_specification import (
28-
Tosa_0_80,
29-
Tosa_1_00,
30-
TosaSpecification,
31-
)
27+
from executorch.backends.arm.tosa_specification import TosaSpecification
3228
from executorch.exir import ExportedProgram
3329
from executorch.exir.backend.utils import WhyNoPartitionReporter
3430
from executorch.exir.dialects._ops import ops as exir_ops
@@ -129,9 +125,7 @@ def tosa_support_factory(
129125
if not tosa_spec.support_float():
130126
negative_checks.append(NeedsDecompositionCheck(reporter))
131127
negative_checks.append(CheckProperQuantization(reporter))
132-
if (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset) or (
133-
isinstance(tosa_spec, Tosa_1_00) and "u55" in tosa_spec.extensions
134-
):
128+
if tosa_spec.is_U55_subset:
135129
negative_checks.append(EthosU55NotSupported(reporter))
136130
negative_checks.append(EthosU55DtypeSupport(reporter))
137131
negative_checks.append(EthosU55TransposeCheck(reporter))

backends/arm/operators/op_rshift_tensor.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
validate_num_inputs,
1818
)
1919
from executorch.backends.arm.tosa_mapping import TosaArg
20-
from executorch.backends.arm.tosa_specification import Tosa_0_80, Tosa_1_00
2120

2221

2322
@register_node_visitor
@@ -39,7 +38,7 @@ def define_node(
3938

4039
attr = ts.TosaSerializerAttribute()
4140
round = False
42-
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
41+
if self.tosa_spec.is_U55_subset:
4342
# U55 only supports INT32 and round == True
4443
# TODO MLETORCH-525 Emulate round == False with different decomposition
4544
round = True
@@ -72,7 +71,7 @@ def define_node(
7271

7372
attr = ts.TosaSerializerAttribute()
7473
round = False
75-
if isinstance(self.tosa_spec, Tosa_1_00) and "u55" in self.tosa_spec.extensions:
74+
if self.tosa_spec.is_U55_subset:
7675
# U55 only supports INT32 and round == True
7776
# TODO MLETORCH-525 Emulate round == False with different decomposition
7877
round = True

backends/arm/test/tester/test_pipeline.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,9 @@ def __init__(
293293
)
294294
quant_stage = (
295295
Quantize(
296-
TOSAQuantizer(compile_spec).set_io(get_symmetric_quantization_config()),
296+
TOSAQuantizer(tosa_profiles[tosa_version]).set_io(
297+
get_symmetric_quantization_config()
298+
),
297299
get_symmetric_quantization_config(),
298300
)
299301
if symmetric_io_quantization

backends/arm/tosa_specification.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class TosaSpecification:
3636
"""
3737

3838
version: Version
39+
is_U55_subset: bool
3940

4041
def support_integer(self) -> bool:
4142
"""
@@ -49,9 +50,13 @@ def support_float(self) -> bool:
4950
"""
5051
raise NotImplementedError
5152

52-
def __init__(self, version: Version):
53+
def __init__(self, version: Version, extras: List[str]):
5354
self.version = version
5455

56+
self.is_U55_subset = "u55" in extras
57+
if self.is_U55_subset:
58+
extras.remove("u55")
59+
5560
@staticmethod
5661
def create_from_string(repr: str) -> "TosaSpecification":
5762
"""
@@ -85,11 +90,10 @@ def create_from_string(repr: str) -> "TosaSpecification":
8590
class Tosa_0_80(TosaSpecification):
8691
profile: str
8792
level_8k: bool
88-
is_U55_subset: bool
8993
available_profiles = ["BI", "MI"] # MT is not defined
9094

9195
def __init__(self, version: Version, extras: List[str]):
92-
super().__init__(version)
96+
super().__init__(version, extras)
9397
assert version >= Version("0.80") and version < Version("0.90")
9498

9599
# Check that we only have one profile in the extensions list
@@ -105,9 +109,6 @@ def __init__(self, version: Version, extras: List[str]):
105109
self.level_8k = "8k" in extras
106110
if self.level_8k:
107111
extras.remove("8k")
108-
self.is_U55_subset = "u55" in extras
109-
if self.is_U55_subset:
110-
extras.remove("u55")
111112

112113
if len(extras) > 0:
113114
raise ValueError(f"Unhandled extras found: {extras}")
@@ -147,7 +148,7 @@ class Tosa_1_00(TosaSpecification):
147148
}
148149

149150
def __init__(self, version: Version, extras: List[str]):
150-
super().__init__(version)
151+
super().__init__(version, extras)
151152

152153
# Check that we have at least one profile in the extensions list
153154
if [e in Tosa_1_00.available_profiles for e in extras].count(True) == 0:
@@ -194,6 +195,8 @@ def __repr__(self):
194195
extensions = self._get_extensions_string()
195196
if self.level_8k:
196197
extensions += "+8k"
198+
if self.is_U55_subset:
199+
extensions += "+u55"
197200
return f"TOSA-{self.version}{self._get_profiles_string()}{extensions}"
198201

199202
def __hash__(self) -> int:

0 commit comments

Comments
 (0)