Skip to content

Numba dispatch of Elemwise of ScalarLoop is broken #1130

Closed
@aseyboldt

Description

@aseyboldt

Description

Originally reported in pymc-devs/nutpie#163.

Here is a smaller reproducer:

import pytensor
import pytensor.tensor as pt
import numpy as np

a = pytensor.scalar.get_scalar_type("float64")()
loop = pytensor.scalar.ScalarLoop([a], [pytensor.scalar.add(a, a)])
x = pt.tensor("x", shape=(3,))
elem = pt.elemwise.Elemwise(loop)(3, x)
elem.eval({x: np.ones(3)})
# Returns array([8., 8., 8.])

# But compiling fails:
func = pytensor.function([x], elem, mode="NUMBA")
File ~/git/pymc-labs/red-cities/.pixi/envs/default/lib/python3.12/site-packages/pytensor/link/numba/dispatch/elemwise.py:357, in numba_funcify_Elemwise(op, node, **kwargs)
    355 if not isinstance(op.scalar_op, Composite):
    356     scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs]
--> 357     scalar_node = op.scalar_op.make_node(*scalar_inputs)
    359 scalar_op_fn = numba_funcify(
    360     op.scalar_op,
    361     node=scalar_node,
   (...)
    364     **kwargs,
    365 )
    367 nin = len(node.inputs)

File ~/git/pymc-labs/red-cities/.pixi/envs/default/lib/python3.12/site-packages/pytensor/scalar/loop.py:180, in ScalarLoop.make_node(self, n_steps, *inputs)
    178 cloned_constant = cloned_inputs[len(cloned_update) :]
    179 # This will fail if the cloned init have a different dtype than the cloned_update
--> 180 op = ScalarLoop(
    181     init=cloned_init,
    182     update=cloned_update,
    183     constant=cloned_constant,
    184     until=cloned_until,
    185     name=self.name,
    186 )
    187 node = op.make_node(n_steps, *inputs)
    188 return node

File ~/git/pymc-labs/red-cities/.pixi/envs/default/lib/python3.12/site-packages/pytensor/scalar/loop.py:69, in ScalarLoop.__init__(self, init, update, constant, until, name)
     66     inputs, outputs = clone([*init, *constant], update)
     68 self.is_while = bool(until)
---> 69 self.inputs, self.outputs = self._cleanup_graph(inputs, outputs)
     70 self._validate_updates(self.inputs, self.outputs)
     72 self.inputs_type = tuple(input.type for input in self.inputs)

File ~/git/pymc-labs/red-cities/.pixi/envs/default/lib/python3.12/site-packages/pytensor/scalar/basic.py:3992, in ScalarInnerGraphOp._cleanup_graph(self, inputs, outputs)
   3990 for node in fgraph.apply_nodes:
   3991     if not isinstance(node.op, ScalarOp):
-> 3992         raise TypeError(
   3993             f"The fgraph of {self.__class__.__name__} must be exclusively "
   3994             "composed of scalar operations."
   3995         )
   3997 # Run MergeOptimization to avoid duplicated nodes
   3998 MergeOptimizer().rewrite(fgraph)

TypeError: The fgraph of ScalarLoop must be exclusively composed of scalar operations.

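The debugger shows that the Numba dispatcher tries to build the `ScalarLoop` node from rank-0 tensors instead of true scalars. The `scalar(dtype=input.dtype)` call on line 356 returns variables whose type is `TensorType(..., shape=())`, as `scalar_inputs[0].type` shows below. These tensor variables are then passed to `ScalarLoop.make_node`, which rebuilds the loop around them and fails with the "must be exclusively composed of scalar operations" error above: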

> /home/adr/git/pymc-labs/red-cities/.pixi/envs/default/lib/python3.12/site-packages/pytensor/link/numba/dispatch/elemwise.py(357)numba_funcify_Elemwise()
    355     if not isinstance(op.scalar_op, Composite):
    356         scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs]
--> 357         scalar_node = op.scalar_op.make_node(*scalar_inputs)
    358 
    359     scalar_op_fn = numba_funcify(

ipdb>  p scalar_inputs
[<Scalar(int8, shape=())>, <Scalar(float64, shape=())>]
ipdb>  p scalar_inputs[0].type
TensorType(int8, shape=())
ipdb>  exit

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions