14 changes: 2 additions & 12 deletions kernels/portable/cpu/op_argmin.cpp
@@ -12,7 +12,6 @@
 
 #include <executorch/kernels/portable/cpu/util/reduce_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
-#include <executorch/runtime/kernel/thread_parallel_interface.h>
 #include <executorch/runtime/platform/assert.h>
 
 namespace torch {
@@ -48,17 +47,8 @@ Tensor& argmin_out(
   ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "argmin.out", CTYPE, [&] {
     long* out_data = out.mutable_data_ptr<long>();
 
-    // REVIEW: this is the parallelization strategy ATen uses
-    // specifically when the reduction is along the last dimension and
-    // that dimension is contiguous. Is there any particular reason we
-    // shouldn't just always use this strategy since we aren't
-    // otherwise capable of parallelizing reductions?
-    const int64_t reduction_size = get_reduced_dim_product(in, dim);
-    const auto grain_size = std::max(
-        static_cast<int64_t>(1),
-        executorch::extension::internal::GRAIN_SIZE / reduction_size);
-    const bool success = executorch::extension::parallel_for(
-        0, out.numel(), grain_size, [&](const auto begin, const auto end) {
+    const bool success = parallel_for_each_reduce_over_dim_output_index(
+        in, dim, out, [&](const auto begin, const auto end) {
           for (const auto out_ix : c10::irange(begin, end)) {
             std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
                 [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
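For context on the grain size the new helper computes (see reduce_util.h below): each output index costs about reduction_size input reads, so dividing GRAIN_SIZE by that cost keeps the work per parallel task near GRAIN_SIZE input elements regardless of reduction length. A minimal standalone sketch of that arithmetic follows; the constant 32768 mirrors ATen's at::internal::GRAIN_SIZE and is an assumption here, not a value shown in this diff.

#include <algorithm>
#include <cstdint>
#include <iostream>

// Assumed value; executorch::extension::internal::GRAIN_SIZE is not shown
// in this diff. ATen uses 32768 as its per-task work target.
constexpr int64_t kGrainSize = 32768;

// Mirrors the helper's computation: work per task is measured in input
// reads, so longer reductions get fewer output indices per task.
int64_t grain_for(int64_t reduction_size) {
  return std::max(static_cast<int64_t>(1), kGrainSize / reduction_size);
}

int main() {
  std::cout << grain_for(1) << "\n";      // 32768 output indices per task
  std::cout << grain_for(1024) << "\n";   // 32 output indices per task
  std::cout << grain_for(100000) << "\n"; // 1: the clamp kicks in
  return 0;
}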
19 changes: 19 additions & 0 deletions kernels/portable/cpu/util/reduce_util.h
@@ -10,6 +10,7 @@
 
 #include <executorch/runtime/core/exec_aten/exec_aten.h>
 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
+#include <executorch/runtime/kernel/thread_parallel_interface.h>
 #include <cstring>
 #include <tuple>
 
@@ -811,5 +812,23 @@ bool check_prod_out_args(

 #endif
 
+/**
+ * parallel_for wrapper for reductions that call reduce_over_dim or
+ * map_reduce_over_dim for each output element. Automatically
+ * calculates an appropriate grain size.
+ */
+template <typename Func>
+[[nodiscard]] bool parallel_for_each_reduce_over_dim_output_index(
+    const Tensor& in,
+    executorch::aten::optional<int64_t> dim,
+    const Tensor& out,
+    const Func& func) {
+  const int64_t reduction_size = get_reduced_dim_product(in, dim);
+  const auto grain_size = std::max(
+      static_cast<int64_t>(1),
+      executorch::extension::internal::GRAIN_SIZE / reduction_size);
+  return executorch::extension::parallel_for(0, out.numel(), grain_size, func);
+}
+
 } // namespace executor
 } // namespace torch
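To illustrate how another per-output-index reduction could adopt this wrapper, here is a hypothetical sketch modeled on the argmin_out change above. Only parallel_for_each_reduce_over_dim_output_index, reduce_over_dim, and the four-argument reduce lambda come from this diff; the function name argmax_sketch and the trailing reduce_over_dim arguments (in, dim, out_ix) are assumptions based on the call shape visible in op_argmin.cpp.

#include <executorch/kernels/portable/cpu/util/reduce_util.h>

namespace torch {
namespace executor {

// Hypothetical argmax-style kernel body; not part of this PR.
template <typename CTYPE>
[[nodiscard]] bool argmax_sketch(
    const Tensor& in,
    executorch::aten::optional<int64_t> dim,
    Tensor& out) {
  long* out_data = out.mutable_data_ptr<long>();
  // The wrapper derives the grain size from the reduction length and
  // parallelizes over output indices, exactly as argmin_out does above.
  return parallel_for_each_reduce_over_dim_output_index(
      in, dim, out, [&](const auto begin, const auto end) {
        for (const auto out_ix : c10::irange(begin, end)) {
          // Same (value, index) accumulator as argmin, keeping the running
          // maximum instead of the minimum.
          std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
              [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
                return v > acc_val ? std::make_tuple(v, ix)
                                   : std::make_tuple(acc_val, acc_ix);
              },
              in, dim, out_ix); // argument order assumed, not shown in diff
          out_data[out_ix] = std::get<1>(acc); // store the winning index
        }
      });
}

} // namespace executor
} // namespace torch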
3 changes: 3 additions & 0 deletions kernels/portable/cpu/util/targets.bzl
@@ -314,6 +314,9 @@ def define_common_targets():
             "//executorch/runtime/kernel:kernel_includes{}".format(suffix),
             "//executorch/runtime/core/exec_aten/util:tensor_util{}".format(suffix),
         ],
+        exported_deps = [
+            "//executorch/runtime/kernel:thread_parallel_interface",
+        ],
         exported_preprocessor_flags = ["-DUSE_ATEN_LIB"] if aten_mode else [],
         visibility = [
             "//executorch/extension/llm/custom_ops/...",
@@ -284,7 +284,6 @@ ATEN_OPS = (
         name = "op_argmin",
         deps = [
             "//executorch/kernels/portable/cpu/util:reduce_util",
-            "//executorch/runtime/kernel:thread_parallel_interface",
         ],
     ),
     op_target(
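The exported_deps addition in targets.bzl pairs with the dependency removal just above: any target that depends on reduce_util now gets thread_parallel_interface transitively, matching the new #include in the header. A hypothetical consumer illustrating the point (this file is not part of the PR):

// Hypothetical consumer: including reduce_util.h now transitively provides
// executorch::extension::parallel_for, so no direct dep on
// //executorch/runtime/kernel:thread_parallel_interface is needed.
#include <executorch/kernels/portable/cpu/util/reduce_util.h>

bool touch_parallel_backend() {
  // Empty range: shows the symbol resolves without doing any work.
  return executorch::extension::parallel_for(
      0, 0, 1, [](int64_t, int64_t) {});
}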