Closed
Description
Originally reported in pymc-devs/nutpie#163.
Here is a smaller reproducer:
import pytensor
import pytensor.tensor as pt
import numpy as np
a = pytensor.scalar.get_scalar_type("float64")()
loop = pytensor.scalar.ScalarLoop([a], [pytensor.scalar.add(a, a)])
x = pt.tensor("x", shape=(3,))
elem = pt.elemwise.Elemwise(loop)(3, x)
elem.eval({x: np.ones(3)})
# Returns array([8., 8., 8.]) -- each of the 3 loop steps doubles the initial value 1.0 (1 -> 2 -> 4 -> 8)
# But compiling the same graph with the NUMBA backend fails:
func = pytensor.function([x], elem, mode="NUMBA")
File ~/git/pymc-labs/red-cities/.pixi/envs/default/lib/python3.12/site-packages/pytensor/link/numba/dispatch/elemwise.py:357, in numba_funcify_Elemwise(op, node, **kwargs)
355 if not isinstance(op.scalar_op, Composite):
356 scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs]
--> 357 scalar_node = op.scalar_op.make_node(*scalar_inputs)
359 scalar_op_fn = numba_funcify(
360 op.scalar_op,
361 node=scalar_node,
(...)
364 **kwargs,
365 )
367 nin = len(node.inputs)
File ~/git/pymc-labs/red-cities/.pixi/envs/default/lib/python3.12/site-packages/pytensor/scalar/loop.py:180, in ScalarLoop.make_node(self, n_steps, *inputs)
178 cloned_constant = cloned_inputs[len(cloned_update) :]
179 # This will fail if the cloned init have a different dtype than the cloned_update
--> 180 op = ScalarLoop(
181 init=cloned_init,
182 update=cloned_update,
183 constant=cloned_constant,
184 until=cloned_until,
185 name=self.name,
186 )
187 node = op.make_node(n_steps, *inputs)
188 return node
File ~/git/pymc-labs/red-cities/.pixi/envs/default/lib/python3.12/site-packages/pytensor/scalar/loop.py:69, in ScalarLoop.__init__(self, init, update, constant, until, name)
66 inputs, outputs = clone([*init, *constant], update)
68 self.is_while = bool(until)
---> 69 self.inputs, self.outputs = self._cleanup_graph(inputs, outputs)
70 self._validate_updates(self.inputs, self.outputs)
72 self.inputs_type = tuple(input.type for input in self.inputs)
File ~/git/pymc-labs/red-cities/.pixi/envs/default/lib/python3.12/site-packages/pytensor/scalar/basic.py:3992, in ScalarInnerGraphOp._cleanup_graph(self, inputs, outputs)
3990 for node in fgraph.apply_nodes:
3991 if not isinstance(node.op, ScalarOp):
-> 3992 raise TypeError(
3993 f"The fgraph of {self.__class__.__name__} must be exclusively "
3994 "composed of scalar operations."
3995 )
3997 # Run MergeOptimization to avoid duplicated nodes
3998 MergeOptimizer().rewrite(fgraph)
TypeError: The fgraph of ScalarLoop must be exclusively composed of scalar operations.
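The debugger shows what goes wrong in `numba_funcify_Elemwise`. For scalar ops that are not a `Composite` (such as `ScalarLoop`), it builds the "scalar" inputs with `scalar(dtype=input.dtype)` and passes them to `ScalarLoop.make_node`. However, those inputs are rank-0 tensors (`TensorType(..., shape=())`), not true scalar variables, so the inner graph that `ScalarLoop` rebuilds is no longer made exclusively of scalar operations: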
> /home/adr/git/pymc-labs/red-cities/.pixi/envs/default/lib/python3.12/site-packages/pytensor/link/numba/dispatch/elemwise.py(357)numba_funcify_Elemwise()
355 if not isinstance(op.scalar_op, Composite):
356 scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs]
--> 357 scalar_node = op.scalar_op.make_node(*scalar_inputs)
358
359 scalar_op_fn = numba_funcify(
ipdb> p scalar_inputs
[<Scalar(int8, shape=())>, <Scalar(float64, shape=())>]
ipdb> p scalar_inputs[0].type
TensorType(int8, shape=())
ipdb> exit
Metadata
Assignees
Labels
No labels