Skip to content

Commit 955abcd

Browse files
Update infer_shape signatures
1 parent 68d5201 commit 955abcd

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

pymc3/distributions/multivariate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -740,7 +740,7 @@ def perform(self, node, inputs, outputs):
740740
pm._log.exception("Failed to check if %s positive definite", x)
741741
raise
742742

743-
def infer_shape(self, node, shapes):
743+
def infer_shape(self, fgraph, node, shapes):
744744
return [[]]
745745

746746
def grad(self, inp, grads):

pymc3/math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ def grad(self, inputs, gout):
359359
idx = tt.arange(gz.shape[-1])
360360
return [gz[..., idx, idx]]
361361

362-
def infer_shape(self, nodes, shapes):
362+
def infer_shape(self, fgraph, nodes, shapes):
363363
return [(shapes[0][0],) + (shapes[0][1],) * 2]
364364

365365

@@ -418,7 +418,7 @@ def grad(self, inputs, gout):
418418
]
419419
return [gout[0][slc] for slc in slices]
420420

421-
def infer_shape(self, nodes, shapes):
421+
def infer_shape(self, fgraph, nodes, shapes):
422422
first, second = zip(*shapes)
423423
return [(tt.add(*first), tt.add(*second))]
424424

pymc3/ode/ode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def perform(self, node, inputs_storage, output_storage):
210210
# simulate states and sensitivities in one forward pass
211211
output_storage[0][0], output_storage[1][0] = self._simulate(y0, theta)
212212

213-
def infer_shape(self, node, input_shapes):
213+
def infer_shape(self, fgraph, node, input_shapes):
214214
s_y0, s_theta = input_shapes
215215
output_shapes = [(self.n_times, self.n_states), (self.n_times, self.n_states, self.n_p)]
216216
return output_shapes

0 commit comments

Comments
 (0)