Skip to content

Commit 3c87e7d

Browse files
michaelosthegetwiecki
authored andcommitted
Evaluate initial values lazily
Related to #4924
1 parent 28f2d43 commit 3c87e7d

File tree

3 files changed

+64
-19
lines changed

3 files changed

+64
-19
lines changed

pymc/model.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -937,32 +937,43 @@ def recompute_initial_point(self) -> Dict[str, np.ndarray]:
937937
Returns
938938
-------
939939
initial_point : dict
940-
Maps free variable names to transformed, numeric initial values.
940+
Maps transformed free variable names to transformed, numeric initial values.
941941
"""
942-
self._initial_point_cache = Point(list(self.initial_values.items()), model=self)
942+
numeric_initvals = {}
943+
# The entries in `initial_values` are already in topological order and can be evaluated one by one.
944+
for rv_value, initval in self.initial_values.items():
945+
rv_var = self.values_to_rvs[rv_value]
946+
transform = getattr(rv_value.tag, "transform", None)
947+
if isinstance(initval, np.ndarray) and transform is None:
948+
# Only untransformed, numeric initvals can be taken as they are.
949+
numeric_initvals[rv_value] = initval
950+
else:
951+
# Evaluate initvals that are None, symbolic or need to be transformed.
952+
# They can depend on other initvals from higher up in the graph,
953+
# which are therefore fed to the evaluation as "givens".
954+
test_value = getattr(rv_var.tag, "test_value", None)
955+
numeric_initvals[rv_value] = self._eval_initval(
956+
rv_var, initval, test_value, transform, given=numeric_initvals
957+
)
958+
959+
# Cache the evaluation results for next time.
960+
self._initial_point_cache = Point(list(numeric_initvals.items()), model=self)
943961
return self._initial_point_cache
944962

945963
@property
946-
def initial_values(self) -> Dict[TensorVariable, np.ndarray]:
947-
"""Maps transformed variables to initial values.
964+
def initial_values(self) -> Dict[TensorVariable, Optional[Union[np.ndarray, Variable]]]:
965+
"""Maps transformed variables to initial value placeholders.
948966
949967
⚠ The keys are NOT the objects returned by, `pm.Normal(...)`.
950-
For a name-based dictionary use the `initial_point` property.
968+
For a name-based dictionary use the `get_initial_point()` method.
951969
"""
952970
return self._initial_values
953971

954972
def set_initval(self, rv_var, initval):
955973
if initval is not None:
956974
initval = rv_var.type.filter(initval)
957975

958-
test_value = getattr(rv_var.tag, "test_value", None)
959-
960976
rv_value_var = self.rvs_to_values[rv_var]
961-
transform = getattr(rv_value_var.tag, "transform", None)
962-
963-
if initval is None or transform:
964-
initval = self._eval_initval(rv_var, initval, test_value, transform)
965-
966977
self.initial_values[rv_value_var] = initval
967978

968979
def _eval_initval(
@@ -971,6 +982,7 @@ def _eval_initval(
971982
initval: Optional[Variable],
972983
test_value: Optional[np.ndarray],
973984
transform: Optional[Transform],
985+
given: Optional[Dict[TensorVariable, np.ndarray]] = None,
974986
) -> np.ndarray:
975987
"""Sample/evaluate an initial value using the existing initial values,
976988
and with the least effect on the RNGs involved (i.e. no in-placing).
@@ -989,6 +1001,8 @@ def _eval_initval(
9891001
transform : optional, Transform
9901002
A transformation associated with the random variable.
9911003
Transformations are automatically applied to initial values.
1004+
given : optional, dict
1005+
Numeric initial values to be used for givens instead of `self.initial_values`.
9921006
9931007
Returns
9941008
-------
@@ -999,6 +1013,9 @@ def _eval_initval(
9991013
opt_qry = mode.provided_optimizer.excluding("random_make_inplace")
10001014
mode = Mode(linker=mode.linker, optimizer=opt_qry)
10011015

1016+
if given is None:
1017+
given = self.initial_values
1018+
10021019
if transform:
10031020
if initval is not None:
10041021
value = initval
@@ -1015,9 +1032,7 @@ def initval_to_rvval(value_var, value):
10151032
else:
10161033
return initval
10171034

1018-
givens = {
1019-
self.values_to_rvs[k]: initval_to_rvval(k, v) for k, v in self.initial_values.items()
1020-
}
1035+
givens = {self.values_to_rvs[k]: initval_to_rvval(k, v) for k, v in given.items()}
10211036
initval_fn = aesara.function([], rv_var, mode=mode, givens=givens, on_unused_input="ignore")
10221037
try:
10231038
initval = initval_fn()

pymc/tests/test_initvals.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import aesara
1415
import aesara.tensor as at
1516
import numpy as np
1617
import pytest
@@ -38,7 +39,8 @@ def test_new_warnings(self):
3839
with pm.Model() as pmodel:
3940
with pytest.warns(DeprecationWarning, match="`testval` argument is deprecated"):
4041
rv = pm.Uniform("u", 0, 1, testval=0.75)
41-
assert pmodel.initial_values[rv.tag.value_var] == transform_fwd(rv, 0.75)
42+
initial_point = pmodel.recompute_initial_point()
43+
assert initial_point["u_interval__"] == transform_fwd(rv, 0.75)
4244
assert not hasattr(rv.tag, "test_value")
4345
pass
4446

@@ -83,6 +85,33 @@ def test_falls_back_to_test_value(self):
8385
assert iv == 0.6
8486
pass
8587

88+
def test_dependent_initvals(self):
89+
with pm.Model() as pmodel:
90+
L = pm.Uniform("L", 0, 1, initval=0.5)
91+
B = pm.Uniform("B", lower=L, upper=2, initval=1.25)
92+
ip = pmodel.recompute_initial_point()
93+
assert ip["L_interval__"] == 0
94+
assert ip["B_interval__"] == 0
95+
96+
# Modify initval of L and re-evaluate
97+
pmodel.initial_values[pmodel.rvs_to_values[L]] = 0.9
98+
ip = pmodel.recompute_initial_point()
99+
assert ip["B_interval__"] < 0
100+
pass
101+
102+
def test_initval_resizing(self):
103+
with pm.Model() as pmodel:
104+
data = aesara.shared(np.arange(4))
105+
rv = pm.Uniform("u", lower=data, upper=10)
106+
107+
ip = pmodel.recompute_initial_point()
108+
assert np.shape(ip["u_interval__"]) == (4,)
109+
110+
data.set_value(np.arange(5))
111+
ip = pmodel.recompute_initial_point()
112+
assert np.shape(ip["u_interval__"]) == (5,)
113+
pass
114+
86115

87116
class TestSpecialDistributions:
88117
def test_automatically_assigned_test_values(self):

pymc/tests/test_model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,8 @@ def test_initial_point():
515515

516516
assert model.rvs_to_values[a] in model.initial_values
517517
assert model.rvs_to_values[x] in model.initial_values
518-
assert model.initial_values[b_value_var] == b_initval_trans
518+
assert model.initial_values[b_value_var] == b_initval
519+
assert model.recompute_initial_point()["b_interval__"] == b_initval_trans
519520
assert model.initial_values[model.rvs_to_values[y]] == y_initval
520521

521522

@@ -662,8 +663,8 @@ def test_set_initval():
662663
value = pm.NegativeBinomial("value", mu=mu, alpha=alpha)
663664

664665
assert np.array_equal(model.initial_values[model.rvs_to_values[mu]], np.array([[100.0]]))
665-
np.testing.assert_almost_equal(model.initial_values[model.rvs_to_values[alpha]], np.log(100))
666-
assert 50 < model.initial_values[model.rvs_to_values[value]] < 150
666+
np.testing.assert_array_equal(model.initial_values[model.rvs_to_values[alpha]], np.array(100))
667+
assert model.initial_values[model.rvs_to_values[value]] is None
667668

668669
# `Flat` cannot be sampled, so let's make sure that doesn't break initial
669670
# value computations

0 commit comments

Comments
 (0)