
Commit 9fe1283

Implement unary_ufunc functions using elementwise_util
One less set of independent implementations to worry about going forward (e.g., we don't have to vectorize these separately from elementwise_util, and they get all the benefits of elementwise_util).

ghstack-source-id: ebf059a
ghstack-comment-id: 2735017402
Pull Request resolved: #9386
1 parent d49b147 commit 9fe1283
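
The commit message's rationale, that a single elementwise_util code path now serves these patterns instead of separately maintained (and separately vectorized) loops, is easiest to see in miniature. The sketch below is a conceptual stand-in, not the actual elementwise_util API: apply_unary and the explicit element sizes are invented for illustration, while load_and_convert / convert_and_store only mirror the helpers of the same name in dtype_util.h (their signatures here are assumptions).

#include <cmath>
#include <cstddef>
#include <vector>

// Simplified stand-ins for internal::load_and_convert /
// internal::convert_and_store from dtype_util.h (signatures assumed).
template <typename CommonT, typename TensorT>
CommonT load_and_convert(const void* p) {
  return static_cast<CommonT>(*static_cast<const TensorT*>(p));
}

template <typename TensorT, typename CommonT>
void convert_and_store(CommonT v, void* p) {
  *static_cast<TensorT*>(p) = static_cast<TensorT>(v);
}

// One generic loop serves every input/output dtype pair once the two
// conversion functions have been chosen, so per-pattern hand-written
// loops (and any future vectorization of this loop) live in one place.
template <typename CommonT, typename Op>
void apply_unary(
    Op op,
    CommonT (*load)(const void*),
    void (*store)(CommonT, void*),
    const void* in,
    void* out,
    std::size_t in_elem_size,
    std::size_t out_elem_size,
    std::size_t numel) {
  const char* src = static_cast<const char*>(in);
  char* dst = static_cast<char*>(out);
  for (std::size_t i = 0; i < numel; ++i) {
    store(op(load(src + i * in_elem_size)), dst + i * out_elem_size);
  }
}

int main() {
  // float input, double output, double as the common compute type.
  std::vector<float> in = {1.0f, 4.0f, 9.0f};
  std::vector<double> out(in.size());
  apply_unary<double>(
      [](double x) { return std::sqrt(x); },
      &load_and_convert<double, float>,
      &convert_and_store<double, double>,
      in.data(),
      out.data(),
      sizeof(float),
      sizeof(double),
      in.size());
  return out[2] == 3.0 ? 0 : 1; // sqrt(9.0) == 3.0
}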

File tree: 6 files changed, +95 -33 lines changed

  kernels/portable/cpu/pattern/targets.bzl
  kernels/portable/cpu/pattern/unary_ufunc_realh.cpp
  kernels/portable/cpu/pattern/unary_ufunc_realhb_to_bool.cpp
  kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_floathbf16.cpp
  kernels/portable/cpu/util/dtype_util.cpp
  kernels/portable/cpu/util/dtype_util.h

kernels/portable/cpu/pattern/targets.bzl

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ def define_common_targets():
         compiler_flags = ["-Wno-missing-prototypes"],
         deps = [
             "//executorch/kernels/portable/cpu/util:broadcast_util",
-            "//executorch/kernels/portable/cpu/util:functional_util",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
             "//executorch/runtime/kernel:kernel_includes",
         ],
         visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/..."],

kernels/portable/cpu/pattern/unary_ufunc_realh.cpp

Lines changed: 13 additions & 6 deletions
@@ -7,7 +7,7 @@
  */

 #include <executorch/kernels/portable/cpu/pattern/pattern.h>
-#include <executorch/kernels/portable/cpu/util/functional_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>

 namespace torch {
@@ -36,12 +36,19 @@ Tensor& unary_ufunc_realh(
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

-  ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, __func__, CTYPE, [&] {
-    apply_unary_map_fn(
+  // TODO: this is broken for dtype_selective_build: this was
+  // __func__, which isn't the operator name.
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "unary_ufunc_realh";
+
+  ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&] {
+    utils::apply_unitensor_elementwise_fn<CTYPE, op_name>(
         [fn](const CTYPE val_in) { return static_cast<CTYPE>(fn(val_in)); },
-        in.const_data_ptr<CTYPE>(),
-        out.mutable_data_ptr<CTYPE>(),
-        in.numel());
+        ctx,
+        in,
+        utils::SupportedTensorDtypes::REALH,
+        out,
+        utils::SupportedTensorDtypes::SAME_AS_COMMON);
   });

   return out;
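
Callers of this pattern are untouched by the change above; only the internals moved onto elementwise_util. For orientation, a hypothetical operator delegating to unary_ufunc_realh might look like the sketch below. The (fn, ctx, in, out) parameter order, the internal:: namespace, and the KernelRuntimeContext type are assumptions inferred from how fn/ctx/in/out are used above, not confirmed by this diff.

#include <cmath>

#include <executorch/kernels/portable/cpu/pattern/pattern.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {
namespace {

// Plain double->double function so there are no overload-resolution
// surprises when it is passed as a function pointer.
double round_up(double x) {
  return std::ceil(x);
}

} // namespace

// Hypothetical operator (not part of this commit). All dtype switching
// and element conversion now happens inside unary_ufunc_realh via
// elementwise_util; this wrapper only supplies the scalar function.
Tensor& my_ceil_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
  return internal::unary_ufunc_realh(round_up, ctx, in, out);
}

} // namespace native
} // namespace executor
} // namespace torch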

kernels/portable/cpu/pattern/unary_ufunc_realhb_to_bool.cpp

Lines changed: 12 additions & 14 deletions
@@ -7,7 +7,7 @@
  */

 #include <executorch/kernels/portable/cpu/pattern/pattern.h>
-#include <executorch/kernels/portable/cpu/util/functional_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>

 namespace torch {
@@ -30,25 +30,23 @@ Tensor& unary_ufunc_realhb_to_bool(
       out,
       "Failed to resize output tensor.");

-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      out.scalar_type() == executorch::aten::ScalarType::Bool,
-      InvalidArgument,
-      out,
-      "Expected out tensor to have dtype Bool, but got %" PRId8 " instead.",
-      static_cast<int8_t>(out.scalar_type()));
-
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

   const auto in_type = in.scalar_type();

-  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, __func__, CTYPE_IN, [&] {
-    apply_unary_map_fn(
+  // TODO: this is broken for dtype_selective_build: this was
+  // __func__, which isn't the operator name.
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "unary_ufunc_realhb_to_bool";
+  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, op_name, CTYPE_IN, [&] {
+    utils::apply_unitensor_elementwise_fn<CTYPE_IN, op_name>(
         [fn](const CTYPE_IN val_in) { return fn(val_in); },
-        in.const_data_ptr<CTYPE_IN>(),
-        out.mutable_data_ptr<bool>(),
-        in.numel());
+        ctx,
+        in,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::BOOL);
   });

   return out;

kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_floathbf16.cpp

Lines changed: 15 additions & 12 deletions
@@ -7,7 +7,7 @@
  */

 #include <executorch/kernels/portable/cpu/pattern/pattern.h>
-#include <executorch/kernels/portable/cpu/util/functional_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>

 namespace torch {
@@ -38,17 +38,20 @@ Tensor& unary_ufunc_realhbbf16_to_floathbf16(
   const auto in_type = in.scalar_type();
   const auto out_type = out.scalar_type();

-  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, __func__, CTYPE_IN, [&] {
-    ET_SWITCH_FLOATHBF16_TYPES(out_type, ctx, __func__, CTYPE_OUT, [&] {
-      apply_unary_map_fn(
-          [fn](const CTYPE_IN val_in) {
-            CTYPE_OUT xi = static_cast<CTYPE_OUT>(val_in);
-            return static_cast<CTYPE_OUT>(fn(xi));
-          },
-          in.const_data_ptr<CTYPE_IN>(),
-          out.mutable_data_ptr<CTYPE_OUT>(),
-          in.numel());
-    });
+  // TODO: this is broken for dtype_selective_build: this was
+  // __func__, which isn't the operator name.
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] =
+      "unary_ufunc_realhbbf16_to_floathbf16";
+
+  ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, op_name, CTYPE_IN, [&] {
+    utils::apply_unitensor_elementwise_fn<CTYPE_IN, op_name>(
+        [fn](const CTYPE_IN val_in) { return fn(val_in); },
+        ctx,
+        in,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::FLOATHBF16);
   });

   return out;

kernels/portable/cpu/util/dtype_util.cpp

Lines changed: 4 additions & 0 deletions
@@ -23,10 +23,14 @@ bool check_tensor_dtype(
       return executorch::runtime::tensor_is_realhbbf16_type(t);
     case SupportedTensorDtypes::REALHBF16:
       return executorch::runtime::tensor_is_realhbf16_type(t);
+    case SupportedTensorDtypes::REALH:
+      return executorch::runtime::tensor_is_realh_type(t);
     case SupportedTensorDtypes::FLOATHBF16:
       return executorch::runtime::tensor_is_floating_type(t);
     case SupportedTensorDtypes::INTB:
       return executorch::runtime::tensor_is_integral_type(t, true);
+    case SupportedTensorDtypes::BOOL:
+      return executorch::runtime::tensor_is_type(t, ScalarType::Bool);
     case SupportedTensorDtypes::BOOL_OR_BYTE:
       return (executorch::runtime::tensor_is_type(
           t, ScalarType::Bool, ScalarType::Byte));

kernels/portable/cpu/util/dtype_util.h

Lines changed: 50 additions & 0 deletions
@@ -51,6 +51,15 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbf16(
   return result;
 }

+template <typename CTYPE_COMMON, const char* op_name>
+load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realh(const Tensor& t) {
+  CTYPE_COMMON (*result)(const void*) = nullptr;
+  ET_SWITCH_REALH_TYPES(t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+    result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
+  });
+  return result;
+}
+
 template <typename CTYPE_COMMON, const char* op_name>
 load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_floathbf16(
     const Tensor& t) {
@@ -72,6 +81,16 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_intb(const Tensor& t) {
   return result;
 }

+template <typename CTYPE_COMMON, const char* op_name>
+load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool(const Tensor& t) {
+  ET_CHECK_MSG(
+      t.scalar_type() == ScalarType::Bool,
+      "Unhandled dtype %s for %s",
+      ::executorch::runtime::toString(t.scalar_type()),
+      op_name);
+  return internal::load_and_convert<CTYPE_COMMON, bool>;
+}
+
 template <typename CTYPE_COMMON, const char* op_name>
 load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte(
     const Tensor& t) {
@@ -137,6 +156,16 @@ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_realhbf16(
   return result;
 }

+template <typename CTYPE_COMMON, const char* op_name>
+store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_realh(
+    const Tensor& t) {
+  void (*result)(CTYPE_COMMON, void*) = nullptr;
+  ET_SWITCH_REALH_TYPES(t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+    result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
+  });
+  return result;
+}
+
 template <typename CTYPE_COMMON, const char* op_name>
 store_common_to_tensor_fn<CTYPE_COMMON>
 get_store_common_to_tensor_fn_floathbf16(const Tensor& t) {
@@ -159,6 +188,17 @@ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_intb(
   return result;
 }

+template <typename CTYPE_COMMON, const char* op_name>
+store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_bool(
+    const Tensor& t) {
+  ET_CHECK_MSG(
+      t.scalar_type() == ScalarType::Bool,
+      "Unhandled dtype %s for %s",
+      ::executorch::runtime::toString(t.scalar_type()),
+      op_name);
+  return internal::convert_and_store<bool, CTYPE_COMMON>;
+}
+
 template <typename CTYPE_COMMON, const char* op_name>
 store_common_to_tensor_fn<CTYPE_COMMON>
 get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) {
@@ -191,8 +231,10 @@ get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
 enum class SupportedTensorDtypes {
   REALHBBF16,
   REALHBF16,
+  REALH,
   FLOATHBF16,
   INTB,
+  BOOL,
   BOOL_OR_BYTE,
   SAME_AS_COMPUTE,
   SAME_AS_COMMON,
@@ -209,10 +251,14 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn(
       return get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
     case SupportedTensorDtypes::REALHBF16:
       return get_load_to_common_fn_realhbf16<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::REALH:
+      return get_load_to_common_fn_realh<CTYPE_COMMON, op_name>(t);
     case SupportedTensorDtypes::FLOATHBF16:
       return get_load_to_common_fn_realhbf16<CTYPE_COMMON, op_name>(t);
     case SupportedTensorDtypes::INTB:
       return get_load_to_common_fn_intb<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::BOOL:
+      return get_load_to_common_fn_bool<CTYPE_COMMON, op_name>(t);
     case SupportedTensorDtypes::BOOL_OR_BYTE:
       return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
     case SupportedTensorDtypes::SAME_AS_COMPUTE:
@@ -233,10 +279,14 @@ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(
       return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
     case SupportedTensorDtypes::REALHBF16:
       return get_store_common_to_tensor_fn_realhbf16<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::REALH:
+      return get_store_common_to_tensor_fn_realh<CTYPE_COMMON, op_name>(t);
     case SupportedTensorDtypes::FLOATHBF16:
       return get_store_common_to_tensor_fn_floathbf16<CTYPE_COMMON, op_name>(t);
     case SupportedTensorDtypes::INTB:
       return get_store_common_to_tensor_fn_intb<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::BOOL:
+      return get_store_common_to_tensor_fn_bool<CTYPE_COMMON, op_name>(t);
     case SupportedTensorDtypes::BOOL_OR_BYTE:
       return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(
           t);
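
As a standalone illustration of the dispatch shape these new cases extend (not the real dtype_util API): a multi-dtype set has to pick its conversion function by switching on the runtime scalar type, while the new single-dtype BOOL set can return its conversion function directly after a check, which is why get_load_to_common_fn_bool above needs no ET_SWITCH. Every name in this sketch is invented for the example.

#include <stdexcept>

// Miniature of the enum-to-function-pointer selection above.
enum class ScalarType { Bool, Float, Double };
enum class SupportedDtypes { BOOL, FLOAT_OR_DOUBLE };

template <typename CommonT, typename TensorT>
CommonT load_and_convert(const void* p) {
  return static_cast<CommonT>(*static_cast<const TensorT*>(p));
}

template <typename CommonT>
using LoadFn = CommonT (*)(const void*);

template <typename CommonT>
LoadFn<CommonT> get_load_fn(SupportedDtypes set, ScalarType st) {
  switch (set) {
    case SupportedDtypes::BOOL:
      // Exactly one dtype in the set, so no inner switch is needed;
      // this mirrors get_load_to_common_fn_bool above.
      if (st != ScalarType::Bool) {
        throw std::runtime_error("expected a Bool tensor");
      }
      return &load_and_convert<CommonT, bool>;
    case SupportedDtypes::FLOAT_OR_DOUBLE:
      // Multi-dtype sets select the conversion per runtime scalar type,
      // as the ET_SWITCH_*_TYPES-based helpers do.
      return st == ScalarType::Float ? &load_and_convert<CommonT, float>
                                     : &load_and_convert<CommonT, double>;
  }
  return nullptr;
}

int main() {
  bool flag = true;
  LoadFn<float> load = get_load_fn<float>(SupportedDtypes::BOOL, ScalarType::Bool);
  return load(&flag) == 1.0f ? 0 : 1; // true loads as 1.0f
}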
