Skip to content

Commit 573162a

Browse files
committed
use rvs_to_values from the model in opi.py
1 parent 232c884 commit 573162a

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

pymc3/variational/opvi.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -966,7 +966,8 @@ def __init_group__(self, group):
966966
if var.type.numpy_dtype.name in discrete_types:
967967
raise ParametrizationError(f"Discrete variables are not supported by VI: {var}")
968968
# 3) This is the way to infer shape and dtype of the variable
969-
test_var = model_initial_point[var.tag.value_var.name]
969+
value_var = self.model.rvs_to_values[var]
970+
test_var = model_initial_point[value_var.name]
970971
if self.batched:
971972
# Leave a more complicated case for future work
972973
raise NotImplementedError("not yet ready")
@@ -989,10 +990,10 @@ def __init_group__(self, group):
989990
size = test_var.size
990991
# TODO: There was self.ordering used in other util funcitons
991992
vr = self.input[..., start_idx:start_idx+size].reshape(shape).astype(dtype)
992-
vr.name = var.tag.value_var.name + "_vi_replacement"
993-
self.replacements[var.tag.value_var] = vr
994-
self.ordering[var.tag.value_var.name] = (
995-
var.tag.value_var.name,
993+
vr.name = value_var.name + "_vi_replacement"
994+
self.replacements[value_var] = vr
995+
self.ordering[value_var.name] = (
996+
value_var.name,
996997
slice(start_idx, start_idx+size),
997998
shape,
998999
dtype
@@ -1599,7 +1600,7 @@ def rslice(self, name):
15991600
"""
16001601

16011602
def vars_names(vs):
1602-
return {v.tag.value_var.name for v in vs}
1603+
return {self.model.rvs_to_values[v].name for v in vs}
16031604

16041605
for vars_, random, ordering in zip(
16051606
self.collect("group"), self.symbolic_randoms, self.collect("ordering")
@@ -1617,7 +1618,7 @@ def vars_names(vs):
16171618
def sample_dict_fn(self):
16181619
# TODO: this breaks
16191620
s = at.iscalar()
1620-
names = [v.tag.value_var.name for v in self.model.free_RVs]
1621+
names = [self.model.rvs_to_values[v].name for v in self.model.free_RVs]
16211622
sampled = [self.rslice(name) for name in names]
16221623
sampled = self.set_size_and_deterministic(sampled, s, 0)
16231624
sample_fn = aesara.function([s], sampled)

0 commit comments

Comments
 (0)