Skip to content

Commit e885838

Browse files
authored
Dedicated kernels for in-place dpt.divide and dpt.floor_divide (#1431)
* Implements dedicated kernels for in-place division, including both floor division and true division.
* Adds tests for in-place division behavior.
* Adds a `static_assert` check to TrueDivideInplaceTypeMapFactory, which checks that the result type is either the same as the third template parameter, or none.
* Adds a comment to TrueDivideInplaceOutputType.
1 parent 39e0700 commit e885838

File tree

6 files changed

+657
-23
lines changed

6 files changed

+657
-23
lines changed

dpctl/tensor/_elementwise_funcs.py

+2
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,7 @@
590590
ti._divide_result_type,
591591
ti._divide,
592592
_divide_docstring_,
593+
binary_inplace_fn=ti._divide_inplace,
593594
acceptance_fn=_acceptance_fn_divide,
594595
)
595596

@@ -720,6 +721,7 @@
720721
ti._floor_divide_result_type,
721722
ti._floor_divide,
722723
_floor_divide_docstring_,
724+
binary_inplace_fn=ti._floor_divide_inplace,
723725
)
724726

725727
# B11: ==== GREATER (x1, x2)

dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp

+179-17
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,7 @@ struct FloorDivideFunctor
5757

5858
resT operator()(const argT1 &in1, const argT2 &in2) const
5959
{
60-
if constexpr (std::is_same_v<argT1, bool> &&
61-
std::is_same_v<argT2, bool>) {
62-
return (in2) ? static_cast<resT>(in1) : resT(0);
63-
}
64-
else if constexpr (std::is_integral_v<argT1> ||
65-
std::is_integral_v<argT2>) {
60+
if constexpr (std::is_integral_v<argT1> || std::is_integral_v<argT2>) {
6661
if (in2 == argT2(0)) {
6762
return resT(0);
6863
}
@@ -87,16 +82,7 @@ struct FloorDivideFunctor
8782
operator()(const sycl::vec<argT1, vec_sz> &in1,
8883
const sycl::vec<argT2, vec_sz> &in2) const
8984
{
90-
if constexpr (std::is_same_v<argT1, bool> &&
91-
std::is_same_v<argT2, bool>) {
92-
sycl::vec<resT, vec_sz> res;
93-
#pragma unroll
94-
for (int i = 0; i < vec_sz; ++i) {
95-
res[i] = (in2[i]) ? static_cast<resT>(in1[i]) : resT(0);
96-
}
97-
return res;
98-
}
99-
else if constexpr (std::is_integral_v<resT>) {
85+
if constexpr (std::is_integral_v<resT>) {
10086
sycl::vec<resT, vec_sz> res;
10187
#pragma unroll
10288
for (int i = 0; i < vec_sz; ++i) {
@@ -165,7 +151,6 @@ template <typename T1, typename T2> struct FloorDivideOutputType
165151
{
166152
using value_type = typename std::disjunction< // disjunction is C++17
167153
// feature, supported by DPC++
168-
td_ns::BinaryTypeMapResultEntry<T1, bool, T2, bool, std::int8_t>,
169154
td_ns::BinaryTypeMapResultEntry<T1,
170155
std::uint8_t,
171156
T2,
@@ -315,6 +300,183 @@ struct FloorDivideStridedFactory
315300
}
316301
};
317302

303+
template <typename argT, typename resT> struct FloorDivideInplaceFunctor
304+
{
305+
using supports_sg_loadstore = std::true_type;
306+
using supports_vec = std::true_type;
307+
308+
void operator()(resT &in1, const argT &in2) const
309+
{
310+
if constexpr (std::is_integral_v<resT>) {
311+
if (in2 == argT(0)) {
312+
in1 = 0;
313+
return;
314+
}
315+
if constexpr (std::is_signed_v<resT>) {
316+
auto tmp = in1;
317+
in1 /= in2;
318+
auto mod = tmp % in2;
319+
auto corr = (mod != 0 && l_xor(mod < 0, in2 < 0));
320+
in1 -= corr;
321+
}
322+
else {
323+
in1 /= in2;
324+
}
325+
}
326+
else {
327+
in1 /= in2;
328+
if (in1 == resT(0)) {
329+
return;
330+
}
331+
in1 = std::floor(in1);
332+
}
333+
}
334+
335+
template <int vec_sz>
336+
void operator()(sycl::vec<resT, vec_sz> &in1,
337+
const sycl::vec<argT, vec_sz> &in2) const
338+
{
339+
if constexpr (std::is_integral_v<resT>) {
340+
#pragma unroll
341+
for (int i = 0; i < vec_sz; ++i) {
342+
if (in2[i] == argT(0)) {
343+
in1[i] = 0;
344+
}
345+
else {
346+
if constexpr (std::is_signed_v<resT>) {
347+
auto tmp = in1[i];
348+
in1[i] /= in2[i];
349+
auto mod = tmp % in2[i];
350+
auto corr = (mod != 0 && l_xor(mod < 0, in2[i] < 0));
351+
in1[i] -= corr;
352+
}
353+
else {
354+
in1[i] /= in2[i];
355+
}
356+
}
357+
}
358+
}
359+
else {
360+
in1 /= in2;
361+
#pragma unroll
362+
for (int i = 0; i < vec_sz; ++i) {
363+
if (in2[i] != argT(0)) {
364+
in1[i] = std::floor(in1[i]);
365+
}
366+
}
367+
}
368+
}
369+
370+
private:
371+
bool l_xor(bool b1, bool b2) const
372+
{
373+
return (b1 != b2);
374+
}
375+
};
376+
377+
template <typename argT,
378+
typename resT,
379+
unsigned int vec_sz = 4,
380+
unsigned int n_vecs = 2>
381+
using FloorDivideInplaceContigFunctor =
382+
elementwise_common::BinaryInplaceContigFunctor<
383+
argT,
384+
resT,
385+
FloorDivideInplaceFunctor<argT, resT>,
386+
vec_sz,
387+
n_vecs>;
388+
389+
template <typename argT, typename resT, typename IndexerT>
390+
using FloorDivideInplaceStridedFunctor =
391+
elementwise_common::BinaryInplaceStridedFunctor<
392+
argT,
393+
resT,
394+
IndexerT,
395+
FloorDivideInplaceFunctor<argT, resT>>;
396+
397+
template <typename argT,
398+
typename resT,
399+
unsigned int vec_sz,
400+
unsigned int n_vecs>
401+
class floor_divide_inplace_contig_kernel;
402+
403+
template <typename argTy, typename resTy>
404+
sycl::event
405+
floor_divide_inplace_contig_impl(sycl::queue &exec_q,
406+
size_t nelems,
407+
const char *arg_p,
408+
py::ssize_t arg_offset,
409+
char *res_p,
410+
py::ssize_t res_offset,
411+
const std::vector<sycl::event> &depends = {})
412+
{
413+
return elementwise_common::binary_inplace_contig_impl<
414+
argTy, resTy, FloorDivideInplaceContigFunctor,
415+
floor_divide_inplace_contig_kernel>(exec_q, nelems, arg_p, arg_offset,
416+
res_p, res_offset, depends);
417+
}
418+
419+
template <typename fnT, typename T1, typename T2>
420+
struct FloorDivideInplaceContigFactory
421+
{
422+
fnT get()
423+
{
424+
if constexpr (std::is_same_v<
425+
typename FloorDivideOutputType<T1, T2>::value_type,
426+
void>)
427+
{
428+
fnT fn = nullptr;
429+
return fn;
430+
}
431+
else {
432+
fnT fn = floor_divide_inplace_contig_impl<T1, T2>;
433+
return fn;
434+
}
435+
}
436+
};
437+
438+
template <typename resT, typename argT, typename IndexerT>
439+
class floor_divide_inplace_strided_kernel;
440+
441+
template <typename argTy, typename resTy>
442+
sycl::event floor_divide_inplace_strided_impl(
443+
sycl::queue &exec_q,
444+
size_t nelems,
445+
int nd,
446+
const py::ssize_t *shape_and_strides,
447+
const char *arg_p,
448+
py::ssize_t arg_offset,
449+
char *res_p,
450+
py::ssize_t res_offset,
451+
const std::vector<sycl::event> &depends,
452+
const std::vector<sycl::event> &additional_depends)
453+
{
454+
return elementwise_common::binary_inplace_strided_impl<
455+
argTy, resTy, FloorDivideInplaceStridedFunctor,
456+
floor_divide_inplace_strided_kernel>(
457+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
458+
res_offset, depends, additional_depends);
459+
}
460+
461+
template <typename fnT, typename T1, typename T2>
462+
struct FloorDivideInplaceStridedFactory
463+
{
464+
fnT get()
465+
{
466+
if constexpr (std::is_same_v<
467+
typename FloorDivideOutputType<T1, T2>::value_type,
468+
void>)
469+
{
470+
fnT fn = nullptr;
471+
return fn;
472+
}
473+
else {
474+
fnT fn = floor_divide_inplace_strided_impl<T1, T2>;
475+
return fn;
476+
}
477+
}
478+
};
479+
318480
} // namespace floor_divide
319481
} // namespace kernels
320482
} // namespace tensor

0 commit comments

Comments
 (0)