Skip to content

Commit 645044a

Browse files
Instantiate atomic reduction templates for min/max ops for double/float types
Added entries for float and double types to TypePairSupportDataForCompReductionAtomic as spotted by @ndgrigorian in the PR review. Also moved comments around.
1 parent 41ec378 commit 645044a

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

dpctl/tensor/libtensor/include/kernels/reductions.hpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2247,11 +2247,10 @@ template <typename argTy, typename outTy>
22472247
struct TypePairSupportDataForCompReductionAtomic
22482248
{
22492249

2250-
/* value if true a kernel for <argTy, outTy> must be instantiated, false
2250+
/* value is true if a kernel for <argTy, outTy> must be instantiated, false
22512251
* otherwise */
2252-
static constexpr bool is_defined = std::disjunction< // disjunction is C++17
2253-
// feature, supported
2254-
// by DPC++
2252+
// disjunction is C++17 feature, supported by DPC++
2253+
static constexpr bool is_defined = std::disjunction<
22552254
// input int32
22562255
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, std::int32_t>,
22572256
// input uint32
@@ -2260,6 +2259,10 @@ struct TypePairSupportDataForCompReductionAtomic
22602259
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, std::int64_t>,
22612260
// input uint64
22622261
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, std::uint64_t>,
2262+
// input float
2263+
td_ns::TypePairDefinedEntry<argTy, float, outTy, float>,
2264+
// input double
2265+
td_ns::TypePairDefinedEntry<argTy, double, outTy, double>,
22632266
// fall-through
22642267
td_ns::NotDefinedEntry>::is_defined;
22652268
};
@@ -2268,19 +2271,17 @@ template <typename argTy, typename outTy>
22682271
struct TypePairSupportDataForCompReductionTemps
22692272
{
22702273

2271-
static constexpr bool is_defined = std::disjunction< // disjunction is C++17
2272-
// feature, supported
2273-
// by DPC++ input bool
2274+
// disjunction is C++17 feature, supported by DPC++
2275+
static constexpr bool is_defined = std::disjunction<
2276+
// input bool
22742277
td_ns::TypePairDefinedEntry<argTy, bool, outTy, bool>,
22752278
// input int8_t
22762279
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int8_t>,
2277-
22782280
// input uint8_t
22792281
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint8_t>,
22802282

22812283
// input int16_t
22822284
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int16_t>,
2283-
22842285
// input uint16_t
22852286
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::uint16_t>,
22862287

0 commit comments

Comments
 (0)