Skip to content

Add New Metadata and Pattern Features #2271

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 18 commits into
base: main
Choose a base branch
from

Conversation

jsmonson
Copy link

@jsmonson jsmonson commented May 5, 2025

This PR will add new metadata and pattern features.

import onnx

from onnxscript import ir
from onnxscript import rewriter

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'rewriter' is not used.
Comment on lines +48 to +49
#else:
# vp_outputs = builder.__getattr__(node.op_type)(*ninputs)

Check notice

Code scanning / CodeQL

Commented-out code Note

This comment appears to contain commented-out code.

return ReplacementPatternGraph(g)

def get_iterator_index(self):

Check notice

Code scanning / CodeQL

Explicit returns mixed with implicit (fall through) returns Note

Mixing implicit and explicit returns may indicate an error as implicit returns always return None.
# Skip the (all) Activation inputs (have been swapped to beginning of the list)
for index in range(activations, len(nodes[0].inputs)):
inputs = []
producers = []

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable producers is not used.
cinput = LoopBody.function.inputs[index]
noutput = vdisconnect(copy.copy(cinput))
noutput._uses = {}
update_node_outputs = False

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable update_node_outputs is not used.
inputs.add(ninput)
elif any(ninput is init for init in node.graph.initializers):
initializers.add(ninput)
elif ninput.producer() == None:

Check notice

Code scanning / CodeQL

Testing equality to None Note

Testing for None should use the 'is' operator.
@@ -0,0 +1,190 @@
import pytest

Check notice

Code scanning / CodeQL

Unused import Note test

Import of 'pytest' is not used.
import ast

import onnx
from onnxscript import script

Check notice

Code scanning / CodeQL

Unused import Note test

Import of 'script' is not used.
Comment on lines +147 to +152
# for rule in tracer.best_matches_map:
# matches = tracer.best_matches_map[rule]
# for match in matches:
# print(f'Reason: {match.match_result.reason}')
# print(f'root_node: {match.root_node}')
# pdb.set_trace()

Check notice

Code scanning / CodeQL

Commented-out code Note test

This comment appears to contain commented-out code.
Comment on lines +166 to +168
# for node in ir.traversal.RecursiveGraphIterator(mypipeline_model.graph):
# if node.domain == '':
# print(node)

Check notice

Code scanning / CodeQL

Commented-out code Note test

This comment appears to contain commented-out code.
Copy link

codecov bot commented May 5, 2025

❌ 1 Tests Failed:

Tests completed Failed Passed Skipped
1755 1 1754 691
View the top 1 failed test(s) by shortest run time
onnxscript.utils.test_PytorchHierarchyNode::test_mistral_pytorch_with_metadata
Stack Traces | 0.001s run time
onnxscript/utils/test_PytorchHierarchyNode.py:165: in test_mistral_pytorch_with_metadata
    model_proto = onnx.load('.../finn_mlo_graphs/demo/mistral.onnx')
..../test_torch_nightly/lib/python3.11....../site-packages/onnx/__init__.py:226: in load_model
    model = _get_serializer(format, f).deserialize_proto(_load_bytes(f), ModelProto())
..../test_torch_nightly/lib/python3.11....../site-packages/onnx/__init__.py:163: in _load_bytes
    with open(f, "rb") as readable:
E   FileNotFoundError: [Errno 2] No such file or directory: '.../finn_mlo_graphs/demo/mistral.onnx'

To view more test analytics, go to the Test Analytics Dashboard
📋 Got 3 mins? Take this short survey to help us improve Test Analytics.

Copy link

@github-advanced-security github-advanced-security bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lintrunner found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.

@justinchuby justinchuby self-assigned this May 5, 2025


def is_initializer(value):
return input.producer() == None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When a value is a graph input, its producer is also None. I will add a method to make this check more robust

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

usage.add("EXTERNAL")
return usage

def find_subgraph_inputs(nodes):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jsmonson could you share how you use this function?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. It's for constructing the graph view

else:
return self.class_metadata[depth]

class PytorchHierarchyNode:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is useful! Let me see how we can provide the functionality to users.

Comment on lines +204 to +216
def append_output_to_node(node, output):
output._producer = node
output._index = node.outputs[-1]._index + 1
node._outputs = (*node._outputs, output)
node._num_outputs = len(node._outputs)

def prepend_output_to_node(node, output):
output._producer = node
output._index = 0
for outp in node._outputs:
outp._index += 1
node._outputs = (output, *node._outputs)
node._num_outputs = len(node._outputs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add this to node so there is no need to modify internal states

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact we recommend creating new nodes if you want to update inputs and outputs. This way invariance of the graph is maintained. You may consider accumulating the inputs and outputs into two lists before constructing the node.

Comment on lines +320 to +333
Ident = ir.Node(domain='',
op_type='Identity',
inputs = [cinput],
outputs = [noutput],
num_outputs =1)
LoopBody.function.append(Ident)

#Add Output to Function Call Nodes
for i,node in enumerate(nodes):
output_copy = copy.copy(noutput)

#preserve single_assignment
output_copy.name += f'_{i}'
append_output_to_node(node,output_copy)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to collect all outputs before constructing the node?

vmap[input] = ValuePattern(input.name)

for init in graph.initializers:
vmap[init] = ValuePattern(init.name)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like you want to map constants and initializers to unconstrained variables in the pattern? I wonder if it would make sense to map them to "Constants" in the pattern that require a matching contstant-value in the graph for a successful match? That makes reasonable, at least for simple and small constants. If it should be abstracted, wouldn't it be better for the user themselves to do that explicitly by mapping them to graph inputs?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Thanks for this suggestion. I think we should do this. This would provide a good level of control for the user.

ninputs.append(vmap[ninput])

#if len(node.outputs) > 1:
vp_outputs = builder.__getattr__(node.op_type)(*ninputs,_domain=node.domain, _outputs=len(node.outputs))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, the attributes of the node are abstracted away, and not matched?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. We ought to match the attributes. I'll make these changes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
Development

Successfully merging this pull request may close these issues.

3 participants