Skip to content

Commit 998ad74

Browse files
Arm backend: Convert assert to throw ValueError in op_tanh
Asserts are converted to proper raises to ensure graph integrity. It should not be possible for tanh to have more than 1 input for a correctly formatted graph, but in the node visitor we cannot know for sure that the graph is formatted correctly. torch.tanh supports more data types than fp32, which is why it should be checked. Change-Id: Ibbe2f6964f85ee6c5883fdbe8973526ff6f224cd Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 4717459 commit 998ad74

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

backends/arm/operators/op_tanh.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,5 +34,14 @@ def define_node(
3434
inputs: List[TosaArg],
3535
output: TosaArg,
3636
) -> None:
37-
assert inputs[0].dtype == output.dtype == ts.DType.FP32
37+
if len(node.all_input_nodes) != 1:
38+
raise ValueError(
39+
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
40+
)
41+
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
42+
raise ValueError(
43+
f"Input and output for {self.target} need to be FP32, got input_dtype: "
44+
f"{inputs[0].dtype} and output_dtype: {output.dtype}"
45+
)
46+
3847
tosa_graph.addOperator(TosaOp.Op().TANH, [inputs[0].name], [output.name])

0 commit comments

Comments
 (0)