Skip to content

Commit 6b5f5af

Browse files
Added SBW docstring
1 parent 5a989e6 commit 6b5f5af

File tree

2 files changed

+18
-17
lines changed

2 files changed

+18
-17
lines changed

pymc/distributions/multivariate.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2188,10 +2188,8 @@ def make_node(self, rng, size, dtype, alpha, K):
21882188
alpha = at.as_tensor_variable(alpha)
21892189
K = at.as_tensor_variable(intX(K))
21902190

2191-
# if at.lt(K, 0):
2192-
# print(at.lt(K, 0).eval())
2193-
# print(K.eval() < 0)
2194-
# raise ValueError("K needs to be positive.")
2191+
if K.eval() < 0:
2192+
raise ValueError("K needs to be positive.")
21952193

21962194
if alpha.ndim > 0:
21972195
raise ValueError("The concentration parameter needs to be a scalar.")
@@ -2238,34 +2236,37 @@ def rng_fn(cls, rng, alpha, K, size):
22382236

22392237
class StickBreakingWeights(Continuous):
22402238
r"""
2241-
Likelihood of truncated stick-breaking weights. The weights are generated
2239+
Likelihood of truncated stick-breaking weights. The weights are generated from a
2240+
stick-breaking proceduce where :math:`x_k = v_k \prod_{\ell < k} (1 - v_\ell)` for
2241+
:math:`k \in \{1, \ldots, K\}` and :math
22422242
22432243
.. math:
22442244
2245-
2246-
f(\mathbf{x}|\alpha, K) =
2245+
f(\mathbf{x}|\alpha, K) =
2246+
B(1, \alpha)^{-K}x_{K+1}^\alpha \prod_{k=1}^{K+1}\left\{\sum_{j=k}^{K+1} x_j\right\}^{-1}
22472247
22482248
======== ===============================================
2249-
Support :math:`x_i \in (0, 1)` for :math:`i \in \{1, \ldots, K+1\}`
2250-
such that :math:`\sum x_i = 1`
2251-
Mean :math:`\dfrac{a_i}{\sum a_i}`
2252-
Variance :math:`\dfrac{a_i - \sum a_0}{a_0^2 (a_0 + 1)}`
2253-
where :math:`a_0 = \sum a_i`
2249+
Support :math:`x_k \in (0, 1)` for :math:`k \in \{1, \ldots, K+1\}`
2250+
such that :math:`\sum x_k = 1`
2251+
Mean :math:`\mathbb{E}[x_k] = \dfrac{1}{1 + \alpha}\left(\dfrac{\alpha}{1 + \alpha}\right)^{k - 1}`
2252+
for :math:`k \in \{1, \ldots, K\}` and `\mathbb{E}[x_{K+1}] = \left(\dfrac{\alpha}{1 + \alpha}\right)^{K}`
22542253
======== ===============================================
22552254
22562255
Parameters
22572256
----------
22582257
alpha: float
2259-
Concentration parameters (alpha > 0).
2258+
Concentration parameter (alpha > 0).
22602259
K: int
22612260
The number of "sticks" to break off from an initial one-unit stick. The length
22622261
of categories is K + 1, where the last weight is one minus the sum of all the first sticks.
22632262
22642263
References
22652264
----------
2266-
.. [1] Ishwaran James
2265+
.. [1] Ishwaran, H., & James, L. F. (2001). Gibbs sampling methods for stick-breaking priors.
2266+
Journal of the American Statistical Association, 96(453), 161-173.
22672267
2268-
.. [2] Peter Mueller
2268+
.. [2] Müller, P., Quintana, F. A., Jara, A., & Hanson, T. (2015). Bayesian nonparametric data
2269+
analysis. New York: Springer.
22692270
"""
22702271
rv_op = stickbreakingweights
22712272

@@ -2329,7 +2330,7 @@ def logp(value, alpha, K):
23292330
),
23302331
axis=-1,
23312332
)
2332-
logp += -(K - 1) * betaln(1, alpha)
2333+
logp += -K * betaln(1, alpha)
23332334
logp += alpha * at.log(value[..., -1])
23342335

23352336
logp = at.switch(

pymc/tests/test_distributions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2138,7 +2138,7 @@ def test_stickbreakingweights(self, value, alpha, K, logp):
21382138
StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None)
21392139
pt = {"sbw": value}
21402140
assert_almost_equal(
2141-
model.fastlogp(pt),
2141+
model.compile_logp()(pt),
21422142
logp,
21432143
decimal=select_by_precision(float64=6, float32=2),
21442144
err_msg=str(pt),

0 commit comments

Comments
 (0)