diff --git a/kernels/portable/cpu/op_argmin.cpp b/kernels/portable/cpu/op_argmin.cpp
index b0816596e4e..87e90de4c04 100644
--- a/kernels/portable/cpu/op_argmin.cpp
+++ b/kernels/portable/cpu/op_argmin.cpp
@@ -12,6 +12,7 @@
 
 #include
 #include
+#include <executorch/runtime/kernel/thread_parallel_interface.h>
 #include
 
 namespace torch {
@@ -47,8 +48,17 @@ Tensor& argmin_out(
   ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "argmin.out", CTYPE, [&] {
     long* out_data = out.mutable_data_ptr<long>();
 
-    const bool success = parallel_for_each_reduce_over_dim_output_index(
-        in, dim, out, [&](const auto begin, const auto end) {
+    // REVIEW: this is the parallelization strategy ATen uses
+    // specifically when the reduction is along the last dimension and
+    // that dimension is contiguous. Is there any particular reason we
+    // shouldn't just always use this strategy since we aren't
+    // otherwise capable of parallelizing reductions?
+    const int64_t reduction_size = get_reduced_dim_product(in, dim);
+    const auto grain_size = std::max(
+        static_cast<int64_t>(1),
+        executorch::extension::internal::GRAIN_SIZE / reduction_size);
+    const bool success = executorch::extension::parallel_for(
+        0, out.numel(), grain_size, [&](const auto begin, const auto end) {
           for (const auto out_ix : c10::irange(begin, end)) {
             std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
                 [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
diff --git a/kernels/portable/cpu/util/reduce_util.h b/kernels/portable/cpu/util/reduce_util.h
index 6e109299920..2160d9810ae 100644
--- a/kernels/portable/cpu/util/reduce_util.h
+++ b/kernels/portable/cpu/util/reduce_util.h
@@ -10,7 +10,6 @@
 
 #include
 #include
-#include <executorch/runtime/kernel/thread_parallel_interface.h>
 #include
 #include
 
@@ -812,23 +811,5 @@ bool check_prod_out_args(
 
 #endif
 
-/**
- * parallel_for wrapper for reductions that call reduce_over_dim or
- * map_reduce_over_dim for each output element. Automatically
- * calculates appropriate grain size.
- */
-template <typename Func>
-[[nodiscard]] bool parallel_for_each_reduce_over_dim_output_index(
-    const Tensor& in,
-    optional<int64_t> dim,
-    const Tensor& out,
-    const Func& func) {
-  const int64_t reduction_size = get_reduced_dim_product(in, dim);
-  const auto grain_size = std::max(
-      static_cast<int64_t>(1),
-      executorch::extension::internal::GRAIN_SIZE / reduction_size);
-  return executorch::extension::parallel_for(0, out.numel(), grain_size, func);
-}
-
 } // namespace executor
 } // namespace torch
diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl
index 95fd1734d8e..db8202a920a 100644
--- a/kernels/portable/cpu/util/targets.bzl
+++ b/kernels/portable/cpu/util/targets.bzl
@@ -314,9 +314,6 @@ def define_common_targets():
                 "//executorch/runtime/kernel:kernel_includes{}".format(suffix),
                 "//executorch/runtime/core/exec_aten/util:tensor_util{}".format(suffix),
             ],
-            exported_deps = [
-                "//executorch/runtime/kernel:thread_parallel_interface",
-            ],
             exported_preprocessor_flags = ["-DUSE_ATEN_LIB"] if aten_mode else [],
             visibility = [
                 "//executorch/extension/llm/custom_ops/...",
diff --git a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
index b56413b92f4..dd48da64c30 100644
--- a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -284,6 +284,7 @@ ATEN_OPS = (
         name = "op_argmin",
         deps = [
             "//executorch/kernels/portable/cpu/util:reduce_util",
+            "//executorch/runtime/kernel:thread_parallel_interface",
         ],
     ),
     op_target(
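
Note on the grain-size heuristic referenced by the REVIEW comment: dividing the backend's target grain size by the length of each per-output reduction keeps the work per parallel task roughly constant, since every output index already pays for a full reduction over `dim`. The sketch below illustrates the idea in isolation; `kGrainSize`, `parallel_for_stub`, and `argmin_last_dim` are hypothetical stand-ins (including the grain-size constant) and not ExecuTorch APIs, which instead use `executorch::extension::parallel_for` and `executorch::extension::internal::GRAIN_SIZE` as shown in the diff.

```cpp
#include <algorithm>
#include <cstdint>

// Illustrative grain size; the real value comes from the threadpool backend.
constexpr int64_t kGrainSize = 32768;

// Serial stand-in for a parallel_for: a real backend would split
// [begin, end) into chunks of at least `grain` indices and dispatch
// each chunk to a worker thread.
template <typename Func>
void parallel_for_stub(int64_t begin, int64_t end, int64_t grain, const Func& f) {
  (void)grain;
  f(begin, end);
}

// Argmin over the last dimension of a row-major [rows x reduction_size]
// buffer. Each output index owns one full reduction, so the grain size is
// scaled down by reduction_size to keep per-task work near kGrainSize
// input elements.
void argmin_last_dim(
    const float* in, int64_t rows, int64_t reduction_size, int64_t* out) {
  const int64_t grain = std::max<int64_t>(1, kGrainSize / reduction_size);
  parallel_for_stub(0, rows, grain, [&](int64_t begin, int64_t end) {
    for (int64_t row = begin; row < end; ++row) {
      const float* src = in + row * reduction_size;
      out[row] = std::min_element(src, src + reduction_size) - src;
    }
  });
}
```

With kGrainSize = 32768 and reduction_size = 1024, each task would claim at least 32 output rows, which is the same proportionality the inlined ExecuTorch code computes.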