Skip to content

Commit 8512f86

Browse files
Sebastian-Larssonkirklandsign
authored andcommitted
Arm backend: Convert assert to throw TypeError in op_add (#9897)
Asserts are converted to proper raises to ensure graph integrity. Signed-off-by: Sebastian Larsson <[email protected]>
1 parent b85eefe commit 8512f86

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

backends/arm/operators/op_add.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,18 @@ def define_node(
4141
) -> None:
4242
# Specification (0.80) states that input and output types
4343
# should all be the same
44-
assert inputs[0].dtype == inputs[1].dtype == output.dtype
44+
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
45+
raise TypeError(
46+
f"All IO needs to have the same data type, got input 1: "
47+
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
48+
f"{output.dtype}"
49+
)
4550
# Handle int8 (quantized) and int32
46-
assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]
51+
supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
52+
if inputs[0].dtype not in supported_dtypes:
53+
raise TypeError(
54+
f'IO data type needs to be {supported_dtypes}, got "{inputs[0].dtype}"'
55+
)
4756

4857
dim_order = (
4958
inputs[0].dim_order
@@ -105,15 +114,22 @@ def define_node(
105114
) -> None:
106115
# Specification (0.80) states that input and output types
107116
# should all be the same
108-
assert inputs[0].dtype == inputs[1].dtype == output.dtype
117+
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
118+
raise TypeError(
119+
f"All IO needs to have the same data type, got input 1: "
120+
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
121+
f"{output.dtype}"
122+
)
109123

110124
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
111125
# Call the inherited define_node for handling integers
112126
super().define_node(node, tosa_graph, inputs, output)
113127
else:
114128
# FP32 Add lowering
115-
assert inputs[0].dtype == ts.DType.FP32
116-
assert output.dtype == ts.DType.FP32
129+
if inputs[0].dtype != ts.DType.FP32:
130+
raise TypeError(
131+
f"Expected IO data type to be FP32, got {inputs[0].dtype}"
132+
)
117133

118134
input1, input2 = tutils.reshape_for_broadcast(tosa_graph, inputs)
119135

0 commit comments

Comments
 (0)