Skip to content

Commit 22072c6

Browse files
authored
Merge pull request #1222 from bstatcomp/feature/issue-1221-opencl-prim-mdivide-left-tri
Feature/issue 1221 OpenCL implementation of primitive mdivide_left_tri
2 parents 9e0bd89 + ae393cd commit 22072c6

File tree

12 files changed

+361
-86
lines changed

12 files changed

+361
-86
lines changed

stan/math/opencl/cholesky_decompose.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
#include <stan/math/opencl/kernels/cholesky_decompose.hpp>
77
#include <stan/math/opencl/multiply.hpp>
88
#include <stan/math/opencl/multiply_transpose.hpp>
9-
#include <stan/math/opencl/lower_tri_inverse.hpp>
9+
#include <stan/math/opencl/tri_inverse.hpp>
1010
#include <stan/math/opencl/transpose.hpp>
1111
#include <stan/math/opencl/subtract.hpp>
1212
#include <stan/math/opencl/err/check_diagonal_zeros.hpp>
@@ -78,7 +78,7 @@ inline void cholesky_decompose(matrix_cl& A) {
7878
// and copies the resulting submatrix to the lower left hand corner of A
7979
matrix_cl L_21
8080
= opencl::multiply<TriangularViewCL::Entire, TriangularViewCL::Upper>(
81-
A_21, transpose(lower_triangular_inverse(A_11)));
81+
A_21, transpose(tri_inverse<TriangularViewCL::Lower>(A_11)));
8282
A.sub_block(L_21, 0, 0, block, 0, block_subset, block);
8383
matrix_cl A_22(block_subset, block_subset);
8484
A_22.sub_block(A, block, block, 0, 0, block_subset, block_subset);

stan/math/opencl/kernels/diag_inv.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ static const char* diag_inv_kernel_code = STRINGIFY(
3636
* @param rows The number of rows for A.
3737
* @note Code is a <code>const char*</code> held in
3838
* <code>diag_inv_kernel_code</code>.
39-
* Used in math/opencl/lower_tri_inverse.hpp.
39+
* Used in math/opencl/tri_inverse.hpp.
4040
* This kernel uses the helper macros available in helpers.cl.
4141
*/
4242
__kernel void diag_inv(__global double* A, __global double* tmp_inv,

stan/math/opencl/kernels/inv_lower_tri_multiply.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@ static const char* inv_lower_tri_multiply_kernel_code = STRINGIFY(
3939
* @param rows The number of rows in a single matrix of the batch
4040
* @note Code is a <code>const char*</code> held in
4141
* <code>inv_lower_tri_multiply_kernel_code</code>.
42-
* Used in math/opencl/lower_tri_inverse.hpp.
42+
* Used in math/opencl/tri_inverse.hpp.
4343
* This kernel uses the helper macros available in helpers.cl.
4444
*/
4545
__kernel void inv_lower_tri_multiply(__global double* A,

stan/math/opencl/kernels/neg_rect_lower_tri_multiply.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ static const char* neg_rect_lower_tri_multiply_kernel_code = STRINGIFY(
3333
* @param rows The number of rows in a single matrix of the batch
3434
* @note Code is a <code>const char*</code> held in
3535
* <code>neg_rect_lower_tri_multiply_kernel_code</code>.
36-
* Used in math/opencl/lower_tri_inverse.hpp.
36+
* Used in math/opencl/tri_inverse.hpp.
3737
* This kernel uses the helper macros available in helpers.cl.
3838
*/
3939
__kernel void neg_rect_lower_tri_multiply(

stan/math/opencl/opencl.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
#include <stan/math/opencl/cholesky_decompose.hpp>
1010
#include <stan/math/opencl/diagonal_multiply.hpp>
1111
#include <stan/math/opencl/identity.hpp>
12-
#include <stan/math/opencl/lower_tri_inverse.hpp>
12+
#include <stan/math/opencl/tri_inverse.hpp>
1313
#include <stan/math/opencl/matrix_cl.hpp>
1414
#include <stan/math/opencl/multiply.hpp>
1515
#include <stan/math/opencl/multiply_transpose.hpp>

stan/math/opencl/opencl_context.hpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -195,6 +195,9 @@ class opencl_context_base {
195195
int cholesky_rev_block_partition = 8;
196196
// used in math/opencl/multiply
197197
int multiply_split_upper_limit = 2000000;
198+
// used in math/prim/mat/fun/mdivide_left_tri
199+
// and math/rev/mat/fun/mdivide_left_tri
200+
int tri_inverse_size_worth_transfer = 100;
198201
} tuning_opts_;
199202

200203
static opencl_context_base& getInstance() {

stan/math/opencl/lower_tri_inverse.hpp renamed to stan/math/opencl/tri_inverse.hpp

Lines changed: 21 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,14 @@
1-
#ifndef STAN_MATH_OPENCL_LOWER_TRI_INVERSE_HPP
2-
#define STAN_MATH_OPENCL_LOWER_TRI_INVERSE_HPP
1+
#ifndef STAN_MATH_OPENCL_TRI_INVERSE_HPP
2+
#define STAN_MATH_OPENCL_TRI_INVERSE_HPP
33

44
#ifdef STAN_OPENCL
55
#include <stan/math/opencl/matrix_cl.hpp>
6+
#include <stan/math/opencl/constants.hpp>
67
#include <stan/math/opencl/kernels/diag_inv.hpp>
78
#include <stan/math/opencl/kernels/inv_lower_tri_multiply.hpp>
89
#include <stan/math/opencl/kernels/neg_rect_lower_tri_multiply.hpp>
910
#include <stan/math/opencl/err/check_opencl.hpp>
10-
11+
#include <stan/math/opencl/transpose.hpp>
1112
#include <stan/math/opencl/identity.hpp>
1213
#include <stan/math/opencl/err/check_square.hpp>
1314
#include <stan/math/opencl/sub_block.hpp>
@@ -19,22 +20,26 @@
1920
namespace stan {
2021
namespace math {
2122
/**
22-
* Computes the inverse of the lower triangular matrix
23+
* Computes the inverse of a triangular matrix
2324
*
2425
* For a full guide to how this works and fits into Cholesky decompositions,
2526
* see the reference report
2627
* <a href="https://github.com/SteveBronder/stancon2018/blob/master/report.pdf">
2728
* here</a> and kernel doc
2829
* <a href="https://github.com/stan-dev/math/wiki/GPU-Kernels">here</a>.
2930
*
31+
* @tparam triangular_view the triangularity of the input matrix
3032
* @param A matrix on the OpenCL device
3133
* @return the inverse of A
3234
*
3335
* @throw <code>std::invalid_argument</code> if the matrix
3436
* is not square
3537
*/
36-
inline matrix_cl lower_triangular_inverse(const matrix_cl& A) {
37-
check_square("lower_triangular_inverse (OpenCL)", "A", A);
38+
template <TriangularViewCL triangular_view>
39+
inline matrix_cl tri_inverse(const matrix_cl& A) {
40+
static_assert(triangular_view != TriangularViewCL::Entire,
41+
"tri_inverse(OpenCL) only supports triangular input matrices");
42+
check_square("tri_inverse (OpenCL)", "A", A);
3843

3944
int thread_block_2D_dim = 32;
4045
int max_1D_thread_block_size = opencl_context.max_thread_block_size();
@@ -69,7 +74,9 @@ inline matrix_cl lower_triangular_inverse(const matrix_cl& A) {
6974
zero_mat.zeros<stan::math::TriangularViewCL::Entire>();
7075
temp.zeros<stan::math::TriangularViewCL::Entire>();
7176
inv_padded.zeros<stan::math::TriangularViewCL::Entire>();
72-
77+
if (triangular_view == TriangularViewCL::Upper) {
78+
inv_mat = transpose(inv_mat);
79+
}
7380
int work_per_thread
7481
= opencl_kernels::inv_lower_tri_multiply.make_functor.get_opts().at(
7582
"WORK_PER_THREAD");
@@ -95,6 +102,9 @@ inline matrix_cl lower_triangular_inverse(const matrix_cl& A) {
95102
inv_padded.zeros<stan::math::TriangularViewCL::Upper>();
96103
if (parts == 1) {
97104
inv_mat.sub_block(inv_padded, 0, 0, 0, 0, inv_mat.rows(), inv_mat.rows());
105+
if (triangular_view == TriangularViewCL::Upper) {
106+
inv_mat = transpose(inv_mat);
107+
}
98108
return inv_mat;
99109
}
100110
parts = ceil(parts / 2.0);
@@ -132,7 +142,10 @@ inline matrix_cl lower_triangular_inverse(const matrix_cl& A) {
132142
inv_padded.zeros<stan::math::TriangularViewCL::Upper>();
133143
}
134144
// un-pad and return
135-
inv_mat.sub_block(inv_padded, 0, 0, 0, 0, A.rows(), A.rows());
145+
inv_mat.sub_block(inv_padded, 0, 0, 0, 0, inv_mat.rows(), inv_mat.rows());
146+
if (triangular_view == TriangularViewCL::Upper) {
147+
inv_mat = transpose(inv_mat);
148+
}
136149
return inv_mat;
137150
}
138151
} // namespace math

stan/math/prim/mat/fun/mdivide_left_tri.hpp

Lines changed: 100 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -6,14 +6,27 @@
66
#include <stan/math/prim/mat/fun/promote_common.hpp>
77
#include <stan/math/prim/mat/err/check_multiplicable.hpp>
88
#include <stan/math/prim/mat/err/check_square.hpp>
9-
9+
#ifdef STAN_OPENCL
10+
#include <stan/math/opencl/opencl_context.hpp>
11+
#include <stan/math/opencl/multiply.hpp>
12+
#include <stan/math/opencl/tri_inverse.hpp>
13+
#include <stan/math/opencl/transpose.hpp>
14+
#include <stan/math/opencl/copy.hpp>
15+
#endif
1016
namespace stan {
1117
namespace math {
1218

1319
/**
14-
* Returns the solution of the system Ax=b when A is triangular
15-
* @param A Triangular matrix. Specify upper or lower with TriView
16-
* being Eigen::Upper or Eigen::Lower.
20+
* Returns the solution of the system Ax=b when A is triangular.
21+
* @tparam TriView Specifies whether A is upper (Eigen::Upper)
22+
* or lower triangular (Eigen::Lower).
23+
* @tparam T1 type of elements in A
24+
* @tparam T2 type of elements in b
25+
* @tparam R1 number of rows in A
26+
* @tparam C1 number of columns in A
27+
* @tparam R2 number of rows in b
28+
* @tparam C2 number of columns in b
29+
* @param A Triangular matrix.
1730
* @param b Right hand side matrix or vector.
1831
* @return x = A^-1 b, solution of the linear system.
1932
* @throws std::domain_error if A is not square or the rows of b don't
@@ -36,8 +49,10 @@ mdivide_left_tri(const Eigen::Matrix<T1, R1, C1> &A,
3649

3750
/**
3851
* Returns the solution of the system Ax=b when A is triangular and b=I.
39-
* @param A Triangular matrix. Specify upper or lower with TriView
40-
* being Eigen::Upper or Eigen::Lower.
52+
* @tparam T type of elements in A
53+
* @tparam R1 number of rows in A
54+
* @tparam C1 number of columns in A
55+
* @param A Triangular matrix.
4156
* @return x = A^-1 .
4257
* @throws std::domain_error if A is not square
4358
*/
@@ -52,6 +67,85 @@ inline Eigen::Matrix<T, R1, C1> mdivide_left_tri(
5267
return b;
5368
}
5469

70+
/**
71+
* Returns the solution of the system Ax=b when A is triangular
72+
* and A and b are matrices of doubles.
*
* When compiled with STAN_OPENCL and A has at least
* opencl_context.tuning_opts().tri_inverse_size_worth_transfer rows,
* A and b are copied into matrix_cl objects, A is inverted with
* tri_inverse, the inverse is multiplied by b, and the product is
* copied back with from_matrix_cl. In every other case the system is
* solved with Eigen's triangularView<TriView>().solve(b).
73+
* @tparam TriView Specifies whether A is upper (Eigen::Upper)
74+
* or lower triangular (Eigen::Lower).
75+
* @tparam R1 number of rows in A
76+
* @tparam C1 number of columns in A
77+
* @tparam R2 number of rows in b
78+
* @tparam C2 number of columns in b
79+
* @param A Triangular matrix.
80+
* @param b Right hand side matrix or vector.
81+
* @return x = A^-1 b, solution of the linear system.
82+
* @throws std::domain_error if A is not square or the rows of b don't
83+
* match the size of A.
* NOTE(review): the exception type raised by check_square and
* check_multiplicable is not visible here -- confirm it is
* std::domain_error.
84+
*/
85+
template <int TriView, int R1, int C1, int R2, int C2>
86+
inline Eigen::Matrix<double, R1, C2> mdivide_left_tri(
87+
const Eigen::Matrix<double, R1, C1> &A,
88+
const Eigen::Matrix<double, R2, C2> &b) {
// Argument checks run before either backend is chosen.
89+
check_square("mdivide_left_tri", "A", A);
90+
check_multiplicable("mdivide_left_tri", "A", A, "b", b);
// The three STAN_OPENCL guards below are paired: the first opens an
// if/else around the Eigen solve, the second closes the else brace.
// Without STAN_OPENCL only the Eigen solve remains.
91+
#ifdef STAN_OPENCL
// OpenCL path: taken only when A has at least
// tri_inverse_size_worth_transfer rows.
92+
if (A.rows()
93+
>= opencl_context.tuning_opts().tri_inverse_size_worth_transfer) {
94+
matrix_cl A_cl(A);
95+
matrix_cl b_cl(b);
// Sized like A here, then assigned the result of tri_inverse below.
96+
matrix_cl A_inv_cl(A.rows(), A.cols());
// TriView is a template argument; any value other than Eigen::Lower
// selects the Upper instantiation of tri_inverse.
97+
if (TriView == Eigen::Lower) {
98+
A_inv_cl = tri_inverse<TriangularViewCL::Lower>(A_cl);
99+
} else {
100+
A_inv_cl = tri_inverse<TriangularViewCL::Upper>(A_cl);
101+
}
// x = A^-1 * b, computed on the matrix_cl objects.
102+
matrix_cl C_cl = A_inv_cl * b_cl;
// Copy the product back into an Eigen matrix for the caller.
103+
return from_matrix_cl(C_cl);
104+
} else {
105+
#endif
// Fallback: Eigen's triangular solve on the host.
106+
return A.template triangularView<TriView>().solve(b);
107+
#ifdef STAN_OPENCL
108+
}
109+
#endif
110+
}
111+
112+
/**
113+
* Returns the solution of the system Ax=b when A is triangular, b=I and
114+
* both are matrices of doubles.
*
* When compiled with STAN_OPENCL and A has at least
* opencl_context.tuning_opts().tri_inverse_size_worth_transfer rows,
* A is copied into a matrix_cl, inverted with tri_inverse, and the
* result is copied back with from_matrix_cl. In every other case an
* n x n identity is built and solved in place with Eigen's
* triangularView<TriView>().solveInPlace.
115+
* @tparam TriView Specifies whether A is upper (Eigen::Upper)
116+
* or lower triangular (Eigen::Lower).
117+
* @tparam R1 number of rows in A
118+
* @tparam C1 number of columns in A
119+
* @param A Triangular matrix.
120+
* @return x = A^-1, the inverse of A.
121+
* @throws std::domain_error if A is not square
* NOTE(review): the exception type raised by check_square is not
* visible here -- confirm it is std::domain_error.
122+
*/
123+
template <int TriView, int R1, int C1>
124+
inline Eigen::Matrix<double, R1, C1> mdivide_left_tri(
125+
const Eigen::Matrix<double, R1, C1> &A) {
126+
check_square("mdivide_left_tri", "A", A);
// n is used only by the Eigen fallback below; the OpenCL branch
// reads A.rows() directly.
127+
const int n = A.rows();
// The three STAN_OPENCL guards below are paired: the first opens an
// if/else around the Eigen fallback, the second closes the else brace.
128+
#ifdef STAN_OPENCL
// OpenCL path: taken only when A has at least
// tri_inverse_size_worth_transfer rows.
129+
if (A.rows()
130+
>= opencl_context.tuning_opts().tri_inverse_size_worth_transfer) {
131+
matrix_cl A_cl(A);
// A_cl is overwritten with its own inverse. TriView is a template
// argument; any value other than Eigen::Lower selects the Upper
// instantiation of tri_inverse.
132+
if (TriView == Eigen::Lower) {
133+
A_cl = tri_inverse<TriangularViewCL::Lower>(A_cl);
134+
} else {
135+
A_cl = tri_inverse<TriangularViewCL::Upper>(A_cl);
136+
}
// Copy the inverse back into an Eigen matrix for the caller.
137+
return from_matrix_cl(A_cl);
138+
} else {
139+
#endif
// Fallback: solve A x = I in place, so b ends up holding A^-1.
// b is dynamically sized and converted to the return type on return.
140+
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> b;
141+
b.setIdentity(n, n);
142+
A.template triangularView<TriView>().solveInPlace(b);
143+
return b;
144+
#ifdef STAN_OPENCL
145+
}
146+
#endif
147+
}
148+
55149
} // namespace math
56150
} // namespace stan
57151
#endif

stan/math/rev/mat/fun/cholesky_decompose.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -299,7 +299,7 @@ class cholesky_opencl : public vari {
299299
L_adj = opencl::multiply<TriangularViewCL::Upper, TriangularViewCL::Entire>(
300300
transpose(L), L_adj);
301301
L_adj.triangular_transpose<TriangularMapCL::LowerToUpper>();
302-
L = transpose(lower_triangular_inverse(L));
302+
L = transpose(tri_inverse<TriangularViewCL::Lower>(L));
303303
L_adj = L
304304
* transpose(opencl::multiply<TriangularViewCL::Upper,
305305
TriangularViewCL::Entire>(L, L_adj));
@@ -360,7 +360,7 @@ class cholesky_opencl : public vari {
360360

361361
C_adj
362362
= opencl::multiply<TriangularViewCL::Entire, TriangularViewCL::Lower>(
363-
C_adj, lower_triangular_inverse(D));
363+
C_adj, tri_inverse<TriangularViewCL::Lower>(D));
364364
B_adj = B_adj - C_adj * R;
365365
D_adj = D_adj - transpose(C_adj) * C;
366366

test/unit/math/opencl/lower_tri_inverse_test.cpp

Lines changed: 0 additions & 64 deletions
This file was deleted.

0 commit comments

Comments
 (0)