
Commit 5362bf2

Add vectorization in elementwise_util (not working yet)
This works with op_mul, which is vectorization-friendly, but it doesn't work when rolled out to pattern.h, because those ops will not work with Vectorized yet. See the TODO in elementwise_util.h.

ghstack-source-id: a12238d
ghstack-comment-id: 2738665976
Pull Request resolved: #9432
1 parent 279cc8c commit 5362bf2
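
For context on what makes op_mul "vectorization-friendly": the lambda it now passes (see the op_mul.cpp hunk below) is generic, so the same callable can be applied either to individual scalars or to whole at::vec::Vectorized lanes. A minimal standalone sketch of that dual use, not part of this diff, assuming the ATen headers that the ET_USE_PYTORCH_HEADERS gate below makes available:

#include <ATen/cpu/vec/vec.h>

#include <array>
#include <cstddef>

int main() {
  // Same shape as the new op_mul lambda: generic `auto` parameters, so it
  // accepts scalars and Vectorized lanes alike.
  auto mul = [](const auto val_a, const auto val_b) { return val_a * val_b; };

  using Vec = at::vec::Vectorized<float>;
  std::array<float, Vec::size()> a{}, b{}, out{};
  for (std::size_t i = 0; i < a.size(); ++i) {
    a[i] = static_cast<float>(i);
    b[i] = 2.0f;
  }

  // Scalar call: one element at a time (the pre-existing portable path).
  out[0] = mul(a[0], b[0]);

  // Vectorized call: load a full lane, multiply, store (the new fast path).
  const Vec va = Vec::loadu(a.data());
  const Vec vb = Vec::loadu(b.data());
  mul(va, vb).store(out.data());
  return 0;
}

The pattern.h-based ops mentioned above do not yet accept Vectorized arguments, which is why the rollout stops at op_mul for now.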

File tree

5 files changed (+102, -5 lines)

.lintrunner.toml

Lines changed: 2 additions & 0 deletions
@@ -264,6 +264,8 @@ exclude_patterns = [
     'examples/**',
     'exir/verification/bindings.cpp',
     'extension/**',
+    # Uses properly-gated (ET_USE_PYTORCH_HEADERS) ATen include.
+    'kernels/portable/cpu/util/elementwise_util.h',
     'kernels/optimized/**',
     'runtime/core/exec_aten/**',
     # Want to be able to keep c10 in sync with PyTorch core.

kernels/portable/cpu/op_mul.cpp

Lines changed: 1 addition & 3 deletions
@@ -56,9 +56,7 @@ Tensor& mul_out(
       CTYPE_COMPUTE,
       op_name,
       utils::SupportedTensorDtypes::REALHBBF16>(
-      [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-        return val_a * val_b;
-      },
+      [](const auto val_a, const auto val_b) { return val_a * val_b; },
       ctx,
       a,
       utils::SupportedTensorDtypes::REALHBBF16,
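
The removed and added lines above are the whole point of the change for op_mul: the new generic lambda satisfies the can_use_vectorized check introduced in elementwise_util.h below, while the old lambda, whose parameters are pinned to CTYPE_COMPUTE, does not, so it could only ever take the scalar path. A hedged standalone sketch of that check applied to both forms (assumes ATen headers; not part of the diff):

#include <ATen/cpu/vec/vec.h>

#include <type_traits>

int main() {
  using Vec = at::vec::Vectorized<float>;

  // New form: operator* exists for both float and Vec, so the lambda is
  // invocable with Vectorized arguments and the fast path can be taken.
  auto generic_mul = [](const auto val_a, const auto val_b) {
    return val_a * val_b;
  };
  static_assert(std::is_invocable_v<decltype(generic_mul), float, float>);
  static_assert(std::is_invocable_v<decltype(generic_mul), Vec, Vec>);

  // Old form: parameters pinned to a concrete scalar type. It remains
  // invocable with floats, but the analogous is_invocable_v query with Vec
  // arguments comes back false, so only the scalar loop is usable.
  auto scalar_mul = [](const float val_a, const float val_b) {
    return val_a * val_b;
  };
  static_assert(std::is_invocable_v<decltype(scalar_mul), float, float>);
  return 0;
}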

kernels/portable/cpu/util/elementwise_util.h

Lines changed: 94 additions & 1 deletion
@@ -15,6 +15,10 @@
 #include <executorch/runtime/kernel/kernel_runtime_context.h>
 #include <executorch/runtime/kernel/thread_parallel_interface.h>
 
+#ifdef ET_USE_PYTORCH_HEADERS
+#include <ATen/cpu/vec/vec.h>
+#endif // ET_USE_PYTORCH_HEADERS
+
 #include <array>
 #include <utility>
 
@@ -58,6 +62,19 @@ template <typename CTYPE_COMMON, typename Op, typename... Args>
 using op_call_result =
     std::invoke_result_t<Op, ignore_first_yield_second<Args, CTYPE_COMMON>...>;
 
+#ifdef ET_USE_PYTORCH_HEADERS
+// Can I call a function of type Op with sizeof...(Args) arguments of type
+// at::vec::Vectorized<CTYPE_COMMON>?
+//
+// See [NOTE: Generic lambdas] below for requirements on Op.
+template <typename CTYPE_COMMON, typename Op, typename... Args>
+constexpr bool can_use_vectorized() {
+  return std::is_invocable_v<
+      Op,
+      ignore_first_yield_second<Args, at::vec::Vectorized<CTYPE_COMMON>>...>;
+}
+#endif // ET_USE_PYTORCH_HEADERS
+
 template <
     typename CTYPE_COMMON,
     typename CTYPE_OUT,
@@ -68,14 +85,72 @@ inline void dtype_specialized_elementwise_fn_impl(
     KernelRuntimeContext& ctx,
     const Tensor& out,
     Args... inputs) {
+  static_assert(
+      (std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
+       ...));
   constexpr auto kNumInputs = sizeof...(inputs);
-  ET_DCHECK(((inputs.first->element_size() == sizeof(CTYPE_COMMON)) && ...));
+  // All inputs must be of type CTYPE_COMMON.
+  ET_DCHECK(
+      ((inputs.first->scalar_type() ==
+        CppTypeToScalarType<CTYPE_COMMON>::value) &&
+       ...));
 
   std::array<const CTYPE_COMMON*, kNumInputs> inputs_data_ptrs = {
       inputs.first->template const_data_ptr<CTYPE_COMMON>()...};
 
   CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 
+#ifdef ET_USE_PYTORCH_HEADERS
+  if constexpr (can_use_vectorized<CTYPE_COMMON, Op, Args...>()) {
+    const bool any_is_broadcasted =
+        !(torch::executor::internal::sizes_match_ignoring_leading_1s(
+              inputs.first->sizes(), out.sizes()) &&
+          ...);
+    if (!any_is_broadcasted) {
+      using Vec = at::vec::Vectorized<CTYPE_COMMON>;
+      ::executorch::extension::parallel_for(
+          0,
+          out.numel(),
+          ::executorch::extension::internal::GRAIN_SIZE,
+          [&](const auto begin, const auto end) {
+            const auto vectorized_begin =
+                begin + (Vec::size() - begin % Vec::size()) % Vec::size();
+            const auto vectorized_end = end - (end % Vec::size());
+            // Scalar prologue.
+            for (const auto idx : c10::irange(begin, vectorized_begin)) {
+              std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
+              }
+              data_out[idx] = std::apply(compute_fun, loaded_inputs);
+            }
+
+            // Main vectorized loop.
+            for (auto idx = vectorized_begin; idx < vectorized_end;
+                 idx += Vec::size()) {
+              std::array<Vec, kNumInputs> loaded_vec_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_vec_inputs[input_idx] =
+                    Vec::loadu(&inputs_data_ptrs[input_idx][idx]);
+              }
+              auto result_vec = std::apply(compute_fun, loaded_vec_inputs);
+              result_vec.store(&data_out[idx]);
+            }
+
+            // Scalar epilogue.
+            for (const auto idx : c10::irange(vectorized_end, end)) {
+              std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
+              }
+              data_out[idx] = std::apply(compute_fun, loaded_inputs);
+            }
+          });
+      return;
+    }
+  }
+#endif
+
   ::executorch::extension::parallel_for(
       0,
       out.numel(),
@@ -255,6 +330,19 @@ inline void apply_unitensor_elementwise_fn(
       compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
 }
 
+/**
+ * Useful for unary elementwise operators. For each element of the
+ * input, call Op and write to the corresponding element of the
+ * output. Tensor broadcasting is applied wherever it is required.
+ *
+ * [NOTE: Generic lambdas]: If Op is a *generic* lambda (i.e., one with `auto`
+ * parameters; normal lambdas are fine), it must fulfill one of the
+ * following conditions. Either:
+ * 1) It must in fact compile when passed at::vec::Vectorized<CTYPE_COMMON>, or
+ * 2) It must be actively SFINAE-friendly, as per the C++17 examples in
+ * https://stackoverflow.com/questions/76525790/detecting-if-a-generic-lambda-with-certain-arguments-is-invocable
+ * .
+ */
 template <
     typename CTYPE_COMMON,
     const char* op_name,
@@ -296,6 +384,8 @@ inline void apply_bitensor_elementwise_fn(
  * Useful for bi-tensor elementwise operators. For each element of the inputs,
  * perform a computation and write to the corresponding element of the output.
  * Tensor broadcasting is applied wherever it is required.
+ * See [NOTE: Generic lambdas] if you want to pass a generic lambda for
+ * compute_fun.
  */
 template <
     typename CTYPE_COMMON,
@@ -362,6 +452,9 @@ inline void apply_tritensor_elementwise_fn(
  *
  * static constexpr const char op_name[] = "my_op";
  * apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
+ *
+ * See [NOTE: Generic lambdas] if you want to pass a generic lambda for
+ * compute_fun.
  */
 template <
     typename CTYPE_COMMON,
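
Two notes on the elementwise_util.h changes above. First, the loop split in dtype_specialized_elementwise_fn_impl: each parallel_for chunk [begin, end) gets a scalar prologue up to the next multiple of Vec::size(), a vectorized middle over whole lanes, and a scalar epilogue for the remainder. For example, with Vec::size() == 8, begin == 3, and end == 37, vectorized_begin is 3 + (8 - 3 % 8) % 8 == 8 and vectorized_end is 37 - 37 % 8 == 32, so indices 3-7 and 32-36 are handled one element at a time and 8-31 in Vectorized steps.

Second, the [NOTE: Generic lambdas] requirements matter because asking std::is_invocable_v about a generic lambda with a deduced return type instantiates its body, so a body that does not compile for Vectorized arguments is a hard error rather than a clean false. A minimal sketch of both conditions, using a placeholder type so the snippet does not depend on ATen (FakeVec, mul, and acos_plus are illustrative names, not part of the diff):

#include <cmath>
#include <type_traits>

// Placeholder for a type, like at::vec::Vectorized<T>, that supports
// operator* but has no std::acos overload.
struct FakeVec {
  FakeVec operator*(const FakeVec&) const { return *this; }
};

int main() {
  // Condition 1: the body compiles for any type providing operator*, so
  // invocability with FakeVec can be answered directly.
  auto mul = [](const auto a, const auto b) { return a * b; };
  static_assert(std::is_invocable_v<decltype(mul), float, float>);
  static_assert(std::is_invocable_v<decltype(mul), FakeVec, FakeVec>);

  // Condition 2: the trailing return type keeps any failure in the
  // immediate context, so the last check below yields false instead of a
  // hard error during return-type deduction.
  auto acos_plus = [](const auto a,
                      const auto b) -> decltype(std::acos(a) + b) {
    return std::acos(a) + b;
  };
  static_assert(std::is_invocable_v<decltype(acos_plus), float, float>);
  static_assert(!std::is_invocable_v<decltype(acos_plus), FakeVec, FakeVec>);
  return 0;
}

If acos_plus were written without the trailing return type, the final is_invocable_v query itself would fail to compile; that is the hazard the NOTE warns about for generic lambdas that are neither Vectorized-compatible nor SFINAE-friendly.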

kernels/portable/cpu/util/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -110,6 +110,7 @@ def define_common_targets():
             ":broadcast_indexes_range",
             ":broadcast_util",
             ":dtype_util",
+            "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
             "//executorch/runtime/kernel:kernel_runtime_context",
             "//executorch/runtime/kernel:thread_parallel_interface",
         ],

runtime/core/portable_type/c10/c10/targets.bzl

Lines changed: 4 additions & 1 deletion
@@ -49,7 +49,10 @@ def define_common_targets():
     runtime.cxx_library(
         name = "aten_headers_for_executorch",
         srcs = [],
-        visibility = ["//executorch/kernels/optimized/..."],
+        visibility = [
+            "//executorch/kernels/optimized/...",
+            "//executorch/kernels/portable/cpu/util/...",
+        ],
         exported_deps = select({
             "DEFAULT": [],
             "ovr_config//cpu:arm64": [
