Skip to content

Commit 771588a

Browse files
Arm backend: Support scalar_tensor op (#9711)
- Handled by the ComputeConstantOpsAOT pass. - Add tests Signed-off-by: Erik Lundell <[email protected]> Co-authored-by: Måns Nilsson <[email protected]>
1 parent 03c5cf7 commit 771588a

File tree

4 files changed

+95
-3
lines changed

4 files changed

+95
-3
lines changed

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def f(node_name_pre_computed):
161161
exir_ops.edge.aten.arange.start_step,
162162
exir_ops.edge.aten.eye.default,
163163
exir_ops.edge.aten.linspace.default,
164+
torch.ops.aten.scalar_tensor.default,
164165
]
165166

166167
def __init__(self, exported_program: ExportedProgram) -> None:

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ def is_node_supported(
205205
exir_ops.edge.aten.amin.default,
206206
exir_ops.edge.aten.eye.default,
207207
exir_ops.edge.aten.linspace.default,
208+
torch.ops.aten.scalar_tensor.default,
208209
]
209210

210211
return supported

backends/arm/test/models/test_conformer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ class TestConformer(unittest.TestCase):
3535
"executorch_exir_dialects_edge__ops_aten_where_self": 4,
3636
"torch.ops.aten._assert_scalar.default": 10,
3737
"torch.ops.aten._local_scalar_dense.default": 1,
38-
"torch.ops.aten.scalar_tensor.default": 2,
3938
"torch.ops.higher_order.executorch_call_delegate": 6,
4039
}
4140

@@ -92,7 +91,7 @@ def test_conformer_tosa_BI(self):
9291
)
9392
)
9493

95-
@conftest.expectedFailureOnFVP # TODO(MLETORCH-635)
94+
@unittest.expectedFailure # TODO(MLETORCH-635)
9695
def test_conformer_u55_BI(self):
9796
tester = (
9897
ArmTester(
@@ -114,7 +113,7 @@ def test_conformer_u55_BI(self):
114113
inputs=get_test_inputs(self.dim, self.lengths, self.num_examples),
115114
)
116115

117-
@conftest.expectedFailureOnFVP # TODO(MLETORCH-635)
116+
@unittest.expectedFailure # TODO(MLETORCH-635)
118117
def test_conformer_u85_BI(self):
119118
tester = (
120119
ArmTester(
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.backends.arm.test import common
8+
9+
from executorch.backends.arm.test.tester.test_pipeline import (
10+
EthosU55PipelineBI,
11+
EthosU85PipelineBI,
12+
TosaPipelineBI,
13+
TosaPipelineMI,
14+
)
15+
16+
float_test_data_suite = {
17+
"scalar_tensor_float_1": (3.7, torch.float32, torch.rand((1, 2, 3, 4))),
18+
"scalar_tensor_float_2": (66, torch.float32, torch.rand((1, 2, 3))),
19+
}
20+
21+
int_test_data_suite = {
22+
"scalar_tensor_int32": (
23+
33,
24+
torch.int32,
25+
torch.randint(0, 10, (1, 2), dtype=torch.int32),
26+
),
27+
"scalar_tensor_int8": (
28+
8,
29+
torch.int8,
30+
torch.rand(1, 2, 3),
31+
),
32+
"scalar_tensor_int16": (
33+
16 * 16 * 16,
34+
torch.int16,
35+
torch.rand((1,)).unsqueeze(0), # Rank 0 inputs not supported
36+
),
37+
}
38+
39+
40+
class ScalarTensor(torch.nn.Module):
41+
aten_op = "torch.ops.aten.scalar_tensor.default"
42+
43+
def __init__(self, scalar, dtype=torch.float32):
44+
super().__init__()
45+
self.scalar = scalar
46+
self.dtype = dtype
47+
48+
def forward(self, x: torch.Tensor):
49+
return torch.scalar_tensor(self.scalar, dtype=self.dtype) + x
50+
51+
52+
@common.parametrize("test_data", int_test_data_suite | float_test_data_suite)
53+
def test_scalar_tensor_tosa_MI(test_data): # Note TOSA MI supports all types
54+
scalar, dtype, data = test_data
55+
TosaPipelineMI(ScalarTensor(scalar, dtype), tuple(data), ScalarTensor.aten_op).run()
56+
57+
58+
@common.parametrize("test_data", int_test_data_suite | float_test_data_suite)
59+
def test_scalar_tensor_tosa_BI(test_data):
60+
scalar, dtype, data = test_data
61+
pipeline: TosaPipelineBI = TosaPipelineBI(
62+
ScalarTensor(scalar, dtype), tuple(data), ScalarTensor.aten_op
63+
)
64+
pipeline.pop_stage("check.quant_nodes")
65+
pipeline.run()
66+
67+
68+
@common.parametrize("test_data", float_test_data_suite)
69+
@common.XfailIfNoCorstone300
70+
def test_scalar_tensor_tosa_u55(test_data):
71+
scalar, dtype, data = test_data
72+
EthosU55PipelineBI(
73+
ScalarTensor(scalar, dtype),
74+
tuple(data),
75+
ScalarTensor.aten_op,
76+
symmetric_io_quantization=True,
77+
run_on_fvp=True,
78+
).run()
79+
80+
81+
@common.parametrize("test_data", float_test_data_suite)
82+
@common.XfailIfNoCorstone320
83+
def test_scalar_tensor_tosa_u85(test_data):
84+
scalar, dtype, data = test_data
85+
EthosU85PipelineBI(
86+
ScalarTensor(scalar, dtype),
87+
tuple(data),
88+
ScalarTensor.aten_op,
89+
symmetric_io_quantization=True,
90+
run_on_fvp=True,
91+
).run()

0 commit comments

Comments
 (0)