@@ -2188,10 +2188,8 @@ def make_node(self, rng, size, dtype, alpha, K):
2188
2188
alpha = at .as_tensor_variable (alpha )
2189
2189
K = at .as_tensor_variable (intX (K ))
2190
2190
2191
- # if at.lt(K, 0):
2192
- # print(at.lt(K, 0).eval())
2193
- # print(K.eval() < 0)
2194
- # raise ValueError("K needs to be positive.")
2191
+ if K .eval () < 0 :
2192
+ raise ValueError ("K needs to be positive." )
2195
2193
2196
2194
if alpha .ndim > 0 :
2197
2195
raise ValueError ("The concentration parameter needs to be a scalar." )
@@ -2238,34 +2236,37 @@ def rng_fn(cls, rng, alpha, K, size):
2238
2236
2239
2237
class StickBreakingWeights (Continuous ):
2240
2238
r"""
2241
- Likelihood of truncated stick-breaking weights. The weights are generated
2239
+ Likelihood of truncated stick-breaking weights. The weights are generated from a
2240
+ stick-breaking proceduce where :math:`x_k = v_k \prod_{\ell < k} (1 - v_\ell)` for
2241
+ :math:`k \in \{1, \ldots, K\}` and :math
2242
2242
2243
2243
.. math:
2244
2244
2245
-
2246
- f(\mathbf{x}| \alpha, K) =
2245
+ f(\mathbf{x}|\alpha, K) =
2246
+ B(1, \alpha)^{-K}x_{K+1}^ \alpha \prod_{k=1}^{K+1}\left\{\sum_{j=k}^{K+1} x_j\right\}^{-1}
2247
2247
2248
2248
======== ===============================================
2249
- Support :math:`x_i \in (0, 1)` for :math:`i \in \{1, \ldots, K+1\}`
2250
- such that :math:`\sum x_i = 1`
2251
- Mean :math:`\dfrac{a_i}{\sum a_i}`
2252
- Variance :math:`\dfrac{a_i - \sum a_0}{a_0^2 (a_0 + 1)}`
2253
- where :math:`a_0 = \sum a_i`
2249
+ Support :math:`x_k \in (0, 1)` for :math:`k \in \{1, \ldots, K+1\}`
2250
+ such that :math:`\sum x_k = 1`
2251
+ Mean :math:`\mathbb{E}[x_k] = \dfrac{1}{1 + \alpha}\left(\dfrac{\alpha}{1 + \alpha}\right)^{k - 1}`
2252
+ for :math:`k \in \{1, \ldots, K\}` and `\mathbb{E}[x_{K+1}] = \left(\dfrac{\alpha}{1 + \alpha}\right)^{K}`
2254
2253
======== ===============================================
2255
2254
2256
2255
Parameters
2257
2256
----------
2258
2257
alpha: float
2259
- Concentration parameters (alpha > 0).
2258
+ Concentration parameter (alpha > 0).
2260
2259
K: int
2261
2260
The number of "sticks" to break off from an initial one-unit stick. The length
2262
2261
of categories is K + 1, where the last weight is one minus the sum of all the first sticks.
2263
2262
2264
2263
References
2265
2264
----------
2266
- .. [1] Ishwaran James
2265
+ .. [1] Ishwaran, H., & James, L. F. (2001). Gibbs sampling methods for stick-breaking priors.
2266
+ Journal of the American Statistical Association, 96(453), 161-173.
2267
2267
2268
- .. [2] Peter Mueller
2268
+ .. [2] Müller, P., Quintana, F. A., Jara, A., & Hanson, T. (2015). Bayesian nonparametric data
2269
+ analysis. New York: Springer.
2269
2270
"""
2270
2271
rv_op = stickbreakingweights
2271
2272
@@ -2329,7 +2330,7 @@ def logp(value, alpha, K):
2329
2330
),
2330
2331
axis = - 1 ,
2331
2332
)
2332
- logp += - ( K - 1 ) * betaln (1 , alpha )
2333
+ logp += - K * betaln (1 , alpha )
2333
2334
logp += alpha * at .log (value [..., - 1 ])
2334
2335
2335
2336
logp = at .switch (
0 commit comments