@@ -57,12 +57,7 @@ struct FloorDivideFunctor
57
57
58
58
resT operator ()(const argT1 &in1, const argT2 &in2) const
59
59
{
60
- if constexpr (std::is_same_v<argT1, bool > &&
61
- std::is_same_v<argT2, bool >) {
62
- return (in2) ? static_cast <resT>(in1) : resT (0 );
63
- }
64
- else if constexpr (std::is_integral_v<argT1> ||
65
- std::is_integral_v<argT2>) {
60
+ if constexpr (std::is_integral_v<argT1> || std::is_integral_v<argT2>) {
66
61
if (in2 == argT2 (0 )) {
67
62
return resT (0 );
68
63
}
@@ -87,16 +82,7 @@ struct FloorDivideFunctor
87
82
operator ()(const sycl::vec<argT1, vec_sz> &in1,
88
83
const sycl::vec<argT2, vec_sz> &in2) const
89
84
{
90
- if constexpr (std::is_same_v<argT1, bool > &&
91
- std::is_same_v<argT2, bool >) {
92
- sycl::vec<resT, vec_sz> res;
93
- #pragma unroll
94
- for (int i = 0 ; i < vec_sz; ++i) {
95
- res[i] = (in2[i]) ? static_cast <resT>(in1[i]) : resT (0 );
96
- }
97
- return res;
98
- }
99
- else if constexpr (std::is_integral_v<resT>) {
85
+ if constexpr (std::is_integral_v<resT>) {
100
86
sycl::vec<resT, vec_sz> res;
101
87
#pragma unroll
102
88
for (int i = 0 ; i < vec_sz; ++i) {
@@ -165,7 +151,6 @@ template <typename T1, typename T2> struct FloorDivideOutputType
165
151
{
166
152
using value_type = typename std::disjunction< // disjunction is C++17
167
153
// feature, supported by DPC++
168
- td_ns::BinaryTypeMapResultEntry<T1, bool , T2, bool , std::int8_t >,
169
154
td_ns::BinaryTypeMapResultEntry<T1,
170
155
std::uint8_t ,
171
156
T2,
@@ -315,6 +300,183 @@ struct FloorDivideStridedFactory
315
300
}
316
301
};
317
302
303
+ template <typename argT, typename resT> struct FloorDivideInplaceFunctor
304
+ {
305
+ using supports_sg_loadstore = std::true_type;
306
+ using supports_vec = std::true_type;
307
+
308
+ void operator ()(resT &in1, const argT &in2) const
309
+ {
310
+ if constexpr (std::is_integral_v<resT>) {
311
+ if (in2 == argT (0 )) {
312
+ in1 = 0 ;
313
+ return ;
314
+ }
315
+ if constexpr (std::is_signed_v<resT>) {
316
+ auto tmp = in1;
317
+ in1 /= in2;
318
+ auto mod = tmp % in2;
319
+ auto corr = (mod != 0 && l_xor (mod < 0 , in2 < 0 ));
320
+ in1 -= corr;
321
+ }
322
+ else {
323
+ in1 /= in2;
324
+ }
325
+ }
326
+ else {
327
+ in1 /= in2;
328
+ if (in1 == resT (0 )) {
329
+ return ;
330
+ }
331
+ in1 = std::floor (in1);
332
+ }
333
+ }
334
+
335
+ template <int vec_sz>
336
+ void operator ()(sycl::vec<resT, vec_sz> &in1,
337
+ const sycl::vec<argT, vec_sz> &in2) const
338
+ {
339
+ if constexpr (std::is_integral_v<resT>) {
340
+ #pragma unroll
341
+ for (int i = 0 ; i < vec_sz; ++i) {
342
+ if (in2[i] == argT (0 )) {
343
+ in1[i] = 0 ;
344
+ }
345
+ else {
346
+ if constexpr (std::is_signed_v<resT>) {
347
+ auto tmp = in1[i];
348
+ in1[i] /= in2[i];
349
+ auto mod = tmp % in2[i];
350
+ auto corr = (mod != 0 && l_xor (mod < 0 , in2[i] < 0 ));
351
+ in1[i] -= corr;
352
+ }
353
+ else {
354
+ in1[i] /= in2[i];
355
+ }
356
+ }
357
+ }
358
+ }
359
+ else {
360
+ in1 /= in2;
361
+ #pragma unroll
362
+ for (int i = 0 ; i < vec_sz; ++i) {
363
+ if (in2[i] != argT (0 )) {
364
+ in1[i] = std::floor (in1[i]);
365
+ }
366
+ }
367
+ }
368
+ }
369
+
370
+ private:
371
+ bool l_xor (bool b1, bool b2) const
372
+ {
373
+ return (b1 != b2);
374
+ }
375
+ };
376
+
377
+ template <typename argT,
378
+ typename resT,
379
+ unsigned int vec_sz = 4 ,
380
+ unsigned int n_vecs = 2 >
381
+ using FloorDivideInplaceContigFunctor =
382
+ elementwise_common::BinaryInplaceContigFunctor<
383
+ argT,
384
+ resT,
385
+ FloorDivideInplaceFunctor<argT, resT>,
386
+ vec_sz,
387
+ n_vecs>;
388
+
389
+ template <typename argT, typename resT, typename IndexerT>
390
+ using FloorDivideInplaceStridedFunctor =
391
+ elementwise_common::BinaryInplaceStridedFunctor<
392
+ argT,
393
+ resT,
394
+ IndexerT,
395
+ FloorDivideInplaceFunctor<argT, resT>>;
396
+
397
+ template <typename argT,
398
+ typename resT,
399
+ unsigned int vec_sz,
400
+ unsigned int n_vecs>
401
+ class floor_divide_inplace_contig_kernel ;
402
+
403
+ template <typename argTy, typename resTy>
404
+ sycl::event
405
+ floor_divide_inplace_contig_impl (sycl::queue &exec_q,
406
+ size_t nelems,
407
+ const char *arg_p,
408
+ py::ssize_t arg_offset,
409
+ char *res_p,
410
+ py::ssize_t res_offset,
411
+ const std::vector<sycl::event> &depends = {})
412
+ {
413
+ return elementwise_common::binary_inplace_contig_impl<
414
+ argTy, resTy, FloorDivideInplaceContigFunctor,
415
+ floor_divide_inplace_contig_kernel>(exec_q, nelems, arg_p, arg_offset,
416
+ res_p, res_offset, depends);
417
+ }
418
+
419
+ template <typename fnT, typename T1, typename T2>
420
+ struct FloorDivideInplaceContigFactory
421
+ {
422
+ fnT get ()
423
+ {
424
+ if constexpr (std::is_same_v<
425
+ typename FloorDivideOutputType<T1, T2>::value_type,
426
+ void >)
427
+ {
428
+ fnT fn = nullptr ;
429
+ return fn;
430
+ }
431
+ else {
432
+ fnT fn = floor_divide_inplace_contig_impl<T1, T2>;
433
+ return fn;
434
+ }
435
+ }
436
+ };
437
+
438
+ template <typename resT, typename argT, typename IndexerT>
439
+ class floor_divide_inplace_strided_kernel ;
440
+
441
+ template <typename argTy, typename resTy>
442
+ sycl::event floor_divide_inplace_strided_impl (
443
+ sycl::queue &exec_q,
444
+ size_t nelems,
445
+ int nd,
446
+ const py::ssize_t *shape_and_strides,
447
+ const char *arg_p,
448
+ py::ssize_t arg_offset,
449
+ char *res_p,
450
+ py::ssize_t res_offset,
451
+ const std::vector<sycl::event> &depends,
452
+ const std::vector<sycl::event> &additional_depends)
453
+ {
454
+ return elementwise_common::binary_inplace_strided_impl<
455
+ argTy, resTy, FloorDivideInplaceStridedFunctor,
456
+ floor_divide_inplace_strided_kernel>(
457
+ exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
458
+ res_offset, depends, additional_depends);
459
+ }
460
+
461
+ template <typename fnT, typename T1, typename T2>
462
+ struct FloorDivideInplaceStridedFactory
463
+ {
464
+ fnT get ()
465
+ {
466
+ if constexpr (std::is_same_v<
467
+ typename FloorDivideOutputType<T1, T2>::value_type,
468
+ void >)
469
+ {
470
+ fnT fn = nullptr ;
471
+ return fn;
472
+ }
473
+ else {
474
+ fnT fn = floor_divide_inplace_strided_impl<T1, T2>;
475
+ return fn;
476
+ }
477
+ }
478
+ };
479
+
318
480
} // namespace floor_divide
319
481
} // namespace kernels
320
482
} // namespace tensor
0 commit comments