diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py index dd131cbdbc..160a65dd7a 100644 --- a/pytensor/graph/op.py +++ b/pytensor/graph/op.py @@ -246,7 +246,9 @@ def make_node(self, *inputs: Variable) -> Apply: ) return Apply(self, inputs, [o() for o in self.otypes]) - def __call__(self, *inputs: Any, **kwargs) -> Variable | list[Variable]: + def __call__( + self, *inputs: Any, name=None, return_list=False, **kwargs + ) -> Variable | list[Variable]: r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs. This method is just a wrapper around :meth:`Op.make_node`. @@ -288,8 +290,15 @@ def __call__(self, *inputs: Any, **kwargs) -> Variable | list[Variable]: the :attr:`Op.default_output` property. """ - return_list = kwargs.pop("return_list", False) node = self.make_node(*inputs, **kwargs) + if name is not None: + if len(node.outputs) == 1: + node.outputs[0].name = name + elif self.default_output is not None: + node.outputs[self.default_output].name = name + else: + for i, n in enumerate(node.outputs): + n.name = f"{name}_{i}" if config.compute_test_value != "off": compute_test_value(node) diff --git a/tests/graph/test_op.py b/tests/graph/test_op.py index 59d81ad59e..73f612c2f5 100644 --- a/tests/graph/test_op.py +++ b/tests/graph/test_op.py @@ -232,3 +232,46 @@ def perform(self, *_): x = pt.TensorType(dtype="float64", shape=(1,))("x") assert SomeOp()(x).type == pt.dvector + + +@pytest.mark.parametrize("multi_output", [True, False]) +def test_call_name(multi_output): + def dummy_variable(name): + return Variable(MyType(thingy=None), None, None, name=name) + + x = dummy_variable("x") + + class TestCallOp(Op): + def __init__(self, default_output, multi_output): + super().__init__() + self.default_output = default_output + self.multi_output = multi_output + + def make_node(self, input): + inputs = [input] + if self.multi_output: + outputs = [input.type(), input.type()] + else: + outputs = [input.type()] + return Apply(self, inputs, outputs) + + def perform(self, node, inputs, outputs): + raise NotImplementedError() + + if multi_output: + multi_op = TestCallOp(default_output=None, multi_output=multi_output) + res = multi_op(x, name="test_name") + for i, r in enumerate(res): + assert r.name == f"test_name_{i}" + + multi_op = TestCallOp(default_output=1, multi_output=multi_output) + result = multi_op(x, name="test_name") + assert result.owner.outputs[0].name is None + assert result.name == "test_name" + else: + single_op = TestCallOp(default_output=None, multi_output=multi_output) + res_single = single_op(x, name="test_name") + assert res_single.name == "test_name" + + res_nameless = single_op(x) + assert res_nameless.name is None