@@ -51,6 +51,15 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbf16(
51
51
return result;
52
52
}
53
53
54
+ template <typename CTYPE_COMMON, const char * op_name>
55
+ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realh (const Tensor& t) {
56
+ CTYPE_COMMON (*result)(const void *) = nullptr ;
57
+ ET_SWITCH_REALH_TYPES (t.scalar_type (), unused, op_name, TENSOR_CTYPE, [&]() {
58
+ result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
59
+ });
60
+ return result;
61
+ }
62
+
54
63
template <typename CTYPE_COMMON, const char * op_name>
55
64
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_floathbf16 (
56
65
const Tensor& t) {
@@ -72,6 +81,16 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_intb(const Tensor& t) {
72
81
return result;
73
82
}
74
83
84
+ template <typename CTYPE_COMMON, const char * op_name>
85
+ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool (const Tensor& t) {
86
+ ET_CHECK_MSG (
87
+ t.scalar_type () == ScalarType::Bool,
88
+ " Unhandled dtype %s for %s" ,
89
+ ::executorch::runtime::toString (t.scalar_type()),
90
+ op_name);
91
+ return internal::load_and_convert<CTYPE_COMMON, bool >;
92
+ }
93
+
75
94
template <typename CTYPE_COMMON, const char * op_name>
76
95
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte (
77
96
const Tensor& t) {
@@ -137,6 +156,16 @@ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_realhbf16(
137
156
return result;
138
157
}
139
158
159
+ template <typename CTYPE_COMMON, const char * op_name>
160
+ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_realh (
161
+ const Tensor& t) {
162
+ void (*result)(CTYPE_COMMON, void *) = nullptr ;
163
+ ET_SWITCH_REALH_TYPES (t.scalar_type (), unused, op_name, TENSOR_CTYPE, [&]() {
164
+ result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
165
+ });
166
+ return result;
167
+ }
168
+
140
169
template <typename CTYPE_COMMON, const char * op_name>
141
170
store_common_to_tensor_fn<CTYPE_COMMON>
142
171
get_store_common_to_tensor_fn_floathbf16 (const Tensor& t) {
@@ -159,6 +188,17 @@ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_intb(
159
188
return result;
160
189
}
161
190
191
+ template <typename CTYPE_COMMON, const char * op_name>
192
+ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_bool (
193
+ const Tensor& t) {
194
+ ET_CHECK_MSG (
195
+ t.scalar_type () == ScalarType::Bool,
196
+ " Unhandled dtype %s for %s" ,
197
+ ::executorch::runtime::toString (t.scalar_type()),
198
+ op_name);
199
+ return internal::convert_and_store<bool , CTYPE_COMMON>;
200
+ }
201
+
162
202
template <typename CTYPE_COMMON, const char * op_name>
163
203
store_common_to_tensor_fn<CTYPE_COMMON>
164
204
get_store_common_to_tensor_fn_bool_or_byte (const Tensor& t) {
@@ -191,8 +231,10 @@ get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
191
231
enum class SupportedTensorDtypes {
192
232
REALHBBF16,
193
233
REALHBF16,
234
+ REALH,
194
235
FLOATHBF16,
195
236
INTB,
237
+ BOOL,
196
238
BOOL_OR_BYTE,
197
239
SAME_AS_COMPUTE,
198
240
SAME_AS_COMMON,
@@ -209,10 +251,14 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn(
209
251
return get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
210
252
case SupportedTensorDtypes::REALHBF16:
211
253
return get_load_to_common_fn_realhbf16<CTYPE_COMMON, op_name>(t);
254
+ case SupportedTensorDtypes::REALH:
255
+ return get_load_to_common_fn_realh<CTYPE_COMMON, op_name>(t);
212
256
case SupportedTensorDtypes::FLOATHBF16:
213
257
return get_load_to_common_fn_realhbf16<CTYPE_COMMON, op_name>(t);
214
258
case SupportedTensorDtypes::INTB:
215
259
return get_load_to_common_fn_intb<CTYPE_COMMON, op_name>(t);
260
+ case SupportedTensorDtypes::BOOL:
261
+ return get_load_to_common_fn_bool<CTYPE_COMMON, op_name>(t);
216
262
case SupportedTensorDtypes::BOOL_OR_BYTE:
217
263
return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
218
264
case SupportedTensorDtypes::SAME_AS_COMPUTE:
@@ -233,10 +279,14 @@ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(
233
279
return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
234
280
case SupportedTensorDtypes::REALHBF16:
235
281
return get_store_common_to_tensor_fn_realhbf16<CTYPE_COMMON, op_name>(t);
282
+ case SupportedTensorDtypes::REALH:
283
+ return get_store_common_to_tensor_fn_realh<CTYPE_COMMON, op_name>(t);
236
284
case SupportedTensorDtypes::FLOATHBF16:
237
285
return get_store_common_to_tensor_fn_floathbf16<CTYPE_COMMON, op_name>(t);
238
286
case SupportedTensorDtypes::INTB:
239
287
return get_store_common_to_tensor_fn_intb<CTYPE_COMMON, op_name>(t);
288
+ case SupportedTensorDtypes::BOOL:
289
+ return get_store_common_to_tensor_fn_bool<CTYPE_COMMON, op_name>(t);
240
290
case SupportedTensorDtypes::BOOL_OR_BYTE:
241
291
return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(
242
292
t);
0 commit comments