Skip to content

Commit a1dd017

Browse files
authored
Refactor elementwise_util: create variants with out_dtypes in template argument list (#9387)
1 parent b71f03c commit a1dd017

File tree

1 file changed

+100
-13
lines changed

1 file changed

+100
-13
lines changed

kernels/portable/cpu/util/elementwise_util.h

Lines changed: 100 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,8 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
5151
}
5252

5353
namespace internal {
54-
template <
55-
typename CTYPE_COMPUTE,
56-
const char* op_name,
57-
typename Op,
58-
typename... Args>
59-
inline void apply_elementwise_fn(
54+
template <typename CTYPE_COMPUTE, typename Op, typename... Args>
55+
inline bool validate_elementwise_fn_inputs(
6056
const Op& compute_fun,
6157
KernelRuntimeContext& ctx,
6258
const Tensor& out,
@@ -65,7 +61,6 @@ inline void apply_elementwise_fn(
6561
static_assert(
6662
(std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
6763
...));
68-
constexpr auto kNumInputs = sizeof...(inputs);
6964
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
7065
const auto check_input_dtype = [](auto input, auto compute_type) {
7166
return internal::check_tensor_dtype(
@@ -75,7 +70,30 @@ inline void apply_elementwise_fn(
7570
ctx,
7671
(check_input_dtype(inputs, compute_type) && ...) &&
7772
internal::check_tensor_dtype(out, out_dtypes, compute_type),
78-
InvalidArgument, );
73+
InvalidArgument,
74+
false);
75+
76+
return true;
77+
}
78+
79+
template <
80+
typename CTYPE_COMPUTE,
81+
const char* op_name,
82+
typename Op,
83+
typename... Args>
84+
inline void apply_elementwise_fn(
85+
const Op& compute_fun,
86+
KernelRuntimeContext& ctx,
87+
const Tensor& out,
88+
SupportedTensorDtypes out_dtypes,
89+
Args... inputs) {
90+
const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
91+
compute_fun, ctx, out, out_dtypes, inputs...);
92+
if (!inputs_valid) {
93+
return;
94+
}
95+
96+
constexpr auto kNumInputs = sizeof...(inputs);
7997

8098
struct InputInfo {
8199
load_to_compute_fn<CTYPE_COMPUTE> load_to_compute;
@@ -120,6 +138,7 @@ inline void apply_elementwise_fn(
120138
});
121139
}
122140

141+
/// DEPRECATED: prefer the variant with out_dtypes in the template argument list.
123142
template <typename CTYPE_COMPUTE, const char* op_name, typename Op>
124143
inline void apply_unitensor_elementwise_fn(
125144
const Op& compute_fun,
@@ -132,19 +151,83 @@ inline void apply_unitensor_elementwise_fn(
132151
compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
133152
}
134153

154+
template <
155+
typename CTYPE_COMPUTE,
156+
const char* op_name,
157+
SupportedTensorDtypes out_dtypes,
158+
typename Op>
159+
inline void apply_unitensor_elementwise_fn(
160+
const Op& compute_fun,
161+
KernelRuntimeContext& ctx,
162+
const Tensor& a,
163+
SupportedTensorDtypes a_dtypes,
164+
const Tensor& out) {
165+
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
166+
compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
167+
}
168+
169+
/**
170+
* DEPRECATED: prefer the variant with out_dtypes in the template argument list.
171+
*/
172+
template <typename CTYPE_COMPUTE, const char* op_name, typename Op>
173+
inline void apply_bitensor_elementwise_fn(
174+
const Op& compute_fun,
175+
KernelRuntimeContext& ctx,
176+
const Tensor& a,
177+
SupportedTensorDtypes a_dtypes,
178+
const Tensor& b,
179+
SupportedTensorDtypes b_dtypes,
180+
const Tensor& out,
181+
SupportedTensorDtypes out_dtypes) {
182+
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
183+
compute_fun,
184+
ctx,
185+
out,
186+
out_dtypes,
187+
std::make_pair(&a, a_dtypes),
188+
std::make_pair(&b, b_dtypes));
189+
}
190+
135191
/**
136192
* Useful for bi-tensor elementwise operators. For each element of the inputs,
137193
* perform a computation and write to the corresponding element of the output.
138194
* Tensor broadcasting is applied wherever it is required.
139195
*/
140-
template <typename CTYPE_COMPUTE, const char* op_name, typename Op>
196+
template <
197+
typename CTYPE_COMPUTE,
198+
const char* op_name,
199+
SupportedTensorDtypes out_dtypes,
200+
typename Op>
141201
inline void apply_bitensor_elementwise_fn(
142202
const Op& compute_fun,
143203
KernelRuntimeContext& ctx,
144204
const Tensor& a,
145205
SupportedTensorDtypes a_dtypes,
146206
const Tensor& b,
147207
SupportedTensorDtypes b_dtypes,
208+
const Tensor& out) {
209+
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
210+
compute_fun,
211+
ctx,
212+
out,
213+
out_dtypes,
214+
std::make_pair(&a, a_dtypes),
215+
std::make_pair(&b, b_dtypes));
216+
}
217+
218+
/**
219+
* DEPRECATED: prefer the variant with out_dtypes in the template argument list.
220+
*/
221+
template <typename CTYPE_COMPUTE, const char* op_name, typename Op>
222+
inline void apply_tritensor_elementwise_fn(
223+
const Op& compute_fun,
224+
KernelRuntimeContext& ctx,
225+
const Tensor& a,
226+
SupportedTensorDtypes a_dtypes,
227+
const Tensor& b,
228+
SupportedTensorDtypes b_dtypes,
229+
const Tensor& c,
230+
SupportedTensorDtypes c_dtypes,
148231
const Tensor& out,
149232
SupportedTensorDtypes out_dtypes) {
150233
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
@@ -153,7 +236,8 @@ inline void apply_bitensor_elementwise_fn(
153236
out,
154237
out_dtypes,
155238
std::make_pair(&a, a_dtypes),
156-
std::make_pair(&b, b_dtypes));
239+
std::make_pair(&b, b_dtypes),
240+
std::make_pair(&c, c_dtypes));
157241
}
158242

159243
/**
@@ -176,7 +260,11 @@ inline void apply_bitensor_elementwise_fn(
176260
* static constexpr const char op_name[] = "my_op";
177261
* apply_tritensor_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>.
178262
*/
179-
template <typename CTYPE_COMPUTE, const char* op_name, typename Op>
263+
template <
264+
typename CTYPE_COMPUTE,
265+
const char* op_name,
266+
SupportedTensorDtypes out_dtypes,
267+
typename Op>
180268
inline void apply_tritensor_elementwise_fn(
181269
const Op& compute_fun,
182270
KernelRuntimeContext& ctx,
@@ -186,8 +274,7 @@ inline void apply_tritensor_elementwise_fn(
186274
SupportedTensorDtypes b_dtypes,
187275
const Tensor& c,
188276
SupportedTensorDtypes c_dtypes,
189-
const Tensor& out,
190-
SupportedTensorDtypes out_dtypes) {
277+
const Tensor& out) {
191278
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
192279
compute_fun,
193280
ctx,

0 commit comments

Comments
 (0)