@@ -51,12 +51,8 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
51
51
}
52
52
53
53
namespace internal {
54
- template <
55
- typename CTYPE_COMPUTE,
56
- const char * op_name,
57
- typename Op,
58
- typename ... Args>
59
- inline void apply_elementwise_fn (
54
+ template <typename CTYPE_COMPUTE, typename Op, typename ... Args>
55
+ inline bool validate_elementwise_fn_inputs (
60
56
const Op& compute_fun,
61
57
KernelRuntimeContext& ctx,
62
58
const Tensor& out,
@@ -65,7 +61,6 @@ inline void apply_elementwise_fn(
65
61
static_assert (
66
62
(std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
67
63
...));
68
- constexpr auto kNumInputs = sizeof ...(inputs);
69
64
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
70
65
const auto check_input_dtype = [](auto input, auto compute_type) {
71
66
return internal::check_tensor_dtype (
@@ -75,7 +70,30 @@ inline void apply_elementwise_fn(
75
70
ctx,
76
71
(check_input_dtype (inputs, compute_type) && ...) &&
77
72
internal::check_tensor_dtype (out, out_dtypes, compute_type),
78
- InvalidArgument, );
73
+ InvalidArgument,
74
+ false );
75
+
76
+ return true ;
77
+ }
78
+
79
+ template <
80
+ typename CTYPE_COMPUTE,
81
+ const char * op_name,
82
+ typename Op,
83
+ typename ... Args>
84
+ inline void apply_elementwise_fn (
85
+ const Op& compute_fun,
86
+ KernelRuntimeContext& ctx,
87
+ const Tensor& out,
88
+ SupportedTensorDtypes out_dtypes,
89
+ Args... inputs) {
90
+ const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
91
+ compute_fun, ctx, out, out_dtypes, inputs...);
92
+ if (!inputs_valid) {
93
+ return ;
94
+ }
95
+
96
+ constexpr auto kNumInputs = sizeof ...(inputs);
79
97
80
98
struct InputInfo {
81
99
load_to_compute_fn<CTYPE_COMPUTE> load_to_compute;
@@ -120,6 +138,7 @@ inline void apply_elementwise_fn(
120
138
});
121
139
}
122
140
141
+ // / DEPRECATED: prefer the variant with out_dtypes in the template argument.
123
142
template <typename CTYPE_COMPUTE, const char * op_name, typename Op>
124
143
inline void apply_unitensor_elementwise_fn (
125
144
const Op& compute_fun,
@@ -132,19 +151,83 @@ inline void apply_unitensor_elementwise_fn(
132
151
compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
133
152
}
134
153
154
+ template <
155
+ typename CTYPE_COMPUTE,
156
+ const char * op_name,
157
+ SupportedTensorDtypes out_dtypes,
158
+ typename Op>
159
+ inline void apply_unitensor_elementwise_fn (
160
+ const Op& compute_fun,
161
+ KernelRuntimeContext& ctx,
162
+ const Tensor& a,
163
+ SupportedTensorDtypes a_dtypes,
164
+ const Tensor& out) {
165
+ internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
166
+ compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
167
+ }
168
+
169
+ /* *
170
+ * DEPRECATED: prefer the variant with out_dtypes in the template argument list.
171
+ */
172
+ template <typename CTYPE_COMPUTE, const char * op_name, typename Op>
173
+ inline void apply_bitensor_elementwise_fn (
174
+ const Op& compute_fun,
175
+ KernelRuntimeContext& ctx,
176
+ const Tensor& a,
177
+ SupportedTensorDtypes a_dtypes,
178
+ const Tensor& b,
179
+ SupportedTensorDtypes b_dtypes,
180
+ const Tensor& out,
181
+ SupportedTensorDtypes out_dtypes) {
182
+ internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
183
+ compute_fun,
184
+ ctx,
185
+ out,
186
+ out_dtypes,
187
+ std::make_pair (&a, a_dtypes),
188
+ std::make_pair (&b, b_dtypes));
189
+ }
190
+
135
191
/* *
136
192
* Useful for bi-tensor elementwise operators. For each element of the inputs,
137
193
* perform a computation and write to the corresponding element of the output.
138
194
* Tensor broadcasting is applied wherever it is required.
139
195
*/
140
- template <typename CTYPE_COMPUTE, const char * op_name, typename Op>
196
+ template <
197
+ typename CTYPE_COMPUTE,
198
+ const char * op_name,
199
+ SupportedTensorDtypes out_dtypes,
200
+ typename Op>
141
201
inline void apply_bitensor_elementwise_fn (
142
202
const Op& compute_fun,
143
203
KernelRuntimeContext& ctx,
144
204
const Tensor& a,
145
205
SupportedTensorDtypes a_dtypes,
146
206
const Tensor& b,
147
207
SupportedTensorDtypes b_dtypes,
208
+ const Tensor& out) {
209
+ internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
210
+ compute_fun,
211
+ ctx,
212
+ out,
213
+ out_dtypes,
214
+ std::make_pair (&a, a_dtypes),
215
+ std::make_pair (&b, b_dtypes));
216
+ }
217
+
218
+ /* *
219
+ * DEPRECATED: prefer the variant with out_dtypes in the template argument list.
220
+ */
221
+ template <typename CTYPE_COMPUTE, const char * op_name, typename Op>
222
+ inline void apply_tritensor_elementwise_fn (
223
+ const Op& compute_fun,
224
+ KernelRuntimeContext& ctx,
225
+ const Tensor& a,
226
+ SupportedTensorDtypes a_dtypes,
227
+ const Tensor& b,
228
+ SupportedTensorDtypes b_dtypes,
229
+ const Tensor& c,
230
+ SupportedTensorDtypes c_dtypes,
148
231
const Tensor& out,
149
232
SupportedTensorDtypes out_dtypes) {
150
233
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
@@ -153,7 +236,8 @@ inline void apply_bitensor_elementwise_fn(
153
236
out,
154
237
out_dtypes,
155
238
std::make_pair (&a, a_dtypes),
156
- std::make_pair (&b, b_dtypes));
239
+ std::make_pair (&b, b_dtypes),
240
+ std::make_pair (&c, c_dtypes));
157
241
}
158
242
159
243
/* *
@@ -176,7 +260,11 @@ inline void apply_bitensor_elementwise_fn(
176
260
* static constexpr const char op_name[] = "my_op";
177
261
* apply_ternary_elementwise_fn<CTYPE_COMPUTE, op_name>.
178
262
*/
179
- template <typename CTYPE_COMPUTE, const char * op_name, typename Op>
263
+ template <
264
+ typename CTYPE_COMPUTE,
265
+ const char * op_name,
266
+ SupportedTensorDtypes out_dtypes,
267
+ typename Op>
180
268
inline void apply_tritensor_elementwise_fn (
181
269
const Op& compute_fun,
182
270
KernelRuntimeContext& ctx,
@@ -186,8 +274,7 @@ inline void apply_tritensor_elementwise_fn(
186
274
SupportedTensorDtypes b_dtypes,
187
275
const Tensor& c,
188
276
SupportedTensorDtypes c_dtypes,
189
- const Tensor& out,
190
- SupportedTensorDtypes out_dtypes) {
277
+ const Tensor& out) {
191
278
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
192
279
compute_fun,
193
280
ctx,
0 commit comments