1717
1818from theano .gof .op import get_test_value
1919from theano .gof .graph import Apply , inputs as tt_inputs
20+ from theano .scan_module .scan_op import Scan
21+ from theano .scan_module .scan_utils import clone
2022
2123from .random_variables import (
2224 observed ,
@@ -96,6 +98,122 @@ def convert_rv_to_dist(node, obs):
9698 return dist_type (rv .name , shape = shape , observed = obs , ** dist_params )
9799
98100
@dispatch(tt.TensorVariable, object)
def logp(var, obs):
    """Build the log-probability of `obs` for the variable `var`.

    The shape of `var` is resolved entry-by-entry to concrete values, a
    log-probability function is obtained by dispatching `_logp_fn` on the
    `Op` of the node that produced `var`, and that function is applied to
    `obs`.

    Parameters
    ----------
    var
        A `tt.TensorVariable`; its `owner` node supplies the `Op` used for
        dispatch.
    obs
        The observed value(s) handed to the log-probability function.

    Returns
    -------
    The result of calling the dispatched log-probability function on `obs`.
    """
    node = var.owner

    def resolve_dim(dim):
        # Use the constant value of a shape entry when it has one; otherwise,
        # fall back on the entry's test value.
        try:
            return tt.get_scalar_constant_value(dim)
        except tt.NotScalarConstantError:
            return dim.tag.test_value

    # Prefer the shape tracked by the graph's shape feature when the node
    # belongs to a graph that has one; otherwise, use the symbolic shape.
    if hasattr(node, "fgraph") and hasattr(node.fgraph, "shape_feature"):
        symbolic_shape = list(node.fgraph.shape_feature.shape_tuple(var))
    else:
        symbolic_shape = list(var.shape)

    shape = [resolve_dim(dim) for dim in symbolic_shape]

    return _logp_fn(node.op, node, shape)(obs)
119+
120+
@dispatch(RandomVariable, Apply, object)
def _logp_fn(op, node, shape=None):
    """Return the `logp` method of a PyMC3 distribution built from `node`.

    Parameters
    ----------
    op
        The `RandomVariable` `Op`; also supplies `ndim_supp` and
        `ndims_params`.
    node
        The `Apply` node passed, together with `op`, to
        `_convert_rv_to_dist`.
    shape
        When not `None`, stored under the ``"shape"`` key of the
        distribution parameters before the distribution is created.

    Returns
    -------
    The `logp` attribute of the object created by `dist_type.dist`.
    """
    dist_type, dist_params = _convert_rv_to_dist(op, node)

    if shape is not None:
        dist_params["shape"] = shape

    dist = dist_type.dist(**dist_params)

    # Attach extra information to the PyMC3 `Distribution` object.
    # TODO: The order of `ndims_params` needs to be maintained so that it
    # corresponds with the `Distribution`'s parameters.
    extra_info = (
        ("dist_params", dist_params),
        ("ndim_supp", op.ndim_supp),
        ("ndims_params", op.ndims_params),
    )
    for attr_name, attr_value in extra_info:
        setattr(dist, attr_name, attr_value)

    return dist.logp
134+
135+
@_logp_fn.register(Scan, Apply, object)
def _logp_fn_Scan(op, scan_node, shape=None):
    """Create a function that builds a log-probability `Scan` for `scan_node`.

    Every inner-graph output of `scan_node.op` that is produced by a
    `RandomVariable` gets a placeholder "observation" variable of the same
    type.  The log-probability graph of each such output, evaluated at its
    placeholder, becomes an inner-graph output of a new `Scan` whose
    inner-graph inputs are the placeholders followed by the original
    inner-graph inputs that those log-probability graphs depend on.

    Parameters
    ----------
    op
        The `Scan` `Op` used for dispatch.  The body reads the `Op` from
        `scan_node.op` instead.
    scan_node
        The `Apply` node of the `Scan` being converted.
    shape
        Accepted for signature compatibility with `_logp_fn`; not used.

    Returns
    -------
    A function ``logp_fn(*obs)`` that takes one array of observations per
    random-variable output and returns the result of applying the new
    log-probability `Scan` to them.

    Raises
    ------
    ValueError
        If no inner-graph output of the `Scan` is produced by a
        `RandomVariable`.
    """
    scan_inner_inputs = scan_node.op.inputs
    scan_inner_outputs = scan_node.op.outputs

    def create_obs_var(i, x):
        # A fresh variable of the same type as `x` that stands in for its
        # observed value; the test value is carried over when there is one.
        obs = x.type()
        obs.name = f"{x.name or x.owner.op.name}_obs_{i}"
        if hasattr(x.tag, "test_value"):
            obs.tag.test_value = x.tag.test_value
        return obs

    rv_outs = [
        (i, x, create_obs_var(i, x))
        for i, x in enumerate(scan_inner_outputs)
        if x.owner and isinstance(x.owner.op, RandomVariable)
    ]

    if not rv_outs:
        raise ValueError(
            "The Scan's inner-graph has no outputs produced by a RandomVariable; "
            "a log-probability cannot be constructed."
        )

    rv_out_obs = [obs for _, _, obs in rv_outs]

    logp_inner_outputs = [clone(logp(rv, obs)) for _, rv, obs in rv_outs]

    # The graph inputs are computed once and reused for every membership
    # check below.
    logp_inner_outputs_inputs = tt_inputs(logp_inner_outputs)

    # Every observation placeholder must be an input of the new graphs.
    assert all(o in logp_inner_outputs_inputs for o in rv_out_obs)

    # Keep only the original inner-inputs that the log-probability graphs
    # actually use (both lists are empty when none are used).
    relevant_inner_inputs = [
        (n, inp) for n, inp in enumerate(scan_inner_inputs) if inp in logp_inner_outputs_inputs
    ]
    rv_relevant_inner_input_idx = [n for n, _ in relevant_inner_inputs]
    rv_relevant_inner_inputs = [inp for _, inp in relevant_inner_inputs]

    logp_inner_inputs = rv_out_obs + rv_relevant_inner_inputs

    # Just like we did for the inner-inputs, we need to keep only the
    # outer-inputs that are relevant to the new logp graphs.  We do that by
    # removing the outer-inputs that map to the irrelevant inner-inputs.
    removed_inner_inputs = set(range(len(scan_inner_inputs))) - set(rv_relevant_inner_input_idx)
    old_in_out_mappings = scan_node.op.get_oinp_iinp_iout_oout_mappings()
    rv_removed_outer_input_idx = [
        old_in_out_mappings["outer_inp_from_inner_inp"][i] for i in removed_inner_inputs
    ]
    rv_removed_outer_inputs = [scan_node.inputs[i] for i in rv_removed_outer_input_idx]

    rv_relevant_outer_inputs = [r for r in scan_node.inputs if r not in rv_removed_outer_inputs]

    # Now, we can create a new op with our new inner-graph inputs and outputs.
    # Also, since our inner graph has new placeholder terms representing
    # an observed value for each random variable, we need to update the
    # "info" `dict`: the placeholders are added as sequences and all the
    # recurrent/shared-output counts are reset.
    logp_info = scan_node.op.info.copy()
    logp_info["tap_array"] = []
    logp_info["n_seqs"] += len(rv_out_obs)
    logp_info["n_mit_mot"] = 0
    logp_info["n_mit_mot_outs"] = 0
    logp_info["mit_mot_out_slices"] = []
    logp_info["n_mit_sot"] = 0
    logp_info["n_sit_sot"] = 0
    logp_info["n_shared_outs"] = 0
    # NOTE(review): the `- 1` presumably assumes the original `Scan` already
    # counts exactly one of these outputs in `n_nit_sot` -- confirm.
    logp_info["n_nit_sot"] += len(rv_out_obs) - 1
    logp_info["name"] = None
    logp_info["strict"] = True

    def logp_fn(*obs):
        # `obs` are the tensor variables corresponding to each random
        # variable's array of observations; they are supplied by the caller
        # and placed right after the first relevant outer-input
        # (presumably the number of steps -- confirm).
        logp_obs_outer_inputs = list(obs)
        logp_outer_inputs = (
            [rv_relevant_outer_inputs[0]] + logp_obs_outer_inputs + rv_relevant_outer_inputs[1:]
        )
        logp_op = Scan(logp_inner_inputs, logp_inner_outputs, logp_info)
        scan_logp = logp_op(*logp_outer_inputs)
        return scan_logp

    return logp_fn
215+
216+
99217@dispatch (pm .Uniform , object )
100218def convert_dist_to_rv (dist , rng ):
101219 size = dist .shape .astype (int )[UniformRV .ndim_supp :]
@@ -467,7 +585,7 @@ def model_graph(pymc_model, output_vars=None, rand_state=None, attach_memo=True)
467585 return model_fg
468586
469587
470- def graph_model (graph , * model_args , ** model_kwargs ):
588+ def graph_model (graph , * model_args , generate_names = False , ** model_kwargs ):
471589 """Create a PyMC3 model from a Theano graph with `RandomVariable` nodes."""
472590 model = pm .Model (* model_args , ** model_kwargs )
473591
@@ -478,13 +596,14 @@ def graph_model(graph, *model_args, **model_kwargs):
478596 nodes = [n for n in fgraph .toposort () if isinstance (n .op , RandomVariable )]
479597 rv_replacements = {}
480598
599+ node_id = 0
600+
481601 for node in nodes :
482602
483603 obs = get_rv_observation (node )
484604
485605 if obs is not None :
486606 obs = obs .inputs [0 ]
487-
488607 obs = tt_get_values (obs )
489608
490609 old_rv_var = node .default_output ()
@@ -500,6 +619,12 @@ def graph_model(graph, *model_args, **model_kwargs):
500619 if op != node
501620 )
502621
622+ if generate_names and rv_var .name is None :
623+ node_name = "{}_{}" .format (node .op .name , node_id )
624+ # warn("Name {} generated for node {}.".format(node, node_name))
625+ node_id += 1
626+ rv_var .name = node_name
627+
503628 with model :
504629 rv = convert_rv_to_dist (node , obs )
505630
0 commit comments