Commit d49b147

elementwise_util: don't cast the result of compute_fun back to the common type
The compute function might return an entirely different type than the common type. For example, if we apply a trigonometric function like acos to an input of type bool and expect an output of type float, we would get bad results: acos(0) ≈ 1.57, but casting that through bool truncates it to 1. Note that we no longer need the pair of ET_CHECK_MSG calls I removed, because we already check tensor dtypes on entry to the elementwise util functions; the checks also became inconvenient because we now call get_store_common_to_tensor_fn without the actual common type.

ghstack-source-id: 9f73cfa
ghstack-comment-id: 2735017325
Pull Request resolved: #9385
1 parent 77b624b commit d49b147
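
To make the truncation concrete, here is a minimal standalone C++ sketch (not ExecuTorch code) contrasting the old behavior, which cast the compute result back through the common type (bool here), with the new behavior, which keeps the compute function's return type until the store step:

    #include <cmath>
    #include <cstdio>

    int main() {
      bool input = false;  // the common/compute input type is bool
      // Old behavior: the result of the compute function was cast back to the
      // common type (bool) before being stored to the float output, so
      // acos(0) ~= 1.5708 collapsed to true and then to 1.0f.
      float old_result = static_cast<bool>(std::acos(static_cast<float>(input)));
      // New behavior: the result keeps the compute function's return type
      // (float) all the way to the store step, so no truncation occurs.
      float new_result = std::acos(static_cast<float>(input));
      std::printf("old: %f  new: %f\n", old_result, new_result);  // old: 1.0, new: ~1.5708
      return 0;
    }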

File tree

2 files changed: +24 -35 lines

  kernels/portable/cpu/util/dtype_util.h
  kernels/portable/cpu/util/elementwise_util.h

kernels/portable/cpu/util/dtype_util.h

Lines changed: 5 additions & 31 deletions
@@ -86,12 +86,6 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte(
 template <typename CTYPE_COMMON, const char* op_name>
 load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_same_as_compute(
     const Tensor& t) {
-  constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
-  ET_CHECK_MSG(
-      t.scalar_type() == common_scalar_type,
-      "Unhandled dtype %s for %s",
-      ::executorch::runtime::toString(common_scalar_type),
-      op_name);
   return internal::load_and_convert<CTYPE_COMMON, CTYPE_COMMON>;
 }

@@ -179,33 +173,13 @@ get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) {
 template <typename CTYPE_COMMON, const char* op_name>
 store_common_to_tensor_fn<CTYPE_COMMON>
 get_store_common_to_tensor_fn_same_as_compute(const Tensor& t) {
-  constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
-  ET_CHECK_MSG(
-      t.scalar_type() == common_scalar_type,
-      "Unhandled dtype %s for %s",
-      ::executorch::runtime::toString(common_scalar_type),
-      op_name);
-  return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON>;
+  // We already validate tensor types earlier in the process, so at
+  // this phase, treat same_as_compute the same as our widest
+  // SupportedTensorDtypes set.
+  return get_store_common_to_tensor_fn_realhbf16<CTYPE_COMMON, op_name>(t);
 }

-template <
-    typename CTYPE_COMMON,
-    const char* op_name,
-    std::enable_if_t<std::is_same_v<CTYPE_COMMON, float>, bool> = true>
-store_common_to_tensor_fn<CTYPE_COMMON>
-get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
-  void (*result)(CTYPE_COMMON, void*) = nullptr;
-  ET_SWITCH_THREE_TYPES(
-      Float, Half, BFloat16, t.scalar_type(), unused, op_name, CTYPE, [&]() {
-        result = internal::convert_and_store<CTYPE, CTYPE_COMMON>;
-      });
-  return result;
-}
-
-template <
-    typename CTYPE_COMMON,
-    const char* op_name,
-    std::enable_if_t<!std::is_same_v<CTYPE_COMMON, float>, bool> = true>
+template <typename CTYPE_COMMON, const char* op_name>
 store_common_to_tensor_fn<CTYPE_COMMON>
 get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
   return get_store_common_to_tensor_fn_same_as_compute<CTYPE_COMMON, op_name>(
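
For context on the new same_as_compute path above: get_store_common_to_tensor_fn_realhbf16 dispatches on the output tensor's dtype and returns a function that converts the compute-typed value and stores it. The following is a simplified, hypothetical sketch of that dispatch pattern (the enum, pick_store_fn, and the plain float/double cases are illustrative stand-ins; the real helper covers ExecuTorch's real + half + bfloat16 dtype set):

    #include <cstring>

    // Hypothetical, simplified illustration -- not the actual ExecuTorch helper.
    enum class ScalarType { Float, Double, Half, BFloat16 };

    // Convert a compute-typed value to the output dtype, then store it.
    template <typename CTYPE_OUT, typename CTYPE_COMMON>
    void convert_and_store(CTYPE_COMMON value, void* out) {
      CTYPE_OUT converted = static_cast<CTYPE_OUT>(value);
      std::memcpy(out, &converted, sizeof(converted));
    }

    template <typename CTYPE_COMMON>
    using store_fn = void (*)(CTYPE_COMMON, void*);

    // Pick a store function based on the output tensor's dtype; the compute
    // type (CTYPE_COMMON) no longer has to match the output dtype.
    template <typename CTYPE_COMMON>
    store_fn<CTYPE_COMMON> pick_store_fn(ScalarType out_dtype) {
      switch (out_dtype) {
        case ScalarType::Float:
          return convert_and_store<float, CTYPE_COMMON>;
        case ScalarType::Double:
          return convert_and_store<double, CTYPE_COMMON>;
        default:
          // Half/BFloat16 map to their storage types in the real helper.
          return convert_and_store<float, CTYPE_COMMON>;
      }
    }

    int main() {
      float out = 0.0f;
      // Compute type float, output dtype Float: stores ~1.5708f untruncated.
      pick_store_fn<float>(ScalarType::Float)(1.5708f, &out);
      return 0;
    }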

kernels/portable/cpu/util/elementwise_util.h

Lines changed: 19 additions & 4 deletions
@@ -51,6 +51,13 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
 }

 namespace internal {
+template <typename Ignore, typename T>
+using ignore_first_yield_second = T;
+
+template <typename CTYPE_COMMON, typename Op, typename... Args>
+using op_call_result =
+    std::invoke_result_t<Op, ignore_first_yield_second<Args, CTYPE_COMMON>...>;
+
 template <
     typename CTYPE_COMMON,
     const char* op_name,
@@ -89,9 +96,16 @@ inline void apply_elementwise_fn(
           inputs.first->element_size(),
       })...};

-  const auto store_common_to_out =
-      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
-          out, out_dtypes);
+  // NOTE: the result of compute_fun is not necessarily CTYPE_COMMON!
+  // For example, consider the possibility that compute_fun is a
+  // trigonometric function like acos, the common input type is bool,
+  // and the output type is float -- we would truncate acos(0) ~= 1.57
+  // to just 1. Conveniently, it costs us nothing at runtime to handle
+  // this correctly.
+  const auto store_compute_result_to_out =
+      internal::get_store_common_to_tensor_fn<
+          op_call_result<CTYPE_COMMON, Op, Args...>,
+          op_name>(out, out_dtypes);
   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
   const auto out_element_size = out.element_size();

@@ -114,7 +128,8 @@ inline void apply_elementwise_fn(
               .data_ptr[indexes[idx + 1] * input_info.element_size]);
     }
     auto result = std::apply(compute_fun, loaded_inputs);
-    store_common_to_out(result, &data_out[indexes[0] * out_element_size]);
+    store_compute_result_to_out(
+        result, &data_out[indexes[0] * out_element_size]);
   }
 });
 }
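
The op_call_result alias added in the first hunk deduces the compute function's return type by invoking Op as if every argument had the common type. A small self-contained check of that deduction (the lambda is only an example, not code from this commit):

    #include <type_traits>

    template <typename Ignore, typename T>
    using ignore_first_yield_second = T;

    template <typename CTYPE_COMMON, typename Op, typename... Args>
    using op_call_result =
        std::invoke_result_t<Op, ignore_first_yield_second<Args, CTYPE_COMMON>...>;

    int main() {
      // A compute function that takes two operands and returns float,
      // e.g. something acos-like applied to a narrow input type.
      auto compute = [](float a, float b) { return a + b; };
      using Op = decltype(compute);
      // Even if the common type is bool, the deduced result type is float,
      // so the store step no longer truncates through bool.
      static_assert(
          std::is_same_v<op_call_result<bool, Op, int, int>, float>, "");
      return 0;
    }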
