Add vectorization in elementwise_util #9432
base: gh/swolchok/439/head
@@ -12,9 +12,14 @@
#include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/dtype_util.h>
#include <executorch/kernels/portable/cpu/util/vectorized_math.h> // Make vectorization support easy for clients.
#include <executorch/runtime/kernel/kernel_runtime_context.h>
#include <executorch/runtime/kernel/thread_parallel_interface.h>

#ifdef ET_USE_PYTORCH_HEADERS
#include <ATen/cpu/vec/vec.h>
#endif // ET_USE_PYTORCH_HEADERS

#include <array>
#include <utility>
@@ -51,6 +56,38 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
}

namespace internal {
template <typename Ignore, typename T>
using ignore_first_yield_second = T;

#ifdef ET_USE_PYTORCH_HEADERS
// Can I call a function of type Op with sizeof...(Args) arguments of type
// at::vec::Vectorized<CTYPE_COMPUTE>?
//
// See [NOTE: Generic lambdas] below for requirements on Op.
template <typename CTYPE_COMPUTE, typename Op, typename... Args>
constexpr bool can_use_vectorized() {
  using Vec = at::vec::Vectorized<CTYPE_COMPUTE>;
  // NOTE: if we start building optimized kernels on platforms that
  // ATen Vectorized doesn't support well, we will want to add a way
  // to check that Vectorized actually does something on our target
  // platform. For now, I see no concrete need for that.
  if constexpr (std::is_invocable_v<
                    Op,
                    ignore_first_yield_second<Args, Vec>...>) {
    // For bool, we will get a false positive if we rely on only the
    // is_invocable_v check above because at::vec::Vectorized is
    // implicitly convertible to a pointer, which makes it implicitly
    // convertible to bool (which was 15 minutes of fun to debug). Also
    // just seems like good hygiene to make sure we get the Vectorized
    // we're expecting.
    return std::is_same_v<
        std::invoke_result_t<Op, ignore_first_yield_second<Args, Vec>...>,
        Vec>;
  }
  return false;
}
#endif // ET_USE_PYTORCH_HEADERS

Review comment (on the "For bool, we will get a false positive" comment): You mean a bool return type?
Reply: A bool CTYPE_COMPUTE.
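To make the failure mode concrete, here is a minimal standalone sketch; FakeVectorized is a hypothetical stand-in reduced to the one behavior that matters, the implicit pointer conversion:

#include <type_traits>

template <typename T>
struct FakeVectorized {
  T values[8];
  operator const T*() const { return values; } // like at::vec::Vectorized
};

// A lambda written for scalar bool.
inline auto logical_not = [](bool x) { return !x; };

// is_invocable_v alone is fooled: FakeVectorized<bool> -> const bool* -> bool.
static_assert(
    std::is_invocable_v<decltype(logical_not), FakeVectorized<bool>>);

// The is_same_v check in can_use_vectorized catches it: the call returns
// bool, not a vector, so this lambda must not take the vectorized path.
static_assert(!std::is_same_v<
              std::invoke_result_t<decltype(logical_not), FakeVectorized<bool>>,
              FakeVectorized<bool>>);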
template <
    typename CTYPE_COMPUTE,
    typename CTYPE_OUT,
@@ -61,8 +98,90 @@ inline void dtype_specialized_elementwise_fn_impl(
    KernelRuntimeContext& ctx,
    const Tensor& out,
    Args... inputs) {
  static_assert(
      (std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
       ...));
  constexpr auto kNumInputs = sizeof...(inputs);
  ET_DCHECK(((inputs.first->element_size() == sizeof(CTYPE_COMPUTE)) && ...));
  // All inputs must be of type CTYPE_COMPUTE.
  ET_DCHECK(
      ((inputs.first->scalar_type() ==
        CppTypeToScalarType<CTYPE_COMPUTE>::value) &&
       ...));

Review comment (on lines +107 to +109): Man, this is a good reminder of all the template metaprogramming magic.
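For readers less steeped in that magic: both checks above are C++17 fold expressions over the inputs pack. A standalone sketch of the same pattern (names hypothetical):

#include <cstdio>

// The && fold applies the predicate to every pack element:
// all_positive(1, -2, 3) expands to (1 > 0) && (-2 > 0) && (3 > 0).
template <typename... Args>
bool all_positive(Args... args) {
  return ((args > 0) && ...);
}

int main() {
  std::printf("%d %d\n", all_positive(1, 2, 3), all_positive(1, -2, 3)); // 1 0
}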
#ifdef ET_USE_PYTORCH_HEADERS
  if constexpr (can_use_vectorized<CTYPE_COMPUTE, Op, Args...>()) {
    const bool any_is_broadcasted =
        !(torch::executor::internal::sizes_match_ignoring_leading_1s(
              inputs.first->sizes(), out.sizes()) &&
          ...);
    if (!any_is_broadcasted) {
      using Vec = at::vec::Vectorized<CTYPE_COMPUTE>;

Review comment: The one point of contention for me is why we need vectorized_math.h, which is largely a trampoline to the underlying Vectorized methods; mostly you don't even need to use it. One place I can potentially see it being useful is for dtypes that Vectorized doesn't support, but for float I am not sure. So maybe you can clarify that.
Reply: Without it, you can't take the same lambda you already wrote for scalars and reuse it for Vectorized. (The change isn't zero, because you have to point at executorch::math, but crucially it doesn't require separate code.)
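A sketch of the reuse being described, under the assumption that vectorized_math.h exposes trampolines in the executorch::math namespace (exp is used as a representative name here, not a confirmed one):

#include <ATen/cpu/vec/vec.h>
#include <executorch/kernels/portable/cpu/util/vectorized_math.h>

// One generic lambda serves both code paths: the scalar prologue/epilogue
// calls it with CTYPE_COMPUTE, the main loop with Vectorized<CTYPE_COMPUTE>.
inline auto exp_op = [](const auto& x) { return executorch::math::exp(x); };

void demo() {
  float s = exp_op(1.0f);                             // scalar instantiation
  auto v = exp_op(at::vec::Vectorized<float>(1.0f));  // vector instantiation
  (void)s;
  (void)v;
}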
Review comment (on the parallel_for call below): I think doing this blindly for each op is a bit risky: not all multithreading is always better. Some ops benefit from a smaller grain size, others from a larger one.
Reply: IIRC XNNPACK blindly parallelizes absolutely everything; we're doing strictly better here.
Reply: Yeah, I am not comparing with XNNPACK. In fact, a big part of the reason we ended up leveraging the optimized op lib for some of the llama stuff was exactly that: it blindly parallelized everything, and that actually hurt perf.

      ::executorch::extension::parallel_for(
          0,
          out.numel(),
          ::executorch::extension::internal::GRAIN_SIZE,
          [&](const auto begin, const auto end) {
            std::array<const CTYPE_COMPUTE*, kNumInputs> inputs_data_ptrs = {
                inputs.first->template const_data_ptr<CTYPE_COMPUTE>()...};

            CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();

            const auto vectorized_begin =
                begin + (Vec::size() - begin % Vec::size()) % Vec::size();

Review comment (on lines +129 to +130): Feels like something that has a chance of a bug; I hope we test this enough. I doubt our test cases will exercise this code path.
Follow-up: Although I do see you treat the scalar leftovers of both head and tail separately.
Reply: Good point; adding tests for each affected op with lengths sufficient to hit the vectorized path.
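For what it's worth, a worked example of the head/tail split arithmetic (standalone; kVecSize stands in for Vec::size(), and the sketch assumes end - begin is at least kVecSize so the bounds don't cross):

#include <cstddef>
#include <cstdio>

int main() {
  constexpr std::size_t kVecSize = 8; // hypothetical Vec::size()
  const std::size_t begin = 3, end = 29;
  // Round begin up to the next multiple of kVecSize: 3 -> 8.
  const auto vectorized_begin =
      begin + (kVecSize - begin % kVecSize) % kVecSize;
  // Round end down to a multiple of kVecSize: 29 -> 24.
  const auto vectorized_end = end - (end % kVecSize);
  // Scalar head [3, 8), vector body [8, 24), scalar tail [24, 29).
  std::printf("head [%zu, %zu) body [%zu, %zu) tail [%zu, %zu)\n",
              begin, vectorized_begin, vectorized_begin, vectorized_end,
              vectorized_end, end);
}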
            const auto vectorized_end = end - (end % Vec::size());
            // Scalar prologue.
            for (const auto idx : c10::irange(begin, vectorized_begin)) {
              // In debug mode, always use Vectorized so that even
              // small-sized tests will test whether using Vectorized broke our
              // lambda.
#ifndef NDEBUG
              std::array<Vec, kNumInputs> loaded_inputs;
#else // NDEBUG
              std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs;
#endif // NDEBUG
              for (const auto input_idx : c10::irange(kNumInputs)) {
                loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
              }
#ifndef NDEBUG
              std::apply(compute_fun, loaded_inputs).store(&data_out[idx], 1);
#else // NDEBUG
              data_out[idx] = std::apply(compute_fun, loaded_inputs);
#endif // NDEBUG
            }

            // Main vectorized loop.
            for (auto idx = vectorized_begin; idx < vectorized_end;
                 idx += Vec::size()) {
              std::array<Vec, kNumInputs> loaded_vec_inputs;
              for (const auto input_idx : c10::irange(kNumInputs)) {
                loaded_vec_inputs[input_idx] =
                    Vec::loadu(&inputs_data_ptrs[input_idx][idx]);
              }
              auto result_vec = std::apply(compute_fun, loaded_vec_inputs);
              result_vec.store(&data_out[idx]);
            }

            // Scalar epilogue.
            for (const auto idx : c10::irange(vectorized_end, end)) {
#ifndef NDEBUG
              std::array<Vec, kNumInputs> loaded_inputs;
#else // NDEBUG
              std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs;
#endif // NDEBUG
              for (const auto input_idx : c10::irange(kNumInputs)) {
                loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
              }
#ifndef NDEBUG
              std::apply(compute_fun, loaded_inputs).store(&data_out[idx], 1);
#else // NDEBUG
              data_out[idx] = std::apply(compute_fun, loaded_inputs);
#endif // NDEBUG
            }
          });
      return;
    }
  }
#endif // ET_USE_PYTORCH_HEADERS

  ::executorch::extension::parallel_for(
      0,
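Two Vectorized behaviors make the debug-mode trick above tick: assigning a scalar to a Vec goes through the broadcast constructor, and store takes an element count, so the full vector code path can still write back a single lane. A standalone sketch of just those two pieces:

#include <ATen/cpu/vec/vec.h>

void debug_style_scalar_step(const float* in, float* out, int idx) {
  using Vec = at::vec::Vectorized<float>;
  Vec v = in[idx];        // broadcast constructor: every lane holds in[idx]
  Vec r = v * Vec(2.0f);  // exercise the real vector code path
  r.store(&out[idx], 1);  // write back only the first lane
}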
@@ -240,6 +359,19 @@ inline void apply_unitensor_elementwise_fn(
      compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
}
/**
 * Useful for unary elementwise operators. For each element of the
 * input, call Op and write to the corresponding element of the
 * output. Tensor broadcasting is applied wherever it is required.
 *
 * [NOTE: Generic lambdas]: If Op is a *generic* lambda (i.e., one with `auto`
 * parameters; normal lambdas are fine), it must fulfill one of the
 * following conditions. Either:
 * 1) it must in fact compile when passed at::vec::Vectorized<CTYPE_COMPUTE>, or
 * 2) it must be actively SFINAE-friendly, as per the C++17 examples in
 *    https://stackoverflow.com/questions/76525790/detecting-if-a-generic-lambda-with-certain-arguments-is-invocable
 */
template <
    typename CTYPE_COMPUTE,
    const char* op_name,
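As a sketch of condition 2, consider a hypothetical scalar-only op. Putting the constraint in the declaration (the trailing return type) rather than letting the body fail means the is_invocable_v probe inside can_use_vectorized() cleanly reports false instead of triggering a hard compile error:

#include <type_traits>

// SFINAE-friendly generic lambda: substituting a non-arithmetic type (such
// as at::vec::Vectorized<T>) fails in the trailing return type, so the
// lambda is simply "not invocable" with it and the scalar path is taken.
inline auto scalar_only_op = [](auto x)
    -> std::enable_if_t<std::is_arithmetic_v<decltype(x)>, decltype(x)> {
  return x * 2;
};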
@@ -281,6 +413,8 @@ inline void apply_bitensor_elementwise_fn(
 * Useful for bi-tensor elementwise operators. For each element of the inputs,
 * perform a computation and write to the corresponding element of the output.
 * Tensor broadcasting is applied wherever it is required.
 * See [NOTE: Generic lambdas] if you want to pass a generic lambda for
 * compute_fun.
 */
template <
    typename CTYPE_COMPUTE,
@@ -347,6 +481,9 @@ inline void apply_tritensor_elementwise_fn(
 *
 *   static constexpr const char op_name[] = "my_op";
 *   apply_ternary_elementwise_fn<CTYPE_COMPUTE, op_name>.
 *
 * See [NOTE: Generic lambdas] if you want to pass a generic lambda for
 * compute_fun.
 */
template <
    typename CTYPE_COMPUTE,
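Conversely, a compute_fun that sticks to operators Vectorized already overloads satisfies condition 1 of [NOTE: Generic lambdas] with no extra work; a hypothetical example:

// Compiles for both CTYPE_COMPUTE and at::vec::Vectorized<CTYPE_COMPUTE>,
// since Vectorized overloads + and *.
auto addcmul_fun = [](auto a, auto b, auto c) { return a + b * c; };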
Review comment: I like these names.