Skip to content

simplify backend implementation of dpnp.kaiser #2472

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -176,16 +176,17 @@ std::pair<sycl::event, sycl::event>

dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);

int src_nd = src.get_ndim();
const int src_nd = src.get_ndim();
if (src_nd != dst.get_ndim()) {
throw py::value_error("Array dimensions are not the same.");
}

const py::ssize_t *src_shape = src.get_shape_raw();
const py::ssize_t *dst_shape = dst.get_shape_raw();

std::size_t nelems = src.get_size();
bool shapes_equal = std::equal(src_shape, src_shape + src_nd, dst_shape);
const std::size_t nelems = src.get_size();
const bool shapes_equal =
std::equal(src_shape, src_shape + src_nd, dst_shape);
if (!shapes_equal) {
throw py::value_error("Array shapes are not the same.");
}
Expand All @@ -209,14 +210,14 @@ std::pair<sycl::event, sycl::event>
char *dst_data = dst.get_data();

// handle contiguous inputs
bool is_src_c_contig = src.is_c_contiguous();
bool is_src_f_contig = src.is_f_contiguous();
const bool is_src_c_contig = src.is_c_contiguous();
const bool is_src_f_contig = src.is_f_contiguous();

bool is_dst_c_contig = dst.is_c_contiguous();
bool is_dst_f_contig = dst.is_f_contiguous();
const bool is_dst_c_contig = dst.is_c_contiguous();
const bool is_dst_f_contig = dst.is_f_contiguous();

bool both_c_contig = (is_src_c_contig && is_dst_c_contig);
bool both_f_contig = (is_src_f_contig && is_dst_f_contig);
const bool both_c_contig = (is_src_c_contig && is_dst_c_contig);
const bool both_f_contig = (is_src_f_contig && is_dst_f_contig);

if (both_c_contig || both_f_contig) {
auto contig_fn = nan_to_num_contig_dispatch_vector[src_typeid];
Expand Down
8 changes: 4 additions & 4 deletions dpnp/backend/extensions/window/bartlett.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
Expand All @@ -35,18 +35,18 @@ template <typename T>
class BartlettFunctor
{
private:
T *data = nullptr;
T *res = nullptr;
const std::size_t N;

public:
BartlettFunctor(T *data, const std::size_t N) : data(data), N(N) {}
BartlettFunctor(T *res, const std::size_t N) : res(res), N(N) {}

void operator()(sycl::id<1> id) const
{
const auto i = id.get(0);

const T alpha = (N - 1) / T(2);
data[i] = T(1) - sycl::fabs(i - alpha) / alpha;
res[i] = T(1) - sycl::fabs(i - alpha) / alpha;
}
};

Expand Down
8 changes: 4 additions & 4 deletions dpnp/backend/extensions/window/blackman.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,19 @@ template <typename T>
class BlackmanFunctor
{
private:
T *data = nullptr;
T *res = nullptr;
const std::size_t N;

public:
BlackmanFunctor(T *data, const std::size_t N) : data(data), N(N) {}
BlackmanFunctor(T *res, const std::size_t N) : res(res), N(N) {}

void operator()(sycl::id<1> id) const
{
const auto i = id.get(0);

const T alpha = T(2) * i / (N - 1);
data[i] = T(0.42) - T(0.5) * sycl::cospi(alpha) +
T(0.08) * sycl::cospi(T(2) * alpha);
res[i] = T(0.42) - T(0.5) * sycl::cospi(alpha) +
T(0.08) * sycl::cospi(T(2) * alpha);
}
};

Expand Down
16 changes: 8 additions & 8 deletions dpnp/backend/extensions/window/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,16 @@ typedef sycl::event (*window_fn_ptr_t)(sycl::queue &,
const std::vector<sycl::event> &);

template <typename T, template <typename> class Functor>
sycl::event window_impl(sycl::queue &q,
sycl::event window_impl(sycl::queue &exec_q,
char *result,
const std::size_t nelems,
const std::vector<sycl::event> &depends)
{
dpctl::tensor::type_utils::validate_type_for_device<T>(q);
dpctl::tensor::type_utils::validate_type_for_device<T>(exec_q);

T *res = reinterpret_cast<T *>(result);

sycl::event window_ev = q.submit([&](sycl::handler &cgh) {
sycl::event window_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);

using WindowKernel = Functor<T>;
Expand All @@ -75,7 +75,7 @@ std::tuple<size_t, char *, funcPtrT>
{
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(result);

int nd = result.get_ndim();
const int nd = result.get_ndim();
if (nd != 1) {
throw py::value_error("Array should be 1d");
}
Expand All @@ -87,17 +87,17 @@ std::tuple<size_t, char *, funcPtrT>

const bool is_result_c_contig = result.is_c_contiguous();
if (!is_result_c_contig) {
throw py::value_error("The result input array is not c-contiguous.");
throw py::value_error("The result array is not c-contiguous.");
}

size_t nelems = result.get_size();
const std::size_t nelems = result.get_size();
if (nelems == 0) {
return std::make_tuple(nelems, nullptr, nullptr);
}

int result_typenum = result.get_typenum();
const int result_typenum = result.get_typenum();
auto array_types = dpctl_td_ns::usm_ndarray_types();
int result_type_id = array_types.typenum_to_lookup_id(result_typenum);
const int result_type_id = array_types.typenum_to_lookup_id(result_typenum);
funcPtrT fn = window_dispatch_vector[result_type_id];

if (fn == nullptr) {
Expand Down
6 changes: 3 additions & 3 deletions dpnp/backend/extensions/window/hamming.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,17 @@ template <typename T>
class HammingFunctor
{
private:
T *data = nullptr;
T *res = nullptr;
const std::size_t N;

public:
HammingFunctor(T *data, const std::size_t N) : data(data), N(N) {}
HammingFunctor(T *res, const std::size_t N) : res(res), N(N) {}

void operator()(sycl::id<1> id) const
{
const auto i = id.get(0);

data[i] = T(0.54) - T(0.46) * sycl::cospi(T(2) * i / (N - 1));
res[i] = T(0.54) - T(0.46) * sycl::cospi(T(2) * i / (N - 1));
}
};

Expand Down
6 changes: 3 additions & 3 deletions dpnp/backend/extensions/window/hanning.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,17 @@ template <typename T>
class HanningFunctor
{
private:
T *data = nullptr;
T *res = nullptr;
const std::size_t N;

public:
HanningFunctor(T *data, const std::size_t N) : data(data), N(N) {}
HanningFunctor(T *res, const std::size_t N) : res(res), N(N) {}

void operator()(sycl::id<1> id) const
{
const auto i = id.get(0);

data[i] = T(0.5) - T(0.5) * sycl::cospi(T(2) * i / (N - 1));
res[i] = T(0.5) - T(0.5) * sycl::cospi(T(2) * i / (N - 1));
}
};

Expand Down
26 changes: 13 additions & 13 deletions dpnp/backend/extensions/window/kaiser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ template <typename T>
class KaiserFunctor
{
private:
T *data = nullptr;
T *res = nullptr;
const std::size_t N;
const T beta;

public:
KaiserFunctor(T *data, const std::size_t N, const T beta)
: data(data), N(N), beta(beta)
KaiserFunctor(T *res, const std::size_t N, const T beta)
: res(res), N(N), beta(beta)
{
}

Expand All @@ -67,27 +67,27 @@ class KaiserFunctor
const auto i = id.get(0);
const T alpha = (N - 1) / T(2);
const T tmp = (i - alpha) / alpha;
data[i] = cyl_bessel_i0(beta * sycl::sqrt(1 - tmp * tmp)) /
cyl_bessel_i0(beta);
res[i] = cyl_bessel_i0(beta * sycl::sqrt(1 - tmp * tmp)) /
cyl_bessel_i0(beta);
}
};

template <typename T, template <typename> class Functor>
sycl::event kaiser_impl(sycl::queue &q,
template <typename T>
sycl::event kaiser_impl(sycl::queue &exec_q,
char *result,
const std::size_t nelems,
const py::object &py_beta,
const std::vector<sycl::event> &depends)
{
dpctl::tensor::type_utils::validate_type_for_device<T>(q);
dpctl::tensor::type_utils::validate_type_for_device<T>(exec_q);

T *res = reinterpret_cast<T *>(result);
const T beta = py::cast<const T>(py_beta);

sycl::event kaiser_ev = q.submit([&](sycl::handler &cgh) {
sycl::event kaiser_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);

using KaiserKernel = Functor<T>;
using KaiserKernel = KaiserFunctor<T>;
cgh.parallel_for<KaiserKernel>(sycl::range<1>(nelems),
KaiserKernel(res, nelems, beta));
});
Expand All @@ -101,7 +101,7 @@ struct KaiserFactory
fnT get()
{
if constexpr (std::is_floating_point_v<T>) {
return kaiser_impl<T, KaiserFunctor>;
return kaiser_impl<T>;
}
else {
return nullptr;
Expand All @@ -115,15 +115,15 @@ std::pair<sycl::event, sycl::event>
const dpctl::tensor::usm_ndarray &result,
const std::vector<sycl::event> &depends)
{
auto [nelems, result_typeless_ptr, fn] =
auto [nelems, result_typeless_ptr, kaiser_fn] =
window_fn<kaiser_fn_ptr_t>(exec_q, result, kaiser_dispatch_vector);

if (nelems == 0) {
return std::make_pair(sycl::event{}, sycl::event{});
}

sycl::event kaiser_ev =
fn(exec_q, result_typeless_ptr, nelems, py_beta, depends);
kaiser_fn(exec_q, result_typeless_ptr, nelems, py_beta, depends);
sycl::event args_ev =
dpctl::utils::keep_args_alive(exec_q, {result}, {kaiser_ev});

Expand Down
Loading