Skip to content

Commit 40c1b1b

Browse files
committed
Update
[ghstack-poisoned]
1 parent def7ed4 commit 40c1b1b

File tree

1 file changed

+5
-19
lines changed

1 file changed

+5
-19
lines changed

kernels/portable/cpu/util/dtype_util.h

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -173,27 +173,13 @@ get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) {
173173
template <typename CTYPE_COMMON, const char* op_name>
174174
store_common_to_tensor_fn<CTYPE_COMMON>
175175
get_store_common_to_tensor_fn_same_as_compute(const Tensor& t) {
176-
return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON>;
176+
// We already validate tensor types earlier in the process, so at
177+
// this phase, treat same_as_compute the same as our widest
178+
// SupportedTensorDtypes set.
179+
return get_store_common_to_tensor_fn_realhbf16<CTYPE_COMMON, op_name>(t);
177180
}
178181

179-
template <
180-
typename CTYPE_COMMON,
181-
const char* op_name,
182-
std::enable_if_t<std::is_same_v<CTYPE_COMMON, float>, bool> = true>
183-
store_common_to_tensor_fn<CTYPE_COMMON>
184-
get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
185-
void (*result)(CTYPE_COMMON, void*) = nullptr;
186-
ET_SWITCH_THREE_TYPES(
187-
Float, Half, BFloat16, t.scalar_type(), unused, op_name, CTYPE, [&]() {
188-
result = internal::convert_and_store<CTYPE, CTYPE_COMMON>;
189-
});
190-
return result;
191-
}
192-
193-
template <
194-
typename CTYPE_COMMON,
195-
const char* op_name,
196-
std::enable_if_t<!std::is_same_v<CTYPE_COMMON, float>, bool> = true>
182+
template <typename CTYPE_COMMON, const char* op_name>
197183
store_common_to_tensor_fn<CTYPE_COMMON>
198184
get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
199185
return get_store_common_to_tensor_fn_same_as_compute<CTYPE_COMMON, op_name>(

0 commit comments

Comments
 (0)