Skip to content

Commit 116be9f

Browse files
committed
Adds more supported types to arithmetic reductions
Permits `float` accumulation type with 64 bit integer and unsigned integer inouts to prevent unnecessary copies on devices that don't support double precision
1 parent 7c500ec commit 116be9f

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

dpctl/tensor/libtensor/include/kernels/reductions.hpp

+4
Original file line numberDiff line numberDiff line change
@@ -2806,10 +2806,12 @@ struct TypePairSupportDataForSumReductionTemps
28062806

28072807
// input int64_t
28082808
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, std::int64_t>,
2809+
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, float>,
28092810
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, double>,
28102811

28112812
// input uint64_t
28122813
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, std::uint64_t>,
2814+
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, float>,
28132815
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, double>,
28142816

28152817
// input half
@@ -3077,10 +3079,12 @@ struct TypePairSupportDataForProductReductionTemps
30773079

30783080
// input int64_t
30793081
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, std::int64_t>,
3082+
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, float>,
30803083
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, double>,
30813084

30823085
// input uint32_t
30833086
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, std::uint64_t>,
3087+
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, float>,
30843088
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, double>,
30853089

30863090
// input half

0 commit comments

Comments
 (0)