Skip to content

Add dpnp.linalg.solve() function #1598

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 73 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
73 commits
Select commit Hold shift + click to select a range
2613795
Correct return of object type at zero copy
vlad-perevezentsev Sep 26, 2023
efdcf2a
Add tests for gh-1570
vlad-perevezentsev Sep 26, 2023
6500a6c
Add dpnp.linalg.solve() function
vlad-perevezentsev Oct 5, 2023
51663d0
Check validity of input array shapes
vlad-perevezentsev Oct 6, 2023
00b7041
Add logic for a.ndim > 2
vlad-perevezentsev Oct 6, 2023
e4e15ad
Raise value_error if coeff_matrix_nd != 2 in gesv
vlad-perevezentsev Oct 6, 2023
e762c66
Add cupy tests for dpnp.linalg.solve()
vlad-perevezentsev Oct 6, 2023
a0f76d5
Add LinAlgError exception and extend error handling for mkl::lapack::…
vlad-perevezentsev Oct 10, 2023
89f47c7
Update test_solve
vlad-perevezentsev Oct 10, 2023
159c460
Add test_solve to test scope
vlad-perevezentsev Oct 10, 2023
2ad3aa8
Fix getting nrhs to avoid CPU falling tests
vlad-perevezentsev Oct 12, 2023
ac02cf2
Add test_solve to test_sycl_queue
vlad-perevezentsev Oct 12, 2023
035d983
Add more tests for solve()
vlad-perevezentsev Oct 12, 2023
f80ba4b
Register a LinAlgError in dpnp.linalg submodule
vlad-perevezentsev Oct 13, 2023
9a3db74
Raise dpnp.linalg.LinAlgError in solve()
vlad-perevezentsev Oct 13, 2023
a8b4fec
Small changes to the docstrings
vlad-perevezentsev Oct 15, 2023
40e74ae
Merge master into impl_solve
vlad-perevezentsev Oct 15, 2023
3f88dc5
Simplify ThresholdType determination
vlad-perevezentsev Oct 16, 2023
22e8734
Small changes to the docstrings
vlad-perevezentsev Oct 17, 2023
79c60df
Merge master into impl_solve_1
vlad-perevezentsev Oct 17, 2023
87cda0b
Remove if op_count due to unreachable
vlad-perevezentsev Oct 17, 2023
76e035d
Improve test coverage
vlad-perevezentsev Oct 17, 2023
1bfc81a
Impl dtype dispatching with linalg_common_type for dpnp.linalg.solve
vlad-perevezentsev Oct 17, 2023
c7b284b
Add a new test_solve_diff_type
vlad-perevezentsev Oct 17, 2023
353b756
Merge master into impl_solve_1
vlad-perevezentsev Oct 17, 2023
e5c7626
Merge master into impl_solve_1
vlad-perevezentsev Oct 26, 2023
e99d37c
Add a common_helpers.hpp file
vlad-perevezentsev Oct 26, 2023
ec1e966
Use bool flag for sycl exception
vlad-perevezentsev Oct 26, 2023
6885aa3
Refactor memory management for ipiv in gesv_impl
vlad-perevezentsev Oct 26, 2023
56c920f
Rename linalg_common_type to _common_type and change the number of ty…
vlad-perevezentsev Oct 27, 2023
07b1418
Address the remarks
vlad-perevezentsev Oct 27, 2023
711a62f
Merge master into impl_solve_1
vlad-perevezentsev Oct 31, 2023
0336c00
Remove the use of prod to get 3d array and rename op_count to batch_size
vlad-perevezentsev Nov 2, 2023
31f6f10
gesv returns pair of events and uses dpctl.utils.keep_args_alive
vlad-perevezentsev Nov 2, 2023
d3717a6
Return the use prod to get 3d arrays
vlad-perevezentsev Nov 2, 2023
8e19740
Merge master into impl_solve_1
vlad-perevezentsev Nov 20, 2023
515df4e
Add test_solve_singular_matrix in TestSolve
vlad-perevezentsev Nov 20, 2023
f992b0b
Adress the remarks
vlad-perevezentsev Nov 20, 2023
cd21c7f
Add res_usm_type variavble and new tests in test_usm_type for dpnp.li…
vlad-perevezentsev Nov 21, 2023
5781f0c
Add skipif for test_solve_singular_matrix on cpu
vlad-perevezentsev Nov 22, 2023
0f87471
Merge master into impl_solve_1
vlad-perevezentsev Nov 22, 2023
b4eb1ad
Merge branch 'master' into impl_solve_1
antonwolfy Nov 29, 2023
cec8154
Modify _common_inexact_type and add a description for it
vlad-perevezentsev Nov 30, 2023
c3e5a0f
A small update of the desctiption of dpnp.linalg.solve() func
vlad-perevezentsev Nov 30, 2023
eb2dd4c
Use device param for default_float_type in _common_type
vlad-perevezentsev Dec 1, 2023
4780597
Simplify getting 3d array in dpnp_solve
vlad-perevezentsev Dec 4, 2023
f1b6a81
Remove unnecessary copying to F order after invoking gesv
vlad-perevezentsev Dec 4, 2023
b8f4cb9
Use get_usm_allocations instead of get_execution_queue
vlad-perevezentsev Dec 4, 2023
00ebef1
Move copying just after the memory allocation
vlad-perevezentsev Dec 4, 2023
22d4d6f
Add additional checks to gesv implementation
vlad-perevezentsev Dec 4, 2023
04f8f41
Add validation functions for array types and dimensions for linalg funcs
vlad-perevezentsev Dec 4, 2023
366be7a
Update test_solve_diff_type in test_linalg.py
vlad-perevezentsev Dec 4, 2023
08ac7fe
Address the remarks
vlad-perevezentsev Dec 5, 2023
eb6a840
Small update
vlad-perevezentsev Dec 5, 2023
61a6073
qwe
vlad-perevezentsev Dec 5, 2023
6a25e69
Merge origin/master into impl_solve_1
vlad-perevezentsev Dec 6, 2023
4ba0c7f
Merge master into impl_solve_1
vlad-perevezentsev Dec 7, 2023
eba811a
Rename assert funcs and make them external in dpnp_utils_linalg
vlad-perevezentsev Dec 7, 2023
3a9d459
Use assert_dtype_allclose for test_solve in test_sycl_queue
vlad-perevezentsev Dec 7, 2023
0de6968
Remove an unnecessary file
vlad-perevezentsev Dec 7, 2023
df72c77
Fix validation for CI
vlad-perevezentsev Dec 7, 2023
4991a81
Remove eqec_q check that will never happen
vlad-perevezentsev Dec 8, 2023
965b89a
Merge master into impl_solve_1
vlad-perevezentsev Dec 8, 2023
9b5a5a5
Set usm_type for out_v
vlad-perevezentsev Dec 11, 2023
3c7ad07
Update test_solve_singular_empty
vlad-perevezentsev Dec 11, 2023
e76b278
Merge master into impl_solve_1
vlad-perevezentsev Dec 11, 2023
3001872
Skip test_solve_singular_empty
vlad-perevezentsev Dec 11, 2023
5c3693b
Merge master into impl_solve_1
vlad-perevezentsev Dec 11, 2023
37f5400
Merge master into impl_solve_1
vlad-perevezentsev Dec 12, 2023
5392a45
Fix validation fell
vlad-perevezentsev Dec 12, 2023
6bea640
Merge master into impl_solve_1
vlad-perevezentsev Dec 12, 2023
c105570
A small update
vlad-perevezentsev Dec 12, 2023
3165397
Merge master into impl_solve_1
vlad-perevezentsev Dec 13, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/conda-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ env:
test_usm_type.py
third_party/cupy/core_tests
third_party/cupy/linalg_tests/test_product.py
third_party/cupy/linalg_tests/test_solve.py
third_party/cupy/logic_tests/test_comparison.py
third_party/cupy/logic_tests/test_truth.py
third_party/cupy/manipulation_tests/test_basic.py
Expand Down
1 change: 1 addition & 0 deletions dpnp/backend/extensions/lapack/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
set(python_module_name _lapack_impl)
set(_module_src
${CMAKE_CURRENT_SOURCE_DIR}/lapack_py.cpp
${CMAKE_CURRENT_SOURCE_DIR}/gesv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/heevd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/syevd.cpp
)
Expand Down
55 changes: 55 additions & 0 deletions dpnp/backend/extensions/lapack/common_helpers.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
//*****************************************************************************
// Copyright (c) 2023, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once
#include <cstring>
#include <stdexcept>

namespace dpnp
{
namespace backend
{
namespace ext
{
namespace lapack
{
namespace helper
{
template <typename T>
struct value_type_of
{
using type = T;
};

template <typename T>
struct value_type_of<std::complex<T>>
{
using type = T;
};
} // namespace helper
} // namespace lapack
} // namespace ext
} // namespace backend
} // namespace dpnp
297 changes: 297 additions & 0 deletions dpnp/backend/extensions/lapack/gesv.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,297 @@
//*****************************************************************************
// Copyright (c) 2023, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#include <pybind11/pybind11.h>

// dpctl tensor headers
#include "utils/memory_overlap.hpp"
#include "utils/type_utils.hpp"

#include "common_helpers.hpp"
#include "gesv.hpp"
#include "linalg_exceptions.hpp"
#include "types_matrix.hpp"

#include "dpnp_utils.hpp"

namespace dpnp
{
namespace backend
{
namespace ext
{
namespace lapack
{
namespace mkl_lapack = oneapi::mkl::lapack;
namespace py = pybind11;
namespace type_utils = dpctl::tensor::type_utils;

typedef sycl::event (*gesv_impl_fn_ptr_t)(sycl::queue,
const std::int64_t,
const std::int64_t,
char *,
std::int64_t,
char *,
std::int64_t,
std::vector<sycl::event> &,
const std::vector<sycl::event> &);

static gesv_impl_fn_ptr_t gesv_dispatch_vector[dpctl_td_ns::num_types];

template <typename T>
static sycl::event gesv_impl(sycl::queue exec_q,
const std::int64_t n,
const std::int64_t nrhs,
char *in_a,
std::int64_t lda,
char *in_b,
std::int64_t ldb,
std::vector<sycl::event> &host_task_events,
const std::vector<sycl::event> &depends)
{
type_utils::validate_type_for_device<T>(exec_q);

T *a = reinterpret_cast<T *>(in_a);
T *b = reinterpret_cast<T *>(in_b);

const std::int64_t scratchpad_size =
mkl_lapack::gesv_scratchpad_size<T>(exec_q, n, nrhs, lda, ldb);
T *scratchpad = nullptr;

std::int64_t *ipiv = nullptr;

std::stringstream error_msg;
std::int64_t info = 0;
bool sycl_exception_caught = false;

sycl::event gesv_event;
try {
scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q);
ipiv = sycl::malloc_device<std::int64_t>(n, exec_q);

gesv_event = mkl_lapack::gesv(
exec_q,
n, // The order of the matrix A (0 ≤ n).
nrhs, // The number of right-hand sides B (0 ≤ nrhs).
a, // Pointer to the square coefficient matrix A (n x n).
lda, // The leading dimension of a, must be at least max(1, n).
ipiv, // The pivot indices that define the permutation matrix P;
// row i of the matrix was interchanged with row ipiv(i),
// must be at least max(1, n).
b, // Pointer to the right hand side matrix B (n x nrhs).
ldb, // The leading dimension of b, must be at least max(1, n).
scratchpad, // Pointer to scratchpad memory to be used by MKL
// routine for storing intermediate results.
scratchpad_size, depends);
} catch (mkl_lapack::exception const &e) {
info = e.info();

if (info < 0) {
error_msg << "Parameter number " << -info
<< " had an illegal value.";
}
else if (info > 0) {
T host_U;
exec_q.memcpy(&host_U, &a[(info - 1) * lda + info - 1], sizeof(T))
.wait();

using ThresholdType = typename helper::value_type_of<T>::type;

const auto threshold =
std::numeric_limits<ThresholdType>::epsilon() * 100;
if (std::abs(host_U) < threshold) {
sycl::free(scratchpad, exec_q);
throw LinAlgError("The input coefficient matrix is singular.");
}
else {
error_msg << "Unexpected MKL exception caught during gesv() "
"call:\nreason: "
<< e.what() << "\ninfo: " << e.info();
}
}
else if (info == scratchpad_size && e.detail() != 0) {
error_msg
<< "Insufficient scratchpad size. Required size is at least "
<< e.detail();
}
else {
error_msg << "Unexpected MKL exception caught during gesv() "
"call:\nreason: "
<< e.what() << "\ninfo: " << e.info();
}
} catch (sycl::exception const &e) {
error_msg << "Unexpected SYCL exception caught during gesv() call:\n"
<< e.what();
sycl_exception_caught = true;
}

if (info != 0 || sycl_exception_caught) // an unexpected error occurs
{
if (scratchpad != nullptr) {
sycl::free(scratchpad, exec_q);
}
if (ipiv != nullptr) {
sycl::free(ipiv, exec_q);
}
throw std::runtime_error(error_msg.str());
}

sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(gesv_event);
auto ctx = exec_q.get_context();
cgh.host_task([ctx, scratchpad, ipiv]() {
sycl::free(scratchpad, ctx);
sycl::free(ipiv, ctx);
});
});
host_task_events.push_back(clean_up_event);

return gesv_event;
}

std::pair<sycl::event, sycl::event>
gesv(sycl::queue exec_q,
dpctl::tensor::usm_ndarray coeff_matrix,
dpctl::tensor::usm_ndarray dependent_vals,
const std::vector<sycl::event> &depends)
{
const int coeff_matrix_nd = coeff_matrix.get_ndim();
const int dependent_vals_nd = dependent_vals.get_ndim();

if (coeff_matrix_nd != 2) {
throw py::value_error("The coefficient matrix has ndim=" +
std::to_string(coeff_matrix_nd) +
", but a 2-dimensional array is expected.");
}

if (dependent_vals_nd > 2) {
throw py::value_error(
"The dependent values array has ndim=" +
std::to_string(dependent_vals_nd) +
", but a 1-dimensional or a 2-dimensional array is expected.");
}

const py::ssize_t *coeff_matrix_shape = coeff_matrix.get_shape_raw();
const py::ssize_t *dependent_vals_shape = dependent_vals.get_shape_raw();

if (coeff_matrix_shape[0] != coeff_matrix_shape[1]) {
throw py::value_error("The coefficient matrix must be square,"
" but got a shape of (" +
std::to_string(coeff_matrix_shape[0]) + ", " +
std::to_string(coeff_matrix_shape[1]) + ").");
}

// check compatibility of execution queue and allocation queue
if (!dpctl::utils::queues_are_compatible(exec_q,
{coeff_matrix, dependent_vals}))
{
throw py::value_error(
"Execution queue is not compatible with allocation queues");
}

auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
if (overlap(coeff_matrix, dependent_vals)) {
throw py::value_error(
"The arrays of coefficients and dependent variables "
"are overlapping segments of memory");
}

bool is_coeff_matrix_f_contig = coeff_matrix.is_f_contiguous();
if (!is_coeff_matrix_f_contig) {
throw py::value_error("The coefficient matrix "
"must be F-contiguous");
}

bool is_dependent_vals_f_contig = dependent_vals.is_f_contiguous();
if (!is_dependent_vals_f_contig) {
throw py::value_error("The array of dependent variables "
"must be F-contiguous");
}

auto array_types = dpctl_td_ns::usm_ndarray_types();
int coeff_matrix_type_id =
array_types.typenum_to_lookup_id(coeff_matrix.get_typenum());
int dependent_vals_type_id =
array_types.typenum_to_lookup_id(dependent_vals.get_typenum());

if (coeff_matrix_type_id != dependent_vals_type_id) {
throw py::value_error("The types of the coefficient matrix and "
"dependent variables are mismatched");
}

gesv_impl_fn_ptr_t gesv_fn = gesv_dispatch_vector[coeff_matrix_type_id];
if (gesv_fn == nullptr) {
throw py::value_error(
"No gesv implementation defined for the provided type "
"of the coefficient matrix.");
}

char *coeff_matrix_data = coeff_matrix.get_data();
char *dependent_vals_data = dependent_vals.get_data();

const std::int64_t n = coeff_matrix_shape[0];
const std::int64_t m = dependent_vals_shape[0];
const std::int64_t nrhs =
(dependent_vals_nd > 1) ? dependent_vals_shape[1] : 1;

const std::int64_t lda = std::max<size_t>(1UL, n);
const std::int64_t ldb = std::max<size_t>(1UL, m);

std::vector<sycl::event> host_task_events;
sycl::event gesv_ev =
gesv_fn(exec_q, n, nrhs, coeff_matrix_data, lda, dependent_vals_data,
ldb, host_task_events, depends);

sycl::event args_ev = dpctl::utils::keep_args_alive(
exec_q, {coeff_matrix, dependent_vals}, host_task_events);

return std::make_pair(args_ev, gesv_ev);
}

template <typename fnT, typename T>
struct GesvContigFactory
{
fnT get()
{
if constexpr (types::GesvTypePairSupportFactory<T>::is_defined) {
return gesv_impl<T>;
}
else {
return nullptr;
}
}
};

void init_gesv_dispatch_vector(void)
{
dpctl_td_ns::DispatchVectorBuilder<gesv_impl_fn_ptr_t, GesvContigFactory,
dpctl_td_ns::num_types>
contig;
contig.populate_dispatch_vector(gesv_dispatch_vector);
}
} // namespace lapack
} // namespace ext
} // namespace backend
} // namespace dpnp
Loading