-
Notifications
You must be signed in to change notification settings - Fork 136
Add name kwarg to Op.__call__ #693
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
bb12ed0
568334b
3e6379e
182b7ae
64c121a
1755294
31f51eb
1470b1b
183dd3f
1c99e01
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -289,7 +289,14 @@ def __call__(self, *inputs: Any, **kwargs) -> Variable | list[Variable]: | |||||||||||||||||||||||||
|
||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||
return_list = kwargs.pop("return_list", False) | ||||||||||||||||||||||||||
name = kwargs.pop("name", None) | ||||||||||||||||||||||||||
node = self.make_node(*inputs, **kwargs) | ||||||||||||||||||||||||||
if name is not None: | ||||||||||||||||||||||||||
if len(node.outputs) == 1: | ||||||||||||||||||||||||||
node.outputs[0].name = name | ||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||
for i, n in enumerate(node.outputs): | ||||||||||||||||||||||||||
n.name = f"{name}_{i}" | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if config.compute_test_value != "off": | ||||||||||||||||||||||||||
compute_test_value(node) | ||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -232,3 +232,26 @@ def perform(self, *_): | |||||||||||||
|
||||||||||||||
x = pt.TensorType(dtype="float64", shape=(1,))("x") | ||||||||||||||
assert SomeOp()(x).type == pt.dvector | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
def test_op_name(): | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
x = pt.vector("x") | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use dummy test variables instead of pt.vector (the reason is this tests are in core abstract functionality). tensors ops and variables are specific implementations of these abstract objects |
||||||||||||||
y = pt.vector("y") | ||||||||||||||
op_name = "op_name" | ||||||||||||||
|
||||||||||||||
class MultiOutOp(Op): | ||||||||||||||
ricardoV94 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
def make_node(self, *inputs): | ||||||||||||||
outputs = [pt.dmatrix(), pt.dmatrix()] | ||||||||||||||
return Apply(self, list(inputs), outputs) | ||||||||||||||
|
||||||||||||||
def perform(self, node, inputs, outputs): | ||||||||||||||
outputs[0] = pt.matrix() | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is wrong but not needed. You can replace contents of perform with raise NotImplementedError |
||||||||||||||
outputs[1] = pt.matrix() | ||||||||||||||
|
||||||||||||||
multi_op = MultiOutOp() | ||||||||||||||
res = multi_op(x, name=op_name) | ||||||||||||||
for i, r in enumerate(res): | ||||||||||||||
assert r.name == f"{op_name}_{i}" | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
|
||||||||||||||
z = pt.add(x, y, name=op_name) | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't use pt.add. Use a single output dummy Op |
||||||||||||||
assert z.name == op_name |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of popping we can make it an explicit optional kwarg in the call signature (so it's actually discoverable). Same for
return_list
. No idea why they went for this implicit approach