66#include < stan/math/prim/mat/fun/promote_common.hpp>
77#include < stan/math/prim/mat/err/check_multiplicable.hpp>
88#include < stan/math/prim/mat/err/check_square.hpp>
9-
9+ #ifdef STAN_OPENCL
10+ #include < stan/math/opencl/opencl_context.hpp>
11+ #include < stan/math/opencl/multiply.hpp>
12+ #include < stan/math/opencl/tri_inverse.hpp>
13+ #include < stan/math/opencl/transpose.hpp>
14+ #include < stan/math/opencl/copy.hpp>
15+ #endif
1016namespace stan {
1117namespace math {
1218
1319/* *
14- * Returns the solution of the system Ax=b when A is triangular
15- * @param A Triangular matrix. Specify upper or lower with TriView
16- * being Eigen::Upper or Eigen::Lower.
20+ * Returns the solution of the system Ax=b when A is triangular.
21+ * @tparam TriView Specifies whether A is upper (Eigen::Upper)
22+ * or lower triangular (Eigen::Lower).
23+ * @tparam T1 type of elements in A
24+ * @tparam T2 type of elements in b
25+ * @tparam R1 number of rows in A
26+ * @tparam C1 number of columns in A
27+ * @tparam R2 number of rows in b
28+ * @tparam C2 number of columns in b
29+ * @param A Triangular matrix.
1730 * @param b Right hand side matrix or vector.
1831 * @return x = A^-1 b, solution of the linear system.
1932 * @throws std::domain_error if A is not square or the rows of b don't
@@ -36,8 +49,10 @@ mdivide_left_tri(const Eigen::Matrix<T1, R1, C1> &A,
3649
3750/* *
3851 * Returns the solution of the system Ax=b when A is triangular and b=I.
39- * @param A Triangular matrix. Specify upper or lower with TriView
40- * being Eigen::Upper or Eigen::Lower.
52+ * @tparam T type of elements in A
53+ * @tparam R1 number of rows in A
54+ * @tparam C1 number of columns in A
55+ * @param A Triangular matrix.
4156 * @return x = A^-1 .
4257 * @throws std::domain_error if A is not square
4358 */
@@ -52,6 +67,85 @@ inline Eigen::Matrix<T, R1, C1> mdivide_left_tri(
5267 return b;
5368}
5469
70+ /* *
71+ * Returns the solution of the system Ax=b when A is triangular
72+ * and A and b are matrices of doubles.
73+ * @tparam TriView Specifies whether A is upper (Eigen::Upper)
74+ * or lower triangular (Eigen::Lower).
75+ * @tparam R1 number of rows in A
76+ * @tparam C1 number of columns in A
77+ * @tparam R2 number of rows in b
78+ * @tparam C2 number of columns in b
79+ * @param A Triangular matrix.
80+ * @param b Right hand side matrix or vector.
81+ * @return x = A^-1 b, solution of the linear system.
82+ * @throws std::domain_error if A is not square or the rows of b don't
83+ * match the size of A.
84+ */
85+ template <int TriView, int R1, int C1, int R2, int C2>
86+ inline Eigen::Matrix<double , R1, C2> mdivide_left_tri (
87+ const Eigen::Matrix<double , R1, C1> &A,
88+ const Eigen::Matrix<double , R2, C2> &b) {
89+ check_square (" mdivide_left_tri" , " A" , A);
90+ check_multiplicable (" mdivide_left_tri" , " A" , A, " b" , b);
91+ #ifdef STAN_OPENCL
92+ if (A.rows ()
93+ >= opencl_context.tuning_opts ().tri_inverse_size_worth_transfer ) {
94+ matrix_cl A_cl (A);
95+ matrix_cl b_cl (b);
96+ matrix_cl A_inv_cl (A.rows (), A.cols ());
97+ if (TriView == Eigen::Lower) {
98+ A_inv_cl = tri_inverse<TriangularViewCL::Lower>(A_cl);
99+ } else {
100+ A_inv_cl = tri_inverse<TriangularViewCL::Upper>(A_cl);
101+ }
102+ matrix_cl C_cl = A_inv_cl * b_cl;
103+ return from_matrix_cl (C_cl);
104+ } else {
105+ #endif
106+ return A.template triangularView <TriView>().solve (b);
107+ #ifdef STAN_OPENCL
108+ }
109+ #endif
110+ }
111+
112+ /* *
113+ * Returns the solution of the system Ax=b when A is triangular, b=I and
114+ * both are matrices of doubles.
115+ * @tparam TriView Specifies whether A is upper (Eigen::Upper)
116+ * or lower triangular (Eigen::Lower).
117+ * @tparam R1 number of rows in A
118+ * @tparam C1 number of columns in A
119+ * @param A Triangular matrix.
120+ * @return x = A^-1 .
121+ * @throws std::domain_error if A is not square
122+ */
123+ template <int TriView, int R1, int C1>
124+ inline Eigen::Matrix<double , R1, C1> mdivide_left_tri (
125+ const Eigen::Matrix<double , R1, C1> &A) {
126+ check_square (" mdivide_left_tri" , " A" , A);
127+ const int n = A.rows ();
128+ #ifdef STAN_OPENCL
129+ if (A.rows ()
130+ >= opencl_context.tuning_opts ().tri_inverse_size_worth_transfer ) {
131+ matrix_cl A_cl (A);
132+ if (TriView == Eigen::Lower) {
133+ A_cl = tri_inverse<TriangularViewCL::Lower>(A_cl);
134+ } else {
135+ A_cl = tri_inverse<TriangularViewCL::Upper>(A_cl);
136+ }
137+ return from_matrix_cl (A_cl);
138+ } else {
139+ #endif
140+ Eigen::Matrix<double , Eigen::Dynamic, Eigen::Dynamic> b;
141+ b.setIdentity (n, n);
142+ A.template triangularView <TriView>().solveInPlace (b);
143+ return b;
144+ #ifdef STAN_OPENCL
145+ }
146+ #endif
147+ }
148+
55149} // namespace math
56150} // namespace stan
57151#endif
0 commit comments