diff --git a/dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp b/dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp index f7fb262550..62621ecf12 100644 --- a/dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp +++ b/dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp @@ -763,10 +763,6 @@ sycl::event masked_place_some_slices_strided_impl( return comp_ev; } -static masked_place_all_slices_strided_impl_fn_ptr_t - masked_place_all_slices_strided_impl_dispatch_vector - [dpctl::tensor::detail::num_types]; - template struct MaskPlaceAllSlicesStridedFactory { fnT get() @@ -776,10 +772,6 @@ template struct MaskPlaceAllSlicesStridedFactory } }; -static masked_place_some_slices_strided_impl_fn_ptr_t - masked_place_some_slices_strided_impl_dispatch_vector - [dpctl::tensor::detail::num_types]; - template struct MaskPlaceSomeSlicesStridedFactory { fnT get() diff --git a/dpctl/tensor/libtensor/include/utils/type_dispatch.hpp b/dpctl/tensor/libtensor/include/utils/type_dispatch.hpp index 7c0ed60ef0..5d7d6b8a8c 100644 --- a/dpctl/tensor/libtensor/include/utils/type_dispatch.hpp +++ b/dpctl/tensor/libtensor/include/utils/type_dispatch.hpp @@ -33,7 +33,7 @@ namespace dpctl namespace tensor { -namespace detail +namespace type_dispatch { enum class typenum_t : int @@ -164,7 +164,7 @@ struct usm_ndarray_types int typenum_to_lookup_id(int typenum) const { - using typenum_t = dpctl::tensor::detail::typenum_t; + using typenum_t = ::dpctl::tensor::type_dispatch::typenum_t; auto const &api = ::dpctl::detail::dpctl_capi::get(); if (typenum == api.UAR_DOUBLE_) { @@ -250,7 +250,7 @@ struct usm_ndarray_types } }; -} // namespace detail +} // namespace type_dispatch } // namespace tensor } // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp b/dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp index 168e29acff..216c1102ab 100644 --- a/dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp +++ b/dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp @@ -93,27 +93,27 @@ void _split_iteration_space(const shT &shape_vec, // Computation of positions of masked elements +namespace td_ns = dpctl::tensor::type_dispatch; + using dpctl::tensor::kernels::indexing::mask_positions_contig_impl_fn_ptr_t; static mask_positions_contig_impl_fn_ptr_t - mask_positions_contig_dispatch_vector[dpctl::tensor::detail::num_types]; + mask_positions_contig_dispatch_vector[td_ns::num_types]; using dpctl::tensor::kernels::indexing::mask_positions_strided_impl_fn_ptr_t; static mask_positions_strided_impl_fn_ptr_t - mask_positions_strided_dispatch_vector[dpctl::tensor::detail::num_types]; + mask_positions_strided_dispatch_vector[td_ns::num_types]; void populate_mask_positions_dispatch_vectors(void) { using dpctl::tensor::kernels::indexing::MaskPositionsContigFactory; - dpctl::tensor::detail::DispatchVectorBuilder< - mask_positions_contig_impl_fn_ptr_t, MaskPositionsContigFactory, - dpctl::tensor::detail::num_types> + td_ns::DispatchVectorBuilder dvb1; dvb1.populate_dispatch_vector(mask_positions_contig_dispatch_vector); using dpctl::tensor::kernels::indexing::MaskPositionsStridedFactory; - dpctl::tensor::detail::DispatchVectorBuilder< - mask_positions_strided_impl_fn_ptr_t, MaskPositionsStridedFactory, - dpctl::tensor::detail::num_types> + td_ns::DispatchVectorBuilder dvb2; dvb2.populate_dispatch_vector(mask_positions_strided_dispatch_vector); @@ -158,14 +158,13 @@ size_t py_mask_positions(dpctl::tensor::usm_ndarray mask, const char *mask_data = mask.get_data(); char *cumsum_data = cumsum.get_data(); - auto const &array_types = dpctl::tensor::detail::usm_ndarray_types(); + auto const &array_types = td_ns::usm_ndarray_types(); int mask_typeid = array_types.typenum_to_lookup_id(mask_typenum); int cumsum_typeid = array_types.typenum_to_lookup_id(cumsum_typenum); // cumsum must be int64_t only - constexpr int int64_typeid = - static_cast(dpctl::tensor::detail::typenum_t::INT64); + constexpr int int64_typeid = static_cast(td_ns::typenum_t::INT64); if (cumsum_typeid != int64_typeid) { throw py::value_error( "Cumulative sum array must have int64 data-type."); @@ -244,30 +243,28 @@ using dpctl::tensor::kernels::indexing:: masked_extract_all_slices_strided_impl_fn_ptr_t; static masked_extract_all_slices_strided_impl_fn_ptr_t - masked_extract_all_slices_strided_impl_dispatch_vector - [dpctl::tensor::detail::num_types]; + masked_extract_all_slices_strided_impl_dispatch_vector[td_ns::num_types]; using dpctl::tensor::kernels::indexing:: masked_extract_some_slices_strided_impl_fn_ptr_t; static masked_extract_some_slices_strided_impl_fn_ptr_t - masked_extract_some_slices_strided_impl_dispatch_vector - [dpctl::tensor::detail::num_types]; + masked_extract_some_slices_strided_impl_dispatch_vector[td_ns::num_types]; void populate_masked_extract_dispatch_vectors(void) { using dpctl::tensor::kernels::indexing::MaskExtractAllSlicesStridedFactory; - dpctl::tensor::detail::DispatchVectorBuilder< + td_ns::DispatchVectorBuilder< masked_extract_all_slices_strided_impl_fn_ptr_t, - MaskExtractAllSlicesStridedFactory, dpctl::tensor::detail::num_types> + MaskExtractAllSlicesStridedFactory, td_ns::num_types> dvb1; dvb1.populate_dispatch_vector( masked_extract_all_slices_strided_impl_dispatch_vector); using dpctl::tensor::kernels::indexing::MaskExtractSomeSlicesStridedFactory; - dpctl::tensor::detail::DispatchVectorBuilder< + td_ns::DispatchVectorBuilder< masked_extract_some_slices_strided_impl_fn_ptr_t, - MaskExtractSomeSlicesStridedFactory, dpctl::tensor::detail::num_types> + MaskExtractSomeSlicesStridedFactory, td_ns::num_types> dvb2; dvb2.populate_dispatch_vector( masked_extract_some_slices_strided_impl_dispatch_vector); @@ -359,13 +356,12 @@ py_extract(dpctl::tensor::usm_ndarray src, int dst_typenum = dst.get_typenum(); int cumsum_typenum = cumsum.get_typenum(); - auto const &array_types = dpctl::tensor::detail::usm_ndarray_types(); + auto const &array_types = td_ns::usm_ndarray_types(); int src_typeid = array_types.typenum_to_lookup_id(src_typenum); int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); int cumsum_typeid = array_types.typenum_to_lookup_id(cumsum_typenum); - constexpr int int64_typeid = - static_cast(dpctl::tensor::detail::typenum_t::INT64); + constexpr int int64_typeid = static_cast(td_ns::typenum_t::INT64); if (cumsum_typeid != int64_typeid) { throw py::value_error( "Unexact data type of cumsum array, expecting 'int64'"); @@ -557,30 +553,28 @@ using dpctl::tensor::kernels::indexing:: masked_place_all_slices_strided_impl_fn_ptr_t; static masked_place_all_slices_strided_impl_fn_ptr_t - masked_place_all_slices_strided_impl_dispatch_vector - [dpctl::tensor::detail::num_types]; + masked_place_all_slices_strided_impl_dispatch_vector[td_ns::num_types]; using dpctl::tensor::kernels::indexing:: masked_place_some_slices_strided_impl_fn_ptr_t; static masked_place_some_slices_strided_impl_fn_ptr_t - masked_place_some_slices_strided_impl_dispatch_vector - [dpctl::tensor::detail::num_types]; + masked_place_some_slices_strided_impl_dispatch_vector[td_ns::num_types]; void populate_masked_place_dispatch_vectors(void) { using dpctl::tensor::kernels::indexing::MaskPlaceAllSlicesStridedFactory; - dpctl::tensor::detail::DispatchVectorBuilder< - masked_place_all_slices_strided_impl_fn_ptr_t, - MaskPlaceAllSlicesStridedFactory, dpctl::tensor::detail::num_types> + td_ns::DispatchVectorBuilder dvb1; dvb1.populate_dispatch_vector( masked_place_all_slices_strided_impl_dispatch_vector); using dpctl::tensor::kernels::indexing::MaskPlaceSomeSlicesStridedFactory; - dpctl::tensor::detail::DispatchVectorBuilder< - masked_place_some_slices_strided_impl_fn_ptr_t, - MaskPlaceSomeSlicesStridedFactory, dpctl::tensor::detail::num_types> + td_ns::DispatchVectorBuilder dvb2; dvb2.populate_dispatch_vector( masked_place_some_slices_strided_impl_dispatch_vector); @@ -673,13 +667,12 @@ py_place(dpctl::tensor::usm_ndarray dst, int rhs_typenum = rhs.get_typenum(); int cumsum_typenum = cumsum.get_typenum(); - auto const &array_types = dpctl::tensor::detail::usm_ndarray_types(); + auto const &array_types = td_ns::usm_ndarray_types(); int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); int rhs_typeid = array_types.typenum_to_lookup_id(rhs_typenum); int cumsum_typeid = array_types.typenum_to_lookup_id(cumsum_typenum); - constexpr int int64_typeid = - static_cast(dpctl::tensor::detail::typenum_t::INT64); + constexpr int int64_typeid = static_cast(td_ns::typenum_t::INT64); if (cumsum_typeid != int64_typeid) { throw py::value_error( "Unexact data type of cumsum array, expecting 'int64'"); @@ -913,15 +906,14 @@ std::pair py_nonzero( py::ssize_t nz_elems = indexes_shape[1]; int indexes_typenum = indexes.get_typenum(); - auto const &array_types = dpctl::tensor::detail::usm_ndarray_types(); + auto const &array_types = td_ns::usm_ndarray_types(); int indexes_typeid = array_types.typenum_to_lookup_id(indexes_typenum); int cumsum_typenum = cumsum.get_typenum(); int cumsum_typeid = array_types.typenum_to_lookup_id(cumsum_typenum); // cumsum must be int64_t only - constexpr int int64_typeid = - static_cast(dpctl::tensor::detail::typenum_t::INT64); + constexpr int int64_typeid = static_cast(td_ns::typenum_t::INT64); if (cumsum_typeid != int64_typeid || indexes_typeid != int64_typeid) { throw py::value_error( "Cumulative sum array and index array must have int64 data-type"); diff --git a/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp b/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp index 1edbb52dfb..db92e2a18e 100644 --- a/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp +++ b/dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp @@ -49,18 +49,18 @@ namespace tensor namespace py_internal { -namespace _ns = dpctl::tensor::detail; +namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::kernels::copy_and_cast::copy_and_cast_1d_fn_ptr_t; using dpctl::tensor::kernels::copy_and_cast::copy_and_cast_contig_fn_ptr_t; using dpctl::tensor::kernels::copy_and_cast::copy_and_cast_generic_fn_ptr_t; static copy_and_cast_generic_fn_ptr_t - copy_and_cast_generic_dispatch_table[_ns::num_types][_ns::num_types]; + copy_and_cast_generic_dispatch_table[td_ns::num_types][td_ns::num_types]; static copy_and_cast_1d_fn_ptr_t - copy_and_cast_1d_dispatch_table[_ns::num_types][_ns::num_types]; + copy_and_cast_1d_dispatch_table[td_ns::num_types][td_ns::num_types]; static copy_and_cast_contig_fn_ptr_t - copy_and_cast_contig_dispatch_table[_ns::num_types][_ns::num_types]; + copy_and_cast_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; namespace py = pybind11; @@ -121,7 +121,7 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src, int src_typenum = src.get_typenum(); int dst_typenum = dst.get_typenum(); - auto array_types = dpctl::tensor::detail::usm_ndarray_types(); + auto array_types = td_ns::usm_ndarray_types(); int src_type_id = array_types.typenum_to_lookup_id(src_typenum); int dst_type_id = array_types.typenum_to_lookup_id(dst_typenum); @@ -277,7 +277,7 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src, void init_copy_and_cast_usm_to_usm_dispatch_tables(void) { - using namespace dpctl::tensor::detail; + using namespace td_ns; using dpctl::tensor::kernels::copy_and_cast::CopyAndCastContigFactory; DispatchTableBuildertype_num; int dst_typenum = dst.get_typenum(); - auto array_types = dpctl::tensor::detail::usm_ndarray_types(); + auto array_types = td_ns::usm_ndarray_types(); int src_type_id = array_types.typenum_to_lookup_id(src_typenum); int dst_type_id = array_types.typenum_to_lookup_id(dst_typenum); @@ -239,11 +239,11 @@ void copy_numpy_ndarray_into_usm_ndarray( void init_copy_numpy_ndarray_into_usm_ndarray_dispatch_tables(void) { - using namespace dpctl::tensor::detail; + using namespace td_ns; using dpctl::tensor::kernels::copy_and_cast::CopyAndCastFromHostFactory; DispatchTableBuilder + CopyAndCastFromHostFactory, num_types> dtb_copy_from_numpy; dtb_copy_from_numpy.populate_dispatch_table( diff --git a/dpctl/tensor/libtensor/source/eye_ctor.cpp b/dpctl/tensor/libtensor/source/eye_ctor.cpp index 867e862633..f04518bc48 100644 --- a/dpctl/tensor/libtensor/source/eye_ctor.cpp +++ b/dpctl/tensor/libtensor/source/eye_ctor.cpp @@ -34,7 +34,7 @@ #include "utils/type_dispatch.hpp" namespace py = pybind11; -namespace _ns = dpctl::tensor::detail; +namespace td_ns = dpctl::tensor::type_dispatch; namespace dpctl { @@ -46,7 +46,7 @@ namespace py_internal using dpctl::utils::keep_args_alive; using dpctl::tensor::kernels::constructors::eye_fn_ptr_t; -static eye_fn_ptr_t eye_dispatch_vector[_ns::num_types]; +static eye_fn_ptr_t eye_dispatch_vector[td_ns::num_types]; std::pair usm_ndarray_eye(py::ssize_t k, @@ -66,7 +66,7 @@ usm_ndarray_eye(py::ssize_t k, "allocation queue"); } - auto array_types = dpctl::tensor::detail::usm_ndarray_types(); + auto array_types = td_ns::usm_ndarray_types(); int dst_typenum = dst.get_typenum(); int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); @@ -118,7 +118,7 @@ usm_ndarray_eye(py::ssize_t k, void init_eye_ctor_dispatch_vectors(void) { - using namespace dpctl::tensor::detail; + using namespace td_ns; using dpctl::tensor::kernels::constructors::EyeFactory; DispatchVectorBuilder dvb; diff --git a/dpctl/tensor/libtensor/source/full_ctor.cpp b/dpctl/tensor/libtensor/source/full_ctor.cpp index 5239f283a9..f4b8ae5f42 100644 --- a/dpctl/tensor/libtensor/source/full_ctor.cpp +++ b/dpctl/tensor/libtensor/source/full_ctor.cpp @@ -37,7 +37,7 @@ #include "full_ctor.hpp" namespace py = pybind11; -namespace _ns = dpctl::tensor::detail; +namespace td_ns = dpctl::tensor::type_dispatch; namespace dpctl { @@ -51,7 +51,7 @@ using dpctl::utils::keep_args_alive; using dpctl::tensor::kernels::constructors::full_contig_fn_ptr_t; -static full_contig_fn_ptr_t full_contig_dispatch_vector[_ns::num_types]; +static full_contig_fn_ptr_t full_contig_dispatch_vector[td_ns::num_types]; std::pair usm_ndarray_full(py::object py_value, @@ -73,7 +73,7 @@ usm_ndarray_full(py::object py_value, "Execution queue is not compatible with the allocation queue"); } - auto array_types = dpctl::tensor::detail::usm_ndarray_types(); + auto array_types = td_ns::usm_ndarray_types(); int dst_typenum = dst.get_typenum(); int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); @@ -99,7 +99,7 @@ usm_ndarray_full(py::object py_value, void init_full_ctor_dispatch_vectors(void) { - using namespace dpctl::tensor::detail; + using namespace td_ns; using dpctl::tensor::kernels::constructors::FullContigFactory; DispatchVectorBuilder diff --git a/dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp b/dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp index ee6e741b7a..1039820014 100644 --- a/dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp +++ b/dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp @@ -51,16 +51,16 @@ namespace tensor namespace py_internal { -namespace _ns = dpctl::tensor::detail; +namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::kernels::indexing::put_fn_ptr_t; using dpctl::tensor::kernels::indexing::take_fn_ptr_t; -static take_fn_ptr_t take_dispatch_table[INDEXING_MODES][_ns::num_types] - [_ns::num_types]; +static take_fn_ptr_t take_dispatch_table[INDEXING_MODES][td_ns::num_types] + [td_ns::num_types]; -static put_fn_ptr_t put_dispatch_table[INDEXING_MODES][_ns::num_types] - [_ns::num_types]; +static put_fn_ptr_t put_dispatch_table[INDEXING_MODES][td_ns::num_types] + [td_ns::num_types]; namespace py = pybind11; @@ -324,7 +324,7 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src, int src_typenum = src.get_typenum(); int dst_typenum = dst.get_typenum(); - auto array_types = dpctl::tensor::detail::usm_ndarray_types(); + auto array_types = td_ns::usm_ndarray_types(); int src_type_id = array_types.typenum_to_lookup_id(src_typenum); int dst_type_id = array_types.typenum_to_lookup_id(dst_typenum); @@ -653,7 +653,7 @@ usm_ndarray_put(dpctl::tensor::usm_ndarray dst, int dst_typenum = dst.get_typenum(); int val_typenum = val.get_typenum(); - auto array_types = dpctl::tensor::detail::usm_ndarray_types(); + auto array_types = td_ns::usm_ndarray_types(); int dst_type_id = array_types.typenum_to_lookup_id(dst_typenum); int val_type_id = array_types.typenum_to_lookup_id(val_typenum); @@ -859,7 +859,7 @@ usm_ndarray_put(dpctl::tensor::usm_ndarray dst, void init_advanced_indexing_dispatch_tables(void) { - using namespace dpctl::tensor::detail; + using namespace td_ns; using dpctl::tensor::kernels::indexing::TakeClipFactory; DispatchTableBuilder diff --git a/dpctl/tensor/libtensor/source/linear_sequences.cpp b/dpctl/tensor/libtensor/source/linear_sequences.cpp index 9384a73198..9b17581b8e 100644 --- a/dpctl/tensor/libtensor/source/linear_sequences.cpp +++ b/dpctl/tensor/libtensor/source/linear_sequences.cpp @@ -37,7 +37,7 @@ #include "linear_sequences.hpp" namespace py = pybind11; -namespace _ns = dpctl::tensor::detail; +namespace td_ns = dpctl::tensor::type_dispatch; namespace dpctl { @@ -50,12 +50,12 @@ using dpctl::utils::keep_args_alive; using dpctl::tensor::kernels::constructors::lin_space_step_fn_ptr_t; -static lin_space_step_fn_ptr_t lin_space_step_dispatch_vector[_ns::num_types]; +static lin_space_step_fn_ptr_t lin_space_step_dispatch_vector[td_ns::num_types]; using dpctl::tensor::kernels::constructors::lin_space_affine_fn_ptr_t; static lin_space_affine_fn_ptr_t - lin_space_affine_dispatch_vector[_ns::num_types]; + lin_space_affine_dispatch_vector[td_ns::num_types]; std::pair usm_ndarray_linear_sequence_step(py::object start, @@ -82,7 +82,7 @@ usm_ndarray_linear_sequence_step(py::object start, "Execution queue is not compatible with the allocation queue"); } - auto array_types = dpctl::tensor::detail::usm_ndarray_types(); + auto array_types = td_ns::usm_ndarray_types(); int dst_typenum = dst.get_typenum(); int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); @@ -130,7 +130,7 @@ usm_ndarray_linear_sequence_affine(py::object start, "Execution queue context is not the same as allocation context"); } - auto array_types = dpctl::tensor::detail::usm_ndarray_types(); + auto array_types = td_ns::usm_ndarray_types(); int dst_typenum = dst.get_typenum(); int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); @@ -155,7 +155,7 @@ usm_ndarray_linear_sequence_affine(py::object start, void init_linear_sequences_dispatch_vectors(void) { - using namespace dpctl::tensor::detail; + using namespace td_ns; using dpctl::tensor::kernels::constructors::LinSpaceAffineFactory; using dpctl::tensor::kernels::constructors::LinSpaceStepFactory; diff --git a/dpctl/tensor/libtensor/source/triul_ctor.cpp b/dpctl/tensor/libtensor/source/triul_ctor.cpp index 686cc77032..47fad15698 100644 --- a/dpctl/tensor/libtensor/source/triul_ctor.cpp +++ b/dpctl/tensor/libtensor/source/triul_ctor.cpp @@ -35,7 +35,7 @@ #include "utils/type_dispatch.hpp" namespace py = pybind11; -namespace _ns = dpctl::tensor::detail; +namespace td_ns = dpctl::tensor::type_dispatch; namespace dpctl { @@ -48,8 +48,8 @@ using dpctl::utils::keep_args_alive; using dpctl::tensor::kernels::constructors::tri_fn_ptr_t; -static tri_fn_ptr_t tril_generic_dispatch_vector[_ns::num_types]; -static tri_fn_ptr_t triu_generic_dispatch_vector[_ns::num_types]; +static tri_fn_ptr_t tril_generic_dispatch_vector[td_ns::num_types]; +static tri_fn_ptr_t triu_generic_dispatch_vector[td_ns::num_types]; std::pair usm_ndarray_triul(sycl::queue exec_q, @@ -100,7 +100,7 @@ usm_ndarray_triul(sycl::queue exec_q, throw py::value_error("Arrays index overlapping segments of memory"); } - auto array_types = dpctl::tensor::detail::usm_ndarray_types(); + auto array_types = td_ns::usm_ndarray_types(); int src_typenum = src.get_typenum(); int dst_typenum = dst.get_typenum(); @@ -219,7 +219,7 @@ usm_ndarray_triul(sycl::queue exec_q, void init_triul_ctor_dispatch_vectors(void) { - using namespace dpctl::tensor::detail; + using namespace td_ns; using dpctl::tensor::kernels::constructors::TrilGenericFactory; using dpctl::tensor::kernels::constructors::TriuGenericFactory; diff --git a/dpctl/tensor/libtensor/source/where.cpp b/dpctl/tensor/libtensor/source/where.cpp index 3122eb31d8..b3843844bd 100644 --- a/dpctl/tensor/libtensor/source/where.cpp +++ b/dpctl/tensor/libtensor/source/where.cpp @@ -46,15 +46,15 @@ namespace tensor namespace py_internal { -namespace _ns = dpctl::tensor::detail; +namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::kernels::search::where_contig_impl_fn_ptr_t; using dpctl::tensor::kernels::search::where_strided_impl_fn_ptr_t; -static where_contig_impl_fn_ptr_t where_contig_dispatch_table[_ns::num_types] - [_ns::num_types]; -static where_strided_impl_fn_ptr_t where_strided_dispatch_table[_ns::num_types] - [_ns::num_types]; +static where_contig_impl_fn_ptr_t where_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static where_strided_impl_fn_ptr_t + where_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; using dpctl::utils::keep_args_alive; @@ -120,7 +120,7 @@ py_where(dpctl::tensor::usm_ndarray condition, int cond_typenum = condition.get_typenum(); int dst_typenum = dst.get_typenum(); - auto const &array_types = dpctl::tensor::detail::usm_ndarray_types(); + auto const &array_types = td_ns::usm_ndarray_types(); int cond_typeid = array_types.typenum_to_lookup_id(cond_typenum); int x1_typeid = array_types.typenum_to_lookup_id(x1_typenum); int x2_typeid = array_types.typenum_to_lookup_id(x2_typenum); @@ -249,17 +249,16 @@ py_where(dpctl::tensor::usm_ndarray condition, void init_where_dispatch_tables(void) { + using namespace td_ns; using dpctl::tensor::kernels::search::WhereContigFactory; - dpctl::tensor::detail::DispatchTableBuilder< - where_contig_impl_fn_ptr_t, WhereContigFactory, - dpctl::tensor::detail::num_types> + DispatchTableBuilder dtb1; dtb1.populate_dispatch_table(where_contig_dispatch_table); using dpctl::tensor::kernels::search::WhereStridedFactory; - dpctl::tensor::detail::DispatchTableBuilder< - where_strided_impl_fn_ptr_t, WhereStridedFactory, - dpctl::tensor::detail::num_types> + DispatchTableBuilder dtb2; dtb2.populate_dispatch_table(where_strided_dispatch_table); }