1- #ifndef STAN_MATH_OPENCL_LOWER_TRI_INVERSE_HPP
2- #define STAN_MATH_OPENCL_LOWER_TRI_INVERSE_HPP
1+ #ifndef STAN_MATH_OPENCL_TRI_INVERSE_HPP
2+ #define STAN_MATH_OPENCL_TRI_INVERSE_HPP
33
44#ifdef STAN_OPENCL
55#include < stan/math/opencl/matrix_cl.hpp>
6+ #include < stan/math/opencl/constants.hpp>
67#include < stan/math/opencl/kernels/diag_inv.hpp>
78#include < stan/math/opencl/kernels/inv_lower_tri_multiply.hpp>
89#include < stan/math/opencl/kernels/neg_rect_lower_tri_multiply.hpp>
910#include < stan/math/opencl/err/check_opencl.hpp>
10-
11+ # include < stan/math/opencl/transpose.hpp >
1112#include < stan/math/opencl/identity.hpp>
1213#include < stan/math/opencl/err/check_square.hpp>
1314#include < stan/math/opencl/sub_block.hpp>
1920namespace stan {
2021namespace math {
2122/* *
22- * Computes the inverse of the lower triangular matrix
23+ * Computes the inverse of a triangular matrix
2324 *
2425 * For a full guide to how this works and fits into Cholesky decompositions,
2526 * see the reference report
2627 * <a href="https://github.com/SteveBronder/stancon2018/blob/master/report.pdf">
2728 * here</a> and kernel doc
2829 * <a href="https://github.com/stan-dev/math/wiki/GPU-Kernels">here</a>.
2930 *
31+ * @tparam triangular_view the triangularity of the input matrix
3032 * @param A matrix on the OpenCL device
3133 * @return the inverse of A
3234 *
3335 * @throw <code>std::invalid_argument</code> if the matrix
3436 * is not square
3537 */
36- inline matrix_cl lower_triangular_inverse (const matrix_cl& A) {
37- check_square (" lower_triangular_inverse (OpenCL)" , " A" , A);
38+ template <TriangularViewCL triangular_view>
39+ inline matrix_cl tri_inverse (const matrix_cl& A) {
40+ static_assert (triangular_view != TriangularViewCL::Entire,
41+ " tri_inverse(OpenCL) only supports triangular input matrices" );
42+ check_square (" tri_inverse (OpenCL)" , " A" , A);
3843
3944 int thread_block_2D_dim = 32 ;
4045 int max_1D_thread_block_size = opencl_context.max_thread_block_size ();
@@ -69,7 +74,9 @@ inline matrix_cl lower_triangular_inverse(const matrix_cl& A) {
6974 zero_mat.zeros <stan::math::TriangularViewCL::Entire>();
7075 temp.zeros <stan::math::TriangularViewCL::Entire>();
7176 inv_padded.zeros <stan::math::TriangularViewCL::Entire>();
72-
77+ if (triangular_view == TriangularViewCL::Upper) {
78+ inv_mat = transpose (inv_mat);
79+ }
7380 int work_per_thread
7481 = opencl_kernels::inv_lower_tri_multiply.make_functor .get_opts ().at (
7582 " WORK_PER_THREAD" );
@@ -95,6 +102,9 @@ inline matrix_cl lower_triangular_inverse(const matrix_cl& A) {
95102 inv_padded.zeros <stan::math::TriangularViewCL::Upper>();
96103 if (parts == 1 ) {
97104 inv_mat.sub_block (inv_padded, 0 , 0 , 0 , 0 , inv_mat.rows (), inv_mat.rows ());
105+ if (triangular_view == TriangularViewCL::Upper) {
106+ inv_mat = transpose (inv_mat);
107+ }
98108 return inv_mat;
99109 }
100110 parts = ceil (parts / 2.0 );
@@ -132,7 +142,10 @@ inline matrix_cl lower_triangular_inverse(const matrix_cl& A) {
132142 inv_padded.zeros <stan::math::TriangularViewCL::Upper>();
133143 }
134144 // un-pad and return
135- inv_mat.sub_block (inv_padded, 0 , 0 , 0 , 0 , A.rows (), A.rows ());
145+ inv_mat.sub_block (inv_padded, 0 , 0 , 0 , 0 , inv_mat.rows (), inv_mat.rows ());
146+ if (triangular_view == TriangularViewCL::Upper) {
147+ inv_mat = transpose (inv_mat);
148+ }
136149 return inv_mat;
137150}
138151} // namespace math
0 commit comments