@@ -966,7 +966,8 @@ def __init_group__(self, group):
966
966
if var .type .numpy_dtype .name in discrete_types :
967
967
raise ParametrizationError (f"Discrete variables are not supported by VI: { var } " )
968
968
# 3) This is the way to infer shape and dtype of the variable
969
- test_var = model_initial_point [var .tag .value_var .name ]
969
+ value_var = self .model .rvs_to_values [var ]
970
+ test_var = model_initial_point [value_var .name ]
970
971
if self .batched :
971
972
# Leave a more complicated case for future work
972
973
raise NotImplementedError ("not yet ready" )
@@ -989,10 +990,10 @@ def __init_group__(self, group):
989
990
size = test_var .size
990
991
# TODO: There was self.ordering used in other util funcitons
991
992
vr = self .input [..., start_idx :start_idx + size ].reshape (shape ).astype (dtype )
992
- vr .name = var . tag . value_var .name + "_vi_replacement"
993
- self .replacements [var . tag . value_var ] = vr
994
- self .ordering [var . tag . value_var .name ] = (
995
- var . tag . value_var .name ,
993
+ vr .name = value_var .name + "_vi_replacement"
994
+ self .replacements [value_var ] = vr
995
+ self .ordering [value_var .name ] = (
996
+ value_var .name ,
996
997
slice (start_idx , start_idx + size ),
997
998
shape ,
998
999
dtype
@@ -1599,7 +1600,7 @@ def rslice(self, name):
1599
1600
"""
1600
1601
1601
1602
def vars_names (vs ):
1602
- return {v . tag . value_var .name for v in vs }
1603
+ return {self . model . rvs_to_values [ v ] .name for v in vs }
1603
1604
1604
1605
for vars_ , random , ordering in zip (
1605
1606
self .collect ("group" ), self .symbolic_randoms , self .collect ("ordering" )
@@ -1617,7 +1618,7 @@ def vars_names(vs):
1617
1618
def sample_dict_fn (self ):
1618
1619
# TODO: this breaks
1619
1620
s = at .iscalar ()
1620
- names = [v . tag . value_var .name for v in self .model .free_RVs ]
1621
+ names = [self . model . rvs_to_values [ v ] .name for v in self .model .free_RVs ]
1621
1622
sampled = [self .rslice (name ) for name in names ]
1622
1623
sampled = self .set_size_and_deterministic (sampled , s , 0 )
1623
1624
sample_fn = aesara .function ([s ], sampled )
0 commit comments