diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp index 268c679f00..32e97df58d 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp @@ -57,7 +57,12 @@ struct FloorDivideFunctor resT operator()(const argT1 &in1, const argT2 &in2) { - if constexpr (std::is_integral_v || std::is_integral_v) { + if constexpr (std::is_same_v && + std::is_same_v) { + return (in2) ? static_cast(in1) : resT(0); + } + else if constexpr (std::is_integral_v || + std::is_integral_v) { if (in2 == argT2(0)) { return resT(0); } @@ -81,7 +86,16 @@ struct FloorDivideFunctor sycl::vec operator()(const sycl::vec &in1, const sycl::vec &in2) { - if constexpr (std::is_integral_v) { + if constexpr (std::is_same_v && + std::is_same_v) { + sycl::vec res; +#pragma unroll + for (int i = 0; i < vec_sz; ++i) { + res[i] = (in2[i]) ? static_cast(in1[i]) : resT(0); + } + return res; + } + else if constexpr (std::is_integral_v) { sycl::vec res; #pragma unroll for (int i = 0; i < vec_sz; ++i) {