|
53 | 53 | from pytensor.tensor.random.utils import normalize_size_param |
54 | 54 | from pytensor.tensor.variable import TensorConstant, TensorVariable |
55 | 55 |
|
| 56 | +from pymc.distributions.custom import CustomDist |
56 | 57 | from pymc.logprob.abstract import _logprob_helper |
57 | 58 | from pymc.logprob.basic import TensorLike, icdf |
58 | 59 | from pymc.pytensorf import normalize_rng_param |
@@ -92,7 +93,7 @@ def polyagamma_cdf(*args, **kwargs): |
92 | 93 | from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Continuous, SymbolicRandomVariable |
93 | 94 | from pymc.distributions.shape_utils import implicit_size_from_params, rv_size_is_none |
94 | 95 | from pymc.distributions.transforms import _default_transform |
95 | | -from pymc.math import invlogit, logdiffexp, logit |
| 96 | +from pymc.math import invlogit, logdiffexp |
96 | 97 |
|
97 | 98 | __all__ = [ |
98 | 99 | "AsymmetricLaplace", |
@@ -3603,28 +3604,7 @@ def icdf(value, mu, s): |
3603 | 3604 | ) |
3604 | 3605 |
|
3605 | 3606 |
|
3606 | | -class LogitNormalRV(SymbolicRandomVariable): |
3607 | | - name = "logit_normal" |
3608 | | - extended_signature = "[rng],[size],(),()->[rng],()" |
3609 | | - _print_name = ("LogitNormal", "\\operatorname{LogitNormal}") |
3610 | | - |
3611 | | - @classmethod |
3612 | | - def rv_op(cls, mu, sigma, *, size=None, rng=None): |
3613 | | - mu = pt.as_tensor(mu) |
3614 | | - sigma = pt.as_tensor(sigma) |
3615 | | - rng = normalize_rng_param(rng) |
3616 | | - size = normalize_size_param(size) |
3617 | | - |
3618 | | - next_rng, normal_draws = normal(loc=mu, scale=sigma, size=size, rng=rng).owner.outputs |
3619 | | - draws = pt.expit(normal_draws) |
3620 | | - |
3621 | | - return cls( |
3622 | | - inputs=[rng, size, mu, sigma], |
3623 | | - outputs=[next_rng, draws], |
3624 | | - )(rng, size, mu, sigma) |
3625 | | - |
3626 | | - |
3627 | | -class LogitNormal(UnitContinuous): |
| 3607 | +class LogitNormal: |
3628 | 3608 | r""" |
3629 | 3609 | Logit-Normal distribution. |
3630 | 3610 |
|
@@ -3672,37 +3652,26 @@ class LogitNormal(UnitContinuous): |
3672 | 3652 | Defaults to 1. |
3673 | 3653 | """ |
3674 | 3654 |
|
3675 | | - rv_type = LogitNormalRV |
3676 | | - rv_op = LogitNormalRV.rv_op |
| 3655 | + @staticmethod |
| 3656 | + def logitnormal_dist(mu, sigma, size): |
| 3657 | + return invlogit(Normal.dist(mu=mu, sigma=sigma, size=size)) |
3677 | 3658 |
|
3678 | | - @classmethod |
3679 | | - def dist(cls, mu=0, sigma=None, tau=None, **kwargs): |
| 3659 | + def __new__(cls, name, mu=0, sigma=None, tau=None, **kwargs): |
3680 | 3660 | _, sigma = get_tau_sigma(tau=tau, sigma=sigma) |
3681 | | - return super().dist([mu, sigma], **kwargs) |
3682 | | - |
3683 | | - def support_point(rv, size, mu, sigma): |
3684 | | - median, _ = pt.broadcast_arrays(invlogit(mu), sigma) |
3685 | | - if not rv_size_is_none(size): |
3686 | | - median = pt.full(size, median) |
3687 | | - return median |
3688 | | - |
3689 | | - def logp(value, mu, sigma): |
3690 | | - tau, _ = get_tau_sigma(sigma=sigma) |
3691 | | - |
3692 | | - res = pt.switch( |
3693 | | - pt.or_(pt.le(value, 0), pt.ge(value, 1)), |
3694 | | - -np.inf, |
3695 | | - ( |
3696 | | - -0.5 * tau * (logit(value) - mu) ** 2 |
3697 | | - + 0.5 * pt.log(tau / (2.0 * np.pi)) |
3698 | | - - pt.log(value * (1 - value)) |
3699 | | - ), |
| 3661 | + return CustomDist( |
| 3662 | + name, |
| 3663 | + mu, |
| 3664 | + sigma, |
| 3665 | + dist=cls.logitnormal_dist, |
| 3666 | + class_name="LogitNormal", |
| 3667 | + **kwargs, |
3700 | 3668 | ) |
3701 | 3669 |
|
3702 | | - return check_parameters( |
3703 | | - res, |
3704 | | - tau > 0, |
3705 | | - msg="tau > 0", |
| 3670 | + @classmethod |
| 3671 | + def dist(cls, mu=0, sigma=None, tau=None, **kwargs): |
| 3672 | + _, sigma = get_tau_sigma(tau=tau, sigma=sigma) |
| 3673 | + return CustomDist.dist( |
| 3674 | + mu, sigma, dist=cls.logitnormal_dist, class_name="LogitNormal", **kwargs |
3706 | 3675 | ) |
3707 | 3676 |
|
3708 | 3677 |
|
|
0 commit comments