diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 9a23d7ab9..867a6e0fc 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -29,6 +29,7 @@ Distributions :toctree: generated/ Chi + Maxwell DiscreteMarkovChain GeneralizedPoisson BetaNegativeBinomial diff --git a/pymc_experimental/distributions/__init__.py b/pymc_experimental/distributions/__init__.py index 9bd0e1a23..f06db9699 100644 --- a/pymc_experimental/distributions/__init__.py +++ b/pymc_experimental/distributions/__init__.py @@ -17,7 +17,7 @@ Experimental probability distributions for stochastic nodes in PyMC. """ -from pymc_experimental.distributions.continuous import Chi, GenExtreme +from pymc_experimental.distributions.continuous import Chi, GenExtreme, Maxwell from pymc_experimental.distributions.discrete import ( BetaNegativeBinomial, GeneralizedPoisson, @@ -36,4 +36,5 @@ "Skellam", "histogram_approximation", "Chi", + "Maxwell", ] diff --git a/pymc_experimental/distributions/continuous.py b/pymc_experimental/distributions/continuous.py index 2e957b4f2..dcf9b775f 100644 --- a/pymc_experimental/distributions/continuous.py +++ b/pymc_experimental/distributions/continuous.py @@ -28,6 +28,7 @@ from pymc.distributions.dist_math import check_parameters from pymc.distributions.distribution import Continuous from pymc.distributions.shape_utils import rv_size_is_none +from pymc.logprob.utils import CheckParameterValue from pymc.pytensorf import floatX from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.variable import TensorVariable @@ -280,3 +281,73 @@ def __new__(cls, name, nu, **kwargs): @classmethod def dist(cls, nu, **kwargs): return CustomDist.dist(nu, dist=cls.chi_dist, class_name="Chi", **kwargs) + + +class Maxwell: + R""" + The Maxwell-Boltzmann distribution + + The pdf of this distribution is + + .. math:: + + f(x \mid a) = {\displaystyle {\sqrt {\frac {2}{\pi }}}\,{\frac {x^{2}}{a^{3}}}\,\exp \left({\frac {-x^{2}}{2a^{2}}}\right)} + + Read more about it on `Wikipedia `_ + + .. plot:: + :context: close-figs + + import matplotlib.pyplot as plt + import numpy as np + import scipy.stats as st + import arviz as az + plt.style.use('arviz-darkgrid') + x = np.linspace(0, 20, 200) + for a in [1, 2, 5]: + pdf = st.maxwell.pdf(x, scale=a) + plt.plot(x, pdf, label=r'$a$ = {}'.format(a)) + plt.xlabel('x', fontsize=12) + plt.ylabel('f(x)', fontsize=12) + plt.legend(loc=1) + plt.show() + + ======== ========================================================================= + Support :math:`x \in (0, \infty)` + Mean :math:`2a \sqrt{\frac{2}{\pi}}` + Variance :math:`\frac{a^2(3 \pi - 8)}{\pi}` + ======== ========================================================================= + + Parameters + ---------- + a : tensor_like of float + Scale parameter (a > 0). + + """ + + @staticmethod + def maxwell_dist(a: TensorVariable, size: TensorVariable) -> TensorVariable: + if rv_size_is_none(size): + size = a.shape + + a = CheckParameterValue("a > 0")(a, pt.all(pt.gt(a, 0))) + + return Chi.dist(nu=3, size=size) * a + + def __new__(cls, name, a, **kwargs): + return CustomDist( + name, + a, + dist=cls.maxwell_dist, + class_name="Maxwell", + **kwargs, + ) + + @classmethod + def dist(cls, a, **kwargs): + return CustomDist.dist( + a, + dist=cls.maxwell_dist, + class_name="Maxwell", + **kwargs, + ) diff --git a/pymc_experimental/tests/distributions/test_continuous.py b/pymc_experimental/tests/distributions/test_continuous.py index 891e7ab3e..5df4aef1f 100644 --- a/pymc_experimental/tests/distributions/test_continuous.py +++ b/pymc_experimental/tests/distributions/test_continuous.py @@ -33,7 +33,7 @@ ) # the distributions to be tested -from pymc_experimental.distributions import Chi, GenExtreme +from pymc_experimental.distributions import Chi, GenExtreme, Maxwell class TestGenExtremeClass: @@ -159,3 +159,26 @@ def test_logcdf(self): {"nu": Rplus}, lambda value, nu: sp.chi.logcdf(value, df=nu), ) + + +class TestMaxwell: + """ + Wrapper class so that tests of experimental additions can be dropped into + PyMC directly on adoption. + """ + + def test_logp(self): + check_logp( + Maxwell, + Rplus, + {"a": Rplus}, + lambda value, a: sp.maxwell.logpdf(value, scale=a), + ) + + def test_logcdf(self): + check_logcdf( + Maxwell, + Rplus, + {"a": Rplus}, + lambda value, a: sp.maxwell.logcdf(value, scale=a), + )