@@ -173,27 +173,13 @@ get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) {
173
173
template <typename CTYPE_COMMON, const char * op_name>
174
174
store_common_to_tensor_fn<CTYPE_COMMON>
175
175
get_store_common_to_tensor_fn_same_as_compute (const Tensor& t) {
176
- return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON>;
176
+ // We already validate tensor types earlier in the process, so at
177
+ // this phase, treat same_as_compute the same as our widest
178
+ // SupportedTensorDtypes set.
179
+ return get_store_common_to_tensor_fn_realhbf16<CTYPE_COMMON, op_name>(t);
177
180
}
178
181
179
- template <
180
- typename CTYPE_COMMON,
181
- const char * op_name,
182
- std::enable_if_t <std::is_same_v<CTYPE_COMMON, float >, bool > = true >
183
- store_common_to_tensor_fn<CTYPE_COMMON>
184
- get_store_common_to_tensor_fn_same_as_common (const Tensor& t) {
185
- void (*result)(CTYPE_COMMON, void *) = nullptr ;
186
- ET_SWITCH_THREE_TYPES (
187
- Float, Half, BFloat16, t.scalar_type (), unused, op_name, CTYPE, [&]() {
188
- result = internal::convert_and_store<CTYPE, CTYPE_COMMON>;
189
- });
190
- return result;
191
- }
192
-
193
- template <
194
- typename CTYPE_COMMON,
195
- const char * op_name,
196
- std::enable_if_t <!std::is_same_v<CTYPE_COMMON, float >, bool > = true >
182
+ template <typename CTYPE_COMMON, const char * op_name>
197
183
store_common_to_tensor_fn<CTYPE_COMMON>
198
184
get_store_common_to_tensor_fn_same_as_common (const Tensor& t) {
199
185
return get_store_common_to_tensor_fn_same_as_compute<CTYPE_COMMON, op_name>(
0 commit comments