Skip to content

Commit 44ee51a

Browse files
committed
Update
[ghstack-poisoned]
1 parent e49080d commit 44ee51a

File tree

2 files changed

+54
-0
lines changed

2 files changed

+54
-0
lines changed

kernels/portable/cpu/util/dtype_util.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,14 @@ bool check_tensor_dtype(
2323
return executorch::runtime::tensor_is_realhbbf16_type(t);
2424
case SupportedTensorDtypes::REALHBF16:
2525
return executorch::runtime::tensor_is_realhbf16_type(t);
26+
case SupportedTensorDtypes::REALH:
27+
return executorch::runtime::tensor_is_realh_type(t);
2628
case SupportedTensorDtypes::FLOATHBF16:
2729
return executorch::runtime::tensor_is_floating_type(t);
2830
case SupportedTensorDtypes::INTB:
2931
return executorch::runtime::tensor_is_integral_type(t, true);
32+
case SupportedTensorDtypes::BOOL:
33+
return executorch::runtime::tensor_is_type(t, ScalarType::Bool);
3034
case SupportedTensorDtypes::BOOL_OR_BYTE:
3135
return (executorch::runtime::tensor_is_type(
3236
t, ScalarType::Bool, ScalarType::Byte));

kernels/portable/cpu/util/dtype_util.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,15 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbf16(
5151
return result;
5252
}
5353

54+
template <typename CTYPE_COMMON, const char* op_name>
55+
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realh(const Tensor& t) {
56+
CTYPE_COMMON (*result)(const void*) = nullptr;
57+
ET_SWITCH_REALH_TYPES(t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
58+
result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
59+
});
60+
return result;
61+
}
62+
5463
template <typename CTYPE_COMMON, const char* op_name>
5564
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_floathbf16(
5665
const Tensor& t) {
@@ -72,6 +81,16 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_intb(const Tensor& t) {
7281
return result;
7382
}
7483

84+
template <typename CTYPE_COMMON, const char* op_name>
85+
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool(const Tensor& t) {
86+
ET_CHECK_MSG(
87+
t.scalar_type() == ScalarType::Bool,
88+
"Unhandled dtype %s for %s",
89+
::executorch::runtime::toString(t.scalar_type()),
90+
op_name);
91+
return internal::load_and_convert<CTYPE_COMMON, bool>;
92+
}
93+
7594
template <typename CTYPE_COMMON, const char* op_name>
7695
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte(
7796
const Tensor& t) {
@@ -137,6 +156,16 @@ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_realhbf16(
137156
return result;
138157
}
139158

159+
template <typename CTYPE_COMMON, const char* op_name>
160+
store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_realh(
161+
const Tensor& t) {
162+
void (*result)(CTYPE_COMMON, void*) = nullptr;
163+
ET_SWITCH_REALH_TYPES(t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
164+
result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
165+
});
166+
return result;
167+
}
168+
140169
template <typename CTYPE_COMMON, const char* op_name>
141170
store_common_to_tensor_fn<CTYPE_COMMON>
142171
get_store_common_to_tensor_fn_floathbf16(const Tensor& t) {
@@ -159,6 +188,17 @@ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_intb(
159188
return result;
160189
}
161190

191+
template <typename CTYPE_COMMON, const char* op_name>
192+
store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_bool(
193+
const Tensor& t) {
194+
ET_CHECK_MSG(
195+
t.scalar_type() == ScalarType::Bool,
196+
"Unhandled dtype %s for %s",
197+
::executorch::runtime::toString(t.scalar_type()),
198+
op_name);
199+
return internal::convert_and_store<bool, CTYPE_COMMON>;
200+
}
201+
162202
template <typename CTYPE_COMMON, const char* op_name>
163203
store_common_to_tensor_fn<CTYPE_COMMON>
164204
get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) {
@@ -191,8 +231,10 @@ get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
191231
enum class SupportedTensorDtypes {
192232
REALHBBF16,
193233
REALHBF16,
234+
REALH,
194235
FLOATHBF16,
195236
INTB,
237+
BOOL,
196238
BOOL_OR_BYTE,
197239
SAME_AS_COMPUTE,
198240
SAME_AS_COMMON,
@@ -209,10 +251,14 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn(
209251
return get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
210252
case SupportedTensorDtypes::REALHBF16:
211253
return get_load_to_common_fn_realhbf16<CTYPE_COMMON, op_name>(t);
254+
case SupportedTensorDtypes::REALH:
255+
return get_load_to_common_fn_realh<CTYPE_COMMON, op_name>(t);
212256
case SupportedTensorDtypes::FLOATHBF16:
213257
return get_load_to_common_fn_realhbf16<CTYPE_COMMON, op_name>(t);
214258
case SupportedTensorDtypes::INTB:
215259
return get_load_to_common_fn_intb<CTYPE_COMMON, op_name>(t);
260+
case SupportedTensorDtypes::BOOL:
261+
return get_load_to_common_fn_bool<CTYPE_COMMON, op_name>(t);
216262
case SupportedTensorDtypes::BOOL_OR_BYTE:
217263
return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
218264
case SupportedTensorDtypes::SAME_AS_COMPUTE:
@@ -233,10 +279,14 @@ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(
233279
return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
234280
case SupportedTensorDtypes::REALHBF16:
235281
return get_store_common_to_tensor_fn_realhbf16<CTYPE_COMMON, op_name>(t);
282+
case SupportedTensorDtypes::REALH:
283+
return get_store_common_to_tensor_fn_realh<CTYPE_COMMON, op_name>(t);
236284
case SupportedTensorDtypes::FLOATHBF16:
237285
return get_store_common_to_tensor_fn_floathbf16<CTYPE_COMMON, op_name>(t);
238286
case SupportedTensorDtypes::INTB:
239287
return get_store_common_to_tensor_fn_intb<CTYPE_COMMON, op_name>(t);
288+
case SupportedTensorDtypes::BOOL:
289+
return get_store_common_to_tensor_fn_bool<CTYPE_COMMON, op_name>(t);
240290
case SupportedTensorDtypes::BOOL_OR_BYTE:
241291
return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(
242292
t);

0 commit comments

Comments
 (0)