Refactor MiniBatch and stop using deprecated MRG sampler #6304


Merged 2 commits on Dec 12, 2022
392 changes: 59 additions & 333 deletions pymc/data.py

Large diffs are not rendered by default.
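Since the pymc/data.py diff is collapsed here, the following is a minimal sketch of the refactored Minibatch interface, inferred from the tests later in this PR (array shapes are illustrative):

import numpy as np
import pymc as pm
from pymc.data import is_minibatch

data = np.random.rand(100, 10)

# A single batch_size now subsamples along the first axis only; the old
# per-dimension specs such as pm.Minibatch(data, [2, None, 20]) are gone.
mb = pm.Minibatch(data, batch_size=16)
assert is_minibatch(mb)
assert mb.eval().shape == (16, 10)

# Several tensors can share one random index; all must have equal shape[0].
x_mb, y_mb = pm.Minibatch(data, data[:, :2], batch_size=16)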

5 changes: 3 additions & 2 deletions pymc/distributions/logprob.py
@@ -102,7 +102,8 @@ def _get_scaling(total_size: TOTAL_SIZE, shape, ndim: int) -> TensorVariable:

 def _check_no_rvs(logp_terms: Sequence[TensorVariable]):
     # Raise if there are unexpected RandomVariables in the logp graph
-    # Only SimulatorRVs are allowed
+    # Only SimulatorRVs and MinibatchIndexRVs are allowed
+    from pymc.data import MinibatchIndexRV
     from pymc.distributions.simulator import SimulatorRV

     unexpected_rv_nodes = [
@@ -111,7 +112,7 @@ def _check_no_rvs(logp_terms: Sequence[TensorVariable]):
         if (
             node.owner
             and isinstance(node.owner.op, RandomVariable)
-            and not isinstance(node.owner.op, SimulatorRV)
+            and not isinstance(node.owner.op, (SimulatorRV, MinibatchIndexRV))
         )
     ]
     if unexpected_rv_nodes:
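The exemption matters because the refactored Minibatch draws its subsampling index inside the graph, so a MinibatchIndexRV appears when the logp of minibatched observed data is evaluated. A minimal check, mirroring test_mixed1 further down in this diff:

import numpy as np
import pymc as pm

with pm.Model():
    data = np.random.rand(10, 20)
    mb = pm.Minibatch(data, batch_size=5)  # indexing happens via a MinibatchIndexRV draw
    v = pm.Normal("n", observed=mb, total_size=10)
    pm.logp(v, 1)  # would be rejected by _check_no_rvs without the exemption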
5 changes: 3 additions & 2 deletions pymc/model.py
@@ -50,7 +50,7 @@
 from pytensor.tensor.var import TensorConstant, TensorVariable

 from pymc.blocking import DictToArrayBijection, RaveledVars
-from pymc.data import GenTensorVariable, Minibatch
+from pymc.data import GenTensorVariable, is_minibatch
 from pymc.distributions.logprob import _joint_logp
 from pymc.distributions.transforms import _default_transform
 from pymc.exceptions import (
@@ -1329,14 +1329,15 @@ def register_rv(
         else:
             if (
                 isinstance(observed, Variable)
-                and not isinstance(observed, (GenTensorVariable, Minibatch))
+                and not isinstance(observed, GenTensorVariable)
                 and observed.owner is not None
                 # The only PyTensor operation we allow on observed data is type casting
                 # Although we could allow for any graph that does not depend on other RVs
                 and not (
                     isinstance(observed.owner.op, Elemwise)
                     and isinstance(observed.owner.op.scalar_op, Cast)
                 )
+                and not is_minibatch(observed)
             ):
                 raise TypeError(
                     "Variables that depend on other nodes cannot be used for observed data."
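Minibatch no longer returns a dedicated variable subclass, so an isinstance check cannot work; is_minibatch has to recognize the result structurally. A hypothetical sketch of such a check (the actual implementation lives in the collapsed pymc/data.py diff):

from pymc.data import MinibatchIndexRV
from pytensor.tensor.subtensor import AdvancedSubtensor

def is_minibatch_sketch(v):
    # A Minibatch result is plain data indexed by a MinibatchIndexRV draw.
    return (
        v.owner is not None
        and isinstance(v.owner.op, AdvancedSubtensor)
        and v.owner.inputs[1].owner is not None
        and isinstance(v.owner.inputs[1].owner.op, MinibatchIndexRV)
    )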
2 changes: 1 addition & 1 deletion pymc/pytensorf.py
@@ -47,10 +47,10 @@
 )
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op
-from pytensor.sandbox.rng_mrg import MRG_RandomStream as RandomStream
 from pytensor.scalar.basic import Cast
 from pytensor.tensor.basic import _as_tensor_variable
 from pytensor.tensor.elemwise import Elemwise
+from pytensor.tensor.random import RandomStream
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.var import (
     RandomGeneratorSharedVariable,
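The replacement stream exposes the same seeded-stream interface as the deprecated MRG sampler, so the alias can simply point at the new module. A minimal usage sketch:

import pytensor
from pytensor.tensor.random import RandomStream

srng = RandomStream(seed=42)
x = srng.normal(0, 1, size=(3,))
f = pytensor.function([], x)
f()  # fresh draws on each call, reproducible under the same seed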
5 changes: 4 additions & 1 deletion pymc/tests/distributions/test_continuous.py
@@ -778,10 +778,13 @@ def test_moyal_logp(self):
         reason="PyMC underflows earlier than scipy on float32",
     )
     def test_moyal_logcdf(self):
+        # SciPy has new (?) precision issues at {mu=-2.1, sigma=0.5, x=2.1}
+        # We circumvent it by skipping sigma=0.5:
+        rplusbig = Domain([0, 0.9, 0.99, 1, 1.5, 2, 20, np.inf])
         check_logcdf(
             pm.Moyal,
             R,
-            {"mu": R, "sigma": Rplusbig},
+            {"mu": R, "sigma": rplusbig},
             lambda value, mu, sigma: floatX(st.moyal.logcdf(value, mu, sigma)),
         )
         if pytensor.config.floatX == "float32":
2 changes: 1 addition & 1 deletion pymc/tests/helpers.py
@@ -26,7 +26,7 @@
 from pytensor.gradient import verify_grad as at_verify_grad
 from pytensor.graph import ancestors
 from pytensor.graph.rewriting.basic import in2out
-from pytensor.sandbox.rng_mrg import MRG_RandomStream as RandomStream
+from pytensor.tensor.random import RandomStream
 from pytensor.tensor.random.op import RandomVariable

 import pymc as pm
83 changes: 28 additions & 55 deletions pymc/tests/test_data.py
@@ -27,6 +27,7 @@

 import pymc as pm

+from pymc.data import is_minibatch
 from pymc.pytensorf import GeneratorOp, floatX
 from pymc.tests.helpers import SeededTest, select_by_precision
@@ -696,15 +697,10 @@ def test_common_errors(self):

     def test_mixed1(self):
         with pm.Model():
-            data = np.random.rand(10, 20, 30, 40, 50)
-            mb = pm.Minibatch(data, [2, None, 20, Ellipsis, 10])
-            pm.Normal("n", observed=mb, total_size=(10, None, 30, Ellipsis, 50))
-
-    def test_mixed2(self):
-        with pm.Model():
-            data = np.random.rand(10, 20, 30, 40, 50)
-            mb = pm.Minibatch(data, [2, None, 20])
-            pm.Normal("n", observed=mb, total_size=(10, None, 30))
+            data = np.random.rand(10, 20)
+            mb = pm.Minibatch(data, batch_size=5)
+            v = pm.Normal("n", observed=mb, total_size=10)
+            assert pm.logp(v, 1) is not None, "Check index is allowed in graph"

     def test_free_rv(self):
         with pm.Model() as model4:
@@ -719,51 +715,28 @@ def test_free_rv(self):

 @pytest.mark.usefixtures("strict_float32")
 class TestMinibatch:
-    data = np.random.rand(30, 10, 40, 10, 50)
+    data = np.random.rand(30, 10)

     def test_1d(self):
-        mb = pm.Minibatch(self.data, 20)
-        assert mb.eval().shape == (20, 10, 40, 10, 50)
-
-    def test_2d(self):
-        mb = pm.Minibatch(self.data, [(10, 42), (4, 42)])
-        assert mb.eval().shape == (10, 4, 40, 10, 50)
-
-    @pytest.mark.parametrize(
-        "batch_size, expected",
-        [
-            ([(10, 42), None, (4, 42)], (10, 10, 4, 10, 50)),
-            ([(10, 42), Ellipsis, (4, 42)], (10, 10, 40, 10, 4)),
-            ([(10, 42), None, Ellipsis, (4, 42)], (10, 10, 40, 10, 4)),
-            ([10, None, Ellipsis, (4, 42)], (10, 10, 40, 10, 4)),
-        ],
-    )
-    def test_special_batch_size(self, batch_size, expected):
-        mb = pm.Minibatch(self.data, batch_size)
-        assert mb.eval().shape == expected
-
-    def test_cloning_available(self):
-        gop = pm.Minibatch(np.arange(100), 1)
-        res = gop**2
-        shared = pytensor.shared(np.array([10]))
-        res1 = pytensor.clone_replace(res, {gop: shared})
-        f = pytensor.function([], res1)
-        assert f() == np.array([100])
-
-    def test_align(self):
-        m = pm.Minibatch(np.arange(1000), 1, random_seed=1)
-        n = pm.Minibatch(np.arange(1000), 1, random_seed=1)
-        f = pytensor.function([], [m, n])
-        n.eval()  # not aligned
-        a, b = zip(*(f() for _ in range(1000)))
-        assert a != b
-        pm.align_minibatches()
-        a, b = zip(*(f() for _ in range(1000)))
-        assert a == b
-        n.eval()  # not aligned
-        pm.align_minibatches([m])
-        a, b = zip(*(f() for _ in range(1000)))
-        assert a != b
-        pm.align_minibatches([m, n])
-        a, b = zip(*(f() for _ in range(1000)))
-        assert a == b
+        mb = pm.Minibatch(self.data, batch_size=20)
+        assert is_minibatch(mb)
+        assert mb.eval().shape == (20, 10)
+
+    def test_allowed(self):
+        mb = pm.Minibatch(at.as_tensor(self.data).astype(int), batch_size=20)
+        assert is_minibatch(mb)
+
+    def test_not_allowed(self):
+        with pytest.raises(ValueError, match="not valid for Minibatch"):
+            mb = pm.Minibatch(at.as_tensor(self.data) * 2, batch_size=20)
+
+    def test_not_allowed2(self):
+        with pytest.raises(ValueError, match="not valid for Minibatch"):
+            mb = pm.Minibatch(self.data, at.as_tensor(self.data) * 2, batch_size=20)
+
+    def test_assert(self):
+        with pytest.raises(
+            AssertionError, match=r"All variables shape\[0\] in Minibatch should be equal"
+        ):
+            d1, d2 = pm.Minibatch(self.data, self.data[::2], batch_size=20)
+            d1.eval()
12 changes: 12 additions & 0 deletions pymc/tests/variational/test_approximations.py
@@ -168,3 +168,15 @@ def test_elbo_beta_kl(aux_total_size):
     np.testing.assert_allclose(
         elbo_via_total_size_scaled.eval(), elbo_via_beta_kl.eval(), rtol=0, atol=1e-1
     )
+
+
+def test_seeding_advi_fit():
+    with pm.Model():
+        x = pm.Normal("x", 0, 10, initval="prior")
+        approx1 = pm.fit(
+            random_seed=42, n=10, method="advi", obj_optimizer=pm.adagrad_window, progressbar=False
+        )
+        approx2 = pm.fit(
+            random_seed=42, n=10, method="advi", obj_optimizer=pm.adagrad_window, progressbar=False
+        )
+        np.testing.assert_allclose(approx1.mean.eval(), approx2.mean.eval())
32 changes: 20 additions & 12 deletions pymc/tests/variational/test_inference.py
@@ -15,6 +15,7 @@
 import io
 import operator

+import cloudpickle
 import numpy as np
 import pytensor
 import pytensor.tensor as at
@@ -26,6 +27,7 @@

 from pymc.pytensorf import intX
 from pymc.variational.inference import ADVI, ASVGD, SVGD, FullRankADVI
+from pymc.variational.opvi import NotImplementedInference

 pytestmark = pytest.mark.usefixtures("strict_float32", "seeded_test")

@@ -60,7 +62,7 @@ def simple_model_data(use_minibatch):
     d = n / sigma**2 + 1 / sigma0**2
     mu_post = (n * np.mean(data) / sigma**2 + mu0 / sigma0**2) / d
     if use_minibatch:
-        data = pm.Minibatch(data)
+        data = pm.Minibatch(data, batch_size=128)
     return dict(
         n=n,
         data=data,
@@ -118,7 +120,7 @@ def init_(**kw):
 @pytest.fixture(scope="function")
 def inference(inference_spec, simple_model):
     with simple_model:
-        return inference_spec()
+        return inference_spec(random_seed=42)


@pytest.fixture(scope="function")
Expand All @@ -129,7 +131,7 @@ def fit_kwargs(inference, use_minibatch):
obj_optimizer=pm.adagrad_window(learning_rate=0.01, n_win=50), n=12000
),
(FullRankADVI, "full"): dict(
obj_optimizer=pm.adagrad_window(learning_rate=0.01, n_win=50), n=6000
obj_optimizer=pm.adagrad_window(learning_rate=0.015, n_win=50), n=6000
),
(FullRankADVI, "mini"): dict(
obj_optimizer=pm.adagrad_window(learning_rate=0.007, n_win=50), n=12000
Expand All @@ -149,6 +151,8 @@ def fit_kwargs(inference, use_minibatch):
inference.approx.scale_cost_to_minibatch = False
else:
key = "full"
if (type(inference), key) in {(SVGD, "mini"), (ASVGD, "mini")}:
pytest.skip("Not Implemented Inference")
return _select[(type(inference), key)]


@@ -179,7 +183,10 @@ def test_fit_start(inference_spec, simple_model):

     with simple_model:
         inference = inference_spec(**kw)
-        trace = inference.fit(n=0).sample(10000)
+        try:
+            trace = inference.fit(n=0).sample(10000)
+        except NotImplementedInference as e:
+            pytest.skip(str(e))
     np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_init, rtol=0.05)
     if has_start_sigma:
         np.testing.assert_allclose(np.std(trace.posterior["mu"]), mu_sigma_init, rtol=0.05)
@@ -218,6 +225,8 @@ def test_fit_fn_text(method, kwargs, error):


 def test_profile(inference):
+    if type(inference) in {SVGD, ASVGD}:
+        pytest.skip("Not Implemented Inference")
     inference.run_profiling(n=100).summary()


@@ -239,8 +248,7 @@ def binomial_model_inference(binomial_model, inference_spec):

 @pytest.mark.xfail("pytensor.config.warn_float64 == 'raise'", reason="too strict float32")
 def test_replacements(binomial_model_inference):
-    d = at.bscalar()
-    d.tag.test_value = 1
+    d = pytensor.shared(1)
     approx = binomial_model_inference.approx
     p = approx.model.p
     p_t = p**3
@@ -252,7 +260,7 @@ def test_replacements(binomial_model_inference):
     ), "p should be replaced"
     if pytensor.config.compute_test_value != "off":
         assert p_s.tag.test_value.shape == p_t.tag.test_value.shape
-    sampled = [p_s.eval() for _ in range(100)]
+    sampled = [pm.draw(p_s) for _ in range(100)]
     assert any(map(operator.ne, sampled[1:], sampled[:-1]))  # stochastic
     p_z = approx.sample_node(p_t, deterministic=False, size=10)
     assert p_z.shape.eval() == (10,)
@@ -264,15 +272,17 @@ def test_replacements(binomial_model_inference):

     try:
         p_d = approx.sample_node(p_t, deterministic=True)
-        sampled = [p_d.eval() for _ in range(100)]
+        sampled = [pm.draw(p_d) for _ in range(100)]
         assert all(map(operator.eq, sampled[1:], sampled[:-1]))  # deterministic
     except opvi.NotImplementedInference:
         pass

     p_r = approx.sample_node(p_t, deterministic=d)
-    sampled = [p_r.eval({d: 1}) for _ in range(100)]
+    d.set_value(1)
+    sampled = [pm.draw(p_r) for _ in range(100)]
     assert all(map(operator.eq, sampled[1:], sampled[:-1]))  # deterministic
-    sampled = [p_r.eval({d: 0}) for _ in range(100)]
+    d.set_value(0)
+    sampled = [pm.draw(p_r) for _ in range(100)]
     assert any(map(operator.ne, sampled[1:], sampled[:-1]))  # stochastic


@@ -325,8 +335,6 @@ def test_var_replacement():


 def test_clear_cache():
-    import cloudpickle
-
     with pm.Model():
         pm.Normal("n", 0, 1)
         inference = ADVI()
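A note on the .eval() to pm.draw switches above: pm.draw compiles the graph and returns a new draw on every call, which is exactly what the stochastic/deterministic assertions need. A minimal sketch:

import pymc as pm

x = pm.Normal.dist(0, 1)
a = pm.draw(x)
b = pm.draw(x)  # an independent draw; a == b only with vanishing probability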
18 changes: 13 additions & 5 deletions pymc/variational/approximations.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+
 import numpy as np
 import pytensor
@@ -24,11 +25,13 @@

 from pymc.blocking import DictToArrayBijection
 from pymc.distributions.dist_math import rho2sigma
+from pymc.pytensorf import makeiter
 from pymc.variational import opvi
 from pymc.variational.opvi import (
     Approximation,
     Group,
     NotImplementedInference,
+    _known_scan_ignored_inputs,
     node_property,
 )
@@ -248,9 +251,12 @@ def randidx(self, size=None):
             pass
         else:
             size = tuple(np.atleast_1d(size))
-        return self._rng.uniform(
-            size=size, low=pm.floatX(0), high=pm.floatX(self.histogram.shape[0]) - pm.floatX(1e-16)
-        ).astype("int32")
+        return at.random.integers(
+            size=size,
+            low=0,
+            high=self.histogram.shape[0],
+            rng=pytensor.shared(np.random.default_rng()),
+        )

     def _new_initial(self, size, deterministic, more_replacements=None):
         pytensor_condition_is_here = isinstance(deterministic, Variable)
@@ -383,8 +389,10 @@ def evaluate_over_trace(self, node):
         """
         node = self.to_flat_input(node)

-        def sample(post, node):
+        def sample(post, *_):
             return pytensor.clone_replace(node, {self.input: post})

-        nodes, _ = pytensor.scan(sample, self.histogram, non_sequences=[node])
+        nodes, _ = pytensor.scan(
+            sample, self.histogram, non_sequences=_known_scan_ignored_inputs(makeiter(node))
+        )
         return nodes
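The randidx change replaces the MRG trick of drawing uniforms in [0, N) and truncating to int32 with a direct integer draw. A small sketch of the new primitive, with illustrative values:

import numpy as np
import pytensor
import pytensor.tensor as at

rng = pytensor.shared(np.random.default_rng(7))
idx = at.random.integers(low=0, high=30, size=(5,), rng=rng)  # high is exclusive
idx.eval()  # five indices into a 30-row histogram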
4 changes: 3 additions & 1 deletion pymc/variational/inference.py
@@ -631,7 +631,9 @@ def __init__(self, approx=None, estimator=KSD, kernel=test_functions.rbf, **kwargs):
             "is often **underestimated** when using temperature = 1."
         )
         if approx is None:
-            approx = FullRank(model=kwargs.pop("model", None))
+            approx = FullRank(
+                model=kwargs.pop("model", None), random_seed=kwargs.pop("random_seed", None)
+            )
         super().__init__(estimator=estimator, approx=approx, kernel=kernel, **kwargs)

     def fit(
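Judging by the signature, this hunk is ASVGD's constructor; forwarding random_seed lets the auto-created FullRank approximation be seeded, as test_seeding_advi_fit does for ADVI. A minimal usage sketch:

import pymc as pm

with pm.Model():
    pm.Normal("x", 0, 1)
    inference = pm.ASVGD(random_seed=42)  # seed reaches the auto-created FullRank approx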