Commit d421fdf

Add vectorization in elementwise_util (not working yet)

This works with op_mul, which is vectorized-friendly, but doesn't work when we
roll out to pattern.h, because those ops will not work with Vectorized yet. See
the TODO in elementwise_util.h.

ghstack-source-id: 4ed038f
ghstack-comment-id: 2738665976
Pull Request resolved: #9432
1 parent d8ac866 commit d421fdf

27 files changed: +555 −48 lines
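For context on "vectorized-friendly": a kernel can take the new vectorized path
only if its compute lambda still compiles when handed
at::vec::Vectorized<CTYPE_COMPUTE> instead of a scalar. A minimal sketch of the
distinction (illustrative only; demo, mul_fn, and atan2_fn are not names from
this commit):

    #include <ATen/cpu/vec/vec.h>
    #include <cmath>

    void demo() {
      using Vec = at::vec::Vectorized<float>;

      // Vectorized-friendly (op_mul's case): operator* exists for both float
      // and Vec, so the same generic lambda instantiates for either type.
      auto mul_fn = [](const auto a, const auto b) { return a * b; };
      float s = mul_fn(2.0f, 3.0f);          // scalar instantiation
      Vec v = mul_fn(Vec(2.0f), Vec(3.0f));  // Vec instantiation

      // Not yet friendly: std::atan2 accepts only scalars, so this lambda
      // cannot be instantiated with Vec arguments; that is why several ops
      // below are rerouted through executorch::math wrappers.
      auto atan2_fn = [](const auto a, const auto b) { return std::atan2(a, b); };
      (void)s;
      (void)v;
      (void)atan2_fn;
    }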

.lintrunner.toml

Lines changed: 4 additions & 0 deletions
@@ -271,6 +271,10 @@ exclude_patterns = [
     'examples/**',
     'exir/verification/bindings.cpp',
     'extension/**',
+    # Uses properly-gated (ET_USE_PYTORCH_HEADERS) ATen include.
+    'kernels/portable/cpu/util/elementwise_util.h',
+    'kernels/portable/cpu/util/math_util.h',
+    'kernels/portable/cpu/util/vectorized_math.h',
     'kernels/optimized/**',
     'runtime/core/exec_aten/**',
     # Want to be able to keep c10 in sync with PyTorch core.

kernels/portable/cpu/op_atan2.cpp

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ Tensor& atan2_out(
       op_name,
       utils::SupportedTensorDtypes::FLOATHBF16>(
       [](const auto val_a, const auto val_b) {
-        return std::atan2(val_a, val_b);
+        return executorch::math::atan2(val_a, val_b);
       },
       ctx,
       a,

kernels/portable/cpu/op_elu.cpp

Lines changed: 1 addition & 2 deletions
@@ -48,8 +48,7 @@ Tensor& elu_out(
       CTYPE,
       op_name,
       utils::SupportedTensorDtypes::SAME_AS_COMMON>(
-      [negcoef, math_scale, math_input_scale](const auto x) {
-        // TODO: rewrite this to be vectorization-capable.
+      [negcoef, math_scale, math_input_scale](const CTYPE x) {
         return MathT(x) <= MathT(0)
             ? std::expm1(MathT(x) * math_input_scale) * negcoef
             : MathT(x) * math_scale;

kernels/portable/cpu/op_fmod.cpp

Lines changed: 3 additions & 5 deletions
@@ -61,7 +61,7 @@ Tensor& fmod_Tensor_out(
       utils::SupportedTensorDtypes::REALHBF16>(
       [&div_by_zero_error](
           const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-        // TODO: rewrite this to be vectorization-capable.
+        // TODO: rewrite this to be vectorization-capable?
         CTYPE_COMPUTE value = 0;
         if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
           if (val_b == 0) {
@@ -138,10 +138,8 @@ Tensor& fmod_Scalar_out(
       CTYPE_COMPUTE,
       op_name,
       utils::SupportedTensorDtypes::REALHBF16>(
-      [val_b](const CTYPE_COMPUTE val_a) {
-        // TODO: rewrite this to be vectorization-capable.
-        CTYPE_COMPUTE value = std::fmod(val_a, val_b);
-        return value;
+      [val_b](const auto val_a) {
+        return executorch::math::fmod(val_a, (decltype(val_a))val_b);
       },
       ctx,
       a,
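A note on the cast in fmod_Scalar_out: (decltype(val_a))val_b is what lets one
generic lambda serve both instantiations. For a scalar val_a it is an ordinary
numeric conversion; for a Vectorized val_a it broadcasts the scalar across all
lanes through Vectorized's converting constructor. A hedged sketch
(illustrative values, assuming ATen headers are available; fmod_cast_demo is
not a name from this commit):

    #include <ATen/cpu/vec/vec.h>

    void fmod_cast_demo() {
      float val_b = 3.0f;

      float a_scalar = 7.0f;
      auto c1 = (decltype(a_scalar))val_b;  // plain cast: c1 == 3.0f

      at::vec::Vectorized<float> a_vec(7.0f);
      auto c2 = (decltype(a_vec))val_b;  // broadcast: every lane of c2 == 3.0f
      (void)c1;
      (void)c2;
    }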

kernels/portable/cpu/op_maximum.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ Tensor& maximum_out(
4949
CTYPE_COMPUTE,
5050
op_name,
5151
utils::SupportedTensorDtypes::REALHBBF16>(
52-
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
52+
[](const auto val_a, const auto val_b) {
5353
return utils::max_override(val_a, val_b);
5454
},
5555
ctx,

kernels/portable/cpu/op_minimum.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,7 @@ Tensor& minimum_out(
4949
CTYPE_COMPUTE,
5050
op_name,
5151
utils::SupportedTensorDtypes::REALHBBF16>(
52-
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
53-
// TODO: rewrite this to be vectorization-capable.
52+
[](const auto val_a, const auto val_b) {
5453
return utils::min_override(val_a, val_b);
5554
},
5655
ctx,

kernels/portable/cpu/op_mul.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,7 @@ Tensor& mul_out(
7272
CTYPE_COMPUTE,
7373
op_name,
7474
utils::SupportedTensorDtypes::REALHBBF16>(
75-
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
76-
return val_a * val_b;
77-
},
75+
[](const auto val_a, const auto val_b) { return val_a * val_b; },
7876
ctx,
7977
a,
8078
utils::SupportedTensorDtypes::REALHBBF16,

kernels/portable/cpu/op_pow.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ Tensor& pow_Tensor_Tensor_out(
5757
CTYPE_COMPUTE,
5858
op_name,
5959
utils::SupportedTensorDtypes::REALHBF16>(
60-
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
60+
[](const auto val_a, const auto val_b) {
6161
// TODO: rewrite this to be vectorization-capable.
62-
return std::pow(val_a, val_b);
62+
return executorch::math::pow(val_a, val_b);
6363
},
6464
ctx,
6565
a,

kernels/portable/cpu/op_sigmoid.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,9 @@ Tensor& sigmoid_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
4949
CTYPE_COMPUTE,
5050
op_name,
5151
utils::SupportedTensorDtypes::FLOATHBF16>(
52-
[](const auto val_in) -> CTYPE_COMPUTE {
53-
// TODO: rewrite this to be vectorization-capable
54-
CTYPE_COMPUTE out_val = static_cast<CTYPE_COMPUTE>(1.0) /
55-
(static_cast<CTYPE_COMPUTE>(1.0) + exp(-val_in));
52+
[](const auto val_in) {
53+
const auto one = static_cast<decltype(val_in)>(1.0);
54+
auto out_val = one / (one + executorch::math::exp(-val_in));
5655
return out_val;
5756
},
5857
ctx,
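The subtle point in the sigmoid rewrite is building the constant from the
argument's own type. The old static_cast<CTYPE_COMPUTE>(1.0) pinned the
expression (and the lambda's return type) to the scalar type, so a Vectorized
instantiation could never return a Vectorized and the can_use_vectorized()
check introduced below would reject it. A sketch of how the new form behaves,
assuming executorch::math::exp (from the new vectorized_math.h) is overloaded
for both scalars and Vectorized:

    // val_in = float:             one is 1.0f, result is float.
    // val_in = Vectorized<float>: one is a broadcast Vectorized<float>(1.0f),
    //                             and the result is Vectorized<float>.
    auto sigmoid = [](const auto val_in) {
      const auto one = static_cast<decltype(val_in)>(1.0);
      return one / (one + executorch::math::exp(-val_in));
    };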

kernels/portable/cpu/op_where.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ Tensor& where_out(
4747
CTYPE_COMPUTE,
4848
op_name,
4949
utils::SupportedTensorDtypes::SAME_AS_COMMON>(
50-
[](const auto val_a, const auto val_b, const auto val_c) {
51-
return val_c ? val_a : val_b;
52-
},
50+
[](const CTYPE_COMPUTE val_a,
51+
const CTYPE_COMPUTE val_b,
52+
const CTYPE_COMPUTE val_c) { return val_c ? val_a : val_b; },
5353
ctx,
5454
a,
5555
utils::SupportedTensorDtypes::REALHBBF16,

kernels/portable/cpu/util/elementwise_util.h

Lines changed: 119 additions & 1 deletion
@@ -12,9 +12,14 @@
 #include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
 #include <executorch/kernels/portable/cpu/util/dtype_util.h>
+#include <executorch/kernels/portable/cpu/util/vectorized_math.h> // Make vectorization support easy for clients.
 #include <executorch/runtime/kernel/kernel_runtime_context.h>
 #include <executorch/runtime/kernel/thread_parallel_interface.h>
 
+#ifdef ET_USE_PYTORCH_HEADERS
+#include <ATen/cpu/vec/vec.h>
+#endif // ET_USE_PYTORCH_HEADERS
+
 #include <array>
 #include <utility>
@@ -51,6 +56,38 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
 }
 
 namespace internal {
+template <typename Ignore, typename T>
+using ignore_first_yield_second = T;
+
+#ifdef ET_USE_PYTORCH_HEADERS
+// Can I call a function of type Op with sizeof...(Args) arguments of type
+// at::vec::Vectorized<CTYPE_COMPUTE>?
+//
+// See [NOTE: Generic lambdas] below for requirements on Op.
+template <typename CTYPE_COMPUTE, typename Op, typename... Args>
+constexpr bool can_use_vectorized() {
+  using Vec = at::vec::Vectorized<CTYPE_COMPUTE>;
+  // NOTE: if we start building optimized kernels on platforms that
+  // ATen Vectorized doesn't support well, we will want to add a way
+  // to check that Vectorized actually does something on our target
+  // platform. For now, I see no concrete need for that.
+  if constexpr (std::is_invocable_v<
+                    Op,
+                    ignore_first_yield_second<Args, Vec>...>) {
+    // For bool, we will get a false positive if we rely on only the
+    // is_invocable_v check above because at::vec::Vectorized is
+    // implicitly convertible to a pointer, which makes it implicitly
+    // convertible to bool (which was 15 minutes of fun to debug). Also
+    // just seems like good hygiene to make sure we get the Vectorized
+    // we're expecting.
+    return std::is_same_v<
+        std::invoke_result_t<Op, ignore_first_yield_second<Args, Vec>...>,
+        Vec>;
+  }
+  return false;
+}
+#endif // ET_USE_PYTORCH_HEADERS
+
 template <
     typename CTYPE_COMPUTE,
     typename CTYPE_OUT,
@@ -61,8 +98,71 @@ inline void dtype_specialized_elementwise_fn_impl(
     KernelRuntimeContext& ctx,
     const Tensor& out,
     Args... inputs) {
+  static_assert(
+      (std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
+       ...));
   constexpr auto kNumInputs = sizeof...(inputs);
-  ET_DCHECK(((inputs.first->element_size() == sizeof(CTYPE_COMPUTE)) && ...));
+  // All inputs must be of type CTYPE_COMPUTE.
+  ET_DCHECK(
+      ((inputs.first->scalar_type() ==
+        CppTypeToScalarType<CTYPE_COMPUTE>::value) &&
+       ...));
+
+#ifdef ET_USE_PYTORCH_HEADERS
+  if constexpr (can_use_vectorized<CTYPE_COMPUTE, Op, Args...>()) {
+    const bool any_is_broadcasted =
+        !(torch::executor::internal::sizes_match_ignoring_leading_1s(
+              inputs.first->sizes(), out.sizes()) &&
+          ...);
+    if (!any_is_broadcasted) {
+      using Vec = at::vec::Vectorized<CTYPE_COMPUTE>;
+      ::executorch::extension::parallel_for(
+          0,
+          out.numel(),
+          ::executorch::extension::internal::GRAIN_SIZE,
+          [&](const auto begin, const auto end) {
+            std::array<const CTYPE_COMPUTE*, kNumInputs> inputs_data_ptrs = {
+                inputs.first->template const_data_ptr<CTYPE_COMPUTE>()...};
+
+            CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
+
+            const auto vectorized_begin =
+                begin + (Vec::size() - begin % Vec::size()) % Vec::size();
+            const auto vectorized_end = end - (end % Vec::size());
+            // Scalar prologue.
+            for (const auto idx : c10::irange(begin, vectorized_begin)) {
+              std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
+              }
+              data_out[idx] = std::apply(compute_fun, loaded_inputs);
+            }
+
+            // Main vectorized loop.
+            for (auto idx = vectorized_begin; idx < vectorized_end;
+                 idx += Vec::size()) {
+              std::array<Vec, kNumInputs> loaded_vec_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_vec_inputs[input_idx] =
+                    Vec::loadu(&inputs_data_ptrs[input_idx][idx]);
+              }
+              auto result_vec = std::apply(compute_fun, loaded_vec_inputs);
+              result_vec.store(&data_out[idx]);
+            }
+
+            // Scalar epilogue.
+            for (const auto idx : c10::irange(vectorized_end, end)) {
+              std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
+              }
+              data_out[idx] = std::apply(compute_fun, loaded_inputs);
+            }
+          });
+      return;
+    }
+  }
+#endif // ET_USE_PYTORCH_HEADERS
 
   ::executorch::extension::parallel_for(
       0,
@@ -240,6 +340,19 @@ inline void apply_unitensor_elementwise_fn(
       compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
 }
 
+/**
+ * Useful for unary elementwise operators. For each element of the
+ * input, call Op and write to the corresponding element of the
+ * output. Tensor broadcasting is applied wherever it is required.
+ *
+ * [NOTE: Generic lambdas]: If Op is a *generic* lambda (i.e., one with `auto`
+ * parameters; normal lambdas are fine), it must fulfill one of the
+ * following conditions. Either:
+ * 1) It must in fact compile when passed at::vec::Vectorized<CTYPE_COMPUTE>, or
+ * 2) It must be actively SFINAE-friendly, as per the C++17 examples in
+ * https://stackoverflow.com/questions/76525790/detecting-if-a-generic-lambda-with-certain-arguments-is-invocable
+ * .
+ */
 template <
     typename CTYPE_COMPUTE,
     const char* op_name,
@@ -281,6 +394,8 @@ inline void apply_bitensor_elementwise_fn(
  * Useful for bi-tensor elementwise operators. For each element of the inputs,
  * perform a computation and write to the corresponding element of the output.
  * Tensor broadcasting is applied wherever it is required.
+ * See [NOTE: Generic lambdas] if you want to pass a generic lambda for
+ * compute_fun.
  */
 template <
     typename CTYPE_COMPUTE,
@@ -347,6 +462,9 @@ inline void apply_tritensor_elementwise_fn(
 *
 * static constexpr const char op_name[] = "my_op";
 * apply_ternary_elementwise_fn<CTYPE_COMPUTE, op_name>.
+ *
+ * See [NOTE: Generic lambdas] if you want to pass a generic lambda for
+ * compute_fun.
 */
 template <
     typename CTYPE_COMPUTE,
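To make option (2) of [NOTE: Generic lambdas] concrete: a generic lambda
becomes SFINAE-friendly when a trailing return type repeats the body
expression, so an invalid instantiation fails softly inside the
std::is_invocable_v probe in can_use_vectorized() rather than hard-erroring
when the body is instantiated. A hypothetical sketch (std::tgamma stands in
for any scalar-only call; neither lambda appears in this commit):

    #include <cmath>

    // Deduced (auto) return type: probing invocability with Vectorized must
    // instantiate the body, and the missing std::tgamma(Vec) overload is a
    // hard compile error.
    auto hard_error = [](const auto x) { return std::tgamma(x); };

    // Trailing return type: the substitution failure stays in the immediate
    // context, so std::is_invocable_v<Op, Vec> is simply false and the
    // non-vectorized path is taken instead.
    auto sfinae_friendly = [](const auto x) -> decltype(std::tgamma(x)) {
      return std::tgamma(x);
    };

A worked example for the loop bounds in the vectorized path above: with
Vec::size() == 8 and a parallel_for chunk [begin, end) = [5, 21),
vectorized_begin = 5 + (8 - 5 % 8) % 8 = 8 and vectorized_end = 21 - 21 % 8 = 16,
so elements [5, 8) run in the scalar prologue, [8, 16) in the Vec loop, and
[16, 21) in the scalar epilogue.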

kernels/portable/cpu/util/math_util.h

Lines changed: 19 additions & 0 deletions
@@ -8,6 +8,10 @@
 
 #pragma once
 
+#ifdef ET_USE_PYTORCH_HEADERS
+#include <ATen/cpu/vec/vec.h>
+#endif
+
 namespace torch {
 namespace executor {
 namespace native {
@@ -138,6 +142,21 @@ T max_override(T a, T b) {
   return b;
 }
 
+#ifdef ET_USE_PYTORCH_HEADERS
+template <typename T>
+at::vec::Vectorized<T> min_override(
+    at::vec::Vectorized<T> a,
+    at::vec::Vectorized<T> b) {
+  return at::vec::minimum(a, b);
+}
+
+template <typename T>
+at::vec::Vectorized<T> max_override(
+    at::vec::Vectorized<T> a,
+    at::vec::Vectorized<T> b) {
+  return at::vec::maximum(a, b);
+}
+#endif
 /**
  * There is a slight difference in how std::fmod works compared to how ATen
  * determines remainders:
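These overloads are what allow op_minimum and op_maximum (above) to switch to
generic lambdas: overload resolution picks the existing scalar template for
scalar arguments and the new Vectorized overloads, which defer to ATen's
lane-wise at::vec::minimum/maximum, for vector arguments. A usage sketch,
assuming these helpers live in the utils namespace as the call sites in the op
files suggest:

    using torch::executor::native::utils::min_override;

    float s = min_override(2.0f, 3.0f);  // scalar template: 2.0f

    #ifdef ET_USE_PYTORCH_HEADERS
    at::vec::Vectorized<float> a(2.0f), b(3.0f);
    auto v = min_override(a, b);  // new overload: lane-wise minimum
    #endif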

kernels/portable/cpu/util/targets.bzl

Lines changed: 16 additions & 0 deletions
@@ -33,6 +33,7 @@ def define_common_targets():
             "//executorch/kernels/portable/cpu/util:slice_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
             "//executorch/kernels/portable/cpu/util:upsample_util",
+            "//executorch/kernels/portable/cpu/util:vectorized_math",
         ],
         visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"],
     )
@@ -110,6 +111,8 @@
             ":broadcast_indexes_range",
             ":broadcast_util",
             ":dtype_util",
+            ":vectorized_math",
+            "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
             "//executorch/runtime/kernel:kernel_runtime_context",
             "//executorch/extension/threadpool:threadpool",
         ],
@@ -260,6 +263,9 @@
         srcs = [],
         exported_headers = ["math_util.h"],
         visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/quantized/..."],
+        exported_deps = [
+            "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
+        ],
     )
 
     runtime.cxx_library(
@@ -307,6 +313,16 @@
         ],
     )
 
+    runtime.cxx_library(
+        name = "vectorized_math",
+        exported_headers = ["vectorized_math.h"],
+        visibility = ["//executorch/..."],
+        exported_deps = [
+            "//executorch/runtime/core/portable_type:portable_type",
+            "//executorch/runtime/core/exec_aten/util:scalar_type_util",
+        ],
+    )
+
     # Utility functions that can be used by operators that perform reduction
     for aten_mode in get_aten_mode_options():
         suffix = "_aten" if aten_mode else ""

kernels/portable/cpu/util/test/CMakeLists.txt

Lines changed: 6 additions & 10 deletions
@@ -4,26 +4,22 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-# @generated by test/utils/generate_gtest_cmakelists.py
-#
-# This file should be formatted with
-# ~~~
-# cmake-format -i CMakeLists.txt
-# ~~~
-# It should also be cmake-lint clean.
-#
-
 cmake_minimum_required(VERSION 3.19)
 
 set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../../..)
 
 include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake)
+include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
 
 set(_test_srcs broadcast_indexes_range_test.cpp broadcast_test.cpp
-               reduce_test.cpp
+               reduce_test.cpp vectorized_math_test.cpp
 )
 
 et_cxx_test(
   kernels_portable_cpu_util_test SOURCES ${_test_srcs} EXTRA_LIBS
   portable_kernels portable_ops_lib
 )
+
+find_package_torch_headers()
+target_include_directories(kernels_portable_cpu_util_test PRIVATE ${TORCH_INCLUDE_DIRS})
+target_compile_definitions(kernels_portable_cpu_util_test PRIVATE ET_USE_PYTORCH_HEADERS)