From 423a178c1b7e6a26142b4001340cf2a3a076c2c8 Mon Sep 17 00:00:00 2001
From: swolchok
Date: Fri, 13 Jun 2025 13:17:30 -0700
Subject: [PATCH 1/2] Define ET_HAS_EXCEPTIONS macro

Summary: Detect at compile time whether C++ exceptions are enabled, so that
ET_USE_PYTORCH_HEADERS can be passed only when they are.

Differential Revision: D76470039
---
 runtime/platform/compiler.h | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/runtime/platform/compiler.h b/runtime/platform/compiler.h
index 7467d5c1e04..864b76e2050 100644
--- a/runtime/platform/compiler.h
+++ b/runtime/platform/compiler.h
@@ -171,6 +171,14 @@ using ssize_t = ptrdiff_t;

 #endif

+#ifdef __EXCEPTIONS
+#define ET_HAS_EXCEPTIONS 1
+#elif defined(_HAS_EXCEPTIONS) && _HAS_EXCEPTIONS
+#define ET_HAS_EXCEPTIONS 1
+#else
+#define ET_HAS_EXCEPTIONS 0
+#endif
+
 // DEPRECATED: Use the non-underscore-prefixed versions instead.
 // TODO(T199005537): Remove these once all users have stopped using them.
 #define __ET_DEPRECATED ET_DEPRECATED

From 735f2148c049813d7e455ab761d4aa41ed064cf2 Mon Sep 17 00:00:00 2001
From: Scott Wolchok
Date: Fri, 13 Jun 2025 13:17:30 -0700
Subject: [PATCH 2/2] Reapply "Add vectorized_math.h (#11204)", "Add
 optimized_portable_kernels test (#11205)", and "Add vectorization in
 elementwise_util (#9432)"

Summary:
The stack was reverted due to internal CI failures. Reapplying as an
exported internal diff so that we are sure to catch any more of those.

New fixes:
- straightforward op_sub build fixes
- s/EXPECT_EQ/EXPECT_FLOAT_EQ/ in vectorized_math_test
- define ET_HAS_EXCEPTIONS to detect whether exceptions are enabled, define
  ET_USE_PYTORCH_HEADERS to it, and use `#if defined(...) && ...` instead of
  `#ifdef` to check the macro, so that we don't use PyTorch headers if
  exceptions are disabled. (Otherwise, we might have problems with e.g.
  TORCH_CHECK.)

Original summary for #11204: Set of math functions that work on both scalars
and at::vec::Vectorized, to be used in #9432.

Original summary for #11205: Make sure we test the optimized versions of
portable kernels even if they are shadowed by optimized implementations.
Intended to support #9432.

Original summary for #9432: This is a first cut at #9241. In this PR I've
vectorized a small initial set of ops: atan2, clamp, fmod_Scalar, maximum,
minimum, mul, pow, and sigmoid. In addition, the following ops should have
gotten vectorized automatically because they already used generic lambdas:
add, div, rsub, sub. I've left covering ops that use the `unary_ufunc_*`
utilities in
[pattern.h](https://github.com/pytorch/executorch/blob/main/kernels/portable/cpu/pattern/pattern.h)
for a follow-up push, because pattern.h and elementwise_util need some work
before we can migrate pattern.h's utilities to be backed by elementwise_util.

This PR raises an interesting testing problem: in theory, *all* operators
might need test cases long enough to tickle vectorization, because we might
accidentally vectorize ops unexpectedly and break their lambdas due to
unanticipated differences in semantics. I address this issue by using
Vectorized for the scalar prologue/epilogue in debug mode (we run tests in
both debug and release) so that we can detect broken lambdas. I additionally
intentionally introduced a bug in the vectorized path in elementwise_util and
manually verified that we saw test failures for each vectorized op called out
above.
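To illustrate the constraint this places on kernel lambdas, here is a minimal
sketch (adapted from the op_mul change in this diff; the surrounding call is
abbreviated). A lambda written against `auto` parameters can be instantiated
with either the scalar compute type or at::vec::Vectorized of it, so
elementwise_util is free to pick the vectorized path:

    utils::apply_bitensor_elementwise_fn<
        CTYPE_COMPUTE,
        op_name,
        utils::SupportedTensorDtypes::REALHBBF16>(
        // Generic lambda: compiles for both CTYPE_COMPUTE and
        // at::vec::Vectorized<CTYPE_COMPUTE>, so it can be vectorized.
        [](const auto val_a, const auto val_b) { return val_a * val_b; },
        ctx,
        a,
        utils::SupportedTensorDtypes::REALHBBF16,
        b,
        utils::SupportedTensorDtypes::REALHBBF16,
        out);

Spelling out the scalar type instead, as op_where now does
([](const CTYPE_COMPUTE val_a, ...)), opts a lambda out of vectorization and
keeps it on the scalar path.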
Differential Revision: D76467389

*** fix ET_USE_PYTORCH_HEADERS
---
 .lintrunner.toml                              |   4 +
 kernels/optimized/CMakeLists.txt              |   2 +-
 kernels/portable/CMakeLists.txt               |  11 +-
 kernels/portable/cpu/op_add.cpp               |  12 +-
 kernels/portable/cpu/op_atan2.cpp             |   2 +-
 kernels/portable/cpu/op_clamp.cpp             |   5 +-
 kernels/portable/cpu/op_elu.cpp               |   3 +-
 kernels/portable/cpu/op_fmod.cpp              |   8 +-
 kernels/portable/cpu/op_maximum.cpp           |   2 +-
 kernels/portable/cpu/op_minimum.cpp           |   3 +-
 kernels/portable/cpu/op_mul.cpp               |   4 +-
 kernels/portable/cpu/op_native_dropout.cpp    |  10 +-
 kernels/portable/cpu/op_pow.cpp               |  23 ++-
 kernels/portable/cpu/op_sigmoid.cpp           |   7 +-
 kernels/portable/cpu/op_sub.cpp               |   7 +-
 kernels/portable/cpu/op_where.cpp             |   6 +-
 kernels/portable/cpu/util/elementwise_util.h  | 139 +++++++++++++++-
 kernels/portable/cpu/util/math_util.h         |  30 ++++
 kernels/portable/cpu/util/targets.bzl         |  16 ++
 kernels/portable/cpu/util/test/CMakeLists.txt |  16 +-
 kernels/portable/cpu/util/test/targets.bzl    |  11 ++
 .../cpu/util/test/vectorized_math_test.cpp    |  95 +++++++++++
 kernels/portable/cpu/util/vectorized_math.h   | 148 ++++++++++++++++++
 kernels/test/CMakeLists.txt                   |  32 +++-
 kernels/test/op_atan2_test.cpp                |  33 ++++
 kernels/test/op_clamp_test.cpp                |  34 ++++
 kernels/test/op_fmod_test.cpp                 |  31 ++++
 kernels/test/op_maximum_test.cpp              |  14 ++
 kernels/test/op_minimum_test.cpp              |  14 ++
 kernels/test/op_mul_test.cpp                  |   6 +-
 kernels/test/op_pow_test.cpp                  |  13 ++
 kernels/test/op_sigmoid_test.cpp              |   4 +
 .../core/portable_type/c10/c10/targets.bzl    |  10 +-
 runtime/core/portable_type/targets.bzl        |   1 +
 test/utils/OSSTestConfig.json                 |  12 --
 35 files changed, 694 insertions(+), 74 deletions(-)
 create mode 100644 kernels/portable/cpu/util/test/vectorized_math_test.cpp
 create mode 100644 kernels/portable/cpu/util/vectorized_math.h

diff --git a/.lintrunner.toml b/.lintrunner.toml
index 8912e65d66d..4c881940155 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -271,6 +271,10 @@ exclude_patterns = [
     'examples/**',
     'exir/verification/bindings.cpp',
     'extension/**',
+    # Uses properly-gated (ET_USE_PYTORCH_HEADERS) ATen include.
+    'kernels/portable/cpu/util/elementwise_util.h',
+    'kernels/portable/cpu/util/math_util.h',
+    'kernels/portable/cpu/util/vectorized_math.h',
     'kernels/optimized/**',
     'runtime/core/exec_aten/**', # Want to be able to keep c10 in sync with PyTorch core.
diff --git a/kernels/optimized/CMakeLists.txt b/kernels/optimized/CMakeLists.txt
index 7b8ebd58f13..8dfc9e0f734 100644
--- a/kernels/optimized/CMakeLists.txt
+++ b/kernels/optimized/CMakeLists.txt
@@ -60,7 +60,7 @@ message("Generated files ${gen_command_sources}")
 list(TRANSFORM _optimized_kernels__srcs PREPEND "${EXECUTORCH_ROOT}/")
 add_library(optimized_kernels ${_optimized_kernels__srcs})
 target_include_directories(optimized_kernels PRIVATE ${TORCH_INCLUDE_DIRS} "${EXECUTORCH_ROOT}/third-party/pocketfft")
-target_compile_definitions(optimized_kernels PRIVATE ET_USE_PYTORCH_HEADERS)
+target_compile_definitions(optimized_kernels PRIVATE "ET_USE_PYTORCH_HEADERS=ET_HAS_EXCEPTIONS")
 target_link_libraries(
   optimized_kernels PUBLIC executorch_core cpublas extension_threadpool kernels_util_all_deps
 )
diff --git a/kernels/portable/CMakeLists.txt b/kernels/portable/CMakeLists.txt
index d301ea564f6..15aaece750e 100644
--- a/kernels/portable/CMakeLists.txt
+++ b/kernels/portable/CMakeLists.txt
@@ -68,9 +68,16 @@ if(EXECUTORCH_BUILD_PTHREADPOOL AND EXECUTORCH_BUILD_KERNELS_OPTIMIZED)
   target_link_libraries(optimized_portable_kernels PUBLIC extension_threadpool)
   target_compile_options(optimized_portable_kernels PUBLIC ${_common_compile_options})
   target_include_directories(optimized_portable_kernels PRIVATE ${TORCH_INCLUDE_DIRS})
-  target_compile_definitions(optimized_portable_kernels PRIVATE ET_USE_PYTORCH_HEADERS)
+  target_compile_definitions(optimized_portable_kernels PRIVATE "ET_USE_PYTORCH_HEADERS=ET_HAS_EXCEPTIONS")
+  gen_selected_ops(LIB_NAME "optimized_portable_ops_lib" OPS_SCHEMA_YAML "${_yaml}")
+  generate_bindings_for_kernels(
+    LIB_NAME "optimized_portable_ops_lib" FUNCTIONS_YAML "${_yaml}"
+  )
+  gen_operators_lib(
+    LIB_NAME "optimized_portable_ops_lib" KERNEL_LIBS optimized_portable_kernels DEPS executorch_core
+  )
   install(
-    TARGETS optimized_portable_kernels
+    TARGETS optimized_portable_kernels optimized_portable_ops_lib
     DESTINATION lib
   )
 endif()
diff --git a/kernels/portable/cpu/op_add.cpp b/kernels/portable/cpu/op_add.cpp
index 555341b3447..83642c4864d 100644
--- a/kernels/portable/cpu/op_add.cpp
+++ b/kernels/portable/cpu/op_add.cpp
@@ -102,14 +102,18 @@ Tensor& add_scalar_out(
   static constexpr const char op_name[] = "add.Scalar_out";

   ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
+    CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
+    auto val_alpha_times_b = val_alpha * val_b;
     utils::apply_unitensor_elementwise_fn<
         CTYPE_COMPUTE,
         op_name,
         utils::SupportedTensorDtypes::SAME_AS_COMMON>(
-        [b, alpha](const auto val_a) {
-          CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
-          CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
-          return val_a + val_alpha * val_b;
+        [val_alpha_times_b](const auto val_a) {
+          // Cast here supports vectorization; either it does nothing
+          // or it casts from CTYPE_COMPUTE to
+          // Vectorized<CTYPE_COMPUTE>.
+          return val_a + decltype(val_a)(val_alpha_times_b);
         },
         ctx,
         a,
diff --git a/kernels/portable/cpu/op_atan2.cpp b/kernels/portable/cpu/op_atan2.cpp
index 33d66cf2ad7..5390eb52820 100644
--- a/kernels/portable/cpu/op_atan2.cpp
+++ b/kernels/portable/cpu/op_atan2.cpp
@@ -60,7 +60,7 @@ Tensor& atan2_out(
       op_name,
       utils::SupportedTensorDtypes::FLOATHBF16>(
       [](const auto val_a, const auto val_b) {
-        return std::atan2(val_a, val_b);
+        return executorch::math::atan2(val_a, val_b);
       },
       ctx,
       a,
diff --git a/kernels/portable/cpu/op_clamp.cpp b/kernels/portable/cpu/op_clamp.cpp
index d7d9fab2f59..af082a18e78 100644
--- a/kernels/portable/cpu/op_clamp.cpp
+++ b/kernels/portable/cpu/op_clamp.cpp
@@ -138,9 +138,8 @@ Tensor& clamp_out(
         CTYPE_COMPUTE,
         op_name,
         utils::SupportedTensorDtypes::SAME_AS_COMMON>(
-        [has_min, min_opt, has_max, max_opt](const CTYPE_COMPUTE val_in) {
-          // TODO: rewrite this to be vectorization-capable.
-          CTYPE_COMPUTE val_out = val_in;
+        [has_min, min_opt, has_max, max_opt](const auto val_in) {
+          auto val_out = val_in;
           if (has_min) {
             val_out = utils::max_override(
                 val_out, utils::scalar_to<CTYPE_COMPUTE>(min_opt.value()));
diff --git a/kernels/portable/cpu/op_elu.cpp b/kernels/portable/cpu/op_elu.cpp
index d6533642860..d7477717a3a 100644
--- a/kernels/portable/cpu/op_elu.cpp
+++ b/kernels/portable/cpu/op_elu.cpp
@@ -48,8 +48,7 @@ Tensor& elu_out(
         CTYPE,
         op_name,
         utils::SupportedTensorDtypes::SAME_AS_COMMON>(
-        [negcoef, math_scale, math_input_scale](const auto x) {
-          // TODO: rewrite this to be vectorization-capable.
+        [negcoef, math_scale, math_input_scale](const CTYPE x) {
           return MathT(x) <= MathT(0)
               ? std::expm1(MathT(x) * math_input_scale) * negcoef
               : MathT(x) * math_scale;
diff --git a/kernels/portable/cpu/op_fmod.cpp b/kernels/portable/cpu/op_fmod.cpp
index 96a971b166a..40bb4a5e94c 100644
--- a/kernels/portable/cpu/op_fmod.cpp
+++ b/kernels/portable/cpu/op_fmod.cpp
@@ -61,7 +61,7 @@ Tensor& fmod_Tensor_out(
         utils::SupportedTensorDtypes::REALHBF16>(
         [&div_by_zero_error](
            const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-          // TODO: rewrite this to be vectorization-capable.
+          // TODO: rewrite this to be vectorization-capable?
           CTYPE_COMPUTE value = 0;
           if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/false>::value) {
             if (val_b == 0) {
@@ -138,10 +138,8 @@ Tensor& fmod_Scalar_out(
         CTYPE_COMPUTE,
         op_name,
         utils::SupportedTensorDtypes::REALHBF16>(
-        [val_b](const CTYPE_COMPUTE val_a) {
-          // TODO: rewrite this to be vectorization-capable.
-          CTYPE_COMPUTE value = std::fmod(val_a, val_b);
-          return value;
+        [val_b](const auto val_a) {
+          return executorch::math::fmod(val_a, (decltype(val_a))val_b);
         },
         ctx,
         a,
diff --git a/kernels/portable/cpu/op_maximum.cpp b/kernels/portable/cpu/op_maximum.cpp
index 3a84095a4df..c7979e40d7c 100644
--- a/kernels/portable/cpu/op_maximum.cpp
+++ b/kernels/portable/cpu/op_maximum.cpp
@@ -49,7 +49,7 @@ Tensor& maximum_out(
         CTYPE_COMPUTE,
         op_name,
         utils::SupportedTensorDtypes::REALHBBF16>(
-        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+        [](const auto val_a, const auto val_b) {
           return utils::max_override(val_a, val_b);
         },
         ctx,
diff --git a/kernels/portable/cpu/op_minimum.cpp b/kernels/portable/cpu/op_minimum.cpp
index 5c0e79eb9bb..1bac23187d8 100644
--- a/kernels/portable/cpu/op_minimum.cpp
+++ b/kernels/portable/cpu/op_minimum.cpp
@@ -49,8 +49,7 @@ Tensor& minimum_out(
         CTYPE_COMPUTE,
         op_name,
         utils::SupportedTensorDtypes::REALHBBF16>(
-        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-          // TODO: rewrite this to be vectorization-capable.
+        [](const auto val_a, const auto val_b) {
           return utils::min_override(val_a, val_b);
         },
         ctx,
diff --git a/kernels/portable/cpu/op_mul.cpp b/kernels/portable/cpu/op_mul.cpp
index ba16ddc075a..6d4f30106ca 100644
--- a/kernels/portable/cpu/op_mul.cpp
+++ b/kernels/portable/cpu/op_mul.cpp
@@ -72,9 +72,7 @@ Tensor& mul_out(
         CTYPE_COMPUTE,
         op_name,
         utils::SupportedTensorDtypes::REALHBBF16>(
-        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-          return val_a * val_b;
-        },
+        [](const auto val_a, const auto val_b) { return val_a * val_b; },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
diff --git a/kernels/portable/cpu/op_native_dropout.cpp b/kernels/portable/cpu/op_native_dropout.cpp
index 1c4d177e8ed..8dafd9e0512 100644
--- a/kernels/portable/cpu/op_native_dropout.cpp
+++ b/kernels/portable/cpu/op_native_dropout.cpp
@@ -57,8 +57,11 @@ std::tuple<Tensor&, Tensor&> native_dropout_out(
   }
   ET_SWITCH_FLOATHBF16_TYPES(
       input.scalar_type(), ctx, op_name, CTYPE_COMPUTE, [&]() {
-        utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-            [](const auto val, const auto mask_val) {
+        utils::apply_bitensor_elementwise_fn<
+            CTYPE_COMPUTE,
+            op_name,
+            utils::SupportedTensorDtypes::SAME_AS_COMMON>(
+            [](const CTYPE_COMPUTE val, const CTYPE_COMPUTE mask_val) {
               if (!mask_val) {
                 return static_cast<CTYPE_COMPUTE>(0);
               }
@@ -70,8 +73,7 @@ std::tuple<Tensor&, Tensor&> native_dropout_out(
             mask,
             // TODO: should really be just BOOL
             utils::SupportedTensorDtypes::BOOL_OR_BYTE,
-            out,
-            utils::SupportedTensorDtypes::SAME_AS_COMMON);
+            out);
       });
   } else if (input.numel() > 0) {
     std::memcpy(out.mutable_data_ptr(), input.data_ptr(), input.nbytes());
diff --git a/kernels/portable/cpu/op_pow.cpp b/kernels/portable/cpu/op_pow.cpp
index 4d2673cb72d..aaf934b9adf 100644
--- a/kernels/portable/cpu/op_pow.cpp
+++ b/kernels/portable/cpu/op_pow.cpp
@@ -57,9 +57,8 @@ Tensor& pow_Tensor_Tensor_out(
         CTYPE_COMPUTE,
         op_name,
         utils::SupportedTensorDtypes::REALHBF16>(
-        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-          // TODO: rewrite this to be vectorization-capable.
-          return std::pow(val_a, val_b);
+        [](const auto val_a, const auto val_b) {
+          return executorch::math::pow(val_a, val_b);
         },
         ctx,
         a,
@@ -111,8 +110,13 @@ Tensor& pow_Tensor_Scalar_out(
         CTYPE_COMPUTE,
         op_name,
         utils::SupportedTensorDtypes::REALHBF16>(
-        // TODO: rewrite this to be vectorization-capable.
-        [val_b](const CTYPE_COMPUTE val_a) { return std::pow(val_a, val_b); },
+        // Casting val_b here supports vectorization; it does
+        // nothing if we are not vectorizing (casts to
+        // CTYPE_COMPUTE) and casts to a vectorized type
+        // otherwise.
+        [val_b](const auto val_a) {
+          return executorch::math::pow(val_a, decltype(val_a)(val_b));
+        },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
@@ -161,8 +165,13 @@ Tensor& pow_Scalar_out(
         CTYPE_COMPUTE,
         op_name,
         utils::SupportedTensorDtypes::REALHBF16>(
-        // TODO: rewrite this to be vectorization-capable.
-        [val_a](const CTYPE_COMPUTE val_b) { return std::pow(val_a, val_b); },
+        // Casting val_a here supports vectorization; it does
+        // nothing if we are not vectorizing (casts to
+        // CTYPE_COMPUTE) and casts to a vectorized type
+        // otherwise.
+        [val_a](const auto val_b) {
+          return executorch::math::pow(decltype(val_b)(val_a), val_b);
+        },
         ctx,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
diff --git a/kernels/portable/cpu/op_sigmoid.cpp b/kernels/portable/cpu/op_sigmoid.cpp
index acb743a2db6..a1eb03c1869 100644
--- a/kernels/portable/cpu/op_sigmoid.cpp
+++ b/kernels/portable/cpu/op_sigmoid.cpp
@@ -49,10 +49,9 @@ Tensor& sigmoid_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
         CTYPE_COMPUTE,
         op_name,
         utils::SupportedTensorDtypes::FLOATHBF16>(
-        [](const auto val_in) -> CTYPE_COMPUTE {
-          // TODO: rewrite this to be vectorization-capable
-          CTYPE_COMPUTE out_val = static_cast<CTYPE_COMPUTE>(1.0) /
-              (static_cast<CTYPE_COMPUTE>(1.0) + exp(-val_in));
+        [](const auto val_in) {
+          const auto one = static_cast<decltype(val_in)>(1.0);
+          auto out_val = one / (one + executorch::math::exp(-val_in));
           return out_val;
         },
         ctx,
diff --git a/kernels/portable/cpu/op_sub.cpp b/kernels/portable/cpu/op_sub.cpp
index aa90df8dee4..b914c411303 100644
--- a/kernels/portable/cpu/op_sub.cpp
+++ b/kernels/portable/cpu/op_sub.cpp
@@ -61,7 +61,7 @@ Tensor& sub_out(
         op_name,
         utils::SupportedTensorDtypes::REALHBF16>(
         [val_alpha](const auto val_a, const auto val_b) {
-          return val_a - val_alpha * val_b;
+          return val_a - (decltype(val_b))(val_alpha)*val_b;
         },
         ctx,
         a,
@@ -112,12 +112,13 @@ Tensor& sub_scalar_out(
   ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
     const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
     const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
+    const auto val_alpha_times_b = val_alpha * val_b;
     utils::apply_unitensor_elementwise_fn<
         CTYPE_COMPUTE,
         op_name,
         utils::SupportedTensorDtypes::SAME_AS_COMMON>(
-        [val_b, val_alpha](const auto val_a) {
-          return val_a - val_alpha * val_b;
+        [val_alpha_times_b](const auto val_a) {
+          return val_a - (decltype(val_a))(val_alpha_times_b);
         },
         ctx,
         a,
diff --git a/kernels/portable/cpu/op_where.cpp b/kernels/portable/cpu/op_where.cpp
index 692e296ee00..b1eb4ff442c 100644
--- a/kernels/portable/cpu/op_where.cpp
+++ b/kernels/portable/cpu/op_where.cpp
@@ -47,9 +47,9 @@ Tensor& where_out(
         CTYPE_COMPUTE,
         op_name,
         utils::SupportedTensorDtypes::SAME_AS_COMMON>(
-        [](const auto val_a, const auto val_b, const auto val_c) {
-          return val_c ? val_a : val_b;
-        },
+        [](const CTYPE_COMPUTE val_a,
+           const CTYPE_COMPUTE val_b,
+           const CTYPE_COMPUTE val_c) { return val_c ? val_a : val_b; },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h
index d07250f1d66..948da50fdd4 100644
--- a/kernels/portable/cpu/util/elementwise_util.h
+++ b/kernels/portable/cpu/util/elementwise_util.h
@@ -12,9 +12,14 @@
 #include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
 #include <executorch/kernels/portable/cpu/util/dtype_util.h>
+#include <executorch/kernels/portable/cpu/util/vectorized_math.h> // Make vectorization support easy for clients.
 #include <executorch/runtime/kernel/kernel_runtime_context.h>
 #include <executorch/runtime/kernel/thread_parallel_interface.h>

+#if defined(ET_USE_PYTORCH_HEADERS) && ET_USE_PYTORCH_HEADERS
+#include <ATen/cpu/vec/vec.h>
+#endif // ET_USE_PYTORCH_HEADERS
+
 #include <array>
 #include <utility>

@@ -60,6 +65,38 @@ struct SupportNoncontiguousInputTensors {
   explicit SupportNoncontiguousInputTensors() = default;
 };

+template <typename Ignore, typename T>
+using ignore_first_yield_second = T;
+
+#if defined(ET_USE_PYTORCH_HEADERS) && ET_USE_PYTORCH_HEADERS
+// Can I call a function of type Op with sizeof...(Args) arguments of type
+// at::vec::Vectorized<CTYPE_COMPUTE>?
+//
+// See [NOTE: Generic lambdas] below for requirements on Op.
+template <typename CTYPE_COMPUTE, typename Op, typename... Args>
+constexpr bool can_use_vectorized() {
+  using Vec = at::vec::Vectorized<CTYPE_COMPUTE>;
+  // NOTE: if we start building optimized kernels on platforms that
+  // ATen Vectorized doesn't support well, we will want to add a way
+  // to check that Vectorized actually does something on our target
+  // platform. For now, I see no concrete need for that.
+  if constexpr (std::is_invocable_v<
+                    Op,
+                    ignore_first_yield_second<Args, Vec>...>) {
+    // For bool, we will get a false positive if we rely on only the
+    // is_invocable_v check above because at::vec::Vectorized is
+    // implicitly convertible to a pointer, which makes it implicitly
+    // convertible to bool (which was 15 minutes of fun to debug). Also
+    // just seems like good hygiene to make sure we get the Vectorized
+    // we're expecting.
+    return std::is_same_v<
+        std::invoke_result_t<Op, ignore_first_yield_second<Args, Vec>...>,
+        Vec>;
+  }
+  return false;
+}
+#endif // ET_USE_PYTORCH_HEADERS
+
 template <
     typename CTYPE_COMPUTE,
     typename CTYPE_OUT,
@@ -71,8 +108,90 @@ inline void dtype_specialized_elementwise_fn_impl(
     KernelRuntimeContext& ctx,
     const Tensor& out,
     Args... inputs) {
+  static_assert(
+      (std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
+       ...));
   constexpr auto kNumInputs = sizeof...(inputs);
-  ET_DCHECK(((inputs.first->element_size() == sizeof(CTYPE_COMPUTE)) && ...));
+  // All inputs must be of type CTYPE_COMPUTE.
+  ET_DCHECK(
+      ((inputs.first->scalar_type() ==
+        CppTypeToScalarType<CTYPE_COMPUTE>::value) &&
+       ...));
+
+#if defined(ET_USE_PYTORCH_HEADERS) && ET_USE_PYTORCH_HEADERS
+  if constexpr (can_use_vectorized<CTYPE_COMPUTE, Op, Args...>()) {
+    const bool any_is_broadcasted =
+        !(torch::executor::internal::sizes_match_ignoring_leading_1s(
+              inputs.first->sizes(), out.sizes()) &&
+          ...);
+    if (!any_is_broadcasted) {
+      using Vec = at::vec::Vectorized<CTYPE_COMPUTE>;
+      ::executorch::extension::parallel_for(
+          0,
+          out.numel(),
+          ::executorch::extension::internal::GRAIN_SIZE,
+          [&](const auto begin, const auto end) {
+            std::array<const CTYPE_COMPUTE*, kNumInputs> inputs_data_ptrs = {
+                inputs.first->template const_data_ptr<CTYPE_COMPUTE>()...};
+
+            CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
+
+            const auto vectorized_begin =
+                begin + (Vec::size() - begin % Vec::size()) % Vec::size();
+            const auto vectorized_end = end - (end % Vec::size());
+            // Scalar prologue.
+            for (const auto idx : c10::irange(begin, vectorized_begin)) {
+              // In debug mode, always use Vectorized so that even
+              // small-sized tests will test whether using Vectorized broke our
+              // lambda.
+#ifndef NDEBUG
+              std::array<Vec, kNumInputs> loaded_inputs;
+#else // NDEBUG
+              std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs;
+#endif // NDEBUG
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
+              }
+#ifndef NDEBUG
+              std::apply(compute_fun, loaded_inputs).store(&data_out[idx], 1);
+#else // NDEBUG
+              data_out[idx] = std::apply(compute_fun, loaded_inputs);
+#endif // NDEBUG
+            }
+
+            // Main vectorized loop.
+            for (auto idx = vectorized_begin; idx < vectorized_end;
+                 idx += Vec::size()) {
+              std::array<Vec, kNumInputs> loaded_vec_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_vec_inputs[input_idx] =
+                    Vec::loadu(&inputs_data_ptrs[input_idx][idx]);
+              }
+              auto result_vec = std::apply(compute_fun, loaded_vec_inputs);
+              result_vec.store(&data_out[idx]);
+            }
+
+            // Scalar epilogue.
+            for (const auto idx : c10::irange(vectorized_end, end)) {
+#ifndef NDEBUG
+              std::array<Vec, kNumInputs> loaded_inputs;
+#else // NDEBUG
+              std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs;
+#endif // NDEBUG
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
+              }
+#ifndef NDEBUG
+              std::apply(compute_fun, loaded_inputs).store(&data_out[idx], 1);
+#else // NDEBUG
+              data_out[idx] = std::apply(compute_fun, loaded_inputs);
+#endif // NDEBUG
+            }
+          });
+      return;
+    }
+  }
+#endif // ET_USE_PYTORCH_HEADERS

   ::executorch::extension::parallel_for(
       0,
@@ -262,6 +381,19 @@ inline void apply_unitensor_elementwise_fn(
       compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
 }

+/**
+ * Useful for unary elementwise operators. For each element of the
+ * input, call Op and write to the corresponding element of the
+ * output. Tensor broadcasting is applied wherever it is required.
+ *
+ * [NOTE: Generic lambdas]: If Op is a *generic* lambda (i.e., one with `auto`
+ * parameters; normal lambdas are fine), it must fulfill one of the
+ * following conditions. Either:
+ * 1) It must in fact compile when passed at::vec::Vectorized<CTYPE_COMPUTE>,
+ * or
+ * 2) It must be actively SFINAE-friendly, as per the C++17 examples in
+ * https://stackoverflow.com/questions/76525790/detecting-if-a-generic-lambda-with-certain-arguments-is-invocable
+ * .
+ */
 template <
     typename CTYPE_COMPUTE,
     const char* op_name,
@@ -327,6 +459,8 @@ inline void apply_bitensor_elementwise_fn(
  * Useful for bi-tensor elementwise operators. For each element of the inputs,
  * perform a computation and write to the corresponding element of the output.
  * Tensor broadcasting is applied wherever it is required.
+ * See [NOTE: Generic lambdas] if you want to pass a generic lambda for
+ * compute_fun.
  */
 template <
     typename CTYPE_COMPUTE,
@@ -423,6 +557,9 @@ inline void apply_tritensor_elementwise_fn(
  *
  * static constexpr const char op_name[] = "my_op";
  * apply_ternary_elementwise_fn<CTYPE_COMPUTE, op_name>.
+ *
+ * See [NOTE: Generic lambdas] if you want to pass a generic lambda for
+ * compute_fun.
 */
 template <
     typename CTYPE_COMPUTE,
diff --git a/kernels/portable/cpu/util/math_util.h b/kernels/portable/cpu/util/math_util.h
index 2ba068da18e..2c4828b9e6e 100644
--- a/kernels/portable/cpu/util/math_util.h
+++ b/kernels/portable/cpu/util/math_util.h
@@ -8,6 +8,10 @@

 #pragma once

+#if defined(ET_USE_PYTORCH_HEADERS) && ET_USE_PYTORCH_HEADERS
+#include <ATen/cpu/vec/vec.h>
+#endif
+
 namespace torch {
 namespace executor {
 namespace native {
@@ -138,6 +142,32 @@ T max_override(T a, T b) {
   return b;
 }

+#if defined(ET_USE_PYTORCH_HEADERS) && ET_USE_PYTORCH_HEADERS
+template <typename T>
+at::vec::Vectorized<T> min_override(
+    at::vec::Vectorized<T> a,
+    at::vec::Vectorized<T> b) {
+  return at::vec::minimum(a, b);
+}
+
+template <typename T>
+at::vec::Vectorized<T> min_override(at::vec::Vectorized<T> a, T b) {
+  return min_override(a, at::vec::Vectorized<T>(b));
+}
+
+template <typename T>
+at::vec::Vectorized<T> max_override(
+    at::vec::Vectorized<T> a,
+    at::vec::Vectorized<T> b) {
+  return at::vec::maximum(a, b);
+}
+
+template <typename T>
+at::vec::Vectorized<T> max_override(at::vec::Vectorized<T> a, T b) {
+  return max_override(a, at::vec::Vectorized<T>(b));
+}
+
+#endif
 /**
  * There is a slight difference in how std::fmod works compared to how ATen
  * determines remainders:
diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl
index abf3f22c00b..65a0c9fc47a 100644
--- a/kernels/portable/cpu/util/targets.bzl
+++ b/kernels/portable/cpu/util/targets.bzl
@@ -33,6 +33,7 @@ def define_common_targets():
             "//executorch/kernels/portable/cpu/util:slice_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
             "//executorch/kernels/portable/cpu/util:upsample_util",
+            "//executorch/kernels/portable/cpu/util:vectorized_math",
         ],
         visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"],
     )
@@ -110,6 +111,8 @@ def define_common_targets():
             ":broadcast_indexes_range",
             ":broadcast_util",
             ":dtype_util",
+            ":vectorized_math",
+            "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
             "//executorch/runtime/kernel:kernel_runtime_context",
             "//executorch/extension/threadpool:threadpool",
         ],
@@ -260,6 +263,9 @@ def define_common_targets():
         srcs = [],
         exported_headers = ["math_util.h"],
         visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/quantized/..."],
+        exported_deps = [
+            "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
+        ],
     )

     runtime.cxx_library(
@@ -307,6 +313,16 @@ def define_common_targets():
         ],
     )

+    runtime.cxx_library(
+        name = "vectorized_math",
+        exported_headers = ["vectorized_math.h"],
+        visibility = ["//executorch/..."],
+        exported_deps = [
+            "//executorch/runtime/core/portable_type:portable_type",
+            "//executorch/runtime/core/exec_aten/util:scalar_type_util",
+        ],
+    )
+
     # Utility functions that can be used by operators that perform reduction
     for aten_mode in get_aten_mode_options():
         suffix = "_aten" if aten_mode else ""
diff --git a/kernels/portable/cpu/util/test/CMakeLists.txt b/kernels/portable/cpu/util/test/CMakeLists.txt
index d95b3a81b5c..41bfea54020 100644
--- a/kernels/portable/cpu/util/test/CMakeLists.txt
+++ b/kernels/portable/cpu/util/test/CMakeLists.txt
@@ -4,26 +4,22 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-# @generated by test/utils/generate_gtest_cmakelists.py
-#
-# This file should be formatted with
-# ~~~
-# cmake-format -i CMakeLists.txt
-# ~~~
-# It should also be cmake-lint clean.
-#
-
 cmake_minimum_required(VERSION 3.19)

 set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../../..)

 include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake)
+include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)

 set(_test_srcs
     broadcast_indexes_range_test.cpp broadcast_test.cpp
-    reduce_test.cpp
+    reduce_test.cpp vectorized_math_test.cpp
 )

 et_cxx_test(
   kernels_portable_cpu_util_test SOURCES ${_test_srcs} EXTRA_LIBS
   portable_kernels portable_ops_lib
 )
+
+find_package_torch_headers()
+target_include_directories(kernels_portable_cpu_util_test PRIVATE ${TORCH_INCLUDE_DIRS})
+target_compile_definitions(kernels_portable_cpu_util_test PRIVATE ET_USE_PYTORCH_HEADERS)
diff --git a/kernels/portable/cpu/util/test/targets.bzl b/kernels/portable/cpu/util/test/targets.bzl
index 178eb25a79b..69ca7d0dd22 100644
--- a/kernels/portable/cpu/util/test/targets.bzl
+++ b/kernels/portable/cpu/util/test/targets.bzl
@@ -32,3 +32,14 @@ def define_common_targets():
             "//executorch/kernels/portable/cpu/util:reduce_util",
         ],
     )
+
+    # this test requires ET_USE_PYTORCH_HEADERS, which doesn't work in OSS Buck.
+    if not runtime.is_oss:
+        runtime.cxx_test(
+            name = "vectorized_math_test",
+            srcs = ["vectorized_math_test.cpp"],
+            deps = [
+                "//executorch/kernels/portable/cpu/util:vectorized_math",
+                "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
+            ],
+        )
diff --git a/kernels/portable/cpu/util/test/vectorized_math_test.cpp b/kernels/portable/cpu/util/test/vectorized_math_test.cpp
new file mode 100644
index 00000000000..69115985fc1
--- /dev/null
+++ b/kernels/portable/cpu/util/test/vectorized_math_test.cpp
@@ -0,0 +1,95 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/portable/cpu/util/vectorized_math.h>
+
+#include <gtest/gtest.h>
+
+#include <cmath>
+
+#include <cstdint>
+
+#ifndef ET_USE_PYTORCH_HEADERS
+#error "This test requires ET_USE_PYTORCH_HEADERS!"
+#endif // ET_USE_PYTORCH_HEADERS
+
+TEST(VectorizedMathTest, BasicUnary) {
+  __at_align__ float result_floats[at::vec::Vectorized<float>::size()] = {0};
+  const auto x_vec = at::vec::Vectorized<float>::arange(0, 1);
+  const auto result_vec = executorch::math::exp(x_vec);
+  result_vec.store(result_floats);
+  for (const auto ii : c10::irange(at::vec::Vectorized<float>::size())) {
+    EXPECT_FLOAT_EQ(result_floats[ii], std::exp(ii));
+  }
+}
+
+namespace {
+template <typename T>
+void test_unary_t_to_float() {
+  __at_align__ float result_floats[at::vec::Vectorized<T>::size()] = {0};
+  const auto x_vec = at::vec::Vectorized<T>::arange(0, 1);
+  const auto result_vec = executorch::math::exp(x_vec);
+  static_assert(decltype(result_vec)::size() >= at::vec::Vectorized<T>::size());
+  result_vec.store(result_floats, at::vec::Vectorized<T>::size());
+  for (const auto ii : c10::irange(at::vec::Vectorized<T>::size())) {
+    EXPECT_FLOAT_EQ(result_floats[ii], std::exp((float)ii)) << ii;
+  }
+}
+
+} // namespace
+
+TEST(VectorizedMathTest, UnaryInt16ToFloat) {
+  test_unary_t_to_float<std::int16_t>();
+}
+
+TEST(VectorizedMathTest, UnaryInt32ToFloat) {
+  test_unary_t_to_float<std::int32_t>();
+}
+
+TEST(VectorizedMathTest, UnaryInt64ToFloat) {
+  test_unary_t_to_float<std::int64_t>();
+}
+
+TEST(VectorizedMathTest, BasicBinary) {
+  __at_align__ float result_floats[at::vec::Vectorized<float>::size()] = {0};
+  const auto x_vec = at::vec::Vectorized<float>::arange(0, 1);
+  const auto y_vec = at::vec::Vectorized<float>(2);
+  const auto result_vec = executorch::math::pow(x_vec, y_vec);
+  result_vec.store(result_floats);
+  for (const auto ii : c10::irange(at::vec::Vectorized<float>::size())) {
+    EXPECT_FLOAT_EQ(result_floats[ii], std::pow((float)ii, 2.0f));
+  }
+}
+
+namespace {
+template <typename T>
+void test_binary_t_to_float() {
+  __at_align__ float result_floats[at::vec::Vectorized<T>::size()] = {0};
+  const auto x_vec = at::vec::Vectorized<T>::arange(0, 1);
+  const auto y_vec = at::vec::Vectorized<T>(2);
+  const auto result_vec = executorch::math::pow(x_vec, y_vec);
+  static_assert(decltype(result_vec)::size() >= at::vec::Vectorized<T>::size());
+  result_vec.store(result_floats, at::vec::Vectorized<T>::size());
+  for (const auto ii : c10::irange(at::vec::Vectorized<T>::size())) {
+    EXPECT_EQ(result_floats[ii], std::pow((float)ii, 2.0f)) << ii;
+  }
+}
+
+TEST(VectorizedMathTest, BinaryInt16ToFloat) {
+  test_binary_t_to_float<std::int16_t>();
+}
+
+TEST(VectorizedMathTest, BinaryInt32ToFloat) {
+  test_binary_t_to_float<std::int32_t>();
+}
+
+TEST(VectorizedMathTest, BinaryInt64ToFloat) {
+  test_binary_t_to_float<std::int64_t>();
+}
+
+} // namespace
diff --git a/kernels/portable/cpu/util/vectorized_math.h b/kernels/portable/cpu/util/vectorized_math.h
new file mode 100644
index 00000000000..e67e862ef62
--- /dev/null
+++ b/kernels/portable/cpu/util/vectorized_math.h
@@ -0,0 +1,148 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
+
+#if defined(ET_USE_PYTORCH_HEADERS) && ET_USE_PYTORCH_HEADERS
+#include <ATen/cpu/vec/vec.h>
+#endif // ET_USE_PYTORCH_HEADERS
+
+#include <cmath>
+#include <type_traits>
+
+#if defined(ET_USE_PYTORCH_HEADERS) && ET_USE_PYTORCH_HEADERS
+namespace executorch {
+inline namespace math {
+namespace internal {
+template <typename T>
+auto convert_to_vectorized_n_of_float(at::vec::Vectorized<T> vec) {
+  static constexpr auto float_vec_size = at::vec::Vectorized<float>::size();
+  static constexpr auto t_vec_size = at::vec::Vectorized<T>::size();
+  static constexpr auto result_size =
+      t_vec_size < float_vec_size ? 1 : t_vec_size / float_vec_size;
+  static_assert(result_size >= 1);
+  return at::vec::convert<float, result_size, T, 1>(
+      at::vec::VectorizedN<T, 1>(vec));
+}
+} // namespace internal
+} // namespace math
+} // namespace executorch
+#endif // ET_USE_PYTORCH_HEADERS
+
+#define _ET_INTERNAL_STD_MATH_FUNC(name) \
+  namespace executorch {                 \
+  inline namespace math {                \
+  using std::name;                       \
+  }                                      \
+  } // namespace executorch
+
+#if defined(ET_USE_PYTORCH_HEADERS) && ET_USE_PYTORCH_HEADERS
+/**
+ * Internal-usage macro for making a vectorized variant of a unary
+ * function available in the executorch::math namespace.
+ */
+#define ET_INTERNAL_VECTORIZED_FLOAT_UNARY_FUNC(func_name)                \
+  namespace executorch {                                                  \
+  inline namespace math {                                                 \
+  template <typename T>                                                   \
+  auto func_name(at::vec::Vectorized<T> vec) {                            \
+    if constexpr (!::executorch::runtime::is_floating_point<T>::value) {  \
+      return internal::convert_to_vectorized_n_of_float(vec).func_name(); \
+    } else {                                                              \
+      return vec.func_name();                                             \
+    }                                                                     \
+  }                                                                       \
+  }                                                                       \
+  }

+#define ET_INTERNAL_VECTORIZED_FLOAT_BINARY_FUNC(func_name)                  \
+  namespace executorch {                                                     \
+  inline namespace math {                                                    \
+  template <typename T>                                                      \
+  auto func_name(at::vec::Vectorized<T> vec0, at::vec::Vectorized<T> vec1) { \
+    if constexpr (!::executorch::runtime::is_floating_point<T>::value) {     \
+      const auto vec_float0 =                                                \
+          internal::convert_to_vectorized_n_of_float(vec0);                  \
+      const auto vec_float1 =                                                \
+          internal::convert_to_vectorized_n_of_float(vec1);                  \
+      return vec_float0.func_name(vec_float1);                               \
+    } else {                                                                 \
+      return vec0.func_name(vec1);                                           \
+    }                                                                        \
+  }                                                                          \
+  }                                                                          \
+  }
+
+/**
+ * Internal-usage macro for making a C++ standard library
+ * floating-point function and a vectorized variant of it available in
+ * the executorch::math namespace. Should be used with functions where the
+ * corresponding operator is a "float op" in TensorIterator parlance
+ * (i.e., uses something like build_borrowing_binary_float_op()),
+ * because it converts non-floating-point arguments to floating point.
+ */
+#define ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(func_name) \
+  _ET_INTERNAL_STD_MATH_FUNC(func_name)                        \
+  ET_INTERNAL_VECTORIZED_FLOAT_UNARY_FUNC(func_name)
+
+#define ET_INTERNAL_VECTORIZED_STD_FLOAT_BINARY_FUNC(func_name) \
+  _ET_INTERNAL_STD_MATH_FUNC(func_name)                         \
+  ET_INTERNAL_VECTORIZED_FLOAT_BINARY_FUNC(func_name)
+
+#else // ET_USE_PYTORCH_HEADERS
+#define ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(name) \
+  _ET_INTERNAL_STD_MATH_FUNC(name)
+#define ET_INTERNAL_VECTORIZED_STD_FLOAT_BINARY_FUNC(name) \
+  _ET_INTERNAL_STD_MATH_FUNC(name)
+#endif // ET_USE_PYTORCH_HEADERS
+
+// To simplify client code, we provide coverage for a bunch of float ops (the
+// same ones listed in ATen vml.h) here.
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(abs)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(acos)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(asin)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(atan)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(ceil)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(cos)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(cosh)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(erf)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(erfc)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(exp)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(expm1)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(floor)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(log)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(log10)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(log1p)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(log2)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(sin)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(sinh)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(sqrt)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(round)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(tan)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(tanh)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(trunc)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_UNARY_FUNC(lgamma)
+
+#if defined(ET_USE_PYTORCH_HEADERS) && ET_USE_PYTORCH_HEADERS
+ET_INTERNAL_VECTORIZED_FLOAT_UNARY_FUNC(rsqrt)
+#endif // ET_USE_PYTORCH_HEADERS
+
+namespace executorch {
+inline namespace math {
+template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T>>>
+T rsqrt(T x) {
+  return T(1) / std::sqrt(x);
+}
+} // namespace math
+} // namespace executorch
+
+ET_INTERNAL_VECTORIZED_STD_FLOAT_BINARY_FUNC(atan2)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_BINARY_FUNC(fmod)
+ET_INTERNAL_VECTORIZED_STD_FLOAT_BINARY_FUNC(pow)
diff --git a/kernels/test/CMakeLists.txt b/kernels/test/CMakeLists.txt
index 4f174b5a652..f5997a1ee3f 100644
--- a/kernels/test/CMakeLists.txt
+++ b/kernels/test/CMakeLists.txt
@@ -17,7 +17,7 @@

 set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
 include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake)

-set(_kernels portable optimized quantized)
+set(_kernels portable optimized_portable optimized quantized)

 foreach(kernel ${_kernels})
   set(_wrapper_dir
       "${CMAKE_CURRENT_BINARY_DIR}/include/${kernel}/executorch/kernels/test"
   )
@@ -37,13 +37,17 @@ foreach(kernel ${_kernels})
     VERBATIM
   )

+  set(_supported_features_kernel ${kernel})
+  if(${kernel} STREQUAL "optimized_portable")
+    set(_supported_features_kernel "portable")
+  endif()
   add_custom_command(
     OUTPUT "${_wrapper_dir}/supported_features.cpp"
            "${_wrapper_dir}/supported_features.h"
     COMMAND mkdir -p ${_wrapper_dir}
     COMMAND
       ${PYTHON_EXECUTABLE} kernels/test/gen_supported_features.py
-      kernels/${kernel}/test/supported_features_def.yaml >
+      kernels/${_supported_features_kernel}/test/supported_features_def.yaml >
       ${_wrapper_dir}/supported_features.cpp
     COMMAND
       ${PYTHON_EXECUTABLE} kernels/test/gen_supported_features.py
@@ -57,6 +61,11 @@ foreach(kernel ${_kernels})
     set(_kernel_ops_lib "optimized_native_cpu_ops_lib")
     set(_kernel_ops_lib_path
         "${CMAKE_CURRENT_BINARY_DIR}/../../configurations/optimized_native_cpu_ops_lib"
+    )
+  elseif(${kernel} STREQUAL "optimized_portable")
+    set(_kernel_ops_lib "${kernel}_ops_lib")
+    set(_kernel_ops_lib_path
+        "${CMAKE_CURRENT_BINARY_DIR}/../../kernels/portable/${kernel}_ops_lib"
     )
   else()
     set(_kernel_ops_lib "${kernel}_ops_lib")
@@ -88,6 +97,9 @@ add_custom_target(
   "${CMAKE_CURRENT_BINARY_DIR}/include/optimized/executorch/kernels/test/FunctionHeaderWrapper.h"
   "${CMAKE_CURRENT_BINARY_DIR}/include/optimized/executorch/kernels/test/supported_features.h"
   "${CMAKE_CURRENT_BINARY_DIR}/include/optimized/executorch/kernels/test/supported_features.cpp"
+  "${CMAKE_CURRENT_BINARY_DIR}/include/optimized_portable/executorch/kernels/test/FunctionHeaderWrapper.h"
+  "${CMAKE_CURRENT_BINARY_DIR}/include/optimized_portable/executorch/kernels/test/supported_features.h"
+  "${CMAKE_CURRENT_BINARY_DIR}/include/optimized_portable/executorch/kernels/test/supported_features.cpp"
   "${CMAKE_CURRENT_BINARY_DIR}/include/quantized/executorch/kernels/test/FunctionHeaderWrapper.h"
   "${CMAKE_CURRENT_BINARY_DIR}/include/quantized/executorch/kernels/test/supported_features.h"
   "${CMAKE_CURRENT_BINARY_DIR}/include/quantized/executorch/kernels/test/supported_features.cpp"
@@ -299,6 +311,22 @@ set(_optimized_kernels_test_sources
 if(TARGET optimized_portable_kernels)
   list(APPEND _optimized_kernels_test_sources ${all_test_sources})
   list(REMOVE_DUPLICATES _optimized_kernels_test_sources)
+
+  # Make sure that we still test optimized versions of portable
+  # kernels even if they would currently be shadowed by specific
+  # optimized implementations.
+  et_cxx_test(
+    optimized_portable_kernels_test
+    SOURCES
+    ${all_test_sources}
+    ${CMAKE_CURRENT_BINARY_DIR}/include/optimized_portable/executorch/kernels/test/supported_features.cpp
+    EXTRA_LIBS
+    optimized_portable_kernels
+  )
+  add_dependencies(optimized_portable_kernels_test generate_wrapper)
+  target_include_directories(
+    optimized_portable_kernels_test PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/include/optimized_portable"
+  )
 endif()

 et_cxx_test(
diff --git a/kernels/test/op_atan2_test.cpp b/kernels/test/op_atan2_test.cpp
index 436826e2b6d..ae19ef687bc 100644
--- a/kernels/test/op_atan2_test.cpp
+++ b/kernels/test/op_atan2_test.cpp
@@ -46,3 +46,36 @@ TEST(OpAtan2OutTest, SmokeTest) {
   op_atan2_out(self, other, out);
   EXPECT_TENSOR_CLOSE(out, out_expected);
 }
+
+TEST(OpAtan2OutTest, SmokeTestNoBroadcastingSameDtype) {
+  TensorFactory<ScalarType::Double> tfDouble;
+
+  std::vector<double> a(18);
+  std::iota(a.begin(), a.end(), -8);
+  std::vector<double> b(18, 2.0);
+  Tensor self = tfDouble.make({18}, a);
+  Tensor other = tfDouble.make({18}, b);
+  Tensor out = tfDouble.zeros({18});
+  Tensor out_expected = tfDouble.make(
+      {18},
+      {-1.3258176636680326,
+       -1.2924966677897853,
+       -1.2490457723982544,
+       -1.1902899496825317,
+       -1.1071487177940904,
+       -0.9827937232473291,
+       -0.7853981633974483,
+       -0.4636476090008061,
+       0.0000000000000000,
+       0.4636476090008061,
+       0.7853981633974483,
+       0.9827937232473291,
+       1.1071487177940904,
+       1.1902899496825317,
+       1.2490457723982544,
+       1.2924966677897853,
+       1.3258176636680326,
+       1.3521273809209546});
+  op_atan2_out(self, other, out);
+  EXPECT_TENSOR_CLOSE(out, out_expected);
+}
diff --git a/kernels/test/op_clamp_test.cpp b/kernels/test/op_clamp_test.cpp
index 68a1c6a1997..81138fc8a55 100644
--- a/kernels/test/op_clamp_test.cpp
+++ b/kernels/test/op_clamp_test.cpp
@@ -31,6 +31,15 @@ using torch::executor::testing::TensorFactory;

 using OptScalar = std::optional<Scalar>;

+namespace {
+template <typename T>
+std::vector<T> arange(T stop) {
+  std::vector<T> result(stop);
+  std::iota(result.begin(), result.end(), 0);
+  return result;
+}
+} // namespace
+
 class OpClampOutTest : public OperatorTest {
  protected:
  Tensor& op_clamp_out(
@@ -114,6 +123,31 @@ class OpClampOutTest : public OperatorTest {
           // Should set all elements to max.
           {6, 6, 6, 6}, // expected_data
         },
+        {
+            std::string(__func__) + ": Simple clamp larger data",
+            {18}, // sizes
+            arange<typename TensorFactory<DTYPE>::ctype>(18), // input_data
+            OptScalar(1), // min
+            OptScalar(6), // max
+            {1,
+             1,
+             2,
+             3,
+             4,
+             5,
+             6,
+             6,
+             6,
+             6,
+             6,
+             6,
+             6,
+             6,
+             6,
+             6,
+             6,
+             6}, // expected_data
+        },
     };

     run_test_cases(test_cases);
diff --git a/kernels/test/op_fmod_test.cpp b/kernels/test/op_fmod_test.cpp
index fa7cc4b63f7..3227a01a17a 100644
--- a/kernels/test/op_fmod_test.cpp
+++ b/kernels/test/op_fmod_test.cpp
@@ -45,3 +45,34 @@ TEST_F(OpFmodTest, SmokeTest) {
   op_fmod_tensor_out(self, other, out);
   EXPECT_TENSOR_CLOSE(out, out_expected);
 }
+
+TEST_F(OpFmodTest, ScalarSmokeTest) {
+  TensorFactory<ScalarType::Float> tfFloat;
+  std::vector<float> a(18);
+  std::iota(a.begin(), a.end(), -8);
+  Tensor self = tfFloat.make({18}, a);
+  Scalar other = 3;
+  Tensor out = tfFloat.zeros({18});
+  Tensor out_expected = tfFloat.make(
+      {18},
+      {-2.,
+       -1.,
+       -0.,
+       -2.,
+       -1.,
+       -0.,
+       -2.,
+       -1.,
+       0.,
+       1.,
+       2.,
+       0.,
+       1.,
+       2.,
+       0.,
+       1.,
+       2.,
+       0.});
+  op_fmod_scalar_out(self, other, out);
+  EXPECT_TENSOR_CLOSE(out, out_expected);
+}
diff --git a/kernels/test/op_maximum_test.cpp b/kernels/test/op_maximum_test.cpp
index faa18fa56cd..c32cf571ff3 100644
--- a/kernels/test/op_maximum_test.cpp
+++ b/kernels/test/op_maximum_test.cpp
@@ -37,3 +37,17 @@ TEST(OpMaximumOutTest, SmokeTest) {
   op_maximum_out(self, other, out);
   EXPECT_TENSOR_CLOSE(out, out_expected);
 }
+
+TEST(OpMaximumOutTest, SmokeTestLarger) {
+  TensorFactory<ScalarType::Float> tfFloat;
+
+  std::vector<float> a(18);
+  std::iota(a.begin(), a.end(), -8);
+  Tensor self = tfFloat.make({18}, a);
+  Tensor other = tfFloat.full({18}, 4);
+  Tensor out = tfFloat.zeros({18});
+  Tensor out_expected = tfFloat.make(
+      {18}, {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 6, 7, 8, 9});
+  op_maximum_out(self, other, out);
+  EXPECT_TENSOR_CLOSE(out, out_expected);
+}
diff --git a/kernels/test/op_minimum_test.cpp b/kernels/test/op_minimum_test.cpp
index 686e1feee64..9c256963943 100644
--- a/kernels/test/op_minimum_test.cpp
+++ b/kernels/test/op_minimum_test.cpp
@@ -266,3 +266,17 @@ TEST_F(OpMinimumOutTest, DynamicShapeUnbound) {
   op_minimum_out(x, y, out);
   EXPECT_TENSOR_EQ(out, expected);
 }
+
+TEST_F(OpMinimumOutTest, SmokeTestLarger) {
+  TensorFactory<ScalarType::Float> tfFloat;
+
+  std::vector<float> a(18);
+  std::iota(a.begin(), a.end(), -8);
+  Tensor self = tfFloat.make({18}, a);
+  Tensor other = tfFloat.full({18}, 4);
+  Tensor out = tfFloat.zeros({18});
+  Tensor out_expected = tfFloat.make(
+      {18}, {-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 4, 4, 4, 4, 4});
+  op_minimum_out(self, other, out);
+  EXPECT_TENSOR_CLOSE(out, out_expected);
+}
diff --git a/kernels/test/op_mul_test.cpp b/kernels/test/op_mul_test.cpp
index e109193e227..2d2f2872b99 100644
--- a/kernels/test/op_mul_test.cpp
+++ b/kernels/test/op_mul_test.cpp
@@ -30,7 +30,7 @@ class OpMulOutTest : public OperatorTest {
     return torch::executor::aten::mul_outf(context_, self, other, out);
   }

-  // Common testing for multipling two integer Tensors
+  // Common testing for multiplying two integer Tensors
   template <ScalarType DTYPE_A, ScalarType DTYPE_B, ScalarType DTYPE_OUT>
   void test_mul() {
     TensorFactory<DTYPE_A> tf_a;
@@ -54,6 +54,10 @@ class OpMulOutTest : public OperatorTest {
         tf_b.make(sizes, /*data=*/{1, 2, 4, 8}),
         out);
     EXPECT_TENSOR_EQ(out, tf_out.make(sizes, /*data=*/{1, 4, 16, 64}));
+
+    out = tf_out.zeros({18});
+    op_mul_out(tf_a.full({18}, 4), tf_b.full({18}, 2), out);
+    EXPECT_TENSOR_EQ(out, tf_out.full({18}, 8));
   }

   template
diff --git a/kernels/test/op_pow_test.cpp b/kernels/test/op_pow_test.cpp
index f9234a748b9..25d0f97526c 100644
--- a/kernels/test/op_pow_test.cpp
+++ b/kernels/test/op_pow_test.cpp
@@ -54,6 +54,19 @@ TEST_F(OpPowTest, TensorTensorSanityCheck) {
   EXPECT_TENSOR_EQ(out, tf.make({2, 2}, {16, 16, 16, 16}));
 }

+TEST_F(OpPowTest, TensorTensorSanityCheckLargerNoBroadcasting) {
+  TensorFactory<ScalarType::Float> tf;
+  Tensor self = tf.full({18}, 2);
+  Tensor exp = tf.full({18}, 4);
+  Tensor out = tf.zeros({18});
+  Tensor out_expected = tf.full({18}, 16);
+
+  Tensor ret = op_pow_tensor_tensor_out(self, exp, out);
+
+  EXPECT_TENSOR_EQ(out, ret);
+  EXPECT_TENSOR_EQ(out_expected, out);
+}
+
 TEST_F(OpPowTest, TensorTensorSanityCheck2) {
   TensorFactory<ScalarType::Bool> tf1;
   TensorFactory<ScalarType::Int> tf2;
diff --git a/kernels/test/op_sigmoid_test.cpp b/kernels/test/op_sigmoid_test.cpp
index 550cebda315..1e3499ba451 100644
--- a/kernels/test/op_sigmoid_test.cpp
+++ b/kernels/test/op_sigmoid_test.cpp
@@ -44,6 +44,10 @@ class OpSigmoidOutTest : public OperatorTest {
     EXPECT_TENSOR_CLOSE(
         out,
         tf_out.make(sizes, /*data=*/{0.731059, 0.880797, 0.982014, 0.999665}));
+
+    out = tf_out.zeros({18});
+    op_sigmoid_out(tf.full({18}, 2), out);
+    EXPECT_TENSOR_CLOSE(out, tf_out.full({18}, 0.880797));
   }

   // Unhandled output dtypes.
diff --git a/runtime/core/portable_type/c10/c10/targets.bzl b/runtime/core/portable_type/c10/c10/targets.bzl
index 827a63d2cef..995cbcd4dfa 100644
--- a/runtime/core/portable_type/c10/c10/targets.bzl
+++ b/runtime/core/portable_type/c10/c10/targets.bzl
@@ -53,7 +53,11 @@ def define_common_targets():
     runtime.cxx_library(
         name = "aten_headers_for_executorch",
         srcs = [],
-        visibility = ["//executorch/kernels/optimized/...", "@EXECUTORCH_CLIENTS"],
+        visibility = [
+            "//executorch/kernels/optimized/...",
+            "//executorch/kernels/portable/cpu/util/...",
+            "@EXECUTORCH_CLIENTS",
+        ],
         exported_deps = select({
             "DEFAULT": [],
             "ovr_config//cpu:arm64": [
@@ -77,7 +81,7 @@ def define_common_targets():
             # -Wmacro-redefined, and we only care about getting
             # reasonable vectorization and Sleef support.
             "-DCPU_CAPABILITY_AVX2",
-            "-DET_USE_PYTORCH_HEADERS",
+            "-DET_USE_PYTORCH_HEADERS=ET_HAS_EXCEPTIONS",
             "-DHAVE_AVX2_CPU_DEFINITION",
             "-DSTANDALONE_TORCH_HEADER",
         ] + get_sleef_preprocessor_flags(),
@@ -91,5 +95,5 @@ def define_common_targets():
             # linker failure.
             "ovr_config//cpu:arm64": get_sleef_preprocessor_flags(),
             "DEFAULT": [],
-        }) + ["-DSTANDALONE_TORCH_HEADER"] + ([] if runtime.is_oss else ["-DET_USE_PYTORCH_HEADERS"]),
+        }) + ["-DSTANDALONE_TORCH_HEADER"] + ([] if runtime.is_oss else ["-DET_USE_PYTORCH_HEADERS=ET_HAS_EXCEPTIONS"]),
     )
diff --git a/runtime/core/portable_type/targets.bzl b/runtime/core/portable_type/targets.bzl
index 41bc6050524..5b6e67fa213 100644
--- a/runtime/core/portable_type/targets.bzl
+++ b/runtime/core/portable_type/targets.bzl
@@ -26,6 +26,7 @@ def define_common_targets():
         visibility = [
             "//executorch/backends/...",
             "//executorch/extension/fb/dynamic_shim/...",
+            "//executorch/kernels/portable/cpu/...",
             "//executorch/runtime/core/exec_aten/...",
             "//executorch/runtime/core/portable_type/test/...",
         ],
diff --git a/test/utils/OSSTestConfig.json b/test/utils/OSSTestConfig.json
index 2cfc4b8a995..182d0bfd58a 100644
--- a/test/utils/OSSTestConfig.json
+++ b/test/utils/OSSTestConfig.json
@@ -68,18 +68,6 @@
       "extension_threadpool"
     ]
   },
-  {
-    "directory": "kernels/portable/cpu/util/test",
-    "sources": [
-      "broadcast_indexes_range_test.cpp",
-      "broadcast_test.cpp",
-      "reduce_test.cpp"
-    ],
-    "additional_libs": [
-      "portable_kernels",
-      "portable_ops_lib"
-    ]
-  },
   {
     "directory": "runtime/core/portable_type/test",
     "sources": [