Add name kwarg to Op.__call__ #693
pytensor/graph/op.py
Outdated
node = self.make_node(*inputs, **kwargs)
if isinstance(node.outputs, list):
This will always be a list.
More importantly, I think we should only assign name to the default output (if there's a single output, that's the default).
If there's no default output, we could perhaps call them f"{name}_{i}". What do you think?
If my understanding is correct: if node.outputs has a length of more than 1, then they should be called f"{name}_{i}". Here i would be the default_output (which is specified in multi-output Ops).
Not quite what I was thinking. Multi-output nodes can have a default output, which is the only thing users usually see. We should pass name directly to either the default output of a multi-output node or the only output of a single-output node.
But for multi-output nodes without a default output we shouldn't. For those I suggest adding a numerical suffix: 0 for the first, 1 for the second, and so on.
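The rule described in this comment could be sketched roughly as follows. This is only an illustration: FakeVariable and assign_output_names are hypothetical stand-ins, not PyTensor classes or functions, and leaving every output untouched when name is None is an assumption.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeVariable:
    """Hypothetical stand-in for a graph variable; only the name matters here."""

    name: Optional[str] = None


def assign_output_names(outputs, default_output, name):
    """Apply the naming rule from the comment above to a list of outputs."""
    if name is None:
        return  # assumption: no name requested, so leave every output unnamed
    if len(outputs) == 1:
        # Single output: it receives the name directly.
        outputs[0].name = name
    elif default_output is not None:
        # Multi-output with a default output: only the default is named.
        outputs[default_output].name = name
    else:
        # Multi-output without a default: add a numerical suffix to each.
        for i, out in enumerate(outputs):
            out.name = f"{name}_{i}"


single = [FakeVariable()]
assign_output_names(single, None, "z")

with_default = [FakeVariable(), FakeVariable()]
assign_output_names(with_default, 1, "z")

no_default = [FakeVariable(), FakeVariable()]
assign_output_names(no_default, None, "z")
```

With these inputs, the single output and the default output are named "z", the non-default output keeps name None, and the outputs of the node without a default become "z_0" and "z_1".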
Small tweaks/suggestion
Also, we need tests for multi-output Ops, both with and without a default output.
There are some dummy outputs we use to test basic functionality that doesn't require finding real ops.
Should be somewhere in tests/graph/....py
pytensor/graph/op.py
Outdated
if len(node.outputs) == 1:
    node.outputs[0].name = name
else:
    for i, n in enumerate(node.outputs):
        n.name = f"{name}_{i}"
- if len(node.outputs) == 1:
-     node.outputs[0].name = name
- else:
-     for i, n in enumerate(node.outputs):
-         n.name = f"{name}_{i}"
+ if len(node.outputs) == 1:
+     node.outputs[0].name = name
+ elif self.default_output is not None:
+     node.outputs[self.default_output].name = name
+ else:
+     for i, n in enumerate(node.outputs):
+         n.name = f"{name}_{i}"
Also, scan seems to have special logic for the name kwarg (see failing tests). Have to check what's going on with that.
For scan some tests were failing because they were assigned names like
tests/graph/test_op.py
Outdated
    for i, r in enumerate(res):
        assert r.name == f"{op_name}_{i}"

    z = pt.add(x, y, name=op_name)
Don't use pt.add. Use a single-output dummy Op.
tests/graph/test_op.py
Outdated
        return Apply(self, list(inputs), outputs)

    def perform(self, node, inputs, outputs):
        outputs[0] = pt.matrix()
This is wrong, but it is also not needed. You can replace the contents of perform with raise NotImplementedError.
tests/graph/test_op.py
Outdated
def test_op_name():
    x = pt.vector("x")
Use dummy test variables instead of pt.vector (the reason is that these tests cover core abstract functionality). Tensor Ops and variables are specific implementations of these abstract objects.
pytensor/graph/op.py
Outdated
@@ -289,7 +289,14 @@ def __call__(self, *inputs: Any, **kwargs) -> Variable | list[Variable]:

        """
        return_list = kwargs.pop("return_list", False)
        name = kwargs.pop("name", None)
Instead of popping we can make it an explicit optional kwarg in the call signature (so it's actually discoverable). Same for return_list. No idea why they went for this implicit approach.
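To illustrate the discoverability point, here is a small sketch comparing the two styles. ImplicitOp and ExplicitOp are hypothetical stand-ins, not the real Op class, and the return value is made up purely so the two variants can be compared.

```python
import inspect


class ImplicitOp:
    """Hypothetical stand-in: options are popped out of **kwargs."""

    def __call__(self, *inputs, **kwargs):
        return_list = kwargs.pop("return_list", False)
        name = kwargs.pop("name", None)
        return inputs, return_list, name


class ExplicitOp:
    """Hypothetical stand-in: the same options as keyword-only parameters."""

    def __call__(self, *inputs, name=None, return_list=False, **kwargs):
        return inputs, return_list, name


# Both variants behave the same when called...
implicit_result = ImplicitOp()(1, 2, name="z")
explicit_result = ExplicitOp()(1, 2, name="z")

# ...but only the explicit signature exposes the options to introspection,
# which is what help(), IDEs, and documentation tools rely on.
implicit_params = list(inspect.signature(ImplicitOp.__call__).parameters)
explicit_params = list(inspect.signature(ExplicitOp.__call__).parameters)
```

Because the parameters follow *inputs, they are keyword-only, so existing positional call sites would not be affected by this change in the sketch.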
tests/graph/test_op.py
Outdated
class DummyType(Type):
    def filter(self, data):
        return data

    def __eq__(self, other):
        return isinstance(other, DummyType)

    def __hash__(self):
        return hash(DummyType)

    def __repr__(self):
        return "DummyType()"
Do we need a new dummy type? Can we reuse one from the existing ones for testing?
The existing dummy type takes an argument called thingy.
I'm not sure if it should be reused (I assume it has a specific use).
Should be fine. It's a test Type; it takes an argument for other purposes, but that shouldn't be problematic for us. Just pass None or whatever you want.
tests/graph/test_op.py
Outdated
res_single = single_op(x, name=op_name)
assert res_single.name == op_name
This may be more readable. Also, the name we are giving is not so much an op_name but a var_name.
- res_single = single_op(x, name=op_name)
- assert res_single.name == op_name
+ res_single = single_op(x, name="test_name")
+ assert res_single.name == "test_name"
Also test that by default name is None?
tests/graph/test_op.py
Outdated
def make_node(self, *inputs):
    outputs = [dummy_variable("a"), dummy_variable("b")]
    return Apply(self, list(inputs), outputs)
- def make_node(self, *inputs):
-     outputs = [dummy_variable("a"), dummy_variable("b")]
-     return Apply(self, list(inputs), outputs)
+ def make_node(self, input):
+     inputs = [input]
+     outputs = [input.type(), input.type()]
+     return Apply(self, inputs, outputs)
tests/graph/test_op.py
Outdated
res = multi_op(x, name=op_name)
for i, r in enumerate(res):
    assert r.name == f"{op_name}_{i}"
- res = multi_op(x, name=op_name)
- for i, r in enumerate(res):
-     assert r.name == f"{op_name}_{i}"
+ res = multi_op(x, name="test_name")
+ for i, r in enumerate(res):
+     assert r.name == f"test_name_{i}"
tests/graph/test_op.py
Outdated
multi_op = MultiOutOp()
multi_op.default_output = 1
A bit of a nitpick, but it is more "realistic" if default_output is something defined when the class is initialized. You just need to define an __init__ method that assigns it to self after calling super().__init__().
- multi_op = MultiOutOp()
- multi_op.default_output = 1
+ multi_op = MultiOutOp(default_output=1)
You can also parametrize the number of outputs with a multi_output=False|True argument. That way you can use a single test Op class, which slightly reduces the number of lines of code in this test. After storing it in __init__, in make_node you can do:
if self.multi_output:
    outputs = [input.type(), input.type()]
else:
    outputs = [input.type()]
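Putting the two suggestions together, a single parametrized test Op might look roughly like the sketch below. FakeType, FakeVariable and FakeApply are hypothetical stand-ins so the snippet runs on its own; in the real test the class would presumably subclass Op, call super().__init__() and build an Apply node with an existing test Type. DummyOp is an assumed name.

```python
class FakeVariable:
    """Hypothetical stand-in for a graph variable."""

    def __init__(self, type_, name=None):
        self.type = type_
        self.name = name


class FakeType:
    """Hypothetical stand-in for a graph Type; calling it makes a new variable."""

    def __call__(self):
        return FakeVariable(self)


class FakeApply:
    """Hypothetical stand-in for an Apply node."""

    def __init__(self, op, inputs, outputs):
        self.op = op
        self.inputs = inputs
        self.outputs = outputs


class DummyOp:
    """One test Op covering single-output, multi-output and default-output cases."""

    def __init__(self, default_output=None, multi_output=False):
        # Stored at construction time rather than patched onto the instance later.
        self.default_output = default_output
        self.multi_output = multi_output

    def make_node(self, input):
        inputs = [input]
        if self.multi_output:
            outputs = [input.type(), input.type()]
        else:
            outputs = [input.type()]
        return FakeApply(self, inputs, outputs)


x = FakeType()()
single_node = DummyOp().make_node(x)
multi_node = DummyOp(default_output=1, multi_output=True).make_node(x)
```

Note that make_node assigns no names of its own here, so any name seen on an output in a test would have come from the name argument under test.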
tests/graph/test_op.py
Outdated
def make_node(self, *inputs):
    outputs = [dummy_variable("a")]
    return Apply(self, list(inputs), outputs)
Traditionally, an Op would not assign a default name in make_node.
- def make_node(self, *inputs):
-     outputs = [dummy_variable("a")]
-     return Apply(self, list(inputs), outputs)
+ def make_node(self, input):
+     inputs = [input]
+     outputs = [input.type()]
+     return Apply(self, inputs, outputs)
Functionality looks great, just some test suggestions.
tests/graph/test_op.py
Outdated
@@ -232,3 +232,55 @@ def perform(self, *_):

    x = pt.TensorType(dtype="float64", shape=(1,))("x")
    assert SomeOp()(x).type == pt.dvector


def test_op_name():
- def test_op_name():
+ def test_call_name():
tests/graph/test_op.py
Outdated
multi_op = MultiOutOp()
multi_op.default_output = 1
res = multi_op(x, name=op_name)
assert res.name == op_name
Test that the entries of res.owner.outputs that are not the default_output still have name None.
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files

@@            Coverage Diff             @@
##             main     #693      +/-   ##
==========================================
- Coverage   80.83%   80.76%   -0.08%
==========================================
  Files         162      162
  Lines       46830    46713     -117
  Branches    11447    11426      -21
==========================================
- Hits        37857    37729     -128
- Misses       6710     6735      +25
+ Partials     2263     2249      -14
Thanks @HarshvirSandhu!
Description
Add name keyword in Op.__call__, allowing arbitrary Ops to be given names.
Related Issue
name keyword argument to Op.__call__ #685
Checklist
Type of change