Skip to content

Commit 0009ff3

Browse files
committed
Remove global RandomStream
1 parent 23a1636 commit 0009ff3

File tree

8 files changed

+10
-71
lines changed

8 files changed

+10
-71
lines changed

pymc/data.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from pytensor.raise_op import Assert
3030
from pytensor.scalar import Cast
3131
from pytensor.tensor.elemwise import Elemwise
32-
from pytensor.tensor.random import RandomStream
3332
from pytensor.tensor.random.basic import IntegersRV
3433
from pytensor.tensor.subtensor import AdvancedSubtensor
3534
from pytensor.tensor.type import TensorType
@@ -132,6 +131,12 @@ def __hash__(self):
132131
class MinibatchIndexRV(IntegersRV):
133132
_print_name = ("minibatch_index", r"\operatorname{minibatch\_index}")
134133

134+
# Work-around for https://github.com/pymc-devs/pytensor/issues/97
135+
def make_node(self, rng, *args, **kwargs):
136+
if rng is None:
137+
rng = pytensor.shared(np.random.default_rng())
138+
return super().make_node(rng, *args, **kwargs)
139+
135140

136141
minibatch_index = MinibatchIndexRV()
137142

@@ -184,10 +189,9 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size:
184189
>>> mdata1, mdata2 = Minibatch(data1, data2, batch_size=10)
185190
"""
186191

187-
rng = RandomStream()
188192
tensor, *tensors = tuple(map(at.as_tensor, (variable, *variables)))
189193
upper = assert_all_scalars_equal(*[t.shape[0] for t in (tensor, *tensors)])
190-
slc = rng.gen(minibatch_index, 0, upper, size=batch_size)
194+
slc = minibatch_index(0, upper, size=batch_size)
191195
for i, v in enumerate((tensor, *tensors)):
192196
if not valid_for_minibatch(v):
193197
raise ValueError(

pymc/distributions/simulator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class Simulator(Distribution):
7676
----------
7777
fn : callable
7878
Python random simulator function. Should expect the following signature
79-
``(rng, arg1, arg2, ... argn, size)``, where rng is a ``numpy.random.RandomStream()``
79+
``(rng, arg1, arg2, ... argn, size)``, where rng is a ``numpy.random.Generator``
8080
and ``size`` defines the size of the desired sample.
8181
*unnamed_params : list of TensorVariable
8282
Parameters used by the Simulator random function. Each parameter can be passed

pymc/pytensorf.py

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
from pytensor.scalar.basic import Cast
5151
from pytensor.tensor.basic import _as_tensor_variable
5252
from pytensor.tensor.elemwise import Elemwise
53-
from pytensor.tensor.random import RandomStream
5453
from pytensor.tensor.random.op import RandomVariable
5554
from pytensor.tensor.random.var import (
5655
RandomGeneratorSharedVariable,
@@ -84,8 +83,6 @@
8483
"join_nonshared_inputs",
8584
"make_shared_replacements",
8685
"generator",
87-
"set_at_rng",
88-
"at_rng",
8986
"convert_observed_data",
9087
"compile_pymc",
9188
"constant_fold",
@@ -891,49 +888,6 @@ def generator(gen, default=None):
891888
return GeneratorOp(gen, default)()
892889

893890

894-
_at_rng = RandomStream()
895-
896-
897-
def at_rng(random_seed=None):
898-
"""
899-
Get the package-level random number generator or new with specified seed.
900-
901-
Parameters
902-
----------
903-
random_seed: int
904-
If not None
905-
returns *new* pytensor random generator without replacing package global one
906-
907-
Returns
908-
-------
909-
`pytensor.tensor.random.utils.RandomStream` instance
910-
`pytensor.tensor.random.utils.RandomStream`
911-
instance passed to the most recent call of `set_at_rng`
912-
"""
913-
if random_seed is None:
914-
return _at_rng
915-
else:
916-
ret = RandomStream(random_seed)
917-
return ret
918-
919-
920-
def set_at_rng(new_rng):
921-
"""
922-
Set the package-level random number generator.
923-
924-
Parameters
925-
----------
926-
new_rng: `pytensor.tensor.random.utils.RandomStream` instance
927-
The random number generator to use.
928-
"""
929-
# pylint: disable=global-statement
930-
global _at_rng
931-
# pylint: enable=global-statement
932-
if isinstance(new_rng, int):
933-
new_rng = RandomStream(new_rng)
934-
_at_rng = new_rng
935-
936-
937891
def floatX_array(x):
938892
return floatX(np.array(x))
939893

pymc/sampling/parallel.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828

2929
from fastprogress.fastprogress import progress_bar
3030

31-
from pymc import pytensorf
3231
from pymc.blocking import DictToArrayBijection
3332
from pymc.exceptions import SamplingError
3433
from pymc.util import RandomSeed
@@ -155,7 +154,6 @@ def _recv_msg(self):
155154

156155
def _start_loop(self):
157156
np.random.seed(self._seed)
158-
pytensorf.set_at_rng(self._at_seed)
159157

160158
draw = 0
161159
tuning = True

pymc/tests/conftest.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
import pytensor
1717
import pytest
1818

19-
import pymc as pm
20-
2119

2220
@pytest.fixture(scope="function", autouse=True)
2321
def pytensor_config():
@@ -47,4 +45,3 @@ def strict_float32():
4745
def seeded_test():
4846
# TODO: use this instead of SeededTest
4947
np.random.seed(42)
50-
pm.set_at_rng(42)

pymc/tests/helpers.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,11 @@
2626
from pytensor.gradient import verify_grad as at_verify_grad
2727
from pytensor.graph import ancestors
2828
from pytensor.graph.rewriting.basic import in2out
29-
from pytensor.tensor.random import RandomStream
3029
from pytensor.tensor.random.op import RandomVariable
3130

3231
import pymc as pm
3332

34-
from pymc.pytensorf import at_rng, local_check_parameter_to_ninf_switch, set_at_rng
33+
from pymc.pytensorf import local_check_parameter_to_ninf_switch
3534
from pymc.tests.checks import close_to
3635
from pymc.tests.models import mv_simple, mv_simple_coarse
3736

@@ -46,11 +45,6 @@ def setup_class(cls):
4645

4746
def setup_method(self):
4847
nr.seed(self.random_seed)
49-
self.old_at_rng = at_rng()
50-
set_at_rng(RandomStream(self.random_seed))
51-
52-
def teardown_method(self):
53-
set_at_rng(self.old_at_rng)
5448

5549
def get_random_state(self, reset=False):
5650
if self.random_state is None or reset:

pymc/tests/test_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ def test_pickling(self, datagen):
553553

554554
def test_gen_cloning_with_shape_change(self, datagen):
555555
gen = pm.generator(datagen)
556-
gen_r = pm.at_rng().normal(size=gen.shape).T
556+
gen_r = at.random.normal(size=gen.shape).T
557557
X = gen.dot(gen_r)
558558
res, _ = pytensor.scan(lambda x: x.sum(), X, n_steps=X.shape[0])
559559
assert res.eval().shape == (50,)

pymc/variational/inference.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -433,8 +433,6 @@ class ADVI(KLqp):
433433
model: :class:`pymc.Model`
434434
PyMC model for inference
435435
random_seed: None or int
436-
leave None to use package global RandomStream or other
437-
valid value to create instance specific one
438436
start: `dict[str, np.ndarray]` or `StartDict`
439437
starting point for inference
440438
start_sigma: `dict[str, np.ndarray]`
@@ -466,8 +464,6 @@ class FullRankADVI(KLqp):
466464
model: :class:`pymc.Model`
467465
PyMC model for inference
468466
random_seed: None or int
469-
leave None to use package global RandomStream or other
470-
valid value to create instance specific one
471467
start: `dict[str, np.ndarray]` or `StartDict`
472468
starting point for inference
473469
@@ -539,8 +535,6 @@ class SVGD(ImplicitGradient):
539535
start: `dict[str, np.ndarray]` or `StartDict`
540536
initial point for inference
541537
random_seed: None or int
542-
leave None to use package global RandomStream or other
543-
valid value to create instance specific one
544538
kwargs: other keyword arguments passed to estimator
545539
546540
References
@@ -685,8 +679,6 @@ def fit(
685679
model: :class:`Model`
686680
PyMC model for inference
687681
random_seed: None or int
688-
leave None to use package global RandomStream or other
689-
valid value to create instance specific one
690682
inf_kwargs: dict
691683
additional kwargs passed to :class:`Inference`
692684
start: `dict[str, np.ndarray]` or `StartDict`

0 commit comments

Comments
 (0)