Skip to content

Commit 1950d5b

Browse files
committed
Add dpnp.linalg.eigh() function
1 parent 2dfa804 commit 1950d5b

18 files changed

+1230
-12
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# CMake build and local install directory
22
_skbuild
3+
build
34
build_cython
45
dpnp.egg-info
56

dpnp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ endfunction()
4747

4848
build_dpnp_cython_ext_with_backend(dparray ${CMAKE_CURRENT_SOURCE_DIR}/dparray.pyx dpnp)
4949
add_subdirectory(backend)
50+
add_subdirectory(backend/extensions/lapack)
5051

5152
add_subdirectory(dpnp_algo)
5253
add_subdirectory(dpnp_utils)

dpnp/backend/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ add_library(dpnp_backend_library INTERFACE IMPORTED GLOBAL)
109109
target_include_directories(dpnp_backend_library BEFORE INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/include ${CMAKE_CURRENT_SOURCE_DIR}/src)
110110
target_link_libraries(dpnp_backend_library INTERFACE ${_trgt})
111111

112-
if(DPNP_BACKEND_TESTS)
112+
if (DPNP_BACKEND_TESTS)
113113
add_subdirectory(tests)
114114
endif()
115115

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2016-2023, Intel Corporation
3+
# All rights reserved.
4+
#
5+
# Redistribution and use in source and binary forms, with or without
6+
# modification, are permitted provided that the following conditions are met:
7+
# - Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# - Redistributions in binary form must reproduce the above copyright notice,
10+
# this list of conditions and the following disclaimer in the documentation
11+
# and/or other materials provided with the distribution.
12+
#
13+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
# THE POSSIBILITY OF SUCH DAMAGE.
24+
# *****************************************************************************
25+
26+
27+
# Build and install the `_lapack_impl` pybind11 extension module from the
# LAPACK wrapper sources located in this directory.
set(python_module_name _lapack_impl)
pybind11_add_module(${python_module_name} MODULE
    lapack_py.cpp
    heevd.cpp
    syevd.cpp
)

if (WIN32)
    if (CMAKE_VERSION VERSION_LESS "3.23")
        # Work-around: target_link_options() inserts the option after the
        # -link option, which causes the linker to ignore it.
        set(CMAKE_CXX_LINK_FLAGS "${CMAKE_CXX_LINK_FLAGS} -fsycl-device-code-split=per_kernel")
    endif()
endif()

# The target property is POSITION_INDEPENDENT_CODE. The CMAKE_-prefixed name is
# only the variable that initializes it and has no effect when set as a property.
set_target_properties(${python_module_name} PROPERTIES POSITION_INDEPENDENT_CODE ON)

# Backend headers and sources shared with the rest of dpnp.
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
target_include_directories(${python_module_name} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../src)

target_include_directories(${python_module_name} PUBLIC ${Dpctl_INCLUDE_DIRS})

# Disable approximate math functions and finite-math-only assumptions.
# On Windows the flags are forwarded with the /clang: prefix.
if (WIN32)
    target_compile_options(${python_module_name} PRIVATE
        /clang:-fno-approx-func
        /clang:-fno-finite-math-only
    )
else()
    target_compile_options(${python_module_name} PRIVATE
        -fno-approx-func
        -fno-finite-math-only
    )
endif()

target_link_options(${python_module_name} PUBLIC -fsycl-device-code-split=per_kernel)
if (UNIX)
    # This option is supported on Linux only.
    target_link_options(${python_module_name} PUBLIC -fsycl-link-huge-device-code)
endif()

if (DPNP_GENERATE_COVERAGE)
    target_link_options(${python_module_name} PRIVATE -fprofile-instr-generate -fcoverage-mapping)
endif()

target_link_libraries(${python_module_name} PUBLIC MKL::MKL_DPCPP)

install(TARGETS ${python_module_name}
    DESTINATION "dpnp/backend/extensions/lapack"
)
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
27+
#include <pybind11/pybind11.h>
28+
29+
#include "heevd.hpp"
30+
31+
#include "dpnp_utils.hpp"
32+
33+
34+
namespace dpnp
35+
{
36+
namespace backend
37+
{
38+
namespace ext
39+
{
40+
namespace lapack
41+
{
42+
43+
namespace mkl_lapack = oneapi::mkl::lapack;
44+
namespace py = pybind11;
45+
46+
template <typename T, typename RealT>
47+
static inline sycl::event call_heevd(sycl::queue exec_q,
48+
const oneapi::mkl::job jobz,
49+
const oneapi::mkl::uplo upper_lower,
50+
const std::int64_t n,
51+
T* a,
52+
RealT* w,
53+
std::vector<sycl::event> &host_task_events,
54+
const std::vector<sycl::event>& depends)
55+
{
56+
validate_type_for_device<T>(exec_q);
57+
validate_type_for_device<RealT>(exec_q);
58+
59+
const std::int64_t lda = std::max<size_t>(1UL, n);
60+
const std::int64_t scratchpad_size = mkl_lapack::heevd_scratchpad_size<T>(exec_q, jobz, upper_lower, n, lda);
61+
T* scratchpad = nullptr;
62+
63+
std::stringstream error_msg;
64+
std::int64_t info = 0;
65+
66+
sycl::event heevd_event;
67+
try
68+
{
69+
scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q);
70+
71+
heevd_event = mkl_lapack::heevd(
72+
exec_q,
73+
jobz, // 'jobz == job::vec' means eigenvalues and eigenvectors are computed.
74+
upper_lower, // 'upper_lower == job::upper' means the upper triangular part of A, or the lower triangular otherwise
75+
n, // The order of the matrix A (0 <= n)
76+
a, // Pointer to A, size (lda, *), where the 2nd dimension, must be at least max(1, n)
77+
// If 'jobz == job::vec', then on exit it will contain the eigenvectors of A
78+
lda, // The leading dimension of a, must be at least max(1, n)
79+
w, // Pointer to array of size at least n, it will contain the eigenvalues of A in ascending order
80+
scratchpad, // Pointer to scratchpad memory to be used by MKL routine for storing intermediate results
81+
scratchpad_size,
82+
depends);
83+
}
84+
catch (mkl_lapack::exception const& e)
85+
{
86+
error_msg << "Unexpected MKL exception caught during heevd() call:\nreason: " << e.what()
87+
<< "\ninfo: " << e.info();
88+
info = e.info();
89+
}
90+
catch (sycl::exception const& e)
91+
{
92+
error_msg << "Unexpected SYCL exception caught during heevd() call:\n" << e.what();
93+
info = -1;
94+
}
95+
96+
if (info != 0) // an unexected error occurs
97+
{
98+
if (scratchpad != nullptr)
99+
{
100+
sycl::free(scratchpad, exec_q);
101+
}
102+
throw std::runtime_error(error_msg.str());
103+
}
104+
105+
sycl::event clean_up_event = exec_q.submit([&](sycl::handler& cgh) {
106+
cgh.depends_on(heevd_event);
107+
auto ctx = exec_q.get_context();
108+
cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); });
109+
});
110+
host_task_events.push_back(clean_up_event);
111+
return heevd_event;
112+
}
113+
114+
// Validates the arrays and dispatches to call_heevd() for the matching types.
//
// Parameters:
//   exec_q       execution queue
//   jobz         cast to oneapi::mkl::job without further validation
//   upper_lower  cast to oneapi::mkl::uplo without further validation
//   eig_vecs     2-D square array passed to heevd() as the matrix `a`
//   eig_vals     1-D array with eig_vecs.shape[0] elements, passed as `w`
//   depends      events the computation must wait for
//
// Supported type pairs (eig_vecs, eig_vals): (complex double, double) and
// (complex float, float).
//
// Returns a pair (event of keep_args_alive() for both arrays, event of the
// heevd() computation). For an empty matrix a pair of default-constructed
// events is returned and nothing is submitted.
//
// Throws py::value_error on wrong ndim, non-square eig_vecs, mismatching
// shapes, incompatible queues or unsupported types.
std::pair<sycl::event, sycl::event> heevd(sycl::queue exec_q,
                                          const std::int8_t jobz,
                                          const std::int8_t upper_lower,
                                          dpctl::tensor::usm_ndarray eig_vecs,
                                          dpctl::tensor::usm_ndarray eig_vals,
                                          const std::vector<sycl::event>& depends)
{
    const int eig_vecs_nd = eig_vecs.get_ndim();
    const int eig_vals_nd = eig_vals.get_ndim();

    if (eig_vecs_nd != 2)
    {
        throw py::value_error("Unexpected ndim=" + std::to_string(eig_vecs_nd) +
                              " of an output array with eigenvectors");
    }
    if (eig_vals_nd != 1)
    {
        throw py::value_error("Unexpected ndim=" + std::to_string(eig_vals_nd) +
                              " of an output array with eigenvalues");
    }

    const py::ssize_t* eig_vecs_shape = eig_vecs.get_shape_raw();
    const py::ssize_t* eig_vals_shape = eig_vals.get_shape_raw();

    if (eig_vecs_shape[0] != eig_vecs_shape[1])
    {
        throw py::value_error("Output array with eigenvectors must be square");
    }
    if (eig_vecs_shape[0] != eig_vals_shape[0])
    {
        throw py::value_error("Eigenvectors and eigenvalues have different shapes");
    }

    const std::int64_t n = eig_vecs_shape[0];

    // eig_vecs is known to be n-by-n here, so it is empty exactly when n == 0.
    if (n == 0)
    {
        // nothing to do
        return std::make_pair(sycl::event(), sycl::event());
    }

    // check compatibility of execution queue and allocation queue
    if (!dpctl::utils::queues_are_compatible(exec_q, {eig_vecs, eig_vals}))
    {
        throw py::value_error("Execution queue is not compatible with allocation queues");
    }

    // check that arrays do not overlap, and concurrent access is safe.
    // TODO: need to be exposed by DPCTL headers
    // auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
    // if (overlap(eig_vecs, eig_vals))
    // {
    //     throw py::value_error("Arrays index overlapping segments of memory");
    // }

    const int eig_vecs_typenum = eig_vecs.get_typenum();
    const int eig_vals_typenum = eig_vals.get_typenum();
    auto const& dpctl_capi = dpctl::detail::dpctl_capi::get();

    sycl::event heevd_ev;
    std::vector<sycl::event> host_task_events;

    const oneapi::mkl::job jobz_val = static_cast<oneapi::mkl::job>(jobz);
    const oneapi::mkl::uplo uplo_val = static_cast<oneapi::mkl::uplo>(upper_lower);

    if ((eig_vecs_typenum == dpctl_capi.UAR_CDOUBLE_) && (eig_vals_typenum == dpctl_capi.UAR_DOUBLE_))
    {
        std::complex<double>* a = reinterpret_cast<std::complex<double>*>(eig_vecs.get_data());
        double* w = reinterpret_cast<double*>(eig_vals.get_data());

        heevd_ev = call_heevd(exec_q, jobz_val, uplo_val, n, a, w, host_task_events, depends);
    }
    else if ((eig_vecs_typenum == dpctl_capi.UAR_CFLOAT_) && (eig_vals_typenum == dpctl_capi.UAR_FLOAT_))
    {
        std::complex<float>* a = reinterpret_cast<std::complex<float>*>(eig_vecs.get_data());
        float* w = reinterpret_cast<float*>(eig_vals.get_data());

        heevd_ev = call_heevd(exec_q, jobz_val, uplo_val, n, a, w, host_task_events, depends);
    }
    else
    {
        throw py::value_error("Unexpected types of either eigenvectors or eigenvalues");
    }

    // Keep both arrays alive until the host tasks collected above have completed.
    sycl::event args_ev = dpctl::utils::keep_args_alive(exec_q, {eig_vecs, eig_vals}, host_task_events);
    return std::make_pair(args_ev, heevd_ev);
}
207+
}
208+
}
209+
}
210+
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
#include <oneapi/mkl.hpp>
30+
31+
#include <dpctl4pybind11.hpp>
32+
33+
34+
namespace dpnp
35+
{
36+
namespace backend
37+
{
38+
namespace ext
39+
{
40+
namespace lapack
41+
{
42+
// Declaration of heevd(): takes an execution queue, the job and triangle
// selectors as 8-bit integers, the arrays for eigenvectors and eigenvalues and
// a list of dependency events; returns a pair of SYCL events.
std::pair<sycl::event, sycl::event>
    heevd(sycl::queue exec_q,
          const std::int8_t jobz,
          const std::int8_t upper_lower,
          dpctl::tensor::usm_ndarray eig_vecs,
          dpctl::tensor::usm_ndarray eig_vals,
          const std::vector<sycl::event>& depends);
48+
}
49+
}
50+
}
51+
}

0 commit comments

Comments
 (0)