Description
Reported in https://discourse.pymc.io/t/shape-issue-with-custom-logp-in-densitydist/13608/5?u=ricardov94
```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.scalar("x")
neg_const = pt.as_tensor(np.full((10,), -0.1))
out = neg_const + x
pytensor.function([x], out)
```
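For reference, a plain NumPy analogue of the reproducer (an illustration only, not PyTensor code) shows the output shape the compiled graph should have: the scalar broadcasts against the homogeneous length-10 constant, so the result is a `(10,)` vector. Any rewrite of this `Add` must preserve that shape.

```python
import numpy as np

# NumPy analogue of the reproducer: a homogeneous length-10 constant plus a scalar.
neg_const = np.full((10,), -0.1)
x = 2.0

# The scalar broadcasts against the vector, so the result keeps shape (10,).
out = neg_const + x

print(out.shape)  # (10,)
```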
```
<<!! BUG IN FGRAPH.REPLACE OR A LISTENER !!>> <class 'TypeError'> The type of the replacement (Vector(float64, shape=(1,))) must be compatible with the type of the original Variable (Vector(float64, shape=(10,))). local_add_neg_to_sub
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: local_add_neg_to_sub
ERROR (pytensor.graph.rewriting.basic): node: Add([-0.1 -0.1 ... -0.1 -0.1], ExpandDims{axis=0}.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/graph/rewriting/basic.py", line 1968, in process_node
    fgraph.replace_all_validate_remove( # type: ignore
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/graph/features.py", line 626, in replace_all_validate_remove
    chk = fgraph.replace_all_validate(replacements, reason=reason, **kwargs)
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/graph/features.py", line 571, in replace_all_validate
    fgraph.replace(r, new_r, reason=reason, verbose=False, **kwargs)
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/graph/fg.py", line 508, in replace
    self.change_node_input(
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/graph/fg.py", line 419, in change_node_input
    raise TypeError(
TypeError: The type of the replacement (Vector(float64, shape=(1,))) must be compatible with the type of the original Variable (Vector(float64, shape=(10,))).
```
Relevant code: pytensor/pytensor/tensor/rewriting/math.py, lines 1897 to 1922 at commit c52154d.
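The helper `get_constant` actually calls `get_unique_constant_value` under the hood, which in this case returns a single scalar (-0.1) rather than the filled, homogeneous vector. Combining that scalar with the `ExpandDims{axis=0}` input, which has shape (1,), produces a (1,) vector instead of the original (10,) one, hence the `TypeError` above. The rewrite itself also expects only a scalar constant; otherwise it would need `np.all` to check that every element of the constant is negative.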
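A NumPy sketch of the failure mode and of one shape-preserving alternative. Both function names, `collapse_to_unique` and `add_neg_to_sub_sketch`, are hypothetical and are not PyTensor APIs. The second one only mimics the idea suggested by the rewrite's name (turning `first + negative_const` into `first - abs(negative_const)`) on plain arrays, assuming the fix keeps the full constant and checks its sign with `np.all` as suggested above. It is not the actual PyTensor patch.

```python
import numpy as np


def collapse_to_unique(const):
    """Hypothetical stand-in: collapse a homogeneous constant to one scalar."""
    const = np.asarray(const)
    if const.size > 0 and np.all(const == const.flat[0]):
        return const.flat[0]
    return None


def add_neg_to_sub_sketch(first, const):
    """Hypothetical: rewrite first + const as first - |const| if all of const is negative.

    The full constant array is kept, so broadcasting still yields the original
    output shape. Returns None when the pattern does not apply.
    """
    const = np.asarray(const)
    if const.size == 0 or not np.all(const < 0):
        return None
    return first - np.abs(const)


neg_const = np.full((10,), -0.1)
first = np.array([2.0])  # plays the role of ExpandDims{axis=0}(x), shape (1,)

original = neg_const + first                           # shape (10,)
buggy = first - np.abs(collapse_to_unique(neg_const))  # shape (1,): broadcast shape lost
fixed = add_neg_to_sub_sketch(first, neg_const)        # shape (10,)

print(original.shape, buggy.shape, fixed.shape)  # (10,) (1,) (10,)
```

Keeping the constant unreduced is what preserves the output type; the `np.all` check is only needed because the constant may no longer be a scalar.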