
Add vectorization in elementwise_util #9432


Open. Wants to merge 96 commits into base branch gh/swolchok/439/head.
Commits (96)
31a49e0
Update
swolchok Mar 19, 2025
9fcd885
Update
swolchok Mar 19, 2025
29d6de9
Update
swolchok Mar 19, 2025
79b908c
Update
swolchok Mar 19, 2025
fd62a07
Update
swolchok Mar 19, 2025
854c991
Update
swolchok Mar 19, 2025
def7ed4
Update
swolchok Mar 19, 2025
40c1b1b
Update
swolchok Mar 19, 2025
7c78357
Update
swolchok Mar 19, 2025
7ba269a
Update
swolchok Mar 19, 2025
edd45fb
Update
swolchok Mar 19, 2025
b9c545f
Update
swolchok Mar 20, 2025
3091007
Update
swolchok Mar 20, 2025
4a00cac
Update
swolchok Mar 20, 2025
21b81bf
Update
swolchok Mar 20, 2025
4c4add0
Update
swolchok Mar 20, 2025
8782a90
Update
swolchok Mar 20, 2025
75f8970
Update
swolchok Mar 20, 2025
2d19e75
Update
swolchok Mar 20, 2025
b61a8a2
Update
swolchok Mar 25, 2025
91161bd
Update
swolchok Mar 25, 2025
4add706
Update
swolchok Mar 25, 2025
5348a92
Update
swolchok Mar 25, 2025
001d72c
Update
swolchok Mar 25, 2025
e49080d
Update
swolchok Mar 25, 2025
44ee51a
Update
swolchok Mar 25, 2025
f659627
Update
swolchok Mar 25, 2025
f1c5429
Update
swolchok Mar 25, 2025
b34f04f
Update
swolchok Mar 25, 2025
f934bc0
Update
swolchok Mar 25, 2025
3a74f25
Update
swolchok Mar 25, 2025
bbc7ba8
Update
swolchok Mar 25, 2025
151bf4a
Update
swolchok Mar 25, 2025
9a93839
Update
swolchok Mar 26, 2025
bb16a55
Update
swolchok Mar 26, 2025
2242f1e
Update
swolchok Mar 26, 2025
0822028
Update
swolchok Mar 26, 2025
f1b97dc
Update
swolchok Mar 26, 2025
7f57a19
Update
swolchok Mar 26, 2025
5d95c06
Update
swolchok Mar 26, 2025
42623bb
Update
swolchok Mar 26, 2025
284bc17
Update
swolchok Mar 26, 2025
29c2cfd
Update
swolchok Mar 26, 2025
4553283
Update
swolchok Mar 26, 2025
39610ad
Update
swolchok Mar 26, 2025
b3120fa
Update
swolchok Mar 26, 2025
350bcd8
Update
swolchok Mar 26, 2025
37e5b7d
Update
swolchok Mar 26, 2025
ff2c358
Update
swolchok Mar 26, 2025
9c2340f
Update
swolchok Mar 26, 2025
545777f
Update
swolchok Mar 26, 2025
7086659
Update
swolchok Mar 28, 2025
e13de0e
Update
swolchok Mar 28, 2025
943ab82
Update
swolchok Mar 28, 2025
f22d039
Update
swolchok Mar 28, 2025
45ce46d
Update
swolchok Mar 28, 2025
754dba4
Update
swolchok Mar 28, 2025
d5dfe2f
Update
swolchok Mar 28, 2025
3f1b775
Update
swolchok Mar 28, 2025
e55ac4a
Update
swolchok Mar 28, 2025
34eb5d4
Update
swolchok Mar 28, 2025
ea9dc6f
Update
swolchok Mar 28, 2025
7d7859e
Update
swolchok Mar 28, 2025
b98829d
Update
swolchok Mar 28, 2025
3140910
Update
swolchok Mar 28, 2025
afad88e
Update
swolchok Mar 28, 2025
946f2e0
Update
swolchok Mar 28, 2025
242995d
Update
swolchok Mar 28, 2025
7c23fec
Update
swolchok Mar 28, 2025
7f2bbdb
Update
swolchok Apr 2, 2025
960315e
Update
swolchok Apr 2, 2025
9e42e93
Update
swolchok Apr 2, 2025
96d258e
Update
swolchok Apr 2, 2025
e6f66ab
Update
swolchok Apr 2, 2025
a756254
Update
swolchok Apr 2, 2025
de9d52f
Update
swolchok Apr 2, 2025
ef74fe1
Update
swolchok Apr 2, 2025
b2e23ae
Update
swolchok Apr 2, 2025
7dc5cee
Update
swolchok Apr 2, 2025
20f3046
Update
swolchok Apr 2, 2025
3aa266d
Update
swolchok Apr 2, 2025
3c88a56
Update
swolchok Apr 2, 2025
153735d
Update
swolchok Apr 2, 2025
cac4293
Update
swolchok Apr 2, 2025
85451ea
Update
swolchok Apr 2, 2025
77a4fc6
Update
swolchok Apr 2, 2025
21ae5da
Update
swolchok Apr 2, 2025
a61c9b8
Update
swolchok Apr 2, 2025
0beabbb
Update
swolchok Apr 2, 2025
a7876b5
Update
swolchok May 27, 2025
3d59208
add tests that are long enough to hit the vectorized path
swolchok May 28, 2025
6541b28
actually verified test coverage
swolchok May 28, 2025
a6d2402
split out some commits
swolchok May 28, 2025
5a0da3f
fix visibility
swolchok May 28, 2025
c0ad8ac
Update
swolchok May 28, 2025
84170e4
rebase, fix lint in #9432
swolchok May 29, 2025
2 changes: 2 additions & 0 deletions .lintrunner.toml
@@ -272,6 +272,8 @@ exclude_patterns = [
'exir/verification/bindings.cpp',
'extension/**',
# Uses properly-gated (ET_USE_PYTORCH_HEADERS) ATen include.
'kernels/portable/cpu/util/elementwise_util.h',
'kernels/portable/cpu/util/math_util.h',
'kernels/portable/cpu/util/vectorized_math.h',
'kernels/optimized/**',
'runtime/core/exec_aten/**',
12 changes: 8 additions & 4 deletions kernels/portable/cpu/op_add.cpp
@@ -102,14 +102,18 @@ Tensor& add_scalar_out(
static constexpr const char op_name[] = "add.Scalar_out";

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
auto val_alpha_times_b = val_alpha * val_b;
utils::apply_unitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::SAME_AS_COMMON>(
[b, alpha](const auto val_a) {
CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
return val_a + val_alpha * val_b;
[val_alpha_times_b](const auto val_a) {
// Cast here supports vectorization; either it does nothing
// or it casts from CTYPE_COMPUTE to
// Vectorized<CTYPE_COMPUTE>.
return val_a + decltype(val_a)(val_alpha_times_b);
},
ctx,
a,
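For illustration, a minimal standalone sketch (not part of the diff) of why the decltype(val_a)(...) cast above lets one lambda serve both the scalar and the Vectorized path. It assumes ATen's vectorization headers are available, as under ET_USE_PYTORCH_HEADERS; the function name is hypothetical.

```cpp
// Illustrative sketch only; not part of the PR.
#include <ATen/cpu/vec/vec.h>

void add_scalar_cast_demo() { // hypothetical name
  const float val_alpha_times_b = 3.0f;
  auto fn = [val_alpha_times_b](const auto val_a) {
    // No-op conversion when val_a is float; broadcasting constructor when
    // val_a is at::vec::Vectorized<float>, so operator+ sees matching types.
    return val_a + decltype(val_a)(val_alpha_times_b);
  };
  float s = fn(2.0f);                            // scalar path: 5.0f
  auto v = fn(at::vec::Vectorized<float>(2.0f)); // vectorized path: every lane 5.0f
  (void)s;
  (void)v;
}
```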
2 changes: 1 addition & 1 deletion kernels/portable/cpu/op_atan2.cpp
@@ -60,7 +60,7 @@ Tensor& atan2_out(
op_name,
utils::SupportedTensorDtypes::FLOATHBF16>(
[](const auto val_a, const auto val_b) {
return std::atan2(val_a, val_b);
return executorch::math::atan2(val_a, val_b);
},
ctx,
a,
5 changes: 2 additions & 3 deletions kernels/portable/cpu/op_clamp.cpp
@@ -138,9 +138,8 @@ Tensor& clamp_out(
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::SAME_AS_COMMON>(
[has_min, min_opt, has_max, max_opt](const CTYPE_COMPUTE val_in) {
// TODO: rewrite this to be vectorization-capable.
CTYPE_COMPUTE val_out = val_in;
[has_min, min_opt, has_max, max_opt](const auto val_in) {
auto val_out = val_in;
if (has_min) {
val_out = utils::max_override(
val_out, utils::scalar_to<CTYPE_COMPUTE>(min_opt.value()));
3 changes: 1 addition & 2 deletions kernels/portable/cpu/op_elu.cpp
@@ -48,8 +48,7 @@ Tensor& elu_out(
CTYPE,
op_name,
utils::SupportedTensorDtypes::SAME_AS_COMMON>(
[negcoef, math_scale, math_input_scale](const auto x) {
// TODO: rewrite this to be vectorization-capable.
[negcoef, math_scale, math_input_scale](const CTYPE x) {
return MathT(x) <= MathT(0)
? std::expm1(MathT(x) * math_input_scale) * negcoef
: MathT(x) * math_scale;
8 changes: 3 additions & 5 deletions kernels/portable/cpu/op_fmod.cpp
@@ -61,7 +61,7 @@ Tensor& fmod_Tensor_out(
utils::SupportedTensorDtypes::REALHBF16>(
[&div_by_zero_error](
const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
// TODO: rewrite this to be vectorization-capable.
// TODO: rewrite this to be vectorization-capable?
CTYPE_COMPUTE value = 0;
if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
if (val_b == 0) {
@@ -138,10 +138,8 @@ Tensor& fmod_Scalar_out(
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::REALHBF16>(
[val_b](const CTYPE_COMPUTE val_a) {
// TODO: rewrite this to be vectorization-capable.
CTYPE_COMPUTE value = std::fmod(val_a, val_b);
return value;
[val_b](const auto val_a) {
return executorch::math::fmod(val_a, (decltype(val_a))val_b);
},
ctx,
a,
2 changes: 1 addition & 1 deletion kernels/portable/cpu/op_maximum.cpp
@@ -49,7 +49,7 @@ Tensor& maximum_out(
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::REALHBBF16>(
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
[](const auto val_a, const auto val_b) {
return utils::max_override(val_a, val_b);
},
ctx,
3 changes: 1 addition & 2 deletions kernels/portable/cpu/op_minimum.cpp
@@ -49,8 +49,7 @@ Tensor& minimum_out(
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::REALHBBF16>(
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
// TODO: rewrite this to be vectorization-capable.
[](const auto val_a, const auto val_b) {
return utils::min_override(val_a, val_b);
},
ctx,
4 changes: 1 addition & 3 deletions kernels/portable/cpu/op_mul.cpp
@@ -72,9 +72,7 @@ Tensor& mul_out(
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::REALHBBF16>(
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
return val_a * val_b;
},
[](const auto val_a, const auto val_b) { return val_a * val_b; },
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
10 changes: 6 additions & 4 deletions kernels/portable/cpu/op_native_dropout.cpp
@@ -57,8 +57,11 @@ std::tuple<Tensor&, Tensor&> native_dropout_out(
}
ET_SWITCH_FLOATHBF16_TYPES(
input.scalar_type(), ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
[](const auto val, const auto mask_val) {
utils::apply_bitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::SAME_AS_COMMON>(
[](const CTYPE_COMPUTE val, const CTYPE_COMPUTE mask_val) {
if (!mask_val) {
return static_cast<decltype(val)>(0);
}
@@ -70,8 +73,7 @@ std::tuple<Tensor&, Tensor&> native_dropout_out(
mask,
// TODO: should really be just BOOL
utils::SupportedTensorDtypes::BOOL_OR_BYTE,
out,
utils::SupportedTensorDtypes::SAME_AS_COMMON);
out);
});
} else if (input.numel() > 0) {
std::memcpy(out.mutable_data_ptr(), input.data_ptr(), input.nbytes());
23 changes: 16 additions & 7 deletions kernels/portable/cpu/op_pow.cpp
@@ -57,9 +57,8 @@ Tensor& pow_Tensor_Tensor_out(
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::REALHBF16>(
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
// TODO: rewrite this to be vectorization-capable.
return std::pow(val_a, val_b);
[](const auto val_a, const auto val_b) {
return executorch::math::pow(val_a, val_b);
},
ctx,
a,
@@ -111,8 +110,13 @@ Tensor& pow_Tensor_Scalar_out(
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::REALHBF16>(
// TODO: rewrite this to be vectorization-capable.
[val_b](const CTYPE_COMPUTE val_a) { return std::pow(val_a, val_b); },
// Casting val_b here supports vectorization; it does
// nothing if we are not vectorizing (casts to
// CTYPE_COMPUTE) and casts to a vectorized type
// otherwise.
[val_b](const auto val_a) {
return executorch::math::pow(val_a, decltype(val_a)(val_b));
},
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
@@ -161,8 +165,13 @@ Tensor& pow_Scalar_out(
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::REALHBF16>(
// TODO: rewrite this to be vectorization-capable.
[val_a](const CTYPE_COMPUTE val_b) { return std::pow(val_a, val_b); },
// Casting val_a here supports vectorization; it does
// nothing if we are not vectorizing (casts to
// CTYPE_COMPUTE) and casts to a vectorized type
// otherwise.
[val_a](const auto val_b) {
return executorch::math::pow(decltype(val_b)(val_a), val_b);
},
ctx,
b,
utils::SupportedTensorDtypes::REALHBBF16,
7 changes: 3 additions & 4 deletions kernels/portable/cpu/op_sigmoid.cpp
@@ -49,10 +49,9 @@ Tensor& sigmoid_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::FLOATHBF16>(
[](const auto val_in) -> CTYPE_COMPUTE {
// TODO: rewrite this to be vectorization-capable
CTYPE_COMPUTE out_val = static_cast<CTYPE_COMPUTE>(1.0) /
(static_cast<CTYPE_COMPUTE>(1.0) + exp(-val_in));
[](const auto val_in) {
const auto one = static_cast<decltype(val_in)>(1.0);
auto out_val = one / (one + executorch::math::exp(-val_in));
return out_val;
},
ctx,
6 changes: 3 additions & 3 deletions kernels/portable/cpu/op_where.cpp
@@ -47,9 +47,9 @@ Tensor& where_out(
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::SAME_AS_COMMON>(
[](const auto val_a, const auto val_b, const auto val_c) {
return val_c ? val_a : val_b;
},
[](const CTYPE_COMPUTE val_a,
const CTYPE_COMPUTE val_b,
const CTYPE_COMPUTE val_c) { return val_c ? val_a : val_b; },
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
139 changes: 138 additions & 1 deletion kernels/portable/cpu/util/elementwise_util.h
@@ -12,9 +12,14 @@
#include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/dtype_util.h>
#include <executorch/kernels/portable/cpu/util/vectorized_math.h> // Make vectorization support easy for clients.
#include <executorch/runtime/kernel/kernel_runtime_context.h>
#include <executorch/runtime/kernel/thread_parallel_interface.h>

#ifdef ET_USE_PYTORCH_HEADERS
#include <ATen/cpu/vec/vec.h>
#endif // ET_USE_PYTORCH_HEADERS

#include <array>
#include <utility>

@@ -51,6 +56,38 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
}

namespace internal {
template <typename Ignore, typename T>
using ignore_first_yield_second = T;
Contributor: I like these names


#ifdef ET_USE_PYTORCH_HEADERS
// Can I call a function of type Op with sizeof...(Args) arguments of type
// at::vec::Vectorized<CTYPE_COMPUTE>?
//
// See [NOTE: Generic lambdas] below for requirements on Op.
template <typename CTYPE_COMPUTE, typename Op, typename... Args>
constexpr bool can_use_vectorized() {
using Vec = at::vec::Vectorized<CTYPE_COMPUTE>;
// NOTE: if we start building optimized kernels on platforms that
// ATen Vectorized doesn't support well, we will want to add a way
// to check that Vectorized actually does something on our target
// platform. For now, I see no concrete need for that.
if constexpr (std::is_invocable_v<
Op,
ignore_first_yield_second<Args, Vec>...>) {
// For bool, we will get a false positive if we rely on only the
Contributor: you mean bool return type?

Contributor Author (@swolchok, Apr 2, 2025): bool CTYPE_COMPUTE

// is_invocable_v check above because at::vec::Vectorized is
// implicitly convertible to a pointer, which makes it implicitly
// convertible to bool (which was 15 minutes of fun to debug). Also
// just seems like good hygiene to make sure we get the Vectorized
// we're expecting.
return std::is_same_v<
std::invoke_result_t<Op, ignore_first_yield_second<Args, Vec>...>,
Vec>;
}
return false;
}
#endif // ET_USE_PYTORCH_HEADERS

template <
typename CTYPE_COMPUTE,
typename CTYPE_OUT,
@@ -61,8 +98,90 @@ inline void dtype_specialized_elementwise_fn_impl(
KernelRuntimeContext& ctx,
const Tensor& out,
Args... inputs) {
static_assert(
(std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
...));
constexpr auto kNumInputs = sizeof...(inputs);
ET_DCHECK(((inputs.first->element_size() == sizeof(CTYPE_COMPUTE)) && ...));
// All inputs must be of type CTYPE_COMPUTE.
ET_DCHECK(
((inputs.first->scalar_type() ==
CppTypeToScalarType<CTYPE_COMPUTE>::value) &&
...));
Comment on lines +107 to +109 (Contributor): Man, this is a good reminder of all the template meta programming magic


#ifdef ET_USE_PYTORCH_HEADERS
if constexpr (can_use_vectorized<CTYPE_COMPUTE, Op, Args...>()) {
const bool any_is_broadcasted =
!(torch::executor::internal::sizes_match_ignoring_leading_1s(
inputs.first->sizes(), out.sizes()) &&
...);
if (!any_is_broadcasted) {
using Vec = at::vec::Vectorized<CTYPE_COMPUTE>;
Contributor: So the one point of contention for me is: why do we need vectorized_math.h, which largely trampolines to the underlying Vectorized methods? You don't even need can_use_vectorized, because on non-accelerated platforms Vectorized falls back to a scalar implementation anyway. Maybe you said that the generated code would be worse if we forced Vectorized, but I am not sure why. The rest makes sense.

However, a place where I can potentially see it being useful is for dtypes that Vectorized doesn't support; for float I am not sure. So if you can clarify that, it would help.

Contributor Author:

> why do we need vectorized_math.h which largely is doing trampoline to underlying vectorized methods

Without it, you can't take the same lambda you already wrote for scalars and reuse it for Vectorized (the change isn't zero because you have to point at executorch::math, but crucially it doesn't require separate code).
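To make the reuse concrete, a minimal sketch (not part of the diff): the sigmoid lambda from op_sigmoid.cpp above, invoked once with a scalar and once with a Vectorized value. It assumes, as described here, that the executorch::math wrappers forward to the std:: functions for scalars and to the corresponding Vectorized operations otherwise; the demo function name is made up.

```cpp
// Illustrative sketch only; not part of the PR.
#include <executorch/kernels/portable/cpu/util/vectorized_math.h>
#include <ATen/cpu/vec/vec.h>

void lambda_reuse_demo() { // hypothetical name
  // Same lambda shape op_sigmoid.cpp passes to the elementwise util.
  auto sigmoid = [](const auto val_in) {
    const auto one = static_cast<decltype(val_in)>(1.0);
    return one / (one + executorch::math::exp(-val_in));
  };
  float s = sigmoid(0.0f);                            // scalar: 0.5f
  auto v = sigmoid(at::vec::Vectorized<float>(0.0f)); // vectorized: every lane 0.5f
  (void)s;
  (void)v;
}
```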

::executorch::extension::parallel_for(
Contributor: I think doing this blindly for each op is a bit risky, in that not all multithreading is always better; some ops benefit from a smaller grain size, others from a larger one.

Contributor Author: IIRC XNNPACK blindly parallelizes absolutely everything; we're doing strictly better here.

Contributor: Yeah, I am not comparing with XNNPACK. In fact, one big part of the reason we ended up leveraging the optimized op lib for some of the Llama stuff was exactly that: it blindly parallelized everything, and that actually hurt perf.

0,
out.numel(),
::executorch::extension::internal::GRAIN_SIZE,
[&](const auto begin, const auto end) {
std::array<const CTYPE_COMPUTE*, kNumInputs> inputs_data_ptrs = {
inputs.first->template const_data_ptr<CTYPE_COMPUTE>()...};

CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();

const auto vectorized_begin =
begin + (Vec::size() - begin % Vec::size()) % Vec::size();
Comment on lines +129 to +130 (Contributor): Feels like something that has a chance of bugs. Hope we test this enough; I doubt our test cases will exercise this code path.

Contributor: Although I do see you treat the scalar leftovers of both the head and the tail separately.

Contributor Author:

> Hope we test this enough.

Good point; adding tests for each affected op with lengths sufficient to hit the vectorized path.
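For concreteness, a small worked example (not part of the diff) of the head/tail split computed on lines +129 to +130, with a made-up Vec::size() of 8 and a made-up parallel_for chunk:

```cpp
// Illustrative sketch only; the numbers are made up.
#include <cassert>
#include <cstddef>

void alignment_split_demo() { // hypothetical name
  constexpr std::size_t kVecSize = 8; // stand-in for Vec::size()
  const std::size_t begin = 5, end = 30;
  const auto vectorized_begin =
      begin + (kVecSize - begin % kVecSize) % kVecSize; // 5 + (8 - 5) % 8 = 8
  const auto vectorized_end = end - (end % kVecSize);   // 30 - 6 = 24
  // Scalar prologue covers [5, 8), the vectorized loop covers [8, 24)
  // in steps of 8, and the scalar epilogue covers [24, 30).
  assert(vectorized_begin == 8 && vectorized_end == 24);
}
```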

const auto vectorized_end = end - (end % Vec::size());
// Scalar prologue.
for (const auto idx : c10::irange(begin, vectorized_begin)) {
// In debug mode, always use Vectorized so that even
// small-sized tests will test whether using Vectorized broke our
// lambda.
#ifndef NDEBUG
std::array<Vec, kNumInputs> loaded_inputs;
#else // NDEBUG
std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs;
#endif // NDEBUG
for (const auto input_idx : c10::irange(kNumInputs)) {
loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
}
#ifndef NDEBUG
std::apply(compute_fun, loaded_inputs).store(&data_out[idx], 1);
#else // NDEBUG
data_out[idx] = std::apply(compute_fun, loaded_inputs);
#endif // NDEBUG
}

// Main vectorized loop.
for (auto idx = vectorized_begin; idx < vectorized_end;
idx += Vec::size()) {
std::array<Vec, kNumInputs> loaded_vec_inputs;
for (const auto input_idx : c10::irange(kNumInputs)) {
loaded_vec_inputs[input_idx] =
Vec::loadu(&inputs_data_ptrs[input_idx][idx]);
}
auto result_vec = std::apply(compute_fun, loaded_vec_inputs);
result_vec.store(&data_out[idx]);
}

// Scalar epilogue.
for (const auto idx : c10::irange(vectorized_end, end)) {
#ifndef NDEBUG
std::array<Vec, kNumInputs> loaded_inputs;
#else // NDEBUG
std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs;
#endif // NDEBUG
for (const auto input_idx : c10::irange(kNumInputs)) {
loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
}
#ifndef NDEBUG
std::apply(compute_fun, loaded_inputs).store(&data_out[idx], 1);
#else // NDEBUG
data_out[idx] = std::apply(compute_fun, loaded_inputs);
#endif // NDEBUG
}
});
return;
}
}
#endif // ET_USE_PYTORCH_HEADERS

::executorch::extension::parallel_for(
0,
@@ -240,6 +359,19 @@ inline void apply_unitensor_elementwise_fn(
compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
}

/**
* Useful for unary elementwise operators. For each element of the
* input, call Op and write to the corresponding element of the
* output. Tensor broadcasting is applied wherever it is required.
*
* [NOTE: Generic lambdas]: If Op is a *generic* lambda (i.e., one with `auto`
* parameters; normal lambdas are fine), it must fulfill one of the
* following conditions. Either:
* 1) It must in fact compile when passed at::vec::Vectorized<CTYPE_COMPUTE>, or
* 2) It must be actively SFINAE-friendly, as per the C++17 examples in
* https://stackoverflow.com/questions/76525790/detecting-if-a-generic-lambda-with-certain-arguments-is-invocable
* .
*/
template <
typename CTYPE_COMPUTE,
const char* op_name,
@@ -281,6 +413,8 @@ inline void apply_bitensor_elementwise_fn(
* Useful for bi-tensor elementwise operators. For each element of the inputs,
* perform a computation and write to the corresponding element of the output.
* Tensor broadcasting is applied wherever it is required.
* See [NOTE: Generic lambdas] if you want to pass a generic lambda for
* compute_fun.
*/
template <
typename CTYPE_COMPUTE,
@@ -347,6 +481,9 @@ inline void apply_tritensor_elementwise_fn(
*
* static constexpr const char op_name[] = "my_op";
* apply_ternary_elementwise_fn<CTYPE_COMPUTE, op_name>.
*
* See [NOTE: Generic lambdas] if you want to pass a generic lambda for
* compute_fun.
*/
template <
typename CTYPE_COMPUTE,
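Regarding the [NOTE: Generic lambdas] added above, a minimal sketch (not part of the diff) of option (2): an SFINAE-friendly generic lambda that deliberately opts out of the Vectorized path, so can_use_vectorized() reports false and the scalar loop is used instead. The helper name is hypothetical.

```cpp
// Illustrative sketch only; not part of the PR.
#include <type_traits>

template <typename CTYPE_COMPUTE>
auto make_scalar_only_square() { // hypothetical name
  // The trailing return type fails substitution (in the immediate context)
  // for anything other than CTYPE_COMPUTE, e.g. at::vec::Vectorized<CTYPE_COMPUTE>,
  // so std::is_invocable_v reports false and the caller falls back to the
  // scalar loop.
  return [](const auto val_in)
             -> std::enable_if_t<
                 std::is_same_v<std::decay_t<decltype(val_in)>, CTYPE_COMPUTE>,
                 CTYPE_COMPUTE> {
    return val_in * val_in;
  };
}
```

Such a lambda can be passed to apply_unitensor_elementwise_fn just like the plain lambdas above; the implementation should then quietly take the non-vectorized path.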