Skip to content

Commit a307c00

Browse files
committed
Adds sub-group work-around to copy, clip, and where.
Moves alignment.hpp into dpctl/tensor/libtensor/include/kernels. Fixes a small typo in angle.hpp.
1 parent 786bec7 commit a307c00

File tree

8 files changed

+121
-24
lines changed

8 files changed

+121
-24
lines changed

dpctl/tensor/libtensor/include/kernels/clip.hpp

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include <pybind11/pybind11.h>
3333
#include <type_traits>
3434

35+
#include "kernels/alignment.hpp"
3536
#include "utils/math_utils.hpp"
3637
#include "utils/offset_utils.hpp"
3738
#include "utils/type_dispatch.hpp"
@@ -51,6 +52,11 @@ namespace td_ns = dpctl::tensor::type_dispatch;
5152

5253
using namespace dpctl::tensor::offset_utils;
5354

55+
using dpctl::tensor::kernels::alignment_utils::
56+
disabled_sg_loadstore_wrapper_krn;
57+
using dpctl::tensor::kernels::alignment_utils::is_aligned;
58+
using dpctl::tensor::kernels::alignment_utils::required_alignment;
59+
5460
template <typename T> T clip(const T &x, const T &min, const T &max)
5561
{
5662
using dpctl::tensor::type_utils::is_complex;
@@ -73,7 +79,11 @@ template <typename T> T clip(const T &x, const T &min, const T &max)
7379
}
7480
}
7581

76-
template <typename T, int vec_sz = 4, int n_vecs = 2> class ClipContigFunctor
82+
template <typename T,
83+
int vec_sz = 4,
84+
int n_vecs = 2,
85+
bool enable_sg_loadstore = true>
86+
class ClipContigFunctor
7787
{
7888
private:
7989
size_t nelems = 0;
@@ -96,7 +106,7 @@ template <typename T, int vec_sz = 4, int n_vecs = 2> class ClipContigFunctor
96106
void operator()(sycl::nd_item<1> ndit) const
97107
{
98108
using dpctl::tensor::type_utils::is_complex;
99-
if constexpr (is_complex<T>::value) {
109+
if constexpr (is_complex<T>::value || !enable_sg_loadstore) {
100110
std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
101111
size_t base = ndit.get_global_linear_id();
102112

@@ -195,10 +205,30 @@ sycl::event clip_contig_impl(sycl::queue &q,
195205
const auto gws_range = sycl::range<1>(n_groups * lws);
196206
const auto lws_range = sycl::range<1>(lws);
197207

198-
cgh.parallel_for<clip_contig_kernel<T, vec_sz, n_vecs>>(
199-
sycl::nd_range<1>(gws_range, lws_range),
200-
ClipContigFunctor<T, vec_sz, n_vecs>(nelems, x_tp, min_tp, max_tp,
201-
dst_tp));
208+
if (is_aligned<required_alignment>(x_cp) &&
209+
is_aligned<required_alignment>(min_cp) &&
210+
is_aligned<required_alignment>(max_cp) &&
211+
is_aligned<required_alignment>(dst_cp))
212+
{
213+
constexpr bool enable_sg_loadstore = true;
214+
using KernelName = clip_contig_kernel<T, vec_sz, n_vecs>;
215+
216+
cgh.parallel_for<KernelName>(
217+
sycl::nd_range<1>(gws_range, lws_range),
218+
ClipContigFunctor<T, vec_sz, n_vecs, enable_sg_loadstore>(
219+
nelems, x_tp, min_tp, max_tp, dst_tp));
220+
}
221+
else {
222+
constexpr bool disable_sg_loadstore = false;
223+
using InnerKernelName = clip_contig_kernel<T, vec_sz, n_vecs>;
224+
using KernelName =
225+
disabled_sg_loadstore_wrapper_krn<InnerKernelName>;
226+
227+
cgh.parallel_for<KernelName>(
228+
sycl::nd_range<1>(gws_range, lws_range),
229+
ClipContigFunctor<T, vec_sz, n_vecs, disable_sg_loadstore>(
230+
nelems, x_tp, min_tp, max_tp, dst_tp));
231+
}
202232
});
203233

204234
return clip_ev;

dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <sycl/sycl.hpp>
3030
#include <type_traits>
3131

32+
#include "kernels/alignment.hpp"
3233
#include "utils/offset_utils.hpp"
3334
#include "utils/type_utils.hpp"
3435

@@ -44,6 +45,11 @@ namespace copy_and_cast
4445
namespace py = pybind11;
4546
using namespace dpctl::tensor::offset_utils;
4647

48+
using dpctl::tensor::kernels::alignment_utils::
49+
disabled_sg_loadstore_wrapper_krn;
50+
using dpctl::tensor::kernels::alignment_utils::is_aligned;
51+
using dpctl::tensor::kernels::alignment_utils::required_alignment;
52+
4753
template <typename srcT, typename dstT, typename IndexerT>
4854
class copy_cast_generic_kernel;
4955

@@ -200,7 +206,8 @@ template <typename srcT,
200206
typename dstT,
201207
typename CastFnT,
202208
int vec_sz = 4,
203-
int n_vecs = 2>
209+
int n_vecs = 2,
210+
bool enable_sg_loadstore = true>
204211
class ContigCopyFunctor
205212
{
206213
private:
@@ -219,7 +226,9 @@ class ContigCopyFunctor
219226
CastFnT fn{};
220227

221228
using dpctl::tensor::type_utils::is_complex;
222-
if constexpr (is_complex<srcT>::value || is_complex<dstT>::value) {
229+
if constexpr (!enable_sg_loadstore || is_complex<srcT>::value ||
230+
is_complex<dstT>::value)
231+
{
223232
std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
224233
size_t base = ndit.get_global_linear_id();
225234

@@ -326,10 +335,32 @@ sycl::event copy_and_cast_contig_impl(sycl::queue &q,
326335
const auto gws_range = sycl::range<1>(n_groups * lws);
327336
const auto lws_range = sycl::range<1>(lws);
328337

329-
cgh.parallel_for<copy_cast_contig_kernel<srcTy, dstTy, n_vecs, vec_sz>>(
330-
sycl::nd_range<1>(gws_range, lws_range),
331-
ContigCopyFunctor<srcTy, dstTy, Caster<srcTy, dstTy>, vec_sz,
332-
n_vecs>(nelems, src_tp, dst_tp));
338+
if (is_aligned<required_alignment>(src_cp) &&
339+
is_aligned<required_alignment>(dst_cp))
340+
{
341+
constexpr bool enable_sg_loadstore = true;
342+
using KernelName =
343+
copy_cast_contig_kernel<srcTy, dstTy, vec_sz, n_vecs>;
344+
345+
cgh.parallel_for<KernelName>(
346+
sycl::nd_range<1>(gws_range, lws_range),
347+
ContigCopyFunctor<srcTy, dstTy, Caster<srcTy, dstTy>, vec_sz,
348+
n_vecs, enable_sg_loadstore>(nelems, src_tp,
349+
dst_tp));
350+
}
351+
else {
352+
constexpr bool disable_sg_loadstore = false;
353+
using InnerKernelName =
354+
copy_cast_contig_kernel<srcTy, dstTy, vec_sz, n_vecs>;
355+
using KernelName =
356+
disabled_sg_loadstore_wrapper_krn<InnerKernelName>;
357+
358+
cgh.parallel_for<KernelName>(
359+
sycl::nd_range<1>(gws_range, lws_range),
360+
ContigCopyFunctor<srcTy, dstTy, Caster<srcTy, dstTy>, vec_sz,
361+
n_vecs, disable_sg_loadstore>(nelems, src_tp,
362+
dst_tp));
363+
}
333364
});
334365

335366
return copy_and_cast_ev;

dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,14 @@ template <typename argTy,
8181
typename resTy = argTy,
8282
unsigned int vec_sz = 4,
8383
unsigned int n_vecs = 2,
84-
bool enable_sg_loadstire = true>
84+
bool enable_sg_loadstore = true>
8585
using AngleContigFunctor =
8686
elementwise_common::UnaryContigFunctor<argTy,
8787
resTy,
8888
AngleFunctor<argTy, resTy>,
8989
vec_sz,
9090
n_vecs,
91-
enable_sg_loadstire>;
91+
enable_sg_loadstore>;
9292

9393
template <typename argTy, typename resTy, typename IndexerT>
9494
using AngleStridedFunctor = elementwise_common::

dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
#include <sycl/sycl.hpp>
3030
#include <utility>
3131

32-
#include "kernels/elementwise_functions/alignment.hpp"
32+
#include "kernels/alignment.hpp"
3333
#include "utils/offset_utils.hpp"
3434

3535
namespace dpctl

dpctl/tensor/libtensor/include/kernels/elementwise_functions/common_inplace.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
#include <pybind11/pybind11.h>
3030
#include <sycl/sycl.hpp>
3131

32-
#include "kernels/elementwise_functions/alignment.hpp"
32+
#include "kernels/alignment.hpp"
3333

3434
namespace dpctl
3535
{

dpctl/tensor/libtensor/include/kernels/where.hpp

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,17 @@
2525
#pragma once
2626
#include "pybind11/numpy.h"
2727
#include "pybind11/stl.h"
28-
#include "utils/offset_utils.hpp"
29-
#include "utils/type_utils.hpp"
3028
#include <algorithm>
3129
#include <complex>
3230
#include <cstdint>
3331
#include <pybind11/pybind11.h>
3432
#include <sycl/sycl.hpp>
3533
#include <type_traits>
3634

35+
#include "kernels/alignment.hpp"
36+
#include "utils/offset_utils.hpp"
37+
#include "utils/type_utils.hpp"
38+
3739
namespace dpctl
3840
{
3941
namespace tensor
@@ -47,12 +49,21 @@ namespace py = pybind11;
4749

4850
using namespace dpctl::tensor::offset_utils;
4951

52+
using dpctl::tensor::kernels::alignment_utils::
53+
disabled_sg_loadstore_wrapper_krn;
54+
using dpctl::tensor::kernels::alignment_utils::is_aligned;
55+
using dpctl::tensor::kernels::alignment_utils::required_alignment;
56+
5057
template <typename T, typename condT, typename IndexerT>
5158
class where_strided_kernel;
5259
template <typename T, typename condT, int vec_sz, int n_vecs>
5360
class where_contig_kernel;
5461

55-
template <typename T, typename condT, int vec_sz = 4, int n_vecs = 2>
62+
template <typename T,
63+
typename condT,
64+
int vec_sz = 4,
65+
int n_vecs = 2,
66+
bool enable_sg_loadstore = true>
5667
class WhereContigFunctor
5768
{
5869
private:
@@ -76,7 +87,9 @@ class WhereContigFunctor
7687
void operator()(sycl::nd_item<1> ndit) const
7788
{
7889
using dpctl::tensor::type_utils::is_complex;
79-
if constexpr (is_complex<condT>::value || is_complex<T>::value) {
90+
if constexpr (!enable_sg_loadstore || is_complex<condT>::value ||
91+
is_complex<T>::value)
92+
{
8093
std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
8194
size_t base = ndit.get_global_linear_id();
8295

@@ -175,10 +188,33 @@ sycl::event where_contig_impl(sycl::queue &q,
175188
const auto gws_range = sycl::range<1>(n_groups * lws);
176189
const auto lws_range = sycl::range<1>(lws);
177190

178-
cgh.parallel_for<where_contig_kernel<T, condT, vec_sz, n_vecs>>(
179-
sycl::nd_range<1>(gws_range, lws_range),
180-
WhereContigFunctor<T, condT, vec_sz, n_vecs>(nelems, cond_tp, x1_tp,
191+
if (is_aligned<required_alignment>(cond_cp) &&
192+
is_aligned<required_alignment>(x1_cp) &&
193+
is_aligned<required_alignment>(x2_cp) &&
194+
is_aligned<required_alignment>(dst_cp))
195+
{
196+
constexpr bool enable_sg_loadstore = true;
197+
using KernelName = where_contig_kernel<T, condT, vec_sz, n_vecs>;
198+
199+
cgh.parallel_for<KernelName>(
200+
sycl::nd_range<1>(gws_range, lws_range),
201+
WhereContigFunctor<T, condT, vec_sz, n_vecs,
202+
enable_sg_loadstore>(nelems, cond_tp, x1_tp,
203+
x2_tp, dst_tp));
204+
}
205+
else {
206+
constexpr bool disable_sg_loadstore = false;
207+
using InnerKernelName =
208+
where_contig_kernel<T, condT, vec_sz, n_vecs>;
209+
using KernelName =
210+
disabled_sg_loadstore_wrapper_krn<InnerKernelName>;
211+
212+
cgh.parallel_for<KernelName>(
213+
sycl::nd_range<1>(gws_range, lws_range),
214+
WhereContigFunctor<T, condT, vec_sz, n_vecs,
215+
disable_sg_loadstore>(nelems, cond_tp, x1_tp,
181216
x2_tp, dst_tp));
217+
}
182218
});
183219

184220
return where_ev;

dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
#include <vector>
3434

3535
#include "elementwise_functions_type_utils.hpp"
36-
#include "kernels/elementwise_functions/alignment.hpp"
36+
#include "kernels/alignment.hpp"
3737
#include "simplify_iteration_space.hpp"
3838
#include "utils/memory_overlap.hpp"
3939
#include "utils/offset_utils.hpp"

0 commit comments

Comments
 (0)