@@ -41,9 +41,18 @@ def define_node(
41
41
) -> None :
42
42
# Specification (0.80) states that input and output types
43
43
# should all be the same
44
- assert inputs [0 ].dtype == inputs [1 ].dtype == output .dtype
44
+ if inputs [0 ].dtype != inputs [1 ].dtype or inputs [0 ].dtype != output .dtype :
45
+ raise TypeError (
46
+ f"All IO needs to have the same data type, got input 1: "
47
+ f"{ inputs [0 ].dtype } , input 2: { inputs [1 ].dtype } and output: "
48
+ f"{ output .dtype } "
49
+ )
45
50
# Handle int8 (quantized) and int32
46
- assert inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]
51
+ supported_dtypes = [ts .DType .INT8 , ts .DType .INT32 ]
52
+ if inputs [0 ].dtype not in supported_dtypes :
53
+ raise TypeError (
54
+ f'IO data type needs to be { supported_dtypes } , got "{ inputs [0 ].dtype } "'
55
+ )
47
56
48
57
dim_order = (
49
58
inputs [0 ].dim_order
@@ -105,15 +114,22 @@ def define_node(
105
114
) -> None :
106
115
# Specification (0.80) states that input and output types
107
116
# should all be the same
108
- assert inputs [0 ].dtype == inputs [1 ].dtype == output .dtype
117
+ if inputs [0 ].dtype != inputs [1 ].dtype or inputs [0 ].dtype != output .dtype :
118
+ raise TypeError (
119
+ f"All IO needs to have the same data type, got input 1: "
120
+ f"{ inputs [0 ].dtype } , input 2: { inputs [1 ].dtype } and output: "
121
+ f"{ output .dtype } "
122
+ )
109
123
110
124
if inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]:
111
125
# Call the inherited define_node for handling integers
112
126
super ().define_node (node , tosa_graph , inputs , output )
113
127
else :
114
128
# FP32 Add lowering
115
- assert inputs [0 ].dtype == ts .DType .FP32
116
- assert output .dtype == ts .DType .FP32
129
+ if inputs [0 ].dtype != ts .DType .FP32 :
130
+ raise TypeError (
131
+ f"Expected IO data type to be FP32, got { inputs [0 ].dtype } "
132
+ )
117
133
118
134
input1 , input2 = tutils .reshape_for_broadcast (tosa_graph , inputs )
119
135
0 commit comments