Skip to content

Commit 97d6c89

Browse files
committed
Merge branch 'develop' into feature/faster-ad-tls-v6
2 parents d013e55 + 22072c6 commit 97d6c89

34 files changed

+2685
-116
lines changed

stan/math/fwd/mat.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <stan/math/fwd/core.hpp>
55
#include <stan/math/fwd/scal/meta/is_fvar.hpp>
66
#include <stan/math/fwd/scal/meta/partials_type.hpp>
7+
#include <stan/math/fwd/mat/meta/operands_and_partials.hpp>
78

89
#include <stan/math/fwd/mat/vectorize/apply_scalar_unary.hpp>
910
#include <stan/math/prim/mat.hpp>
@@ -47,6 +48,4 @@
4748
#include <stan/math/fwd/mat/functor/hessian.hpp>
4849
#include <stan/math/fwd/mat/functor/jacobian.hpp>
4950

50-
#include <stan/math/fwd/mat/meta/operands_and_partials.hpp>
51-
5251
#endif

stan/math/fwd/mat/meta/operands_and_partials.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
#ifndef STAN_MATH_FWD_MAT_META_OPERANDS_AND_PARTIALS_HPP
22
#define STAN_MATH_FWD_MAT_META_OPERANDS_AND_PARTIALS_HPP
33

4-
#include <stan/math/fwd/scal/meta/operands_and_partials.hpp>
5-
#include <stan/math/prim/scal/meta/broadcast_array.hpp>
64
#include <stan/math/prim/mat/fun/Eigen.hpp>
5+
#include <stan/math/prim/arr/meta/length.hpp>
6+
#include <stan/math/prim/scal/meta/broadcast_array.hpp>
7+
#include <stan/math/fwd/scal/meta/operands_and_partials.hpp>
78
#include <vector>
89

910
namespace stan {

stan/math/opencl/cholesky_decompose.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include <stan/math/opencl/kernels/cholesky_decompose.hpp>
77
#include <stan/math/opencl/multiply.hpp>
88
#include <stan/math/opencl/multiply_transpose.hpp>
9-
#include <stan/math/opencl/lower_tri_inverse.hpp>
9+
#include <stan/math/opencl/tri_inverse.hpp>
1010
#include <stan/math/opencl/transpose.hpp>
1111
#include <stan/math/opencl/subtract.hpp>
1212
#include <stan/math/opencl/err/check_diagonal_zeros.hpp>
@@ -78,7 +78,7 @@ inline void cholesky_decompose(matrix_cl& A) {
7878
// and copies the resulting submatrix to the lower left hand corner of A
7979
matrix_cl L_21
8080
= opencl::multiply<TriangularViewCL::Entire, TriangularViewCL::Upper>(
81-
A_21, transpose(lower_triangular_inverse(A_11)));
81+
A_21, transpose(tri_inverse<TriangularViewCL::Lower>(A_11)));
8282
A.sub_block(L_21, 0, 0, block, 0, block_subset, block);
8383
matrix_cl A_22(block_subset, block_subset);
8484
A_22.sub_block(A, block, block, 0, 0, block_subset, block_subset);

stan/math/opencl/kernels/diag_inv.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ static const char* diag_inv_kernel_code = STRINGIFY(
3636
* @param rows The number of rows for A.
3737
* @note Code is a <code>const char*</code> held in
3838
* <code>diag_inv_kernel_code</code>.
39-
* Used in math/opencl/lower_tri_inverse.hpp.
39+
* Used in math/opencl/tri_inverse.hpp.
4040
* This kernel uses the helper macros available in helpers.cl.
4141
*/
4242
__kernel void diag_inv(__global double* A, __global double* tmp_inv,

stan/math/opencl/kernels/inv_lower_tri_multiply.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ static const char* inv_lower_tri_multiply_kernel_code = STRINGIFY(
3939
* @param rows The number of rows in a single matrix of the batch
4040
* @note Code is a <code>const char*</code> held in
4141
* <code>inv_lower_tri_multiply_kernel_code</code>.
42-
* Used in math/opencl/lower_tri_inverse.hpp.
42+
* Used in math/opencl/tri_inverse.hpp.
4343
* This kernel uses the helper macros available in helpers.cl.
4444
*/
4545
__kernel void inv_lower_tri_multiply(__global double* A,

stan/math/opencl/kernels/neg_rect_lower_tri_multiply.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ static const char* neg_rect_lower_tri_multiply_kernel_code = STRINGIFY(
3333
* @param rows The number of rows in a single matrix of the batch
3434
* @note Code is a <code>const char*</code> held in
3535
* <code>neg_rect_lower_tri_multiply_kernel_code</code>.
36-
* Used in math/opencl/lower_tri_inverse.hpp.
36+
* Used in math/opencl/tri_inverse.hpp.
3737
* This kernel uses the helper macros available in helpers.cl.
3838
*/
3939
__kernel void neg_rect_lower_tri_multiply(

stan/math/opencl/opencl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#include <stan/math/opencl/cholesky_decompose.hpp>
1010
#include <stan/math/opencl/diagonal_multiply.hpp>
1111
#include <stan/math/opencl/identity.hpp>
12-
#include <stan/math/opencl/lower_tri_inverse.hpp>
12+
#include <stan/math/opencl/tri_inverse.hpp>
1313
#include <stan/math/opencl/matrix_cl.hpp>
1414
#include <stan/math/opencl/multiply.hpp>
1515
#include <stan/math/opencl/multiply_transpose.hpp>

stan/math/opencl/opencl_context.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,9 @@ class opencl_context_base {
195195
int cholesky_rev_block_partition = 8;
196196
// used in math/opencl/multiply
197197
int multiply_split_upper_limit = 2000000;
198+
// used in math/prim/mat/fun/mdivide_left_tri
199+
// and math/rev/mat/fun/mdivide_left_tri
200+
int tri_inverse_size_worth_transfer = 100;
198201
} tuning_opts_;
199202

200203
static opencl_context_base& getInstance() {

stan/math/opencl/lower_tri_inverse.hpp renamed to stan/math/opencl/tri_inverse.hpp

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
#ifndef STAN_MATH_OPENCL_LOWER_TRI_INVERSE_HPP
2-
#define STAN_MATH_OPENCL_LOWER_TRI_INVERSE_HPP
1+
#ifndef STAN_MATH_OPENCL_TRI_INVERSE_HPP
2+
#define STAN_MATH_OPENCL_TRI_INVERSE_HPP
33

44
#ifdef STAN_OPENCL
55
#include <stan/math/opencl/matrix_cl.hpp>
6+
#include <stan/math/opencl/constants.hpp>
67
#include <stan/math/opencl/kernels/diag_inv.hpp>
78
#include <stan/math/opencl/kernels/inv_lower_tri_multiply.hpp>
89
#include <stan/math/opencl/kernels/neg_rect_lower_tri_multiply.hpp>
910
#include <stan/math/opencl/err/check_opencl.hpp>
10-
11+
#include <stan/math/opencl/transpose.hpp>
1112
#include <stan/math/opencl/identity.hpp>
1213
#include <stan/math/opencl/err/check_square.hpp>
1314
#include <stan/math/opencl/sub_block.hpp>
@@ -19,22 +20,26 @@
1920
namespace stan {
2021
namespace math {
2122
/**
22-
* Computes the inverse of the lower triangular matrix
23+
* Computes the inverse of a triangular matrix
2324
*
2425
* For a full guide to how this works and fits into Cholesky decompositions,
2526
* see the reference report
2627
* <a href="https://github.com/SteveBronder/stancon2018/blob/master/report.pdf">
2728
* here</a> and kernel doc
2829
* <a href="https://github.com/stan-dev/math/wiki/GPU-Kernels">here</a>.
2930
*
31+
* @tparam triangular_view the triangularity of the input matrix
3032
* @param A matrix on the OpenCL device
3133
* @return the inverse of A
3234
*
3335
* @throw <code>std::invalid_argument</code> if the matrix
3436
* is not square
3537
*/
36-
inline matrix_cl lower_triangular_inverse(const matrix_cl& A) {
37-
check_square("lower_triangular_inverse (OpenCL)", "A", A);
38+
template <TriangularViewCL triangular_view>
39+
inline matrix_cl tri_inverse(const matrix_cl& A) {
40+
static_assert(triangular_view != TriangularViewCL::Entire,
41+
"tri_inverse(OpenCL) only supports triangular input matrices");
42+
check_square("tri_inverse (OpenCL)", "A", A);
3843

3944
int thread_block_2D_dim = 32;
4045
int max_1D_thread_block_size = opencl_context.max_thread_block_size();
@@ -69,7 +74,9 @@ inline matrix_cl lower_triangular_inverse(const matrix_cl& A) {
6974
zero_mat.zeros<stan::math::TriangularViewCL::Entire>();
7075
temp.zeros<stan::math::TriangularViewCL::Entire>();
7176
inv_padded.zeros<stan::math::TriangularViewCL::Entire>();
72-
77+
if (triangular_view == TriangularViewCL::Upper) {
78+
inv_mat = transpose(inv_mat);
79+
}
7380
int work_per_thread
7481
= opencl_kernels::inv_lower_tri_multiply.make_functor.get_opts().at(
7582
"WORK_PER_THREAD");
@@ -95,6 +102,9 @@ inline matrix_cl lower_triangular_inverse(const matrix_cl& A) {
95102
inv_padded.zeros<stan::math::TriangularViewCL::Upper>();
96103
if (parts == 1) {
97104
inv_mat.sub_block(inv_padded, 0, 0, 0, 0, inv_mat.rows(), inv_mat.rows());
105+
if (triangular_view == TriangularViewCL::Upper) {
106+
inv_mat = transpose(inv_mat);
107+
}
98108
return inv_mat;
99109
}
100110
parts = ceil(parts / 2.0);
@@ -132,7 +142,10 @@ inline matrix_cl lower_triangular_inverse(const matrix_cl& A) {
132142
inv_padded.zeros<stan::math::TriangularViewCL::Upper>();
133143
}
134144
// un-pad and return
135-
inv_mat.sub_block(inv_padded, 0, 0, 0, 0, A.rows(), A.rows());
145+
inv_mat.sub_block(inv_padded, 0, 0, 0, 0, inv_mat.rows(), inv_mat.rows());
146+
if (triangular_view == TriangularViewCL::Upper) {
147+
inv_mat = transpose(inv_mat);
148+
}
136149
return inv_mat;
137150
}
138151
} // namespace math

stan/math/prim/mat.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@
281281
#include <stan/math/prim/mat/functor/map_rect_reduce.hpp>
282282
#include <stan/math/prim/mat/prob/bernoulli_logit_glm_log.hpp>
283283
#include <stan/math/prim/mat/prob/bernoulli_logit_glm_lpmf.hpp>
284+
#include <stan/math/prim/mat/prob/bernoulli_logit_glm_rng.hpp>
284285
#include <stan/math/prim/mat/prob/categorical_log.hpp>
285286
#include <stan/math/prim/mat/prob/categorical_logit_log.hpp>
286287
#include <stan/math/prim/mat/prob/categorical_logit_lpmf.hpp>
@@ -292,6 +293,7 @@
292293
#include <stan/math/prim/mat/prob/dirichlet_rng.hpp>
293294
#include <stan/math/prim/mat/prob/gaussian_dlm_obs_log.hpp>
294295
#include <stan/math/prim/mat/prob/gaussian_dlm_obs_lpdf.hpp>
296+
#include <stan/math/prim/mat/prob/gaussian_dlm_obs_rng.hpp>
295297
#include <stan/math/prim/mat/prob/inv_wishart_log.hpp>
296298
#include <stan/math/prim/mat/prob/inv_wishart_lpdf.hpp>
297299
#include <stan/math/prim/mat/prob/inv_wishart_rng.hpp>
@@ -305,6 +307,7 @@
305307
#include <stan/math/prim/mat/prob/lkj_cov_lpdf.hpp>
306308
#include <stan/math/prim/mat/prob/matrix_normal_prec_log.hpp>
307309
#include <stan/math/prim/mat/prob/matrix_normal_prec_lpdf.hpp>
310+
#include <stan/math/prim/mat/prob/matrix_normal_prec_rng.hpp>
308311
#include <stan/math/prim/mat/prob/multi_gp_cholesky_log.hpp>
309312
#include <stan/math/prim/mat/prob/multi_gp_cholesky_lpdf.hpp>
310313
#include <stan/math/prim/mat/prob/multi_gp_log.hpp>
@@ -316,6 +319,7 @@
316319
#include <stan/math/prim/mat/prob/multi_normal_lpdf.hpp>
317320
#include <stan/math/prim/mat/prob/multi_normal_prec_log.hpp>
318321
#include <stan/math/prim/mat/prob/multi_normal_prec_lpdf.hpp>
322+
#include <stan/math/prim/mat/prob/multi_normal_prec_rng.hpp>
319323
#include <stan/math/prim/mat/prob/multi_normal_rng.hpp>
320324
#include <stan/math/prim/mat/prob/multi_student_t_log.hpp>
321325
#include <stan/math/prim/mat/prob/multi_student_t_lpdf.hpp>

0 commit comments

Comments
 (0)