diff --git a/kernels/portable/cpu/util/broadcast_indexes_range.h b/kernels/portable/cpu/util/broadcast_indexes_range.h
index c623fdb4c31..5fa50d8d212 100644
--- a/kernels/portable/cpu/util/broadcast_indexes_range.h
+++ b/kernels/portable/cpu/util/broadcast_indexes_range.h
@@ -14,6 +14,7 @@
 #include
 #include
+#include
 #include
 #include
@@ -78,7 +79,9 @@ class BroadcastIndexesIterator {
       // You might wonder what happens if output_shape_[ii] == 0. In
       // that case, output.numel() would be 0, and thus we would have
       // begin() == end() and no iteration.
-      if ET_UNLIKELY (delinearized_output_index_[ii] == output_shape_[ii] - 1) {
+      if ET_UNLIKELY (
+          static_cast(delinearized_output_index_[ii]) ==
+          output_shape_[ii] - 1) {
         const auto old_delinearized_output_index_item =
             delinearized_output_index_[ii];
         delinearized_output_index_[ii] = 0;
@@ -104,11 +107,42 @@ class BroadcastIndexesIterator {
     return it;
   }
 
+  BroadcastIndexesIterator& operator+=(difference_type n) {
+    if (n <= 3) {
+      std::advance(*this, n);
+      return *this;
+    }
+
+    output_index() += n;
+    delinearize_index(
+        output_index(),
+        output_shape_,
+        delinearized_output_index_.data(),
+        delinearized_output_index_.size());
+    for (const auto ii : c10::irange(1, kNumInputs + 1)) {
+      current_indexes_[ii] = 0;
+      for (const auto jj : c10::irange(output_dim_)) {
+        current_indexes_[ii] += delinearized_output_index_[jj] *
+            effective_input_broadcast_strides_[ii - 1][jj];
+      }
+    }
+    return *this;
+  }
+
+  BroadcastIndexesIterator operator+(difference_type n) {
+    auto it = *this;
+    it += n;
+    return it;
+  }
+
   difference_type operator-(const BroadcastIndexesIterator& rhs) const {
     return difference_type(output_index() - rhs.output_index());
   }
 
  private:
+  using ShapeType =
+      std::array;
+
   ssize_t output_index() const {
     return current_indexes_[0];
   }
@@ -117,11 +151,10 @@ class BroadcastIndexesIterator {
     return current_indexes_[0];
   }
 
-  std::array
-  effective_input_broadcast_stride(const Tensor& output, const Tensor& t)
-      const {
-    std::array
-        result = {0};
+  ShapeType effective_input_broadcast_stride(
+      const Tensor& output,
+      const Tensor& t) const {
+    ShapeType result = {0};
     ET_CHECK_MSG(
         t.dim() <= output.dim(),
         "input to broadcasting op should have dim at most output dim, but %d > %d!",
@@ -146,8 +179,6 @@ class BroadcastIndexesIterator {
   // The 0th entry is the current linear index into the output,
   // followed by kNumInputs input indexes.
   std::array current_indexes_ = {0};
-  using ShapeType = std::
-      array;
   ShapeType delinearized_output_index_ = {0};
   ssize_t output_dim_;
   ArrayRef output_shape_;
diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h
index 09db5f7180d..a5bcd6ff98b 100644
--- a/kernels/portable/cpu/util/elementwise_util.h
+++ b/kernels/portable/cpu/util/elementwise_util.h
@@ -14,6 +14,9 @@
 #include
 #include
+#include
+#include
+
 namespace torch {
 namespace executor {
 namespace native {
@@ -46,38 +49,94 @@ inline int64_t scalar_to(const Scalar& s) {
       : s.to<int64_t>();
 }
 
-template <typename CTYPE_COMMON, const char* op_name, typename Op>
-inline void apply_unitensor_elementwise_fn(
+namespace internal {
+template <
+    typename CTYPE_COMMON,
+    const char* op_name,
+    typename Op,
+    typename... Args>
+inline void apply_elementwise_fn(
     const Op& compute_fun,
     KernelRuntimeContext& ctx,
-    const Tensor& a,
-    SupportedTensorDtypes a_dtypes,
     const Tensor& out,
-    SupportedTensorDtypes out_dtypes) {
+    SupportedTensorDtypes out_dtypes,
+    Args... inputs) {
+  static_assert(
+      (std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
+       ...));
+  constexpr auto kNumInputs = sizeof...(inputs);
   constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
-
+  const auto check_input_dtype = [](auto input, auto compute_type) {
+    return internal::check_tensor_dtype(
+        *input.first, input.second, compute_type);
+  };
   ET_KERNEL_CHECK(
       ctx,
-      (internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
-       internal::check_tensor_dtype(out, out_dtypes, compute_type)),
+      (check_input_dtype(inputs, compute_type) && ...) &&
+          internal::check_tensor_dtype(out, out_dtypes, compute_type),
       InvalidArgument, );
 
-  const auto load_a_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
+  bool any_is_broadcasted = false;
+  if constexpr (kNumInputs > 1) {
+    any_is_broadcasted = (!out.sizes().equals(inputs.first->sizes()) || ...);
+  }
+
+  struct InputInfo {
+    load_to_common_fn<CTYPE_COMMON> load_to_common;
+    const char* data_ptr;
+    ssize_t element_size;
+  };
+  std::array<InputInfo, kNumInputs> inputs_info = {(InputInfo{
+      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(
+          *inputs.first, inputs.second),
+      reinterpret_cast<const char*>(inputs.first->const_data_ptr()),
+      inputs.first->element_size(),
+  })...};
+
   const auto store_common_to_out =
       internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
          out, out_dtypes);
-  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
-  const auto a_element_size = a.element_size();
-  const auto out_element_size = out.element_size();
   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
+  const auto out_element_size = out.element_size();
 
-  auto out_numel = out.numel();
-  for (const auto i : c10::irange(out_numel)) {
-    auto result = compute_fun(load_a_to_common(&data_a[i * a_element_size]));
-    store_common_to_out(result, &data_out[i * out_element_size]);
+  if (any_is_broadcasted) {
+    for (const auto& indexes :
+         BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...)) {
+      std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+      for (const auto idx : c10::irange(kNumInputs)) {
+        const auto& input_info = inputs_info[idx];
+        loaded_inputs[idx] = input_info.load_to_common(
+            &input_info.data_ptr[indexes[idx + 1] * input_info.element_size]);
+      }
+      auto result = std::apply(compute_fun, loaded_inputs);
+      store_common_to_out(result, &data_out[indexes[0] * out_element_size]);
+    }
+  } else {
+    for (const auto i : c10::irange(out.numel())) {
+      std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+      for (const auto idx : c10::irange(kNumInputs)) {
+        const auto& input_info = inputs_info[idx];
+        loaded_inputs[idx] = input_info.load_to_common(
+            &input_info.data_ptr[i * input_info.element_size]);
+      }
+      auto result = std::apply(compute_fun, loaded_inputs);
+      store_common_to_out(result, &data_out[i * out_element_size]);
+    }
   }
 }
+} // namespace internal
+
+template <typename CTYPE_COMMON, const char* op_name, typename Op>
+inline void apply_unitensor_elementwise_fn(
+    const Op& compute_fun,
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    SupportedTensorDtypes a_dtypes,
+    const Tensor& out,
+    SupportedTensorDtypes out_dtypes) {
+  internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
+      compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
+}
 
 /**
  * Useful for bi-tensor elementwise operators. For each element of the inputs,
@@ -94,53 +153,13 @@ inline void apply_bitensor_elementwise_fn(
     SupportedTensorDtypes b_dtypes,
     const Tensor& out,
     SupportedTensorDtypes out_dtypes) {
-  constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
-
-  ET_KERNEL_CHECK(
+  internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
+      compute_fun,
       ctx,
-      (internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
-       internal::check_tensor_dtype(b, b_dtypes, compute_type) &&
-       internal::check_tensor_dtype(out, out_dtypes, compute_type)),
-      InvalidArgument, );
-
-  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
-  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
-  const bool any_is_broadcasted = (a_is_broadcasted || b_is_broadcasted);
-
-  const auto load_a_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
-  const auto load_b_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
-  const auto store_common_to_out =
-      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
-          out, out_dtypes);
-  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
-  const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
-  const auto a_element_size = a.element_size();
-  const auto b_element_size = b.element_size();
-  const auto out_element_size = out.element_size();
-  char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
-
-  auto out_numel = out.numel();
-  if (any_is_broadcasted) {
-    for (const auto [out_index, a_index, b_index] :
-         BroadcastIndexesRange<2>(out, a, b)) {
-      auto result = compute_fun(
-          load_a_to_common(&data_a[a_index * a_element_size]),
-          load_b_to_common(&data_b[b_index * b_element_size]));
-      store_common_to_out(result, &data_out[out_index * out_element_size]);
-    }
-  } else {
-    for (const auto i : c10::irange(out_numel)) {
-      size_t a_linear_index = i;
-      size_t b_linear_index = i;
-
-      auto result = compute_fun(
-          load_a_to_common(&data_a[a_linear_index * a_element_size]),
-          load_b_to_common(&data_b[b_linear_index * b_element_size]));
-      store_common_to_out(result, &data_out[i * out_element_size]);
-    }
-  }
+      out,
+      out_dtypes,
+      std::make_pair(&a, a_dtypes),
+      std::make_pair(&b, b_dtypes));
 }
 
 /**
@@ -175,63 +194,14 @@ inline void apply_tritensor_elementwise_fn(
     SupportedTensorDtypes c_dtypes,
     const Tensor& out,
     SupportedTensorDtypes out_dtypes) {
-  constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
-
-  ET_KERNEL_CHECK(
+  internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
+      compute_fun,
       ctx,
-      (internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
-       internal::check_tensor_dtype(b, b_dtypes, compute_type) &&
-       internal::check_tensor_dtype(c, c_dtypes, compute_type) &&
-       internal::check_tensor_dtype(out, out_dtypes, compute_type)),
-      InvalidArgument, );
-
-  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
-  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
-  const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
-  const bool any_is_broadcasted =
-      (a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);
-
-  const auto load_a_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
-  const auto load_b_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
-  const auto load_c_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(c, c_dtypes);
-  const auto store_common_to_out =
-      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
-          out, out_dtypes);
-  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
-  const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
-  const char* const data_c = reinterpret_cast<const char*>(c.const_data_ptr());
-  const auto a_element_size = a.element_size();
-  const auto b_element_size = b.element_size();
-  const auto c_element_size = c.element_size();
-  const auto out_element_size = out.element_size();
-  char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
-
-  auto out_numel = out.numel();
-  if (any_is_broadcasted) {
-    for (const auto [out_index, a_index, b_index, c_index] :
-         BroadcastIndexesRange<3>(out, a, b, c)) {
-      auto result = compute_fun(
-          load_a_to_common(&data_a[a_index * a_element_size]),
-          load_b_to_common(&data_b[b_index * b_element_size]),
-          load_c_to_common(&data_c[c_index * c_element_size]));
-      store_common_to_out(result, &data_out[out_index * out_element_size]);
-    }
-  } else {
-    for (const auto i : c10::irange(out_numel)) {
-      size_t a_linear_index = i;
-      size_t b_linear_index = i;
-      size_t c_linear_index = i;
-
-      auto result = compute_fun(
-          load_a_to_common(&data_a[a_linear_index * a_element_size]),
-          load_b_to_common(&data_b[b_linear_index * b_element_size]),
-          load_c_to_common(&data_c[c_linear_index * c_element_size]));
-      store_common_to_out(result, &data_out[i * out_element_size]);
-    }
-  }
+      out,
+      out_dtypes,
+      std::make_pair(&a, a_dtypes),
+      std::make_pair(&b, b_dtypes),
+      std::make_pair(&c, c_dtypes));
 }
 
 inline ScalarType get_compute_type(ScalarType& common_type) {
diff --git a/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp b/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp
index f147958558d..519cd9fe9f9 100644
--- a/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp
+++ b/kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp
@@ -68,6 +68,15 @@ TEST(BroadcastIndexesRangeTest, ScalarBroadcastToOneD) {
   EXPECT_EQ(expected, actual);
 }
 
+template <typename Range>
+void test_operator_plus(const Range& range) {
+  size_t idx = 0;
+  for (const auto indexes : range) {
+    EXPECT_EQ(*(range.begin() + idx), indexes);
+    idx++;
+  }
+}
+
 // [1] -> [H, W]
 // [W] -> [H, W]
 // [1, 1] -> [H, W]
@@ -87,14 +96,15 @@ TEST(BroadcastIndexesRangeTest, OneAndTwoDExhaustive) {
   Tensor in_not_broadcast = tf.zeros({3, 4});
 
-  auto actual = range_to_vec(BroadcastIndexesRange<6>(
+  const auto range = BroadcastIndexesRange<6>(
       out,
       in_0d_scalar,
       in_1d_scalar,
       in_2d_scalar,
       in_row,
       in_col,
-      in_not_broadcast));
+      in_not_broadcast);
+  auto actual = range_to_vec(range);
   decltype(actual) expected = {
       {0, 0, 0, 0, 0, 0, 0},
       {1, 0, 0, 0, 1, 0, 1},
@@ -110,6 +120,8 @@ TEST(BroadcastIndexesRangeTest, OneAndTwoDExhaustive) {
       {11, 0, 0, 0, 3, 2, 11},
   };
   EXPECT_EQ(expected, actual);
+
+  test_operator_plus(range);
 }
 
 // Make sure nothing is thrown off by a size-1 dim in the output:
@@ -138,20 +150,20 @@ TEST(BroadcastIndexesRangeTest, OneAndTwoDWith1InOutputShapeExhaustive) {
   Tensor in_col = tf.zeros({H, 1});
 
   size_t idx = 0;
+  const auto range_row = BroadcastIndexesRange<5>(
+      out_row,
+      in_0d_scalar,
+      in_1d_scalar,
+      in_2d_scalar,
+      in_row,
+      in_leading_one_row);
   for (const auto
            [out_idx,
            in_0d_idx,
            in_1d_idx,
            in_2d_idx,
            in_row_idx,
-           in_leading_one_row_idx] :
-       BroadcastIndexesRange<5>(
-           out_row,
-           in_0d_scalar,
-           in_1d_scalar,
-           in_2d_scalar,
-           in_row,
-           in_leading_one_row)) {
+           in_leading_one_row_idx] : range_row) {
     EXPECT_EQ(out_idx, idx++);
     EXPECT_EQ(in_0d_idx, 0);
     EXPECT_EQ(in_1d_idx, 0);
@@ -160,16 +172,21 @@
     EXPECT_EQ(in_leading_one_row_idx, out_idx);
   }
 
+  test_operator_plus(range_row);
+
   idx = 0;
+  const auto range_col = BroadcastIndexesRange<4>(
+      out_col, in_0d_scalar, in_1d_scalar, in_2d_scalar, in_col);
   for (const auto [out_idx, in_0d_idx, in_1d_idx, in_2d_idx, in_col_idx] :
-       BroadcastIndexesRange<4>(
-           out_col, in_0d_scalar, in_1d_scalar, in_2d_scalar, in_col)) {
+       range_col) {
     EXPECT_EQ(out_idx, idx++);
     EXPECT_EQ(in_0d_idx, 0);
     EXPECT_EQ(in_1d_idx, 0);
     EXPECT_EQ(in_2d_idx, 0);
     EXPECT_EQ(in_col_idx, out_idx);
   }
+
+  test_operator_plus(range_col);
 }
 
 // [1, 1, 1] -> [C, H, W]
@@ -197,16 +214,17 @@ TEST(BroadcastIndexesRangeTest, ThreeDBroadcasting) {
   // take the opportunity to mutation test against delinearize_index
   // and linearize_access_indexes.
   int idx = 0;
-  for (const auto indexes : BroadcastIndexesRange<8>(
-           out,
-           input_tensors[0],
-           input_tensors[1],
-           input_tensors[2],
-           input_tensors[3],
-           input_tensors[4],
-           input_tensors[5],
-           input_tensors[6],
-           input_tensors[7])) {
+  const auto range = BroadcastIndexesRange<8>(
+      out,
+      input_tensors[0],
+      input_tensors[1],
+      input_tensors[2],
+      input_tensors[3],
+      input_tensors[4],
+      input_tensors[5],
+      input_tensors[6],
+      input_tensors[7]);
+  for (const auto indexes : range) {
     const auto out_idx = indexes[0];
     EXPECT_EQ(out_idx, idx++);
     size_t out_indexes[executorch::runtime::kTensorDimensionLimit];
@@ -219,6 +237,7 @@
           out_indexes, out.dim(), input_tensors[tensor_idx]));
     }
   }
+  test_operator_plus(range);
 }
 
 // 4-D should generalize, but we will go ahead and test:
@@ -235,8 +254,9 @@ void four_d_broadcasting_test() {
   // take the opportunity to mutation test against delinearize_index
   // and linearize_access_indexes.
   int idx = 0;
-  for (const auto [out_idx, in_cw_idx, in_nh_idx] :
-       BroadcastIndexesRange<2>(out, in_broadcast_cw, in_broadcast_nh)) {
+  const auto range =
+      BroadcastIndexesRange<2>(out, in_broadcast_cw, in_broadcast_nh);
+  for (const auto [out_idx, in_cw_idx, in_nh_idx] : range) {
     EXPECT_EQ(out_idx, idx++);
     size_t out_indexes[executorch::runtime::kTensorDimensionLimit];
     delinearize_index(
@@ -248,6 +268,8 @@
         in_nh_idx,
         linearize_access_indexes(out_indexes, out.dim(), in_broadcast_nh));
   }
+
+  test_operator_plus(range);
 }
 
 TEST(BroadcastIndexesRangeTest, FourDBroadcasting) {
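Note: internal::apply_elementwise_fn above replaces the three hand-written uni/bi/tri-tensor loops with one variadic helper. The inputs arrive as a pack of (const Tensor*, SupportedTensorDtypes) pairs, fold expressions perform the dtype checks and broadcast detection, and std::apply forwards one loaded value per input to compute_fun. The standalone sketch below illustrates that pack-of-pairs pattern in isolation; it is not the ExecuTorch API, and every name in it (apply_elementwise_sketch, the pointer/length pair layout) is illustrative only.

// Minimal sketch of the variadic elementwise pattern, assuming plain
// (const T*, length) pairs instead of (const Tensor*, SupportedTensorDtypes).
// Compiles as C++17; not the ExecuTorch API.
#include <array>
#include <cstddef>
#include <cstdio>
#include <tuple>
#include <type_traits>
#include <utility>

template <typename CTYPE_COMMON, typename Op, typename... Args>
void apply_elementwise_sketch(
    const Op& compute_fun,
    CTYPE_COMMON* out,
    std::size_t numel,
    Args... inputs) {
  // Every pack element must be a (data pointer, element count) pair,
  // mirroring the static_assert over std::pair in the patch.
  static_assert(
      (std::is_same_v<Args, std::pair<const CTYPE_COMMON*, std::size_t>> &&
       ...));
  constexpr auto kNumInputs = sizeof...(inputs);
  for (std::size_t i = 0; i < numel; ++i) {
    // Gather the i-th element of every input into a fixed-size array...
    std::array<CTYPE_COMMON, kNumInputs> loaded = {inputs.first[i]...};
    // ...then expand it into one argument per input for the functor.
    out[i] = std::apply(compute_fun, loaded);
  }
}

int main() {
  const float a[] = {1.f, 2.f, 3.f};
  const float b[] = {10.f, 20.f, 30.f};
  float out[3] = {};
  apply_elementwise_sketch(
      [](float x, float y) { return x + y; },
      out,
      3,
      std::make_pair(static_cast<const float*>(a), std::size_t{3}),
      std::make_pair(static_cast<const float*>(b), std::size_t{3}));
  for (float v : out) {
    std::printf("%g\n", v); // prints 11, 22, 33
  }
  return 0;
}

Gathering the loaded values into a std::array of the common type and expanding it with std::apply lets one loop body serve any arity without heap allocation, which is the same property the real helper relies on to subsume the former per-arity function bodies.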