@@ -937,32 +937,43 @@ def recompute_initial_point(self) -> Dict[str, np.ndarray]:
937
937
Returns
938
938
-------
939
939
initial_point : dict
940
- Maps free variable names to transformed, numeric initial values.
940
+ Maps transformed free variable names to transformed, numeric initial values.
941
941
"""
942
- self ._initial_point_cache = Point (list (self .initial_values .items ()), model = self )
942
+ numeric_initvals = {}
943
+ # The entries in `initial_values` are already in topological order and can be evaluated one by one.
944
+ for rv_value , initval in self .initial_values .items ():
945
+ rv_var = self .values_to_rvs [rv_value ]
946
+ transform = getattr (rv_value .tag , "transform" , None )
947
+ if isinstance (initval , np .ndarray ) and transform is None :
948
+ # Only untransformed, numeric initvals can be taken as they are.
949
+ numeric_initvals [rv_value ] = initval
950
+ else :
951
+ # Evaluate initvals that are None, symbolic or need to be transformed.
952
+ # They can depend on other initvals from higher up in the graph,
953
+ # which are therefore fed to the evaluation as "givens".
954
+ test_value = getattr (rv_var .tag , "test_value" , None )
955
+ numeric_initvals [rv_value ] = self ._eval_initval (
956
+ rv_var , initval , test_value , transform , given = numeric_initvals
957
+ )
958
+
959
+ # Cache the evaluation results for next time.
960
+ self ._initial_point_cache = Point (list (numeric_initvals .items ()), model = self )
943
961
return self ._initial_point_cache
944
962
945
963
@property
946
- def initial_values (self ) -> Dict [TensorVariable , np .ndarray ]:
947
- """Maps transformed variables to initial values .
964
+ def initial_values (self ) -> Dict [TensorVariable , Optional [ Union [ np .ndarray , Variable ]] ]:
965
+ """Maps transformed variables to initial value placeholders .
948
966
949
967
⚠ The keys are NOT the objects returned by, `pm.Normal(...)`.
950
- For a name-based dictionary use the `initial_point` property .
968
+ For a name-based dictionary use the `get_initial_point()` method .
951
969
"""
952
970
return self ._initial_values
953
971
954
972
def set_initval (self , rv_var , initval ):
955
973
if initval is not None :
956
974
initval = rv_var .type .filter (initval )
957
975
958
- test_value = getattr (rv_var .tag , "test_value" , None )
959
-
960
976
rv_value_var = self .rvs_to_values [rv_var ]
961
- transform = getattr (rv_value_var .tag , "transform" , None )
962
-
963
- if initval is None or transform :
964
- initval = self ._eval_initval (rv_var , initval , test_value , transform )
965
-
966
977
self .initial_values [rv_value_var ] = initval
967
978
968
979
def _eval_initval (
@@ -971,6 +982,7 @@ def _eval_initval(
971
982
initval : Optional [Variable ],
972
983
test_value : Optional [np .ndarray ],
973
984
transform : Optional [Transform ],
985
+ given : Optional [Dict [TensorVariable , np .ndarray ]] = None ,
974
986
) -> np .ndarray :
975
987
"""Sample/evaluate an initial value using the existing initial values,
976
988
and with the least effect on the RNGs involved (i.e. no in-placing).
@@ -989,6 +1001,8 @@ def _eval_initval(
989
1001
transform : optional, Transform
990
1002
A transformation associated with the random variable.
991
1003
Transformations are automatically applied to initial values.
1004
+ given : optional, dict
1005
+ Numeric initial values to be used for givens instead of `self.initial_values`.
992
1006
993
1007
Returns
994
1008
-------
@@ -999,6 +1013,9 @@ def _eval_initval(
999
1013
opt_qry = mode .provided_optimizer .excluding ("random_make_inplace" )
1000
1014
mode = Mode (linker = mode .linker , optimizer = opt_qry )
1001
1015
1016
+ if given is None :
1017
+ given = self .initial_values
1018
+
1002
1019
if transform :
1003
1020
if initval is not None :
1004
1021
value = initval
@@ -1015,9 +1032,7 @@ def initval_to_rvval(value_var, value):
1015
1032
else :
1016
1033
return initval
1017
1034
1018
- givens = {
1019
- self .values_to_rvs [k ]: initval_to_rvval (k , v ) for k , v in self .initial_values .items ()
1020
- }
1035
+ givens = {self .values_to_rvs [k ]: initval_to_rvval (k , v ) for k , v in given .items ()}
1021
1036
initval_fn = aesara .function ([], rv_var , mode = mode , givens = givens , on_unused_input = "ignore" )
1022
1037
try :
1023
1038
initval = initval_fn ()
0 commit comments