diff --git a/docs/source/contributing/developer_guide_implementing_distribution.md b/docs/source/contributing/developer_guide_implementing_distribution.md
index 85ceee6f90..c615fb355e 100644
--- a/docs/source/contributing/developer_guide_implementing_distribution.md
+++ b/docs/source/contributing/developer_guide_implementing_distribution.md
@@ -4,7 +4,8 @@ This guide provides an overview on how to implement a distribution for version 4
 It is designed for developers who wish to add a new distribution to the library.
 Users will not be aware of all this complexity and should instead make use of helper methods such as (TODO).

-PyMC {class}`~pymc.distributions.Distribution` build on top of Aesara's {class}`~aesara.tensor.random.op.RandomVariable`, and implement `logp` and `logcdf` methods as well as other initialization and validation helpers, most notably `shape/dims`, alternative parametrizations, and default `transforms`.
+PyMC {class}`~pymc.distributions.Distribution` builds on top of Aesara's {class}`~aesara.tensor.random.op.RandomVariable`, and implements the `logp`, `logcdf` and `get_moment` methods as well as other initialization and validation helpers.
+The most notable of these helpers are the `shape/dims` kwargs, alternative parametrizations, and default `transforms`.

 Here is a summary check-list of the steps needed to implement a new distribution.
 Each section will be expanded below:
@@ -12,7 +13,7 @@ Each section will be expanded below:
 1. Creating a new `RandomVariable` `Op`
 1. Implementing the corresponding `Distribution` class
 1. Adding tests for the new `RandomVariable`
-1. Adding tests for the `logp` / `logcdf` methods
+1. Adding tests for the `logp` / `logcdf` and `get_moment` methods
 1. Documenting the new `Distribution`.

 This guide does not attempt to explain the rationale behind the `Distributions` current implementation, and details are provided only insofar as they help to implement new "standard" distributions.
@@ -118,7 +119,7 @@ After implementing the new `RandomVariable` `Op`, it's time to make use of it in
 PyMC 4.x works in a very {term}`functional <Functional Programming>` way, and the `distribution` classes are there mostly to facilitate porting the `PyMC3` v3.x code to the new `PyMC` v4.x version, add PyMC API features and keep related methods organized together.
 In practice, they take care of:

-1. Linking ({term}`Dispatching`) a rv_op class with the corresponding logp and logcdf methods.
+1. Linking ({term}`Dispatching`) an `rv_op` class with the corresponding `get_moment`, `logp` and `logcdf` methods.
 1. Defining a standard transformation (for continuous distributions) that converts a bounded variable domain (e.g., positive line) to an unbounded domain (i.e., the real line), which many samplers prefer.
 1. Validating the parametrization of a distribution and converting non-symbolic inputs (i.e., numeric literals or numpy arrays) to symbolic variables.
 1. Converting multiple alternative parametrizations to the standard parametrization that the `RandomVariable` is defined in terms of.
@@ -153,6 +154,14 @@ class Blah(PositiveContinuous):
         # the rv_op needs in order to be instantiated
         return super().dist([param1, param2], **kwargs)

+    # get_moment returns a symbolic expression for the stable moment from which to start sampling
+    # the variable, given the implicit `rv`, `size` and `param1` ... `paramN`
+    def get_moment(rv, size, param1, param2):
+        moment, _ = at.broadcast_arrays(param1, param2)
+        if not rv_size_is_none(size):
+            moment = at.full(size, moment)
+        return moment
+
     # Logp returns a symbolic expression for the logp evaluation of the variable
     # given the `value` of the variable and the parameters `param1` ... `paramN`
     def logp(value, param1, param2):
@@ -188,27 +197,34 @@ Some notes:
    overriding `__new__`.
 1. As mentioned above, `PyMC` v4.x works in a very {term}`functional <Functional Programming>` way, and all the information that is needed in the `logp` and `logcdf` methods is expected to be "carried" via the `RandomVariable` inputs. You may pass numerical arguments that are not strictly needed for the `rng_fn` method but are used in the `logp` and `logcdf` methods. Just keep in mind whether this affects the correct shape inference behavior of the `RandomVariable`. If specialized non-numeric information is needed you might need to define your custom `_logp` and `_logcdf` {term}`Dispatching` functions, but this should be done as a last resort.
 1. The `logcdf` method is not a requirement, but it's a nice plus!
+1. Currently, only one moment is supported by the `get_moment` method, and the "highest-order" moment available is usually the most useful one (that is, prefer the `mean` over the `median` over the `mode`). You might need to truncate the moment if you are dealing with a discrete distribution.
+1. When creating the `get_moment` method, be careful to handle `size != None` and to broadcast properly: parameters that do not appear in the moment expression may still determine the shape of the distribution. For example, `pm.Normal.dist(mu=0, sigma=np.arange(1, 6))` returns a moment of `[mu, mu, mu, mu, mu]`, as illustrated in the sketch below.
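+
+The following is a minimal sketch of that broadcasting behavior, using `pm.Normal` and the same `get_moment` import that the quick check below relies on:
+
+```python
+import numpy as np
+import pymc as pm
+
+from pymc.distributions.distribution import get_moment
+
+# sigma does not appear in the moment itself (the mean is just mu), but it still
+# determines the shape of the distribution, so the moment is broadcast to (5,)
+x = pm.Normal.dist(mu=0, sigma=np.arange(1, 6))
+get_moment(x).eval()
+# array([0., 0., 0., 0., 0.])
+```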

 For a quick check that things are working you can try the following:

 ```python
 import pymc as pm
+from pymc.distributions.distribution import get_moment

-# pm.blah = pm.Uniform in this example
-blah = pm.Blah.dist([0, 0], [1, 2])
+# pm.blah = pm.Normal in this example
+blah = pm.blah.dist(mu=0, sigma=1)

 # Test that the returned blah_op is still working fine
 blah.eval()
-# array([0.62778803, 1.95165513])
+# array(-1.01397228)
+
+# Test the get_moment method
+get_moment(blah).eval()
+# array(0.)

-# Test the logp
-pm.logp(blah, [1.5, 1.5]).eval()
-# array([ -inf, -0.69314718])
+# Test the logp method
+pm.logp(blah, [-0.5, 1.5]).eval()
+# array([-1.04393853, -2.04393853])

-# Test the logcdf
-pm.logcdf(blah, [1.5, 1.5]).eval()
-# array([ 0. , -0.28768207])
+# Test the logcdf method
+pm.logcdf(blah, [-0.5, 1.5]).eval()
+# array([-1.17591177, -0.06914345])
 ```

 ## 3. Adding tests for the new `RandomVariable`
@@ -350,8 +366,40 @@ def test_blah_logcdf(self):
 ```

+## 5. Adding tests for the `get_moment` method
+
+Tests for the `get_moment` method are contained in `pymc/tests/test_distributions_moments.py` and make use of the function `assert_moment_is_expected`, which checks that:
+1. Moments return the `expected` values
+1. Moments have the expected size and shape
+
+```python
+import numpy as np
+import pytest
+
+from pymc import Model
+from pymc.distributions import Blah
+
+# `assert_moment_is_expected` comes from pymc/tests/test_distributions_moments.py,
+# the same file where this test should be added
+
+@pytest.mark.parametrize(
+    "param1, param2, size, expected",
+    [
+        (0, 1, None, 0),
+        (0, np.ones(5), None, np.zeros(5)),
+        (np.arange(5), 1, None, np.arange(5)),
+        (np.arange(5), np.arange(1, 6), (2, 5), np.full((2, 5), np.arange(5))),
+    ],
+)
+def test_blah_moment(param1, param2, size, expected):
+    with Model() as model:
+        Blah("x", param1=param1, param2=param2, size=size)
+    assert_moment_is_expected(model, expected)
+```
+
+Here are some details worth keeping in mind:
+
+1. If you have to manually broadcast the parameters with each other, it's important to add test conditions that would fail if you forgot to do so. A straightforward way to achieve this is to make the parameter that is used in the moment a scalar, the unused one(s) a vector (one at a time), and leave `size` as `None`; see also the short sketch after this list.
+1. In other words, make sure to test different combinations of `size` and broadcasting to cover these cases.
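+
+As a rough illustration of why such a scalar-plus-vector case catches a missing broadcast (purely a sketch, using Aesara directly rather than code from the PyMC test suite):
+
+```python
+import numpy as np
+import aesara.tensor as at
+
+# param1 is the parameter the moment is based on; param2 only informs the shape
+param1 = at.as_tensor_variable(0.0)
+param2 = at.as_tensor_variable(np.ones(5))
+
+# Forgetting to broadcast keeps param1's scalar shape ...
+naive_moment = param1
+naive_moment.eval().shape
+# ()
+
+# ... while broadcasting against the "unused" parameter yields the expected
+# shape of (5,), which is what `assert_moment_is_expected` checks for
+moment, _ = at.broadcast_arrays(param1, param2)
+moment.eval().shape
+# (5,)
+```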

-## 5. Documenting the new `Distribution`
+## 6. Documenting the new `Distribution`

 New distributions should have a rich docstring, following the same format as that of previously implemented distributions.
 It generally looks something like this: