Skip to content

Commit e7fc039

Browse files
Merge pull request #1179 from IntelPython/refactoring/dpctl-tensor-type-dispatch-namespace
Refactoring/dpctl tensor type dispatch namespace
2 parents 4e41318 + 0ab2223 commit e7fc039

12 files changed

+87
-104
lines changed

dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -763,10 +763,6 @@ sycl::event masked_place_some_slices_strided_impl(
763763
return comp_ev;
764764
}
765765

766-
static masked_place_all_slices_strided_impl_fn_ptr_t
767-
masked_place_all_slices_strided_impl_dispatch_vector
768-
[dpctl::tensor::detail::num_types];
769-
770766
template <typename fnT, typename T> struct MaskPlaceAllSlicesStridedFactory
771767
{
772768
fnT get()
@@ -776,10 +772,6 @@ template <typename fnT, typename T> struct MaskPlaceAllSlicesStridedFactory
776772
}
777773
};
778774

779-
static masked_place_some_slices_strided_impl_fn_ptr_t
780-
masked_place_some_slices_strided_impl_dispatch_vector
781-
[dpctl::tensor::detail::num_types];
782-
783775
template <typename fnT, typename T> struct MaskPlaceSomeSlicesStridedFactory
784776
{
785777
fnT get()

dpctl/tensor/libtensor/include/utils/type_dispatch.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ namespace dpctl
3333
namespace tensor
3434
{
3535

36-
namespace detail
36+
namespace type_dispatch
3737
{
3838

3939
enum class typenum_t : int
@@ -164,7 +164,7 @@ struct usm_ndarray_types
164164

165165
int typenum_to_lookup_id(int typenum) const
166166
{
167-
using typenum_t = dpctl::tensor::detail::typenum_t;
167+
using typenum_t = ::dpctl::tensor::type_dispatch::typenum_t;
168168
auto const &api = ::dpctl::detail::dpctl_capi::get();
169169

170170
if (typenum == api.UAR_DOUBLE_) {
@@ -250,7 +250,7 @@ struct usm_ndarray_types
250250
}
251251
};
252252

253-
} // namespace detail
253+
} // namespace type_dispatch
254254

255255
} // namespace tensor
256256
} // namespace dpctl

dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp

Lines changed: 30 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -93,27 +93,27 @@ void _split_iteration_space(const shT &shape_vec,
9393

9494
// Computation of positions of masked elements
9595

96+
namespace td_ns = dpctl::tensor::type_dispatch;
97+
9698
using dpctl::tensor::kernels::indexing::mask_positions_contig_impl_fn_ptr_t;
9799
static mask_positions_contig_impl_fn_ptr_t
98-
mask_positions_contig_dispatch_vector[dpctl::tensor::detail::num_types];
100+
mask_positions_contig_dispatch_vector[td_ns::num_types];
99101

100102
using dpctl::tensor::kernels::indexing::mask_positions_strided_impl_fn_ptr_t;
101103
static mask_positions_strided_impl_fn_ptr_t
102-
mask_positions_strided_dispatch_vector[dpctl::tensor::detail::num_types];
104+
mask_positions_strided_dispatch_vector[td_ns::num_types];
103105

104106
void populate_mask_positions_dispatch_vectors(void)
105107
{
106108
using dpctl::tensor::kernels::indexing::MaskPositionsContigFactory;
107-
dpctl::tensor::detail::DispatchVectorBuilder<
108-
mask_positions_contig_impl_fn_ptr_t, MaskPositionsContigFactory,
109-
dpctl::tensor::detail::num_types>
109+
td_ns::DispatchVectorBuilder<mask_positions_contig_impl_fn_ptr_t,
110+
MaskPositionsContigFactory, td_ns::num_types>
110111
dvb1;
111112
dvb1.populate_dispatch_vector(mask_positions_contig_dispatch_vector);
112113

113114
using dpctl::tensor::kernels::indexing::MaskPositionsStridedFactory;
114-
dpctl::tensor::detail::DispatchVectorBuilder<
115-
mask_positions_strided_impl_fn_ptr_t, MaskPositionsStridedFactory,
116-
dpctl::tensor::detail::num_types>
115+
td_ns::DispatchVectorBuilder<mask_positions_strided_impl_fn_ptr_t,
116+
MaskPositionsStridedFactory, td_ns::num_types>
117117
dvb2;
118118
dvb2.populate_dispatch_vector(mask_positions_strided_dispatch_vector);
119119

@@ -158,14 +158,13 @@ size_t py_mask_positions(dpctl::tensor::usm_ndarray mask,
158158
const char *mask_data = mask.get_data();
159159
char *cumsum_data = cumsum.get_data();
160160

161-
auto const &array_types = dpctl::tensor::detail::usm_ndarray_types();
161+
auto const &array_types = td_ns::usm_ndarray_types();
162162

163163
int mask_typeid = array_types.typenum_to_lookup_id(mask_typenum);
164164
int cumsum_typeid = array_types.typenum_to_lookup_id(cumsum_typenum);
165165

166166
// cumsum must be int64_t only
167-
constexpr int int64_typeid =
168-
static_cast<int>(dpctl::tensor::detail::typenum_t::INT64);
167+
constexpr int int64_typeid = static_cast<int>(td_ns::typenum_t::INT64);
169168
if (cumsum_typeid != int64_typeid) {
170169
throw py::value_error(
171170
"Cumulative sum array must have int64 data-type.");
@@ -244,30 +243,28 @@ using dpctl::tensor::kernels::indexing::
244243
masked_extract_all_slices_strided_impl_fn_ptr_t;
245244

246245
static masked_extract_all_slices_strided_impl_fn_ptr_t
247-
masked_extract_all_slices_strided_impl_dispatch_vector
248-
[dpctl::tensor::detail::num_types];
246+
masked_extract_all_slices_strided_impl_dispatch_vector[td_ns::num_types];
249247

250248
using dpctl::tensor::kernels::indexing::
251249
masked_extract_some_slices_strided_impl_fn_ptr_t;
252250

253251
static masked_extract_some_slices_strided_impl_fn_ptr_t
254-
masked_extract_some_slices_strided_impl_dispatch_vector
255-
[dpctl::tensor::detail::num_types];
252+
masked_extract_some_slices_strided_impl_dispatch_vector[td_ns::num_types];
256253

257254
void populate_masked_extract_dispatch_vectors(void)
258255
{
259256
using dpctl::tensor::kernels::indexing::MaskExtractAllSlicesStridedFactory;
260-
dpctl::tensor::detail::DispatchVectorBuilder<
257+
td_ns::DispatchVectorBuilder<
261258
masked_extract_all_slices_strided_impl_fn_ptr_t,
262-
MaskExtractAllSlicesStridedFactory, dpctl::tensor::detail::num_types>
259+
MaskExtractAllSlicesStridedFactory, td_ns::num_types>
263260
dvb1;
264261
dvb1.populate_dispatch_vector(
265262
masked_extract_all_slices_strided_impl_dispatch_vector);
266263

267264
using dpctl::tensor::kernels::indexing::MaskExtractSomeSlicesStridedFactory;
268-
dpctl::tensor::detail::DispatchVectorBuilder<
265+
td_ns::DispatchVectorBuilder<
269266
masked_extract_some_slices_strided_impl_fn_ptr_t,
270-
MaskExtractSomeSlicesStridedFactory, dpctl::tensor::detail::num_types>
267+
MaskExtractSomeSlicesStridedFactory, td_ns::num_types>
271268
dvb2;
272269
dvb2.populate_dispatch_vector(
273270
masked_extract_some_slices_strided_impl_dispatch_vector);
@@ -359,13 +356,12 @@ py_extract(dpctl::tensor::usm_ndarray src,
359356
int dst_typenum = dst.get_typenum();
360357
int cumsum_typenum = cumsum.get_typenum();
361358

362-
auto const &array_types = dpctl::tensor::detail::usm_ndarray_types();
359+
auto const &array_types = td_ns::usm_ndarray_types();
363360
int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
364361
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
365362
int cumsum_typeid = array_types.typenum_to_lookup_id(cumsum_typenum);
366363

367-
constexpr int int64_typeid =
368-
static_cast<int>(dpctl::tensor::detail::typenum_t::INT64);
364+
constexpr int int64_typeid = static_cast<int>(td_ns::typenum_t::INT64);
369365
if (cumsum_typeid != int64_typeid) {
370366
throw py::value_error(
371367
"Unexact data type of cumsum array, expecting 'int64'");
@@ -557,30 +553,28 @@ using dpctl::tensor::kernels::indexing::
557553
masked_place_all_slices_strided_impl_fn_ptr_t;
558554

559555
static masked_place_all_slices_strided_impl_fn_ptr_t
560-
masked_place_all_slices_strided_impl_dispatch_vector
561-
[dpctl::tensor::detail::num_types];
556+
masked_place_all_slices_strided_impl_dispatch_vector[td_ns::num_types];
562557

563558
using dpctl::tensor::kernels::indexing::
564559
masked_place_some_slices_strided_impl_fn_ptr_t;
565560

566561
static masked_place_some_slices_strided_impl_fn_ptr_t
567-
masked_place_some_slices_strided_impl_dispatch_vector
568-
[dpctl::tensor::detail::num_types];
562+
masked_place_some_slices_strided_impl_dispatch_vector[td_ns::num_types];
569563

570564
void populate_masked_place_dispatch_vectors(void)
571565
{
572566
using dpctl::tensor::kernels::indexing::MaskPlaceAllSlicesStridedFactory;
573-
dpctl::tensor::detail::DispatchVectorBuilder<
574-
masked_place_all_slices_strided_impl_fn_ptr_t,
575-
MaskPlaceAllSlicesStridedFactory, dpctl::tensor::detail::num_types>
567+
td_ns::DispatchVectorBuilder<masked_place_all_slices_strided_impl_fn_ptr_t,
568+
MaskPlaceAllSlicesStridedFactory,
569+
td_ns::num_types>
576570
dvb1;
577571
dvb1.populate_dispatch_vector(
578572
masked_place_all_slices_strided_impl_dispatch_vector);
579573

580574
using dpctl::tensor::kernels::indexing::MaskPlaceSomeSlicesStridedFactory;
581-
dpctl::tensor::detail::DispatchVectorBuilder<
582-
masked_place_some_slices_strided_impl_fn_ptr_t,
583-
MaskPlaceSomeSlicesStridedFactory, dpctl::tensor::detail::num_types>
575+
td_ns::DispatchVectorBuilder<masked_place_some_slices_strided_impl_fn_ptr_t,
576+
MaskPlaceSomeSlicesStridedFactory,
577+
td_ns::num_types>
584578
dvb2;
585579
dvb2.populate_dispatch_vector(
586580
masked_place_some_slices_strided_impl_dispatch_vector);
@@ -673,13 +667,12 @@ py_place(dpctl::tensor::usm_ndarray dst,
673667
int rhs_typenum = rhs.get_typenum();
674668
int cumsum_typenum = cumsum.get_typenum();
675669

676-
auto const &array_types = dpctl::tensor::detail::usm_ndarray_types();
670+
auto const &array_types = td_ns::usm_ndarray_types();
677671
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
678672
int rhs_typeid = array_types.typenum_to_lookup_id(rhs_typenum);
679673
int cumsum_typeid = array_types.typenum_to_lookup_id(cumsum_typenum);
680674

681-
constexpr int int64_typeid =
682-
static_cast<int>(dpctl::tensor::detail::typenum_t::INT64);
675+
constexpr int int64_typeid = static_cast<int>(td_ns::typenum_t::INT64);
683676
if (cumsum_typeid != int64_typeid) {
684677
throw py::value_error(
685678
"Unexact data type of cumsum array, expecting 'int64'");
@@ -913,15 +906,14 @@ std::pair<sycl::event, sycl::event> py_nonzero(
913906
py::ssize_t nz_elems = indexes_shape[1];
914907

915908
int indexes_typenum = indexes.get_typenum();
916-
auto const &array_types = dpctl::tensor::detail::usm_ndarray_types();
909+
auto const &array_types = td_ns::usm_ndarray_types();
917910
int indexes_typeid = array_types.typenum_to_lookup_id(indexes_typenum);
918911

919912
int cumsum_typenum = cumsum.get_typenum();
920913
int cumsum_typeid = array_types.typenum_to_lookup_id(cumsum_typenum);
921914

922915
// cumsum must be int64_t only
923-
constexpr int int64_typeid =
924-
static_cast<int>(dpctl::tensor::detail::typenum_t::INT64);
916+
constexpr int int64_typeid = static_cast<int>(td_ns::typenum_t::INT64);
925917
if (cumsum_typeid != int64_typeid || indexes_typeid != int64_typeid) {
926918
throw py::value_error(
927919
"Cumulative sum array and index array must have int64 data-type");

dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,18 @@ namespace tensor
4949
namespace py_internal
5050
{
5151

52-
namespace _ns = dpctl::tensor::detail;
52+
namespace td_ns = dpctl::tensor::type_dispatch;
5353

5454
using dpctl::tensor::kernels::copy_and_cast::copy_and_cast_1d_fn_ptr_t;
5555
using dpctl::tensor::kernels::copy_and_cast::copy_and_cast_contig_fn_ptr_t;
5656
using dpctl::tensor::kernels::copy_and_cast::copy_and_cast_generic_fn_ptr_t;
5757

5858
static copy_and_cast_generic_fn_ptr_t
59-
copy_and_cast_generic_dispatch_table[_ns::num_types][_ns::num_types];
59+
copy_and_cast_generic_dispatch_table[td_ns::num_types][td_ns::num_types];
6060
static copy_and_cast_1d_fn_ptr_t
61-
copy_and_cast_1d_dispatch_table[_ns::num_types][_ns::num_types];
61+
copy_and_cast_1d_dispatch_table[td_ns::num_types][td_ns::num_types];
6262
static copy_and_cast_contig_fn_ptr_t
63-
copy_and_cast_contig_dispatch_table[_ns::num_types][_ns::num_types];
63+
copy_and_cast_contig_dispatch_table[td_ns::num_types][td_ns::num_types];
6464

6565
namespace py = pybind11;
6666

@@ -121,7 +121,7 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
121121
int src_typenum = src.get_typenum();
122122
int dst_typenum = dst.get_typenum();
123123

124-
auto array_types = dpctl::tensor::detail::usm_ndarray_types();
124+
auto array_types = td_ns::usm_ndarray_types();
125125
int src_type_id = array_types.typenum_to_lookup_id(src_typenum);
126126
int dst_type_id = array_types.typenum_to_lookup_id(dst_typenum);
127127

@@ -277,7 +277,7 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
277277

278278
void init_copy_and_cast_usm_to_usm_dispatch_tables(void)
279279
{
280-
using namespace dpctl::tensor::detail;
280+
using namespace td_ns;
281281

282282
using dpctl::tensor::kernels::copy_and_cast::CopyAndCastContigFactory;
283283
DispatchTableBuilder<copy_and_cast_contig_fn_ptr_t,

dpctl/tensor/libtensor/source/copy_for_reshape.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,14 @@ namespace tensor
3939
namespace py_internal
4040
{
4141

42-
namespace _ns = dpctl::tensor::detail;
42+
namespace td_ns = dpctl::tensor::type_dispatch;
4343

4444
using dpctl::tensor::kernels::copy_and_cast::copy_for_reshape_fn_ptr_t;
4545
using dpctl::utils::keep_args_alive;
4646

4747
// define static vector
4848
static copy_for_reshape_fn_ptr_t
49-
copy_for_reshape_generic_dispatch_vector[_ns::num_types];
49+
copy_for_reshape_generic_dispatch_vector[td_ns::num_types];
5050

5151
/*
5252
* Copies src into dst (same data type) of different shapes by using flat
@@ -121,7 +121,7 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
121121
int src_nd = src.get_ndim();
122122
int dst_nd = dst.get_ndim();
123123

124-
auto array_types = dpctl::tensor::detail::usm_ndarray_types();
124+
auto array_types = td_ns::usm_ndarray_types();
125125
int type_id = array_types.typenum_to_lookup_id(src_typenum);
126126

127127
auto fn = copy_for_reshape_generic_dispatch_vector[type_id];
@@ -172,7 +172,7 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
172172

173173
void init_copy_for_reshape_dispatch_vectors(void)
174174
{
175-
using namespace dpctl::tensor::detail;
175+
using namespace td_ns;
176176
using dpctl::tensor::kernels::copy_and_cast::CopyForReshapeGenericFactory;
177177

178178
DispatchVectorBuilder<copy_for_reshape_fn_ptr_t,

dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
#include "simplify_iteration_space.hpp"
3737

3838
namespace py = pybind11;
39-
namespace _ns = dpctl::tensor::detail;
39+
namespace td_ns = dpctl::tensor::type_dispatch;
4040

4141
namespace dpctl
4242
{
@@ -49,8 +49,8 @@ using dpctl::tensor::kernels::copy_and_cast::
4949
copy_and_cast_from_host_blocking_fn_ptr_t;
5050

5151
static copy_and_cast_from_host_blocking_fn_ptr_t
52-
copy_and_cast_from_host_blocking_dispatch_table[_ns::num_types]
53-
[_ns::num_types];
52+
copy_and_cast_from_host_blocking_dispatch_table[td_ns::num_types]
53+
[td_ns::num_types];
5454

5555
void copy_numpy_ndarray_into_usm_ndarray(
5656
py::array npy_src,
@@ -111,7 +111,7 @@ void copy_numpy_ndarray_into_usm_ndarray(
111111
py::detail::array_descriptor_proxy(npy_src.dtype().ptr())->type_num;
112112
int dst_typenum = dst.get_typenum();
113113

114-
auto array_types = dpctl::tensor::detail::usm_ndarray_types();
114+
auto array_types = td_ns::usm_ndarray_types();
115115
int src_type_id = array_types.typenum_to_lookup_id(src_typenum);
116116
int dst_type_id = array_types.typenum_to_lookup_id(dst_typenum);
117117

@@ -239,11 +239,11 @@ void copy_numpy_ndarray_into_usm_ndarray(
239239

240240
void init_copy_numpy_ndarray_into_usm_ndarray_dispatch_tables(void)
241241
{
242-
using namespace dpctl::tensor::detail;
242+
using namespace td_ns;
243243
using dpctl::tensor::kernels::copy_and_cast::CopyAndCastFromHostFactory;
244244

245245
DispatchTableBuilder<copy_and_cast_from_host_blocking_fn_ptr_t,
246-
CopyAndCastFromHostFactory, _ns::num_types>
246+
CopyAndCastFromHostFactory, num_types>
247247
dtb_copy_from_numpy;
248248

249249
dtb_copy_from_numpy.populate_dispatch_table(

dpctl/tensor/libtensor/source/eye_ctor.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
#include "utils/type_dispatch.hpp"
3535

3636
namespace py = pybind11;
37-
namespace _ns = dpctl::tensor::detail;
37+
namespace td_ns = dpctl::tensor::type_dispatch;
3838

3939
namespace dpctl
4040
{
@@ -46,7 +46,7 @@ namespace py_internal
4646
using dpctl::utils::keep_args_alive;
4747

4848
using dpctl::tensor::kernels::constructors::eye_fn_ptr_t;
49-
static eye_fn_ptr_t eye_dispatch_vector[_ns::num_types];
49+
static eye_fn_ptr_t eye_dispatch_vector[td_ns::num_types];
5050

5151
std::pair<sycl::event, sycl::event>
5252
usm_ndarray_eye(py::ssize_t k,
@@ -66,7 +66,7 @@ usm_ndarray_eye(py::ssize_t k,
6666
"allocation queue");
6767
}
6868

69-
auto array_types = dpctl::tensor::detail::usm_ndarray_types();
69+
auto array_types = td_ns::usm_ndarray_types();
7070
int dst_typenum = dst.get_typenum();
7171
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
7272

@@ -118,7 +118,7 @@ usm_ndarray_eye(py::ssize_t k,
118118

119119
void init_eye_ctor_dispatch_vectors(void)
120120
{
121-
using namespace dpctl::tensor::detail;
121+
using namespace td_ns;
122122
using dpctl::tensor::kernels::constructors::EyeFactory;
123123

124124
DispatchVectorBuilder<eye_fn_ptr_t, EyeFactory, num_types> dvb;

0 commit comments

Comments
 (0)