@@ -60,8 +60,43 @@ using op_call_result =
6060
6161template <
6262 typename CTYPE_COMMON,
63+ typename CTYPE_OUT,
6364 typename Op,
64- typename ... Args>
65+ typename ... Args>
66+ inline void dtype_specialized_elementwise_fn_impl (
67+ const Op& compute_fun,
68+ KernelRuntimeContext& ctx,
69+ const Tensor& out,
70+ Args... inputs) {
71+ constexpr auto kNumInputs = sizeof ...(inputs);
72+ ET_DCHECK (((inputs.first ->element_size () == sizeof (CTYPE_COMMON)) && ...));
73+
74+ std::array<const CTYPE_COMMON*, kNumInputs > inputs_data_ptrs = {
75+ inputs.first ->template const_data_ptr <CTYPE_COMMON>()...};
76+
77+ CTYPE_OUT* const data_out = out.mutable_data_ptr <CTYPE_OUT>();
78+
79+ ::executorch::extension::parallel_for (
80+ 0 ,
81+ out.numel(),
82+ ::executorch::extension::internal::GRAIN_SIZE,
83+ [&](const auto begin, const auto end) {
84+ const auto range =
85+ BroadcastIndexesRange<kNumInputs >(out, (*inputs.first )...);
86+ auto begin_it = range.begin ();
87+ begin_it += begin;
88+ for (; (*begin_it)[0 ] < end; ++begin_it) {
89+ const auto & indexes = *begin_it;
90+ std::array<CTYPE_COMMON, kNumInputs > loaded_inputs;
91+ for (const auto idx : c10::irange (kNumInputs )) {
92+ loaded_inputs[idx] = inputs_data_ptrs[idx][indexes[idx + 1 ]];
93+ }
94+ data_out[indexes[0 ]] = std::apply (compute_fun, loaded_inputs);
95+ }
96+ });
97+ }
98+
99+ template <typename CTYPE_COMMON, typename Op, typename ... Args>
65100inline bool validate_elementwise_fn_inputs (
66101 const Op& compute_fun,
67102 KernelRuntimeContext& ctx,
@@ -80,7 +115,8 @@ inline bool validate_elementwise_fn_inputs(
80115 ctx,
81116 (check_input_dtype (inputs, compute_type) && ...) &&
82117 internal::check_tensor_dtype (out, out_dtypes, compute_type),
83- InvalidArgument, false );
118+ InvalidArgument,
119+ false );
84120
85121 return true ;
86122}
@@ -90,22 +126,12 @@ template <
90126 const char * op_name,
91127 typename Op,
92128 typename ... Args>
93- inline void apply_elementwise_fn (
129+ inline void apply_elementwise_fn_generic_impl (
94130 const Op& compute_fun,
95131 KernelRuntimeContext& ctx,
96132 const Tensor& out,
97133 SupportedTensorDtypes out_dtypes,
98134 Args... inputs) {
99- const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
100- compute_fun,
101- ctx,
102- out,
103- out_dtypes,
104- inputs...);
105- if (!inputs_valid) {
106- return ;
107- }
108-
109135 constexpr auto kNumInputs = sizeof ...(inputs);
110136
111137 struct InputInfo {
@@ -157,6 +183,65 @@ inline void apply_elementwise_fn(
157183 }
158184 });
159185}
186+
187+ template <
188+ typename CTYPE_COMMON,
189+ const char * op_name,
190+ typename Op,
191+ typename ... Args>
192+ inline void apply_elementwise_fn_runtime_out_dtypes (
193+ const Op& compute_fun,
194+ KernelRuntimeContext& ctx,
195+ const Tensor& out,
196+ SupportedTensorDtypes out_dtypes,
197+ Args... inputs) {
198+ const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
199+ compute_fun, ctx, out, out_dtypes, inputs...);
200+ if (!inputs_valid) {
201+ return ;
202+ }
203+
204+ apply_elementwise_fn_generic_impl<CTYPE_COMMON, op_name>(
205+ compute_fun, ctx, out, out_dtypes, inputs...);
206+ }
207+
208+ template <
209+ typename CTYPE_COMMON,
210+ const char * op_name,
211+ SupportedTensorDtypes out_dtypes,
212+ typename Op,
213+ typename ... Args>
214+ inline void apply_elementwise_fn (
215+ const Op& compute_fun,
216+ KernelRuntimeContext& ctx,
217+ const Tensor& out,
218+ Args... inputs) {
219+ const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
220+ compute_fun, ctx, out, out_dtypes, inputs...);
221+ if (!inputs_valid) {
222+ return ;
223+ }
224+
225+ constexpr auto kNumInputs = sizeof ...(inputs);
226+
227+ constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
228+ const bool all_inputs_compute_dtype =
229+ ((inputs.first ->scalar_type () == compute_type) && ...);
230+
231+ constexpr ScalarType out_specialized_scalar_type =
232+ specialized_output_scalar_type<CTYPE_COMMON>(out_dtypes);
233+ if (all_inputs_compute_dtype &&
234+ out.scalar_type () == out_specialized_scalar_type) {
235+ using CTYPE_OUT =
236+ typename ScalarTypeToCppType<out_specialized_scalar_type>::type;
237+ dtype_specialized_elementwise_fn_impl<CTYPE_COMMON, CTYPE_OUT>(
238+ compute_fun, ctx, out, inputs...);
239+ return ;
240+ }
241+
242+ apply_elementwise_fn_generic_impl<CTYPE_COMMON, op_name>(
243+ compute_fun, ctx, out, out_dtypes, inputs...);
244+ }
160245} // namespace internal
161246
162247// / DEPRECATED: prefer the variant with out_dtypes in the template argument.
@@ -168,18 +253,22 @@ inline void apply_unitensor_elementwise_fn(
168253 SupportedTensorDtypes a_dtypes,
169254 const Tensor& out,
170255 SupportedTensorDtypes out_dtypes) {
171- internal::apply_elementwise_fn <CTYPE_COMMON, op_name>(
256+ internal::apply_elementwise_fn_runtime_out_dtypes <CTYPE_COMMON, op_name>(
172257 compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
173258}
174259
175- template <typename CTYPE_COMMON, const char * op_name, SupportedTensorDtypes out_dtypes, typename Op>
260+ template <
261+ typename CTYPE_COMMON,
262+ const char * op_name,
263+ SupportedTensorDtypes out_dtypes,
264+ typename Op>
176265inline void apply_unitensor_elementwise_fn (
177266 const Op& compute_fun,
178267 KernelRuntimeContext& ctx,
179268 const Tensor& a,
180269 SupportedTensorDtypes a_dtypes,
181270 const Tensor& out) {
182- internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
271+ internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes >(
183272 compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
184273}
185274
@@ -196,7 +285,7 @@ inline void apply_bitensor_elementwise_fn(
196285 SupportedTensorDtypes b_dtypes,
197286 const Tensor& out,
198287 SupportedTensorDtypes out_dtypes) {
199- internal::apply_elementwise_fn <CTYPE_COMMON, op_name>(
288+ internal::apply_elementwise_fn_runtime_out_dtypes <CTYPE_COMMON, op_name>(
200289 compute_fun,
201290 ctx,
202291 out,
@@ -210,7 +299,11 @@ inline void apply_bitensor_elementwise_fn(
210299 * perform a computation and write to the corresponding element of the output.
211300 * Tensor broadcasting is applied wherever it is required.
212301 */
213- template <typename CTYPE_COMMON, const char * op_name, SupportedTensorDtypes out_dtypes, typename Op>
302+ template <
303+ typename CTYPE_COMMON,
304+ const char * op_name,
305+ SupportedTensorDtypes out_dtypes,
306+ typename Op>
214307inline void apply_bitensor_elementwise_fn (
215308 const Op& compute_fun,
216309 KernelRuntimeContext& ctx,
@@ -219,11 +312,10 @@ inline void apply_bitensor_elementwise_fn(
219312 const Tensor& b,
220313 SupportedTensorDtypes b_dtypes,
221314 const Tensor& out) {
222- internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
315+ internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes >(
223316 compute_fun,
224317 ctx,
225318 out,
226- out_dtypes,
227319 std::make_pair (&a, a_dtypes),
228320 std::make_pair (&b, b_dtypes));
229321}
@@ -243,7 +335,7 @@ inline void apply_tritensor_elementwise_fn(
243335 SupportedTensorDtypes c_dtypes,
244336 const Tensor& out,
245337 SupportedTensorDtypes out_dtypes) {
246- internal::apply_elementwise_fn <CTYPE_COMMON, op_name>(
338+ internal::apply_elementwise_fn_runtime_out_dtypes <CTYPE_COMMON, op_name>(
247339 compute_fun,
248340 ctx,
249341 out,
@@ -273,7 +365,11 @@ inline void apply_tritensor_elementwise_fn(
273365 * static constexpr const char op_name[] = "my_op";
274366 * apply_tritensor_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes>.
275367 */
276- template <typename CTYPE_COMMON, const char * op_name, SupportedTensorDtypes out_dtypes, typename Op>
368+ template <
369+ typename CTYPE_COMMON,
370+ const char * op_name,
371+ SupportedTensorDtypes out_dtypes,
372+ typename Op>
277373inline void apply_tritensor_elementwise_fn (
278374 const Op& compute_fun,
279375 KernelRuntimeContext& ctx,
@@ -284,11 +380,10 @@ inline void apply_tritensor_elementwise_fn(
284380 const Tensor& c,
285381 SupportedTensorDtypes c_dtypes,
286382 const Tensor& out) {
287- internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
383+ internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes >(
288384 compute_fun,
289385 ctx,
290386 out,
291- out_dtypes,
292387 std::make_pair (&a, a_dtypes),
293388 std::make_pair (&b, b_dtypes),
294389 std::make_pair (&c, c_dtypes));
0 commit comments