From d2df97c3797921e6248860bdc7e2c25264202f5b Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Thu, 27 Mar 2025 21:37:47 +0000 Subject: [PATCH 01/14] initial loop rolling import --- onnxscript/utils/graph_view_utils.py | 224 ++++++ onnxscript/utils/pattern_builder.py | 894 ++++++++++++++++++++++++ tests/loop_rolling/test_loop_rolling.py | 214 ++++++ 3 files changed, 1332 insertions(+) create mode 100644 onnxscript/utils/graph_view_utils.py create mode 100644 onnxscript/utils/pattern_builder.py create mode 100644 tests/loop_rolling/test_loop_rolling.py diff --git a/onnxscript/utils/graph_view_utils.py b/onnxscript/utils/graph_view_utils.py new file mode 100644 index 0000000000..73423fa99d --- /dev/null +++ b/onnxscript/utils/graph_view_utils.py @@ -0,0 +1,224 @@ +from onnxscript import ir +import onnx + +def is_initializer(value): + return input.producer() == None + +def gather_initializers(inputs): + inits = set() + for input in inputs: + if is_initializer(input): + inits.add(input) + +def is_constant(value): + return value.producer.op_type == 'Constant' + +def gather_constants(inputs): + consts = set() + for input in inputs: + if is_constant(input): + consts.add(input) + + +def has_internal_usage(usage): + return "INTERNAL" in usage + +def has_external_usage(usage): + return "EXTERNAL" in usage + +def classify_usage(value, nodes): + usage = set() + for use in value.uses(): + user_node = use[0] + if user_node in nodes: + usage.add("INTERNAL") + else: + usage.add("EXTERNAL") + return usage + +def find_subgraph_inputs(nodes): + inputs = set() + initializers = set() + for node in nodes: + print(f'{type(node)} {type(nodes)}') + print(f'{node}') + for ninput in node.inputs: + if ninput in node.graph.inputs: + inputs.add(ninput) + elif any(ninput is init for init in node.graph.initializers): + print("adding graph initializater") + initializers.add(ninput) + elif ninput.producer() == None: + print(f"adding none initializer: {ninput}") + inputs.add(ninput) + elif ninput.producer() not in nodes: + print(f"adding not in nodes: {ninput}") + if ninput.producer() in node.graph._nodes: + print(f'\tIn Graph node list') + print(f"\t {ninput.producer()}") + print(f"\t {ninput.producer().metadata_props}") + inputs.add(ninput) + + return inputs, initializers + +def find_subgraph_outputs(nodes): + output = set() + used_output = set() + for node in nodes: + for noutput in node.outputs: + usage = classify_usage(noutput, nodes) + if has_external_usage(usage): + if has_internal_usage(usage): + used_output.add(noutput) + else: + output.add(noutput) + return [output, used_output] + + +def bGraphView(name, nodes): + + + [view_inputs, view_initializers] = find_subgraph_inputs(nodes) + [view_outputs, used_outputs] = find_subgraph_outputs(nodes) + + for used_output in used_outputs: + producer_node = used_output.producer() + nodes.remove(producer_node) + for output in producer_node.outputs: + usage = classify_usage(output,nodes) + if has_internal_usage(usage): + view_inputs.add(output) + if has_external_usage(usage): + if output in view_outputs: + view_outputs.remove(output) + + return ir.GraphView(name=name, + inputs=view_inputs, + outputs=view_outputs, + nodes=nodes, + initializers=view_initializers) + +######################################## +# rebuild_pytorch_dynamo_instance_code # +######################################## + +################################################## +## TODO (JSM): encapsulte this into a function ## +################################################## + +# model = onnx.load('mistral.onnx') + +# model_ir = ir.serde.deserialize_model(model) + +# layer_dict = {} + +# no_name_scopes = set() +# for node in ir.traversal.RecursiveGraphIterator(model_ir.graph): +# if 'pkg.torch.onnx.name_scopes' in node.metadata_props: +# name_scopes = ast.literal_eval(node.metadata_props['pkg.torch.onnx.name_scopes']) +# if name_scopes[1].startswith('layer'): +# if name_scopes[1] not in layer_dict: +# layer_dict[name_scopes[1]] = [] +# layer_dict[name_scopes[1]].append(node) +# else: +# print(node) +# else: +# no_name_scopes.add(node) + +# scoped_nodes = set() +# stop = False +# for node in no_name_scopes: +# #input('pause for enter') +# print(node) +# layer_usage = set() +# for value in node.outputs: +# print(f"\t{value}") +# print(f"\t{node.name}") +# if value.name == 'val_39': +# stop = True +# input('found val_39') +# for use in value.uses(): +# used_node = use[0] +# print(f"\t\t{used_node}") +# if 'pkg.torch.onnx.name_scopes' in used_node.metadata_props: +# print(f"\t\t\t{used_node.metadata_props['pkg.torch.onnx.name_scopes']}") +# name_scopes = ast.literal_eval(used_node.metadata_props['pkg.torch.onnx.name_scopes']) +# layer_usage.add(name_scopes[1]) +# else: +# print(f"\t\t\tno scope") +# layer_usage.add("") +# print(f'\t\tlayer usage {layer_usage}') +# if len(layer_usage) == 1: +# scope = next(iter(layer_usage)) +# print(scope) +# print(layer_dict.keys()) +# print(node) +# if scope in layer_dict: +# layer_dict[scope].insert(0, node) +# scoped_nodes.add(node) +# if stop == True: +# pass + +# for key in layer_dict: +# print(key+"********") +# layer = bGraphView(key, layer_dict[key]) + +# print(key+"STARTIO********") +# print(f"inputs: {layer.inputs}") +# print(f"outputs: {layer.outputs}") +# print(key+"ENDIO********") +# print("\n\n") + +# #exit(1) +# #layer0_gv = bGraphView('layers.0', layer_dict['layers.0']) +# layer1_gv = bGraphView('layers_1',layer_dict['layers.1']) + +# print(f"inputs: {layer1_gv.inputs}") +# print(f"outputs: {layer1_gv.outputs}") +# #print(f"initializers: {.initializers}") + +# #exit(1) + +# layer0 = layer_dict['layers.0'] +# layer1 = layer_dict['layers.1'] + +# d = {} + +# for a in layer0: +# d[a] = {a} +# for b in layer1: +# if ir_node__eq__(a, b): +# d[a].add(b) + +# for node in d: +# if len(d[node]) == 1: +# print(f"single node set: {node}") +# print(f"value: {node.outputs[0]}") +# #print(f"uses: {node.outputs[0].uses()}") +# print(f"len: {len(node.outputs[0].uses())}") +# for use in node.outputs[0].uses(): +# print(str(use)+'\n\n') +# #if len(d[node]) > 2: +# # print(f"set with {len(d[node])} nodes") +# # for n in d[node]: +# # print(f'\t{n}') +# # input('press n to continue') + +# #print(model.graph) + +# layer = bGraphView('layers_0', layer_dict['layers.0']) + +# for node in layer._nodes: +# if node.name == 'node_Constant_2143': +# print(node) +# input('found it!') + + +# model = ir.Model(layer, ir_version=10) +# proto = ir.serde.serialize_model(model) +# print('saving') +# onnx.save(proto, 'layer0.onnx') + + +# print('done') + diff --git a/onnxscript/utils/pattern_builder.py b/onnxscript/utils/pattern_builder.py new file mode 100644 index 0000000000..7e9f429e94 --- /dev/null +++ b/onnxscript/utils/pattern_builder.py @@ -0,0 +1,894 @@ +import ast +import astor +import copy + +import numpy as np +import onnx + +import pdb +from typing import Callable + +from onnxscript import ir, script, INT64, BOOL +from onnxscript.rewriter import pattern, rewrite +from onnxscript.rewriter.pattern import ValuePattern, GraphPattern, OpsetPatternBuilder, pattern_builder, NodeOutputPattern + + + +from collections.abc import Iterable +from onnxscript.utils.graph_view_utils import bGraphView + + +def bAstModule(body=[]): + return ast.Module(body=body) + +def bAstArguments(posonlyargs=[], args=[], vararg=None, + kwonlyargs=None, kw_defaults=None, kwarg=None, + defaults=[]): + return ast.arguments(posonlyargs, args, vararg, + kwonlyargs, kw_defaults, kwarg, + defaults) + +def bAstFunctionDef(name = '', args=bAstArguments(), body=[], + decorator_list=[], returns=None, type_comment='', + type_params=[]): + f_def = ast.FunctionDef(name, args, body, decorator_list, + returns, type_comment, type_params=type_params) + return f_def + +def bAstArg(value): + if isinstance(value, ir.Value): + return ast.arg(arg=value.name) + else: + return ast.arg(arg=value) + +def bAstName(value): + if isinstance(value, ir.Value): + return ast.Name(id=value.name, ctx=ast.Load()) + else: + return ast.Name(id=value, ctx=ast.Load()) + +def bAstList(values, bfunc): + return [bfunc(x) for x in values] + +def bAstArgList(values): + return bAstList(values,bAstArg) + +def bAstNameList(values): + return bAstList(values,bAstName) + + +def bAstOpAttr(node, opAttr='op'): + return ast.Attribute(value = ast.Name(id=opAttr, ctx=ast.Load()), + attr = node.op_type, + ctx = ast.Load()) + +def bAstCall(func, args, keywords=[]): + return ast.Call(func=func, args=args, keywords=keywords) + +def bAstAssign(node, opAttr='op'): + keywords = [] + for attribute in node.attributes: + print(attribute) + keywords.append(ast.keyword(attribute, ast.Constant(0))) + return ast.Assign(targets = bAstNameList(node.outputs), + value = bAstCall(bAstOpAttr(node, opAttr), bAstNameList(node.inputs), keywords=keywords)) + + +# def bAstTensorProto(tensor): +# attr0 = ast.Attribute(value=ast.Name(id='onnx',ctx=ast.Load()), +# attr = 'helper', +# ctx = ast.Load()) +# attr1 = ast.Attribute(value=attr1, +# attr = 'make_tensor', +# ctx = ast.Load()) +# call = bAstCall(func = attr1, args = , key + + +#const_value = onnx.helper.make_tensor("tensor", onnx.TensorProto.INT64, (3,), [3, 384, 384]) + +def bAstReturn(value): + if isinstance(value, ir.Value): + value = bAstName(value) + + return ast.Return(value=value) + +def bAstModule(): + return ast.Module(body=[]) + +def bIrValueList(names): + return [ir.Value(name=name) for name in names] + +def build_pattern(tree, name): + code = 'from onnxscript import script, opset18, FLOAT, BOOL\n' + code += 'import onnx\n' + code += 'from typing import Tuple\n' + code += astor.to_source(tree) + code += '\nfunc = ' + name + #print(code) + scope = {} + exec(code, scope) + return scope['func'] + + +import pdb +def direct_convert_ir_graph_to_pattern(graph): + + + # Transform IR values to ValuePatterns + vmap = {} + for input in graph.inputs: + vmap[input] = ValuePattern(input.name) + + for init in graph.initializers: + print(init) + vmap[init] = ValuePattern(init.name) + + + for node in graph._nodes: + if node.op_type == 'Constant': + vmap[node.outputs[0]] = ValuePattern(node.outputs[0].name) + + builder = OpsetPatternBuilder("", record=True) + + with pattern_builder(builder): + for node in graph._nodes: + ninputs = [] + for ninput in node.inputs: + ninputs.append(vmap[ninput]) + + #if len(node.outputs) > 1: + vp_outputs = builder.__getattr__(node.op_type)(*ninputs,_domain=node.domain, _outputs=len(node.outputs)) + #else: + # vp_outputs = builder.__getattr__(node.op_type)(*ninputs) + + + if isinstance(vp_outputs,NodeOutputPattern): + vp_outputs = [vp_outputs] + + for vp_output in iter(vp_outputs): + vmap[node.outputs[vp_output.output_index]] = vp_output + + + pinputs = [] + for input in graph.inputs: + pinputs.append(vmap[input]) + + # build graph outputs + poutputs = [] + for output in graph.outputs: + poutputs.append(vmap[output]) + + return GraphPattern(inputs=pinputs, outputs=poutputs, nodes=builder.nodes()) + + + + +# def convert_graph_to_pattern(graph): + +# #Build Function Definition +# f_def = bAstFunctionDef(name=graph.name) +# #print(f_def) +# # Add Arguments +# f_def.args.args.append(ast.arg(arg='op')) +# f_def.args.args.extend(bAstArgList(graph.inputs)) + +# # Build Body +# for node in ir.traversal.RecursiveGraphIterator(graph): +# f_def.body.append(bAstAssign(node)) + +# # Add Return Statement +# if len(graph.outputs) == 1: +# f_def.body.append(bAstReturn(graph.outputs[0])) +# else: +# print("only one output supported for now") +# return None + +# return build_pattern(f_def, graph.name) + + +from enum import Enum + +def remove_input_from_node(node, inp): + node._inputs = [x for x in node._inputs if x is not inp] + inp._remove_usage(node) + + +class LoopBodyInputType(Enum): + UNDEFINED = 0 + ACTIVATION = 1 + CONSTANT = 2 + PARAMETER = 3 + ITERATOR = 4 + CONDITION = 5 + + def __str__(self): + return self.name + +class LoopBodyTemplate: + def __init__(self, filename): + self.load(filename) + self.pattern = direct_convert_ir_graph_to_pattern(self._ir_graph) + self.function = self._build_ir_function() + self.function_replace = self._build_function_replace_pattern() + self.signature = [LoopBodyInputType.UNDEFINED] * len(self._ir_graph.inputs) + + def _build_ir_function(self): + return ir.Function(domain='loop', + name='fn_' + self._ir_graph.name, + graph = self._ir_graph, + attributes=[]) + + def _build_function_replace_pattern(self): + + inputs = [vdisconnect(copy.copy(x)) for x in self._ir_graph.inputs] + outputs = [vdisconnect(copy.copy(x)) for x in self._ir_graph.outputs] + + print(type(self.function.domain)) + node = ir.Node(domain=self.function.domain, + version=0, + op_type=self.function.name, + inputs=inputs, outputs=outputs) + + g = ir.Graph(inputs=inputs, outputs=outputs, nodes=[node]) + + return ReplacementPatternGraph(g) + + def get_iterator_index(self): + for i in range(len(self.signature)): + if self.signature[i] == LoopBodyInputType.ITERATOR: + return i + + def insert_gather_nodes(self, loop_iterations): + for index,LoopInputType in enumerate(self.signature): + if LoopInputType == LoopBodyInputType.PARAMETER: + # The Current Input Value will be the output of the gather node + gather_index = self.function.inputs[self.get_iterator_index()] + squeeze_out = self.function.inputs[index] + + gather_in = ir.Value(name=squeeze_out.name+"_gather_in", + shape=ir.Shape([loop_iterations, *squeeze_out.shape.dims]), + type=squeeze_out.type) + gather_out = ir.Value(name=squeeze_out.name+"_gather_out", + shape=ir.Shape([1, *squeeze_out.shape.dims]), + type=squeeze_out.type) + for usage in squeeze_out.uses(): + if usage.node.op_type == "Identity": + usage.node.replace_input_with(usage.idx, gather_in) + usage.node.outputs[0].shape = copy.copy(gather_in.shape) + + self.function.inputs[index] = gather_in + + self.function.append(ir.Node(domain='', + op_type='Gather', + inputs = [gather_in, gather_index], + outputs = [gather_out], + num_outputs = 1 + ) + ) + squeeze_out.name += "_squeeze_out" + squeeze_axis = build_constant_from_tensor(f'{gather_out.name}_squeeze_axis', ir.Tensor(np.array([0]))) + self.function.append(squeeze_axis) + self.function.append(ir.Node(domain='', + op_type='Squeeze', + inputs=[gather_out, squeeze_axis.outputs[0]], + outputs=[squeeze_out], + num_outputs= 1, + version = 13 + ) + ) + + self.function.sort() + + + def build_function_match_pattern(self, graph): + graph.sort() + print(self.function.name) + nodes = find_nodes_of_optype(graph, self.function.name) + nodes.insert(0,graph.node('iteration_ext')) + nodes.insert(0,graph.node('condition_ext')) + print(f'found {len(nodes)}') + + ir_model = ir.Model(bGraphView('inlined_pipe_pattern', nodes), ir_version=10) + + model = ir.serde.serialize_model(ir_model) + onnx.save(model, 'pipeline_match_pattern.onnx') + + pattern = direct_convert_ir_graph_to_pattern(ir_model.graph) + + return (pattern, nodes) + + + def load(self, filename): + self._model_proto = onnx.load(filename) + self._ir_model = ir.serde.deserialize_model(self._model_proto) + self._ir_graph = self._ir_model.graph + + def update(self): + self._ir_model = ir.Model(self._ir_graph, ir_version = 10) + self._model_proto = ir.serde.serialize_model(self._ir_model) + + def save(self, filename): + self.update() + onnx.save(self._model_proto, filename) + + def set_signature_index(self, index, stype): + self.signature[index] = stype + + @property + def output_signature(self): + # The output signature is the same as the input signature but without the iteration input + return self.signature[1:] + + + +# class LoopBodyBuilderInfo: +# def __init__(self, nodes): +# self._nodes = nodes +# self._input_types = [UNDEFINED] * len(self._nodes.inputs()) + +# def _validate_nodes_ + + +# def size(self): +# return len(self._input_types) + +# def set_type(self, index, input_type): +# self._input_types[index] = input_type + +def same(input_list): + return len(set(input_list)) == 1 + +def append_output_to_node(node, output): + output._producer = node + output._index = node.outputs[-1]._index + 1 + node._outputs = (*node._outputs, output) + node._num_outputs = len(node._outputs) + +def prepend_output_to_node(node, output): + output._producer = node + output._index = 0 + for outp in node._outputs: + outp._index += 1 + node._outputs = (output, *node._outputs) + node._num_outputs = len(node._outputs) + +def prepend_input_to_node(node, input): + input._add_usage(node, 0) + # increment the index for all uses on this node + for i,inp in enumerate(node._inputs): + inp._remove_usage(node, i) + inp._add_usage(node, i+1) + + node._inputs = (input, *node._inputs) + +def normalize_io_for_loop_rolling(graph, LoopBody): + + # This takes a graph that has consecutive identical nodes + # and normalizes the i/o indicies prior to the loop rolling + # optimization. Specifically, this function identifies output- + # to-input pairs and permutes the indices of the input to match + # the previous output. + + # The ONNX loop node requires that interloop dependencies + # have identical input and output indices. + + # Run a topological sort to ensure the layers are in order. + graph.sort() + + # get the consecutive node layers + # TODO: write a check to ensure that there is only one + # set of consecutive nodes. + nodes = find_nodes_of_optype(graph, LoopBody.function.name) + + # Loop through all the nodes (execept the last one) and + # identify the input to output pairs + input_swaps = [] + for i in range(len(nodes)-1): + a_node = nodes[i] + b_node = nodes[i+1] + + for a_out in a_node.outputs: + # Require that outputs of a have a single use of b_node + assert(len(a_out.uses()) == 1) + assert(a_out.uses()[0][0] is b_node) + + a_use_index = a_out.uses()[0][1] + input_swap = (a_out.index(), a_use_index) + if i == 0: + # add swaps from the first node + input_swaps.append(input_swap) + else: + # check that they are the same in the rest + assert(input_swap in input_swaps) + + # apply the input swaps to each nodes + for node in nodes: + for swap in input_swaps: + a = node.inputs[swap[0]] + b = node.inputs[swap[1]] + node.replace_input_with(swap[0], b) + node.replace_input_with(swap[1], a) + + # apply the input swaps to the function graph + # mark swapped nodes as activations + activations = 0 + for swap in input_swaps: + a = LoopBody.function.inputs[swap[0]] + b = LoopBody.function.inputs[swap[1]] + LoopBody.function.inputs[swap[0]] = b + LoopBody.function.inputs[swap[1]] = a + LoopBody.signature[swap[0]] = LoopBodyInputType.ACTIVATION + activations+=1 + #print(f"Added Activation: {swap[0]}") + + # Next Inputs according to how they are produced. + # Indexable inputs will have different constant or none producers + # Constant values broadcast to all nodes will have the same producer + # Skip the (all) Activation inputs (have been swapped to beginning of the list) + for index in range(activations, len(nodes[0].inputs)): + inputs = [] + producers = [] + for node in nodes: + cinput = node.inputs[index] + inputs.append(cinput) + if same(inputs): + # Constant with Respect to Loop + #print(f"Added Constant: {index}") + LoopBody.signature[index] = LoopBodyInputType.CONSTANT + else: + # Must be Indexed + #print(f"Added Parameter: {index}") + LoopBody.signature[index] = LoopBodyInputType.PARAMETER + + + # Match Output Signature to Input Signature + for index,LoopInputType in enumerate(LoopBody.signature): + cinput = LoopBody.function.inputs[index] + noutput = vdisconnect(copy.copy(cinput)) + noutput._uses = {} + update_node_outputs = False + + if LoopInputType == LoopBodyInputType.CONSTANT or \ + LoopInputType == LoopBodyInputType.PARAMETER: + # Update Names and Add Output + cinput.name += "_" + str(LoopInputType) + "_in" + noutput.name += "_" + str(LoopInputType) + "_out" + + LoopBody.function.outputs.append(noutput) + + # Add Identify to Pass Inputs to new Output + Ident = ir.Node(domain='', + op_type='Identity', + inputs = [cinput], + outputs = [noutput], + num_outputs =1) + LoopBody.function.append(Ident) + + #Add Output to Function Call Nodes + for i,node in enumerate(nodes): + output_copy = copy.copy(noutput) + + #preserve single_assignment + output_copy.name += f'_{i}' + append_output_to_node(node,output_copy) + + # Add Iterator and Condition Inputs (Leave Unconnected within function for now) + iteration = ir.Value(name='iteration', shape=ir.Shape([1]), type=ir.TensorType(ir.DataType.INT64)) + condition = ir.Value(name='cond', shape=ir.Shape([1]), type=ir.TensorType(ir.DataType.BOOL)) + + LoopBody.signature.insert(0, LoopBodyInputType.CONDITION) + LoopBody.function.inputs.insert(0,condition) + + LoopBody.signature.insert(0, LoopBodyInputType.ITERATOR) + LoopBody.function.inputs.insert(0,iteration) + + + iteration_ext = build_constant_from_tensor('iteration_ext', ir.Tensor(np.array([len(nodes)]))) + condition_ext = build_constant_from_tensor('condition_ext', ir.Tensor(np.array([True]))) + #iteration_ext = ir.Value(name='iteration', shape=ir.Shape([1]), type=ir.TensorType(ir.DataType.INT64)) + #condition_ext = ir.Value(name='cond_ext', shape=ir.Shape([1]), type=ir.TensorType(ir.DataType.BOOL)) + + # add these nodes to the graph before the rest to maintain topological sorted-ness + nodes[0].prepend([iteration_ext,condition_ext]) + + for node in nodes: + prepend_input_to_node(node, condition_ext.outputs[0]) + prepend_input_to_node(node, iteration_ext.outputs[0]) + + # Add Identity Node for Condition + condition_out = ir.Value(name='cond_out', type=ir.TensorType(ir.DataType.BOOL)) + condition_ident = ir.Node(domain='', + op_type='Identity', + inputs = [condition], + outputs = [condition_out], + num_outputs =1) + LoopBody.function.append(condition_ident) + LoopBody.function.outputs.insert(0,condition_out) + + # Add New Condition Output to Node + for i,node in enumerate(nodes): + noutput = ir.Value(name=f'cond_out_{i}', type=ir.TensorType(ir.DataType.BOOL)) + prepend_output_to_node(node,noutput) + + + # Add External Constants to Drive These + # M_value = ir.Attr(name='value', type=ir.AttributeType.TENSOR, value=ir.Tensor(np.array([number_of_layers]))) + # M = ir.Node('', 'Constant', inputs=[], outputs=[loop_node.inputs[0]], attributes=[M_value], graph=model_graph) + + # cond_value = ir.Attr(name='value', type=ir.AttributeType.TENSOR, value=ir.Tensor(np.array([True]))) + # cond = ir.Node('', 'Constant', inputs=[], outputs=[loop_node.inputs[1]],attributes=[cond_value], graph=model_graph) + + # model_graph.insert_before(model_graph[0], [cond]) + # model_graph.insert_before(model_graph[0], [M]) + + # Add Gather Nodes + # for index,LoopInputType in enumerate(LoopBody.signature): +# if LoopInputType == LoopBodyInputType.PARAMETER: +# # Update the Shape + + +# gvo = copy.copy(indexed_input) +# gvo.name = gvo.name + '_gather_out' +# for node in indexed_input.consumers(): +# for i,inp in enumerate(node.inputs): +# if indexed_input is inp: +# node.replace_input_with(i, gvo) +# graph.append(ir.Node(domain='', +# op_type='Gather', +# inputs = [nv, viter], +# outputs = [gvo], +# num_outputs = 1, +# graph = graph +# )) + graph.sort() + return graph + + + +# def find_parameter_input_indexes(graph, layername): + +# nodes = find_nodes_of_optype(graph, layername) + +# if len(nodes) == 0: +# raise Exception("Did not find node with name {layername}") + +# indexes = [] +# for i in range(len(nodes[0].inputs)): +# value = nodes[0].inputs[i] +# print(f'node input[{i}] {value} {value.producer()}') +# if value.producer() == None: # is an initializer/weight value (unless an input) +# print(f"appending {value} at index {i}") +# indexes.append(i) +# elif value.producer().op_type == 'Constant': +# print("found constant optype") +# same_value = True +# for node in nodes: +# if node.inputs[i] is not value: +# same_value = False +# if not same_value: +# print(f"appending constant {value} at index {i}") +# indexes.append(i) + +# return indexes + + + +def build_loop_body(graph, number_of_layers, parameter_indexes = []): + + # for i,inp in enumerate(graph.inputs): + # print(f"graph input: {i} {inp}") + + # print(f'parameter indexes: {parameter_indexes}') + + + # viter = ir.Value(name='iteration', shape=ir.Shape([1]), type=ir.TensorType(ir.DataType.INT64)) + # vcond_in = ir.Value(name='cond', shape=ir.Shape([1]), type=ir.TensorType(ir.DataType.BOOL)) + + # for index in parameter_indexes: + # indexed_input = graph.inputs[index] + + # nv = copy.copy(indexed_input) + # nv.shape = ir.Shape((number_of_layers, *nv.shape.dims)) + # nv.name = nv.name + '_array_in' + # graph.inputs[index] = nv + + # nvo = copy.copy(indexed_input) + # nvo.shape = ir.Shape((nv.shape.dims)) + # nvo.name = nvo.name + '_array_out' + # graph.outputs.append(nvo) + + # graph.append(ir.Node(domain='', op_type='Identity', + # inputs = [nv], + # outputs = [nvo], + # num_outputs =1)) + + # for index in + # gvo = copy.copy(indexed_input) + # gvo.name = gvo.name + '_gather_out' + # for node in indexed_input.consumers(): + # for i,inp in enumerate(node.inputs): + # if indexed_input is inp: + # node.replace_input_with(i, gvo) + # graph.append(ir.Node(domain='', + # op_type='Gather', + # inputs = [nv, viter], + # outputs = [gvo], + # num_outputs = 1, + # graph = graph + # )) + + # print(f'indexed input: {indexed_input} {gvo.producer}') + # print(f'usages: {gvo.uses()}') + # print(f'const_value: {gvo.const_value}') + # print(f'type: {gvo.type}') + # print(f'meta: {gvo.meta}') + # print(f'initializers: {graph.initializers}') + #exit(1) + #new_values.append(nv) + + + + graph.inputs.insert(0,vcond_in) + graph.inputs.insert(0,viter) + + + vcond_out = ir.Value(name='cond_out', type=ir.TensorType(ir.DataType.BOOL)) + vcond_ident = ir.Node(domain='', op_type='Identity', + inputs = [vcond_in], + outputs = [vcond_out], + num_outputs =1, + graph = graph) + graph.append(vcond_ident) + + #ovalue = graph.outputs[0] + #graph.outputs[0] = (ovalue, vcond_out) + + graph.outputs.insert(0,vcond_out) + #print(graph) + print(f"*** graph.outputs = {graph.outputs}") + # for i in range(len(indexed_values)): + # graph.append(ir.Node(domain='', + # op_type='Gather', + # inputs = [new_values[i], viter], + # outputs = [indexed_values[i]], + # num_outputs = 1, + # graph = graph + # )) + + graph.sort() + #print(indexed_values) + #print(graph.inputs) + #print(graph) + + return graph + +import copy + +def vdisconnect(value): + value._uses = {} + value._producer = None + value._index = None + return value + + +class ReplacementPatternGraph(pattern.ReplacementPatternFunction): + def __init__(self, ir_graph): + self._graph = ir_graph + + + def get_replacement(self, match: pattern.MatchResult) -> pattern.ReplacementSubgraph | None: + print(f"in get replacement!!!!") + print(f"match pattern full: {match.outputs}") + + context = pattern.RewriterContext() + # match.bindings is dictionary of value_name (str) in replacement subgraph pattern (i.e. ir_graph -> IR Value in actual graph) + vvmap = {} # Dictionary mapping values in replacement subgraph pattern -> values in the replacement subgraph + + #print(f'self._graph: {self._graph}') + #print(f'len inputs: {self._graph.inputs}') + #print(f'match bindings: {match.bindings}') + for value in self._graph.inputs: + if value.name in match.bindings: + #print(f'***bound input value name: {value.name}') + vvmap[value] = match.bindings[value.name] + else: + #print(f'***unbound input value name: {value.name}') + #print(f'***unbound value : {value}') + vvmap[value] = value + + for node in self._graph._nodes: + ninputs = [] + print(node) + for ninput in node.inputs: + #print(f'{ninput}') + ninputs.append(vvmap[ninput]) + + + coutput = context.__getattr__(node.op_type)(*ninputs, **node.attributes, _outputs=len(node.outputs), _domain=node.domain, _version=node.version) + print(f"coutput type: {type(coutput)}") + print(f"coutput type: {coutput}") + if not isinstance(coutput,Iterable): + print("Not an Iterable") + coutput = [coutput] + + for i, cout in enumerate(coutput): + print(f"listing outputs: {node.outputs[cout.index()]}") + print(f"cout {cout.index()} {cout} ") + vvmap[node.outputs[cout.index()]] = cout + + for x in vvmap: + print(x) + print(f"context nodes") + for node in context.nodes: + print(node) + new_outputs = [vvmap[x] for x in self._graph.outputs] + print(f'new_outputs={new_outputs}') + print(f'_graph.outputs={self._graph.outputs}') + return pattern.ReplacementSubgraph( + match, new_outputs, context.nodes, context.initializers, context.used_opsets + ) + + +def convert_graph_to_function_call_pattern(graph): + + inputs = [vdisconnect(copy.copy(x)) for x in graph.inputs] + outputs = [vdisconnect(copy.copy(x)) for x in graph.outputs] + + node = ir.Node('', graph.name+'_fcall', inputs, outputs=outputs) + + g = ir.Graph(inputs=inputs, outputs=outputs, nodes=[node]) + + + return ReplacementPatternGraph(g) + + +def find_nodes_of_optype(graph, layername): + nodes = [] + for node in ir.traversal.RecursiveGraphIterator(graph): + #print(f'node op_type: {node.op_type} {layername}') + if node.op_type == layername: + nodes.append(node) + return nodes + + +def build_layer_pipeline_pattern(graph, layername): + + nodes = find_nodes_of_optype(graph, layername) + ir_model = ir.Model(bGraphView('inlined_pipe_pattern', nodes), ir_version=10) + + model = ir.serde.serialize_model(ir_model) + onnx.save(model, 'pipeline_match_pattern.onnx') + + pattern = direct_convert_ir_graph_to_pattern(ir_model.graph) + + return (pattern, nodes) + +def connect_loop_constants(model_graph, layer_nodes): + number_of_layers = len(layer_nodes) + + loop_node = find_nodes_of_optype(model_graph,'Loop')[0] + + + #r.TensorType(ir.DataType.INT64), const_value=ir.Tensor(np.array([number_of_layers]))) + M_value = ir.Attr(name='value', type=ir.AttributeType.TENSOR, value=ir.Tensor(np.array([number_of_layers]))) + M = ir.Node('', 'Constant', inputs=[], outputs=[loop_node.inputs[0]], attributes=[M_value], graph=model_graph) + cond_value = ir.Attr(name='value', type=ir.AttributeType.TENSOR, value=ir.Tensor(np.array([True]))) + cond = ir.Node('', 'Constant', inputs=[], outputs=[loop_node.inputs[1]],attributes=[cond_value], graph=model_graph) + model_graph.insert_before(model_graph[0], [cond]) + model_graph.insert_before(model_graph[0], [M]) + return model_graph + +def build_constant_from_tensor(name, tensor): + value_attribute = ir.Attr(name='value', type=ir.AttributeType.TENSOR, value=tensor) + ir_value_out = ir.Value(name=name+'_out', type=ir.TensorType(tensor.dtype)) + return ir.Node('', 'Constant', name=name, inputs=[], outputs=[ir_value_out], attributes=[value_attribute]) + +def build_concat_node_from_inputs(inputs): + + axis = ir.Attr(name='axis', type=ir.AttributeType.INT, value=0) + ndim = len(inputs) * inputs[0].shape.dims[0] + output_shape = ir.Shape([ndim, *inputs[0].shape.dims[1:]]) + output = ir.Value(name=f'{inputs[0].name}_concat', shape=output_shape, type=inputs[0].type) + return ir.Node('', 'Concat', inputs=inputs, attributes=[axis], outputs=[output]) + +def build_reshape_node(inp, reshape_shape): + reshape_out = ir.Value(name=f'{inp.name}_reshape', type=inp.type) + return ir.Node('', 'Reshape', inputs=[inp, reshape_shape], outputs=[reshape_out]) + + + +def build_loop_replace_pattern(graph, LoopBody): + + nodes = find_nodes_of_optype(graph, LoopBody.function.name) + + graph_nodes = [] + loop_inputs = [] + + # Build max_iteration and condition constants + M = build_constant_from_tensor('M', ir.Tensor(np.array([len(nodes)]))) + cond = build_constant_from_tensor('cond', ir.Tensor(np.array([True]))) + + graph_nodes.append(M) + graph_nodes.append(cond) + + loop_inputs.append(M.outputs[0]) + loop_inputs.append(cond.outputs[0]) + + graph_inputs = [] + print(LoopBody.signature) + for i, LoopInputType in enumerate(LoopBody.signature): + + if LoopInputType == LoopBodyInputType.PARAMETER: + # Build Concat Node + concat_inputs = [] + for node in nodes: + print(f"node.optype: {node}") + print(f"node inputs:{i} {node.inputs}") + nvalue = vdisconnect(copy.copy(node.inputs[i])) + graph_inputs.append(nvalue) + concat_inputs.append(nvalue) + + concat_node = build_concat_node_from_inputs(concat_inputs) + graph_nodes.append(concat_node) + + # Build Reshape Node + reshape_shape_const = build_constant_from_tensor(f'reshape_shape_const_{i}', ir.Tensor(np.array([len(nodes),*concat_inputs[0].shape.dims]))) + graph_nodes.append(reshape_shape_const) + + reshape_node = build_reshape_node(concat_node.outputs[0], reshape_shape_const.outputs[0]) + graph_nodes.append(reshape_node) + loop_inputs.append(reshape_node.outputs[0]) + elif LoopInputType == LoopBodyInputType.CONSTANT: + constant_input = nodes[0].inputs[i] + constant_node = constant_input.producer() + constant_value = constant_node.attributes['value'].value.numpy() + # n_constant_input = ir.Value(name = 'n_'+constant_input.name, + # shape = constant_input.shape, + # type = constant_input.type) + # n_constant_node = ir.Node('', 'Constant', + # inputs =[], + # outputs=[n_constant_input], + # attributes=[constant_value_attr], + # version=13) + n_constant_node = build_constant_from_tensor(constant_input.name+"_const_val", ir.Tensor(constant_value)) + graph_nodes.append(n_constant_node) + loop_inputs.append(n_constant_node.outputs[0]) + elif LoopInputType == LoopBodyInputType.ACTIVATION: + cinp = vdisconnect(copy.copy(LoopBody.function.inputs[i])) + graph_inputs.append(cinp) + loop_inputs.append(cinp) + + + + loop_outputs = [] + graph_outputs = [] + for i, LoopOutputType in enumerate(LoopBody.output_signature): + output = vdisconnect(copy.copy(LoopBody.function.outputs[i])) + if LoopOutputType != LoopBodyInputType.CONDITION: + loop_outputs.append(output) + if LoopOutputType == LoopBodyInputType.ACTIVATION: + graph_outputs.append(output) + + LoopBody.insert_gather_nodes(len(nodes)) + body_attr = ir.Attr(name='body', type=ir.AttributeType.GRAPH, value=LoopBody.function._graph) + graph_nodes.append(ir.Node('', 'Loop', inputs=loop_inputs, attributes = [body_attr], outputs=loop_outputs, graph=None)) + + graph = ir.Graph(name='loop_replace',nodes=graph_nodes, inputs=graph_inputs, outputs=graph_outputs) + + print('sorting graph') + graph.sort() + + model = ir.serde.serialize_model(ir.Model(graph, ir_version=8)) + onnx.save(model, 'replacementgraph.onnx') + + return ReplacementPatternGraph(graph) + + + + + + + + + + diff --git a/tests/loop_rolling/test_loop_rolling.py b/tests/loop_rolling/test_loop_rolling.py new file mode 100644 index 0000000000..710c975bdc --- /dev/null +++ b/tests/loop_rolling/test_loop_rolling.py @@ -0,0 +1,214 @@ +import argparse +import ast + +import numpy as np + +import onnx +import onnxscript + +from onnxscript import ir +from onnxscript.rewriter import pattern, rewrite +import onnxruntime as onnxrt + +import pdb + +from onnxscript.utils.pattern_builder import build_layer_pipeline_pattern +from onnxscript.utils.pattern_builder import direct_convert_ir_graph_to_pattern +from onnxscript.utils.pattern_builder import convert_graph_to_function_call_pattern +from onnxscript.utils.pattern_builder import build_loop_body +from onnxscript.utils.pattern_builder import build_loop_replace_pattern +#from pattern_builder import find_parameter_input_indexes +from onnxscript.utils.pattern_builder import connect_loop_constants +from onnxscript.utils.pattern_builder import normalize_io_for_loop_rolling +from onnxscript.utils.pattern_builder import LoopBodyTemplate + +from onnx import shape_inference + +def open_ir(filename): + print('loading onnx') + f = onnx.load(filename) + print('deserializing') + return ir.serde.deserialize_model(f) + +# Handle Parser Stuff +parser = argparse.ArgumentParser(description="A simple argparse example") +parser.add_argument("filename", type=str, help="The name of the input onnx file") +parser.add_argument("patternfilename", type=str, help="The name of the input onnx file that is a single layer") +args = parser.parse_args() + +def ort_make_rand_io(filename): + + session = onnxrt.InferenceSession(filename) + + input_name = session.get_inputs()[0].name + input_type = session.get_inputs()[0].type + input_shape = session.get_inputs()[0].shape + + if input_type == 'tensor(int64)': + np_type = np.int64 + elif input_type == 'tensor(float)': + np_type = np.float32 + else: + raise Exception("unsupported type {input_type}") + + input_data = np.random.random(input_shape).astype(np_type) + + return ({input_name: input_data}, session.get_outputs()) + + +def ort_run_graph(filename, input_dict, output_name): + sess_options = onnxrt.SessionOptions() + sess_options.log_severity_level = 0 + sess_options.graph_optimization_level = onnxrt.GraphOptimizationLevel.ORT_ENABLE_BASIC + + return onnxrt.InferenceSession(filename, sess_options=sess_options).run([output_name], input_dict) + +input_dict, outputs = ort_make_rand_io(args.filename) +golden_results = ort_run_graph(args.filename, input_dict, outputs[0].name) + + +LoopBody = LoopBodyTemplate(args.patternfilename) + +change_layers_to_function_calls = pattern.RewriteRule( + LoopBody.pattern, + LoopBody.function_replace +) + + +print("Find and Replace layers with single layer op.") + +print(LoopBody.function.identifier()) + +mypipeline = onnx.load(args.filename) + + + +print('applying rewrite rule') + +mypipeline_layers_replaced = rewrite( + mypipeline, + pattern_rewrite_rules = [change_layers_to_function_calls] +) + + +mypipeline_model = ir.serde.deserialize_model(mypipeline_layers_replaced) +mypipeline_model.functions[LoopBody.function.identifier()] = LoopBody.function +mypipeline_model.graph.opset_imports['loop']=0 +mypipeline_layers_replaced = ir.serde.serialize_model(mypipeline_model) + + +replaced_filename = "replaced_"+args.filename +print(f"Writing Updated Graph to {replaced_filename}") +onnx.save(mypipeline_layers_replaced, + replaced_filename, save_as_external_data=True, location=replaced_filename+'.data') + +print("Replace Layer Ops with Loop Body") + +normalized_graph = normalize_io_for_loop_rolling(mypipeline_model.graph, LoopBody) + + + +model = ir.serde.serialize_model(mypipeline_model) +onnx.save(model, 'normalized.onnx', save_as_external_data=True, location='normalized.onnx.data') + + +LoopMatchPattern,nodes = LoopBody.build_function_match_pattern(normalized_graph) + + +loop_replace_pattern = build_loop_replace_pattern(normalized_graph, LoopBody) + + +change_function_calls_to_loop = pattern.RewriteRule( + LoopMatchPattern, + loop_replace_pattern +) + +#print(loop_replace_pattern.get_replacement()) +#exit() +class AllTracer(pattern.MatchingTracer): + def __init__(self): + super().__init__() + + def log( + self, + rule: pattern.RewriteRule, + container: ir.Graph | ir.Function, + node: ir.Node, + match_result: pattern.MatchResult, + status: pattern.MatchStatus, + ) -> None: + this_match = pattern.MatchInfo(match_result, node, container, status) + #this_score = this_match.score() + #if this_score == 0: + # return + best_matches = self._best_matches_map[rule] + #if best_matches: + # if this_score < best_matches[0].score(): + # return + # if this_score > best_matches[0].score(): + # best_matches.clear() + best_matches.append(this_match) + + +tracer = pattern.MatchingTracer() +rewrite_set = pattern.RewriteRuleSet([change_function_calls_to_loop]) +#model_ir = ir.serde.deserialize_model(mypipeline_l) +#pdb.set_trace() +count = rewrite_set.apply_to_model(mypipeline_model, verbose=None) +print(f"Count {count}") +print(tracer.best_matches_map) + +# tracer.report() +# for rule in tracer.best_matches_map: +# matches = tracer.best_matches_map[rule] +# for match in matches: +# print(f'Reason: {match.match_result.reason}') +# print(f'root_node: {match.root_node}') +# pdb.set_trace() + +#loop_added = rewrite ( +# mypipeline_layers_replaced, +# pattern_rewrite_rules = [change_function_calls_to_loop] +#) + +print(mypipeline_model.opset_imports) +#mypipeline_model.opset_imports.pop('') +mypipeline_model.opset_imports.pop('loop') +mypipeline_model._functions = {} + + + +# scanning for empty domains +for node in ir.traversal.RecursiveGraphIterator(mypipeline_model.graph): + if node.domain == '': + print(node) + + +#mypipeline_model.opset_imports['main'] = 13 + +loop_added = ir.serde.serialize_model(mypipeline_model) + + + + + +onnx.save(loop_added, 'loop_added.onnx', save_as_external_data=True, location='loop_added.onnx.data') + +onnx.checker.check_model(loop_added) +loop_added = shape_inference.infer_shapes(loop_added) + + + + +transformed_results = ort_run_graph('loop_added.onnx', input_dict, outputs[0].name) + + +if (np.isclose(transformed_results, golden_results[0], rtol=1e-5, atol=1e-6)).all(): + print("Results Equal!!") +else: + print("errors found") + print(transformed_results[0] == golden_results[0]) + print(transformed_results[0]) + print(golden_results[0]) + +print('done') From 2f85bdc8a5d817fd865cd92074feadab50937199 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Thu, 27 Mar 2025 21:53:17 +0000 Subject: [PATCH 02/14] removed old ast code and commented code. --- onnxscript/utils/pattern_builder.py | 319 ------------------------ tests/loop_rolling/test_loop_rolling.py | 6 - 2 files changed, 325 deletions(-) diff --git a/onnxscript/utils/pattern_builder.py b/onnxscript/utils/pattern_builder.py index 7e9f429e94..7550b1ca79 100644 --- a/onnxscript/utils/pattern_builder.py +++ b/onnxscript/utils/pattern_builder.py @@ -1,12 +1,9 @@ -import ast -import astor import copy import numpy as np import onnx import pdb -from typing import Callable from onnxscript import ir, script, INT64, BOOL from onnxscript.rewriter import pattern, rewrite @@ -18,99 +15,6 @@ from onnxscript.utils.graph_view_utils import bGraphView -def bAstModule(body=[]): - return ast.Module(body=body) - -def bAstArguments(posonlyargs=[], args=[], vararg=None, - kwonlyargs=None, kw_defaults=None, kwarg=None, - defaults=[]): - return ast.arguments(posonlyargs, args, vararg, - kwonlyargs, kw_defaults, kwarg, - defaults) - -def bAstFunctionDef(name = '', args=bAstArguments(), body=[], - decorator_list=[], returns=None, type_comment='', - type_params=[]): - f_def = ast.FunctionDef(name, args, body, decorator_list, - returns, type_comment, type_params=type_params) - return f_def - -def bAstArg(value): - if isinstance(value, ir.Value): - return ast.arg(arg=value.name) - else: - return ast.arg(arg=value) - -def bAstName(value): - if isinstance(value, ir.Value): - return ast.Name(id=value.name, ctx=ast.Load()) - else: - return ast.Name(id=value, ctx=ast.Load()) - -def bAstList(values, bfunc): - return [bfunc(x) for x in values] - -def bAstArgList(values): - return bAstList(values,bAstArg) - -def bAstNameList(values): - return bAstList(values,bAstName) - - -def bAstOpAttr(node, opAttr='op'): - return ast.Attribute(value = ast.Name(id=opAttr, ctx=ast.Load()), - attr = node.op_type, - ctx = ast.Load()) - -def bAstCall(func, args, keywords=[]): - return ast.Call(func=func, args=args, keywords=keywords) - -def bAstAssign(node, opAttr='op'): - keywords = [] - for attribute in node.attributes: - print(attribute) - keywords.append(ast.keyword(attribute, ast.Constant(0))) - return ast.Assign(targets = bAstNameList(node.outputs), - value = bAstCall(bAstOpAttr(node, opAttr), bAstNameList(node.inputs), keywords=keywords)) - - -# def bAstTensorProto(tensor): -# attr0 = ast.Attribute(value=ast.Name(id='onnx',ctx=ast.Load()), -# attr = 'helper', -# ctx = ast.Load()) -# attr1 = ast.Attribute(value=attr1, -# attr = 'make_tensor', -# ctx = ast.Load()) -# call = bAstCall(func = attr1, args = , key - - -#const_value = onnx.helper.make_tensor("tensor", onnx.TensorProto.INT64, (3,), [3, 384, 384]) - -def bAstReturn(value): - if isinstance(value, ir.Value): - value = bAstName(value) - - return ast.Return(value=value) - -def bAstModule(): - return ast.Module(body=[]) - -def bIrValueList(names): - return [ir.Value(name=name) for name in names] - -def build_pattern(tree, name): - code = 'from onnxscript import script, opset18, FLOAT, BOOL\n' - code += 'import onnx\n' - code += 'from typing import Tuple\n' - code += astor.to_source(tree) - code += '\nfunc = ' + name - #print(code) - scope = {} - exec(code, scope) - return scope['func'] - - -import pdb def direct_convert_ir_graph_to_pattern(graph): @@ -160,32 +64,6 @@ def direct_convert_ir_graph_to_pattern(graph): return GraphPattern(inputs=pinputs, outputs=poutputs, nodes=builder.nodes()) - - - -# def convert_graph_to_pattern(graph): - -# #Build Function Definition -# f_def = bAstFunctionDef(name=graph.name) -# #print(f_def) -# # Add Arguments -# f_def.args.args.append(ast.arg(arg='op')) -# f_def.args.args.extend(bAstArgList(graph.inputs)) - -# # Build Body -# for node in ir.traversal.RecursiveGraphIterator(graph): -# f_def.body.append(bAstAssign(node)) - -# # Add Return Statement -# if len(graph.outputs) == 1: -# f_def.body.append(bAstReturn(graph.outputs[0])) -# else: -# print("only one output supported for now") -# return None - -# return build_pattern(f_def, graph.name) - - from enum import Enum def remove_input_from_node(node, inp): @@ -320,21 +198,6 @@ def output_signature(self): return self.signature[1:] - -# class LoopBodyBuilderInfo: -# def __init__(self, nodes): -# self._nodes = nodes -# self._input_types = [UNDEFINED] * len(self._nodes.inputs()) - -# def _validate_nodes_ - - -# def size(self): -# return len(self._input_types) - -# def set_type(self, index, input_type): -# self._input_types[index] = input_type - def same(input_list): return len(set(input_list)) == 1 @@ -510,158 +373,10 @@ def normalize_io_for_loop_rolling(graph, LoopBody): noutput = ir.Value(name=f'cond_out_{i}', type=ir.TensorType(ir.DataType.BOOL)) prepend_output_to_node(node,noutput) - - # Add External Constants to Drive These - # M_value = ir.Attr(name='value', type=ir.AttributeType.TENSOR, value=ir.Tensor(np.array([number_of_layers]))) - # M = ir.Node('', 'Constant', inputs=[], outputs=[loop_node.inputs[0]], attributes=[M_value], graph=model_graph) - - # cond_value = ir.Attr(name='value', type=ir.AttributeType.TENSOR, value=ir.Tensor(np.array([True]))) - # cond = ir.Node('', 'Constant', inputs=[], outputs=[loop_node.inputs[1]],attributes=[cond_value], graph=model_graph) - - # model_graph.insert_before(model_graph[0], [cond]) - # model_graph.insert_before(model_graph[0], [M]) - - # Add Gather Nodes - # for index,LoopInputType in enumerate(LoopBody.signature): -# if LoopInputType == LoopBodyInputType.PARAMETER: -# # Update the Shape - - -# gvo = copy.copy(indexed_input) -# gvo.name = gvo.name + '_gather_out' -# for node in indexed_input.consumers(): -# for i,inp in enumerate(node.inputs): -# if indexed_input is inp: -# node.replace_input_with(i, gvo) -# graph.append(ir.Node(domain='', -# op_type='Gather', -# inputs = [nv, viter], -# outputs = [gvo], -# num_outputs = 1, -# graph = graph -# )) graph.sort() return graph - -# def find_parameter_input_indexes(graph, layername): - -# nodes = find_nodes_of_optype(graph, layername) - -# if len(nodes) == 0: -# raise Exception("Did not find node with name {layername}") - -# indexes = [] -# for i in range(len(nodes[0].inputs)): -# value = nodes[0].inputs[i] -# print(f'node input[{i}] {value} {value.producer()}') -# if value.producer() == None: # is an initializer/weight value (unless an input) -# print(f"appending {value} at index {i}") -# indexes.append(i) -# elif value.producer().op_type == 'Constant': -# print("found constant optype") -# same_value = True -# for node in nodes: -# if node.inputs[i] is not value: -# same_value = False -# if not same_value: -# print(f"appending constant {value} at index {i}") -# indexes.append(i) - -# return indexes - - - -def build_loop_body(graph, number_of_layers, parameter_indexes = []): - - # for i,inp in enumerate(graph.inputs): - # print(f"graph input: {i} {inp}") - - # print(f'parameter indexes: {parameter_indexes}') - - - # viter = ir.Value(name='iteration', shape=ir.Shape([1]), type=ir.TensorType(ir.DataType.INT64)) - # vcond_in = ir.Value(name='cond', shape=ir.Shape([1]), type=ir.TensorType(ir.DataType.BOOL)) - - # for index in parameter_indexes: - # indexed_input = graph.inputs[index] - - # nv = copy.copy(indexed_input) - # nv.shape = ir.Shape((number_of_layers, *nv.shape.dims)) - # nv.name = nv.name + '_array_in' - # graph.inputs[index] = nv - - # nvo = copy.copy(indexed_input) - # nvo.shape = ir.Shape((nv.shape.dims)) - # nvo.name = nvo.name + '_array_out' - # graph.outputs.append(nvo) - - # graph.append(ir.Node(domain='', op_type='Identity', - # inputs = [nv], - # outputs = [nvo], - # num_outputs =1)) - - # for index in - # gvo = copy.copy(indexed_input) - # gvo.name = gvo.name + '_gather_out' - # for node in indexed_input.consumers(): - # for i,inp in enumerate(node.inputs): - # if indexed_input is inp: - # node.replace_input_with(i, gvo) - # graph.append(ir.Node(domain='', - # op_type='Gather', - # inputs = [nv, viter], - # outputs = [gvo], - # num_outputs = 1, - # graph = graph - # )) - - # print(f'indexed input: {indexed_input} {gvo.producer}') - # print(f'usages: {gvo.uses()}') - # print(f'const_value: {gvo.const_value}') - # print(f'type: {gvo.type}') - # print(f'meta: {gvo.meta}') - # print(f'initializers: {graph.initializers}') - #exit(1) - #new_values.append(nv) - - - - graph.inputs.insert(0,vcond_in) - graph.inputs.insert(0,viter) - - - vcond_out = ir.Value(name='cond_out', type=ir.TensorType(ir.DataType.BOOL)) - vcond_ident = ir.Node(domain='', op_type='Identity', - inputs = [vcond_in], - outputs = [vcond_out], - num_outputs =1, - graph = graph) - graph.append(vcond_ident) - - #ovalue = graph.outputs[0] - #graph.outputs[0] = (ovalue, vcond_out) - - graph.outputs.insert(0,vcond_out) - #print(graph) - print(f"*** graph.outputs = {graph.outputs}") - # for i in range(len(indexed_values)): - # graph.append(ir.Node(domain='', - # op_type='Gather', - # inputs = [new_values[i], viter], - # outputs = [indexed_values[i]], - # num_outputs = 1, - # graph = graph - # )) - - graph.sort() - #print(indexed_values) - #print(graph.inputs) - #print(graph) - - return graph - import copy def vdisconnect(value): @@ -745,7 +460,6 @@ def convert_graph_to_function_call_pattern(graph): def find_nodes_of_optype(graph, layername): nodes = [] for node in ir.traversal.RecursiveGraphIterator(graph): - #print(f'node op_type: {node.op_type} {layername}') if node.op_type == layername: nodes.append(node) return nodes @@ -763,21 +477,6 @@ def build_layer_pipeline_pattern(graph, layername): return (pattern, nodes) -def connect_loop_constants(model_graph, layer_nodes): - number_of_layers = len(layer_nodes) - - loop_node = find_nodes_of_optype(model_graph,'Loop')[0] - - - #r.TensorType(ir.DataType.INT64), const_value=ir.Tensor(np.array([number_of_layers]))) - M_value = ir.Attr(name='value', type=ir.AttributeType.TENSOR, value=ir.Tensor(np.array([number_of_layers]))) - M = ir.Node('', 'Constant', inputs=[], outputs=[loop_node.inputs[0]], attributes=[M_value], graph=model_graph) - cond_value = ir.Attr(name='value', type=ir.AttributeType.TENSOR, value=ir.Tensor(np.array([True]))) - cond = ir.Node('', 'Constant', inputs=[], outputs=[loop_node.inputs[1]],attributes=[cond_value], graph=model_graph) - model_graph.insert_before(model_graph[0], [cond]) - model_graph.insert_before(model_graph[0], [M]) - return model_graph - def build_constant_from_tensor(name, tensor): value_attribute = ir.Attr(name='value', type=ir.AttributeType.TENSOR, value=tensor) ir_value_out = ir.Value(name=name+'_out', type=ir.TensorType(tensor.dtype)) @@ -842,14 +541,6 @@ def build_loop_replace_pattern(graph, LoopBody): constant_input = nodes[0].inputs[i] constant_node = constant_input.producer() constant_value = constant_node.attributes['value'].value.numpy() - # n_constant_input = ir.Value(name = 'n_'+constant_input.name, - # shape = constant_input.shape, - # type = constant_input.type) - # n_constant_node = ir.Node('', 'Constant', - # inputs =[], - # outputs=[n_constant_input], - # attributes=[constant_value_attr], - # version=13) n_constant_node = build_constant_from_tensor(constant_input.name+"_const_val", ir.Tensor(constant_value)) graph_nodes.append(n_constant_node) loop_inputs.append(n_constant_node.outputs[0]) @@ -882,13 +573,3 @@ def build_loop_replace_pattern(graph, LoopBody): onnx.save(model, 'replacementgraph.onnx') return ReplacementPatternGraph(graph) - - - - - - - - - - diff --git a/tests/loop_rolling/test_loop_rolling.py b/tests/loop_rolling/test_loop_rolling.py index 710c975bdc..274e64e78a 100644 --- a/tests/loop_rolling/test_loop_rolling.py +++ b/tests/loop_rolling/test_loop_rolling.py @@ -12,13 +12,7 @@ import pdb -from onnxscript.utils.pattern_builder import build_layer_pipeline_pattern -from onnxscript.utils.pattern_builder import direct_convert_ir_graph_to_pattern -from onnxscript.utils.pattern_builder import convert_graph_to_function_call_pattern -from onnxscript.utils.pattern_builder import build_loop_body from onnxscript.utils.pattern_builder import build_loop_replace_pattern -#from pattern_builder import find_parameter_input_indexes -from onnxscript.utils.pattern_builder import connect_loop_constants from onnxscript.utils.pattern_builder import normalize_io_for_loop_rolling from onnxscript.utils.pattern_builder import LoopBodyTemplate From c7cbf5cdfaf2dd1503e53c9fe39016738c23e576 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Thu, 27 Mar 2025 22:23:29 +0000 Subject: [PATCH 03/14] add code to remove existing data files before writing to disk. --- tests/loop_rolling/test_loop_rolling.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/loop_rolling/test_loop_rolling.py b/tests/loop_rolling/test_loop_rolling.py index 274e64e78a..fec91729e3 100644 --- a/tests/loop_rolling/test_loop_rolling.py +++ b/tests/loop_rolling/test_loop_rolling.py @@ -1,23 +1,25 @@ import argparse -import ast import numpy as np import onnx -import onnxscript +import os from onnxscript import ir from onnxscript.rewriter import pattern, rewrite import onnxruntime as onnxrt -import pdb - from onnxscript.utils.pattern_builder import build_loop_replace_pattern from onnxscript.utils.pattern_builder import normalize_io_for_loop_rolling from onnxscript.utils.pattern_builder import LoopBodyTemplate from onnx import shape_inference +def remove_existing_data_file(filename): + if os.path.exists(filename): + print(f"Removing existing data file: {filename}") + os.remove(filename) + def open_ir(filename): print('loading onnx') f = onnx.load(filename) @@ -93,8 +95,9 @@ def ort_run_graph(filename, input_dict, output_name): replaced_filename = "replaced_"+args.filename print(f"Writing Updated Graph to {replaced_filename}") +remove_existing_data_file(replaced_filename+'.data') onnx.save(mypipeline_layers_replaced, - replaced_filename, save_as_external_data=True, location=replaced_filename+'.data') + replaced_filename, save_as_external_data=True, location= replaced_filename+'.data') print("Replace Layer Ops with Loop Body") @@ -103,6 +106,7 @@ def ort_run_graph(filename, input_dict, output_name): model = ir.serde.serialize_model(mypipeline_model) +remove_existing_data_file('normalized.onnx.data') onnx.save(model, 'normalized.onnx', save_as_external_data=True, location='normalized.onnx.data') @@ -185,7 +189,7 @@ def log( - +remove_existing_data_file('loop_added.onnx.data') onnx.save(loop_added, 'loop_added.onnx', save_as_external_data=True, location='loop_added.onnx.data') onnx.checker.check_model(loop_added) From 8ace32fecd82b3b576632e5cac8cbe68f23d46e6 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Thu, 27 Mar 2025 22:35:30 +0000 Subject: [PATCH 04/14] remove debug printing statements --- onnxscript/utils/graph_view_utils.py | 9 ------- onnxscript/utils/pattern_builder.py | 33 ------------------------- tests/loop_rolling/test_loop_rolling.py | 24 +++--------------- 3 files changed, 3 insertions(+), 63 deletions(-) diff --git a/onnxscript/utils/graph_view_utils.py b/onnxscript/utils/graph_view_utils.py index 73423fa99d..fd823588de 100644 --- a/onnxscript/utils/graph_view_utils.py +++ b/onnxscript/utils/graph_view_utils.py @@ -40,23 +40,14 @@ def find_subgraph_inputs(nodes): inputs = set() initializers = set() for node in nodes: - print(f'{type(node)} {type(nodes)}') - print(f'{node}') for ninput in node.inputs: if ninput in node.graph.inputs: inputs.add(ninput) elif any(ninput is init for init in node.graph.initializers): - print("adding graph initializater") initializers.add(ninput) elif ninput.producer() == None: - print(f"adding none initializer: {ninput}") inputs.add(ninput) elif ninput.producer() not in nodes: - print(f"adding not in nodes: {ninput}") - if ninput.producer() in node.graph._nodes: - print(f'\tIn Graph node list') - print(f"\t {ninput.producer()}") - print(f"\t {ninput.producer().metadata_props}") inputs.add(ninput) return inputs, initializers diff --git a/onnxscript/utils/pattern_builder.py b/onnxscript/utils/pattern_builder.py index 7550b1ca79..ee6c446a0b 100644 --- a/onnxscript/utils/pattern_builder.py +++ b/onnxscript/utils/pattern_builder.py @@ -24,7 +24,6 @@ def direct_convert_ir_graph_to_pattern(graph): vmap[input] = ValuePattern(input.name) for init in graph.initializers: - print(init) vmap[init] = ValuePattern(init.name) @@ -101,7 +100,6 @@ def _build_function_replace_pattern(self): inputs = [vdisconnect(copy.copy(x)) for x in self._ir_graph.inputs] outputs = [vdisconnect(copy.copy(x)) for x in self._ir_graph.outputs] - print(type(self.function.domain)) node = ir.Node(domain=self.function.domain, version=0, op_type=self.function.name, @@ -160,11 +158,9 @@ def insert_gather_nodes(self, loop_iterations): def build_function_match_pattern(self, graph): graph.sort() - print(self.function.name) nodes = find_nodes_of_optype(graph, self.function.name) nodes.insert(0,graph.node('iteration_ext')) nodes.insert(0,graph.node('condition_ext')) - print(f'found {len(nodes)}') ir_model = ir.Model(bGraphView('inlined_pipe_pattern', nodes), ir_version=10) @@ -282,7 +278,6 @@ def normalize_io_for_loop_rolling(graph, LoopBody): LoopBody.function.inputs[swap[1]] = a LoopBody.signature[swap[0]] = LoopBodyInputType.ACTIVATION activations+=1 - #print(f"Added Activation: {swap[0]}") # Next Inputs according to how they are produced. # Indexable inputs will have different constant or none producers @@ -296,11 +291,9 @@ def normalize_io_for_loop_rolling(graph, LoopBody): inputs.append(cinput) if same(inputs): # Constant with Respect to Loop - #print(f"Added Constant: {index}") LoopBody.signature[index] = LoopBodyInputType.CONSTANT else: # Must be Indexed - #print(f"Added Parameter: {index}") LoopBody.signature[index] = LoopBodyInputType.PARAMETER @@ -392,53 +385,31 @@ def __init__(self, ir_graph): def get_replacement(self, match: pattern.MatchResult) -> pattern.ReplacementSubgraph | None: - print(f"in get replacement!!!!") - print(f"match pattern full: {match.outputs}") context = pattern.RewriterContext() # match.bindings is dictionary of value_name (str) in replacement subgraph pattern (i.e. ir_graph -> IR Value in actual graph) vvmap = {} # Dictionary mapping values in replacement subgraph pattern -> values in the replacement subgraph - #print(f'self._graph: {self._graph}') - #print(f'len inputs: {self._graph.inputs}') - #print(f'match bindings: {match.bindings}') for value in self._graph.inputs: if value.name in match.bindings: - #print(f'***bound input value name: {value.name}') vvmap[value] = match.bindings[value.name] else: - #print(f'***unbound input value name: {value.name}') - #print(f'***unbound value : {value}') vvmap[value] = value for node in self._graph._nodes: ninputs = [] - print(node) for ninput in node.inputs: - #print(f'{ninput}') ninputs.append(vvmap[ninput]) coutput = context.__getattr__(node.op_type)(*ninputs, **node.attributes, _outputs=len(node.outputs), _domain=node.domain, _version=node.version) - print(f"coutput type: {type(coutput)}") - print(f"coutput type: {coutput}") if not isinstance(coutput,Iterable): - print("Not an Iterable") coutput = [coutput] for i, cout in enumerate(coutput): - print(f"listing outputs: {node.outputs[cout.index()]}") - print(f"cout {cout.index()} {cout} ") vvmap[node.outputs[cout.index()]] = cout - for x in vvmap: - print(x) - print(f"context nodes") - for node in context.nodes: - print(node) new_outputs = [vvmap[x] for x in self._graph.outputs] - print(f'new_outputs={new_outputs}') - print(f'_graph.outputs={self._graph.outputs}') return pattern.ReplacementSubgraph( match, new_outputs, context.nodes, context.initializers, context.used_opsets ) @@ -514,15 +485,12 @@ def build_loop_replace_pattern(graph, LoopBody): loop_inputs.append(cond.outputs[0]) graph_inputs = [] - print(LoopBody.signature) for i, LoopInputType in enumerate(LoopBody.signature): if LoopInputType == LoopBodyInputType.PARAMETER: # Build Concat Node concat_inputs = [] for node in nodes: - print(f"node.optype: {node}") - print(f"node inputs:{i} {node.inputs}") nvalue = vdisconnect(copy.copy(node.inputs[i])) graph_inputs.append(nvalue) concat_inputs.append(nvalue) @@ -566,7 +534,6 @@ def build_loop_replace_pattern(graph, LoopBody): graph = ir.Graph(name='loop_replace',nodes=graph_nodes, inputs=graph_inputs, outputs=graph_outputs) - print('sorting graph') graph.sort() model = ir.serde.serialize_model(ir.Model(graph, ir_version=8)) diff --git a/tests/loop_rolling/test_loop_rolling.py b/tests/loop_rolling/test_loop_rolling.py index fec91729e3..5bb980ba85 100644 --- a/tests/loop_rolling/test_loop_rolling.py +++ b/tests/loop_rolling/test_loop_rolling.py @@ -73,12 +73,8 @@ def ort_run_graph(filename, input_dict, output_name): print("Find and Replace layers with single layer op.") -print(LoopBody.function.identifier()) - mypipeline = onnx.load(args.filename) - - print('applying rewrite rule') mypipeline_layers_replaced = rewrite( @@ -121,8 +117,6 @@ def ort_run_graph(filename, input_dict, output_name): loop_replace_pattern ) -#print(loop_replace_pattern.get_replacement()) -#exit() class AllTracer(pattern.MatchingTracer): def __init__(self): super().__init__() @@ -136,25 +130,14 @@ def log( status: pattern.MatchStatus, ) -> None: this_match = pattern.MatchInfo(match_result, node, container, status) - #this_score = this_match.score() - #if this_score == 0: - # return best_matches = self._best_matches_map[rule] - #if best_matches: - # if this_score < best_matches[0].score(): - # return - # if this_score > best_matches[0].score(): - # best_matches.clear() best_matches.append(this_match) tracer = pattern.MatchingTracer() rewrite_set = pattern.RewriteRuleSet([change_function_calls_to_loop]) -#model_ir = ir.serde.deserialize_model(mypipeline_l) -#pdb.set_trace() count = rewrite_set.apply_to_model(mypipeline_model, verbose=None) print(f"Count {count}") -print(tracer.best_matches_map) # tracer.report() # for rule in tracer.best_matches_map: @@ -169,7 +152,6 @@ def log( # pattern_rewrite_rules = [change_function_calls_to_loop] #) -print(mypipeline_model.opset_imports) #mypipeline_model.opset_imports.pop('') mypipeline_model.opset_imports.pop('loop') mypipeline_model._functions = {} @@ -177,9 +159,9 @@ def log( # scanning for empty domains -for node in ir.traversal.RecursiveGraphIterator(mypipeline_model.graph): - if node.domain == '': - print(node) +# for node in ir.traversal.RecursiveGraphIterator(mypipeline_model.graph): +# if node.domain == '': +# print(node) #mypipeline_model.opset_imports['main'] = 13 From f71060ab465d52b6d459f425ed4871ec79fea5fa Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Fri, 28 Mar 2025 20:55:50 +0000 Subject: [PATCH 05/14] fix imports and stuff --- ...tern_builder.py => pattern_builder_jsm.py} | 21 +++++--- tests/loop_rolling/test_loop_rolling.py | 52 ++++++++++--------- 2 files changed, 41 insertions(+), 32 deletions(-) rename onnxscript/utils/{pattern_builder.py => pattern_builder_jsm.py} (96%) diff --git a/onnxscript/utils/pattern_builder.py b/onnxscript/utils/pattern_builder_jsm.py similarity index 96% rename from onnxscript/utils/pattern_builder.py rename to onnxscript/utils/pattern_builder_jsm.py index ee6c446a0b..b70414cc35 100644 --- a/onnxscript/utils/pattern_builder.py +++ b/onnxscript/utils/pattern_builder_jsm.py @@ -5,15 +5,22 @@ import pdb -from onnxscript import ir, script, INT64, BOOL -from onnxscript.rewriter import pattern, rewrite -from onnxscript.rewriter.pattern import ValuePattern, GraphPattern, OpsetPatternBuilder, pattern_builder, NodeOutputPattern +from onnxscript import ir +from onnxscript import rewriter +from onnxscript.rewriter.pattern import ( + MatchResult, ValuePattern, GraphPattern, OpsetPatternBuilder, pattern_builder, NodeOutputPattern, ReplacementSubgraph + +) from collections.abc import Iterable from onnxscript.utils.graph_view_utils import bGraphView +#print("**************************************") +#print("********* Pattern Builder ************") +#print("**************************************") + def direct_convert_ir_graph_to_pattern(graph): @@ -379,14 +386,14 @@ def vdisconnect(value): return value -class ReplacementPatternGraph(pattern.ReplacementPatternFunction): +class ReplacementPatternGraph(rewriter.pattern.ReplacementPatternFunction): def __init__(self, ir_graph): self._graph = ir_graph - def get_replacement(self, match: pattern.MatchResult) -> pattern.ReplacementSubgraph | None: + def get_replacement(self, match: MatchResult) -> ReplacementSubgraph | None: - context = pattern.RewriterContext() + context = rewriter.RewriterContext() # match.bindings is dictionary of value_name (str) in replacement subgraph pattern (i.e. ir_graph -> IR Value in actual graph) vvmap = {} # Dictionary mapping values in replacement subgraph pattern -> values in the replacement subgraph @@ -410,7 +417,7 @@ def get_replacement(self, match: pattern.MatchResult) -> pattern.ReplacementSubg vvmap[node.outputs[cout.index()]] = cout new_outputs = [vvmap[x] for x in self._graph.outputs] - return pattern.ReplacementSubgraph( + return rewriter.ReplacementSubgraph( match, new_outputs, context.nodes, context.initializers, context.used_opsets ) diff --git a/tests/loop_rolling/test_loop_rolling.py b/tests/loop_rolling/test_loop_rolling.py index 5bb980ba85..0f5101a2e2 100644 --- a/tests/loop_rolling/test_loop_rolling.py +++ b/tests/loop_rolling/test_loop_rolling.py @@ -6,12 +6,14 @@ import os from onnxscript import ir -from onnxscript.rewriter import pattern, rewrite +from onnxscript import rewriter + +#.rewriter import PatternRewriteRule, RewriteRuleSet, rewrite import onnxruntime as onnxrt -from onnxscript.utils.pattern_builder import build_loop_replace_pattern -from onnxscript.utils.pattern_builder import normalize_io_for_loop_rolling -from onnxscript.utils.pattern_builder import LoopBodyTemplate +from onnxscript.utils.pattern_builder_jsm import build_loop_replace_pattern +from onnxscript.utils.pattern_builder_jsm import normalize_io_for_loop_rolling +from onnxscript.utils.pattern_builder_jsm import LoopBodyTemplate from onnx import shape_inference @@ -65,7 +67,7 @@ def ort_run_graph(filename, input_dict, output_name): LoopBody = LoopBodyTemplate(args.patternfilename) -change_layers_to_function_calls = pattern.RewriteRule( +change_layers_to_function_calls = rewriter.PatternRewriteRule( LoopBody.pattern, LoopBody.function_replace ) @@ -112,30 +114,30 @@ def ort_run_graph(filename, input_dict, output_name): loop_replace_pattern = build_loop_replace_pattern(normalized_graph, LoopBody) -change_function_calls_to_loop = pattern.RewriteRule( +change_function_calls_to_loop = rewriter.PatternRewriteRule( LoopMatchPattern, loop_replace_pattern ) -class AllTracer(pattern.MatchingTracer): - def __init__(self): - super().__init__() - - def log( - self, - rule: pattern.RewriteRule, - container: ir.Graph | ir.Function, - node: ir.Node, - match_result: pattern.MatchResult, - status: pattern.MatchStatus, - ) -> None: - this_match = pattern.MatchInfo(match_result, node, container, status) - best_matches = self._best_matches_map[rule] - best_matches.append(this_match) - - -tracer = pattern.MatchingTracer() -rewrite_set = pattern.RewriteRuleSet([change_function_calls_to_loop]) +# class AllTracer(pattern.MatchingTracer): +# def __init__(self): +# super().__init__() + +# def log( +# self, +# rule: PatternRewriteRule, +# container: ir.Graph | ir.Function, +# node: ir.Node, +# match_result: pattern.MatchResult, +# status: pattern.MatchStatus, +# ) -> None: +# this_match = pattern.MatchInfo(match_result, node, container, status) +# best_matches = self._best_matches_map[rule] +# best_matches.append(this_match) + + +# tracer = pattern.MatchingTracer() +rewrite_set = RewriteRuleSet([change_function_calls_to_loop]) count = rewrite_set.apply_to_model(mypipeline_model, verbose=None) print(f"Count {count}") From a1c138c3dbb7acd1a7adb9fa581d8c7d40cb1c2b Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Fri, 28 Mar 2025 23:03:25 +0000 Subject: [PATCH 06/14] updates to make tests orks --- .../{utils => rewriter}/pattern_builder_jsm.py | 13 ++++++------- tests/loop_rolling/test_loop_rolling.py | 8 +++++--- 2 files changed, 11 insertions(+), 10 deletions(-) rename onnxscript/{utils => rewriter}/pattern_builder_jsm.py (98%) diff --git a/onnxscript/utils/pattern_builder_jsm.py b/onnxscript/rewriter/pattern_builder_jsm.py similarity index 98% rename from onnxscript/utils/pattern_builder_jsm.py rename to onnxscript/rewriter/pattern_builder_jsm.py index b70414cc35..2b398a1067 100644 --- a/onnxscript/utils/pattern_builder_jsm.py +++ b/onnxscript/rewriter/pattern_builder_jsm.py @@ -1,15 +1,14 @@ +print("***** IMPORTING JSM PATTERN BUILDER *****") + import copy import numpy as np import onnx -import pdb - from onnxscript import ir from onnxscript import rewriter from onnxscript.rewriter.pattern import ( - MatchResult, ValuePattern, GraphPattern, OpsetPatternBuilder, pattern_builder, NodeOutputPattern, ReplacementSubgraph - + RewriterContext, MatchResult, ValuePattern, GraphPattern, OpsetPatternBuilder, pattern_builder, NodeOutputPattern, ReplacementSubgraph, ReplacementPatternFunction ) @@ -386,14 +385,14 @@ def vdisconnect(value): return value -class ReplacementPatternGraph(rewriter.pattern.ReplacementPatternFunction): +class ReplacementPatternGraph(ReplacementPatternFunction): def __init__(self, ir_graph): self._graph = ir_graph def get_replacement(self, match: MatchResult) -> ReplacementSubgraph | None: - context = rewriter.RewriterContext() + context = RewriterContext() # match.bindings is dictionary of value_name (str) in replacement subgraph pattern (i.e. ir_graph -> IR Value in actual graph) vvmap = {} # Dictionary mapping values in replacement subgraph pattern -> values in the replacement subgraph @@ -417,7 +416,7 @@ def get_replacement(self, match: MatchResult) -> ReplacementSubgraph | None: vvmap[node.outputs[cout.index()]] = cout new_outputs = [vvmap[x] for x in self._graph.outputs] - return rewriter.ReplacementSubgraph( + return ReplacementSubgraph( match, new_outputs, context.nodes, context.initializers, context.used_opsets ) diff --git a/tests/loop_rolling/test_loop_rolling.py b/tests/loop_rolling/test_loop_rolling.py index 0f5101a2e2..7ff7b476d8 100644 --- a/tests/loop_rolling/test_loop_rolling.py +++ b/tests/loop_rolling/test_loop_rolling.py @@ -11,9 +11,11 @@ #.rewriter import PatternRewriteRule, RewriteRuleSet, rewrite import onnxruntime as onnxrt -from onnxscript.utils.pattern_builder_jsm import build_loop_replace_pattern -from onnxscript.utils.pattern_builder_jsm import normalize_io_for_loop_rolling -from onnxscript.utils.pattern_builder_jsm import LoopBodyTemplate +from onnxscript.rewriter import rewrite, RewriteRuleSet + +from onnxscript.rewriter.pattern_builder_jsm import build_loop_replace_pattern +from onnxscript.rewriter.pattern_builder_jsm import normalize_io_for_loop_rolling +from onnxscript.rewriter.pattern_builder_jsm import LoopBodyTemplate from onnx import shape_inference From 3274be309eb6b88a1e92a335d4e3910fc554e7d9 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Fri, 28 Mar 2025 23:09:59 +0000 Subject: [PATCH 07/14] remove print statements --- onnxscript/rewriter/pattern_builder_jsm.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxscript/rewriter/pattern_builder_jsm.py b/onnxscript/rewriter/pattern_builder_jsm.py index 2b398a1067..8f3ec1f324 100644 --- a/onnxscript/rewriter/pattern_builder_jsm.py +++ b/onnxscript/rewriter/pattern_builder_jsm.py @@ -1,5 +1,3 @@ -print("***** IMPORTING JSM PATTERN BUILDER *****") - import copy import numpy as np From 65985d86e47ebdf5c35a8f80eaece4615af87937 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Mon, 14 Apr 2025 17:28:25 +0000 Subject: [PATCH 08/14] print_hierarchy works, initial tests pass --- onnxscript/utils/graph_view_utils.py | 139 +++++++++++++++++- onnxscript/utils/test_PytorchHierarchyNode.py | 93 ++++++++++++ 2 files changed, 231 insertions(+), 1 deletion(-) create mode 100644 onnxscript/utils/test_PytorchHierarchyNode.py diff --git a/onnxscript/utils/graph_view_utils.py b/onnxscript/utils/graph_view_utils.py index fd823588de..5d1964e993 100644 --- a/onnxscript/utils/graph_view_utils.py +++ b/onnxscript/utils/graph_view_utils.py @@ -1,5 +1,7 @@ from onnxscript import ir -import onnx + +import ast + def is_initializer(value): return input.producer() == None @@ -93,6 +95,141 @@ def bGraphView(name, nodes): # rebuild_pytorch_dynamo_instance_code # ######################################## +from typing import List + + +class PytorchMetadataNode: + def __init__(self, node): + self._node = node + + if self.check_node_metadata_exists(): + self.instance_metadata = ast.literal_eval(self._node.metadata_props['pkg.torch.onnx.name_scopes']) + self.class_metadata = ast.literal_eval(self._node.metadata_props['pkg.torch.onnx.class_hierarchy']) + print(f'self.node.metadata_props: {self._node.metadata_props}') + else: + raise ValueError(f"Node {self._node.name} does not have the required metadata properties") + + def check_node_metadata_exists(self): + if 'pkg.torch.onnx.name_scopes' in self._node.metadata_props and \ + 'pkg.torch.onnx.class_hierarchy' in self._node.metadata_props: + return True + else: + return False + + def is_last_level(self, level): + if len(self.instance_metadata) - 1 == level: + return True + else: + return False + + def get_instance_name(self, depth=0): + if depth >= len(self.instance_metadata): + return None + else: + return self.instance_metadata[depth] + + def get_class_name(self, depth=0): + if depth >= len(self.instance_metadata): + return None + else: + return self.class_metadata[depth] + +class PytorchHierarchyNode: + def __init__(self): + self.instance_name = None + self.module_type = None + self.children = [] + self.nodes = [] + + def print_hierarchy(self, instance_hierarchy: List[str] = None): + if instance_hierarchy is None: + instance_hierarchy = [] + instance_hierarchy.append(self.instance_name) + + for child in self.children: + child.print_hierarchy(list(instance_hierarchy)) + + for node in self.nodes: + print(f"Node: {node._node.name}, Instance: {'/'.join(instance_hierarchy)}, Module: {self.module_type}") + + + def get_unwrapped_nodes(self): + # Return _node from self._nodes + return [node._node for node in self.nodes] + + # Checks if the search hierarchy matches the instance hierarchy + def hierarchy_matches(self, search_hierarchy: List[str], instance_hierarchy: List[str] = []): + search_length = min(len(search_hierarchy), len(instance_hierarchy)) + for i in range(search_length): + if search_hierarchy[i] != instance_hierarchy[i]: + return False + return True + + # Return all nodes from the given name hierarchy on down + def get_nodes(self, search_hierarchy: List[str], instance_hierarchy: List[str] = None): + if instance_hierarchy is None: + instance_hierarchy = [] + + nodes_to_return = [] + # base case for recursion + # 1 - search_hierarchy does not match instance_hierarchy + instance_hierarchy.append(self.instance_name) + print(f"search_hierarchy: {search_hierarchy}") + print(f"instance_hierarchy: {instance_hierarchy}") + + if not self.hierarchy_matches(search_hierarchy, instance_hierarchy): + return [] + + for child in self.children: + child_nodes = child.get_nodes(search_hierarchy, list(instance_hierarchy)) + nodes_to_return.extend(child_nodes) + + if len(instance_hierarchy) >= len(search_hierarchy): + nodes_to_return.extend(self.get_unwrapped_nodes()) + + return nodes_to_return + + def add_nodes(self, nodes): + for node in nodes: + self.add_node(node) + + def add_node(self, node, level=0): + + print("calling add_node") + if not isinstance(node, PytorchMetadataNode): + node = PytorchMetadataNode(node) + + if self.instance_name is None: + print(f"setting instance name to {node.get_instance_name(level)}") + self.instance_name = node.get_instance_name(level) + if self.module_type is None: + self.module_type = node.get_class_name(level) + + # check that instance name and module type match + if self.instance_name != node.get_instance_name(level): + raise ValueError(f"Instance name mismatch: {self.instance_name} != {node.get_instance_name(level)}") + if self.module_type != node.get_class_name(level): + raise ValueError(f"Module type mismatch: {self.module_type} != {node.get_class_name(level)}") + + # if this is the last level of the hierarchy, add the node to this node + # otherwise find the child node that matches the next level of the hierarchy + # and add the node to that child + if node.is_last_level(level): + print(f"Adding node {node} to {self.instance_name}") + self.nodes.append(node) + else: + for child in self.children: + if child.instance_name == node.get_instance_name(level + 1): + child.add_node(node, level + 1) + return + + # if no child matches the next level of the hierarchy, create a new child node + new_child = PytorchHierarchyNode() + new_child.instance_name = node.get_instance_name(level + 1) + new_child.module_type = node.get_class_name(level + 1) + self.children.append(new_child) + new_child.add_node(node, level + 1) + ################################################## ## TODO (JSM): encapsulte this into a function ## ################################################## diff --git a/onnxscript/utils/test_PytorchHierarchyNode.py b/onnxscript/utils/test_PytorchHierarchyNode.py new file mode 100644 index 0000000000..308831f398 --- /dev/null +++ b/onnxscript/utils/test_PytorchHierarchyNode.py @@ -0,0 +1,93 @@ +import pytest + +from onnxscript import script +from typing import List, Tuple + + +from onnxscript.utils import graph_view_utils as gvu +from onnxscript import ir + + + +tape = ir._tape.Tape() + + +def build_metadata(instance_hierarchy, class_hierarchy): + return { + "pkg.torch.onnx.class_hierarchy": str(class_hierarchy), + "pkg.torch.onnx.name_scopes": str(instance_hierarchy) + } + +class HierarchyBuilder(ir._tape.Tape): + def __init__(self, graph_like: ir.Graph | ir.Function | None = None) -> None: + super().__init__(graph_like) + + def add_hierarchical_node(self, hierarchy: List[Tuple[str, str]]): + # Hierarchy is a list of tuples, where each tuple contains (instance name, module type) + + instance_hierarchy = [] + class_hierarchy = [] + for instance_name, module_type in hierarchy: + instance_hierarchy.append(instance_name) + class_hierarchy.append(module_type) + + self.op( + op_type="HierarchyNode", + inputs=[], + metadata_props=build_metadata(instance_hierarchy, class_hierarchy), + ) + +# add a basic test to make sure the test file is working +def test_onenode(): + + B = HierarchyBuilder() + B.add_hierarchical_node([("", "class0")]) + + P = gvu.PytorchHierarchyNode() + print("\nadding bnode") + P.add_node( + B.nodes[0] + ) + print("\nadded bnode") + nodes = P.get_nodes([""]) + assert len(nodes) == 1 + assert nodes[0] is B.nodes[0] + +def test_twonodes(): + B = HierarchyBuilder() + B.add_hierarchical_node([("", "class0")]) + B.add_hierarchical_node([("", "class0")]) + + P = gvu.PytorchHierarchyNode() + P.add_nodes(B.nodes) + + nodes = P.get_nodes([""]) + assert len(nodes) == 2 + assert nodes[0] is B.nodes[0] + assert nodes[1] is B.nodes[1] + +def test_twonodes_one_with_hierarchy(): + B = HierarchyBuilder() + B.add_hierarchical_node([("", "class0")]) + B.add_hierarchical_node([("", "class0"), ("a", "class1")]) + + + P = gvu.PytorchHierarchyNode() + P.add_nodes(B.nodes) + + print("Printing hierarchy") + import pdb + pdb.set_trace() + P.print_hierarchy() + + print("test 1") + nodes = P.get_nodes([""]) + assert len(nodes) == 2 + assert nodes[0] in B.nodes + assert nodes[1] in B.nodes + assert nodes[0] is not nodes[1] + + print("test 2") + nodes = P.get_nodes(["", "a"]) + assert len(nodes) == 1 + assert nodes[0] is B.nodes[1] From 90f85d173e3f2d4b1069f3ae5e97f1b64079ca99 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Mon, 14 Apr 2025 21:50:41 +0000 Subject: [PATCH 09/14] now properly rejects non-hierarchical nodes --- onnxscript/utils/graph_view_utils.py | 27 ++++---- onnxscript/utils/test_PytorchHierarchyNode.py | 63 ++++++++++++++++--- 2 files changed, 68 insertions(+), 22 deletions(-) diff --git a/onnxscript/utils/graph_view_utils.py b/onnxscript/utils/graph_view_utils.py index 5d1964e993..a1f8f32f90 100644 --- a/onnxscript/utils/graph_view_utils.py +++ b/onnxscript/utils/graph_view_utils.py @@ -106,8 +106,6 @@ def __init__(self, node): self.instance_metadata = ast.literal_eval(self._node.metadata_props['pkg.torch.onnx.name_scopes']) self.class_metadata = ast.literal_eval(self._node.metadata_props['pkg.torch.onnx.class_hierarchy']) print(f'self.node.metadata_props: {self._node.metadata_props}') - else: - raise ValueError(f"Node {self._node.name} does not have the required metadata properties") def check_node_metadata_exists(self): if 'pkg.torch.onnx.name_scopes' in self._node.metadata_props and \ @@ -174,8 +172,8 @@ def get_nodes(self, search_hierarchy: List[str], instance_hierarchy: List[str] = # base case for recursion # 1 - search_hierarchy does not match instance_hierarchy instance_hierarchy.append(self.instance_name) - print(f"search_hierarchy: {search_hierarchy}") - print(f"instance_hierarchy: {instance_hierarchy}") + #print(f"search_hierarchy: {search_hierarchy}") + #print(f"instance_hierarchy: {instance_hierarchy}") if not self.hierarchy_matches(search_hierarchy, instance_hierarchy): return [] @@ -189,15 +187,17 @@ def get_nodes(self, search_hierarchy: List[str], instance_hierarchy: List[str] = return nodes_to_return - def add_nodes(self, nodes): - for node in nodes: - self.add_node(node) + # def add_nodes(self, nodes): + # for node in nodes: + # self.add_node(node) def add_node(self, node, level=0): print("calling add_node") if not isinstance(node, PytorchMetadataNode): node = PytorchMetadataNode(node) + if node.check_node_metadata_exists() is False: + return False if self.instance_name is None: print(f"setting instance name to {node.get_instance_name(level)}") @@ -207,28 +207,29 @@ def add_node(self, node, level=0): # check that instance name and module type match if self.instance_name != node.get_instance_name(level): - raise ValueError(f"Instance name mismatch: {self.instance_name} != {node.get_instance_name(level)}") + #raise ValueError(f"Instance name mismatch: {self.instance_name} != {node.get_instance_name(level)}") + return False if self.module_type != node.get_class_name(level): - raise ValueError(f"Module type mismatch: {self.module_type} != {node.get_class_name(level)}") - + #raise ValueError(f"Module type mismatch: {self.module_type} != {node.get_class_name(level)}") + return False # if this is the last level of the hierarchy, add the node to this node # otherwise find the child node that matches the next level of the hierarchy # and add the node to that child if node.is_last_level(level): print(f"Adding node {node} to {self.instance_name}") self.nodes.append(node) + return True else: for child in self.children: if child.instance_name == node.get_instance_name(level + 1): - child.add_node(node, level + 1) - return + return child.add_node(node, level + 1) # if no child matches the next level of the hierarchy, create a new child node new_child = PytorchHierarchyNode() new_child.instance_name = node.get_instance_name(level + 1) new_child.module_type = node.get_class_name(level + 1) self.children.append(new_child) - new_child.add_node(node, level + 1) + return new_child.add_node(node, level + 1) ################################################## ## TODO (JSM): encapsulte this into a function ## diff --git a/onnxscript/utils/test_PytorchHierarchyNode.py b/onnxscript/utils/test_PytorchHierarchyNode.py index 308831f398..f1d522edbd 100644 --- a/onnxscript/utils/test_PytorchHierarchyNode.py +++ b/onnxscript/utils/test_PytorchHierarchyNode.py @@ -22,6 +22,13 @@ class HierarchyBuilder(ir._tape.Tape): def __init__(self, graph_like: ir.Graph | ir.Function | None = None) -> None: super().__init__(graph_like) + def add_non_hierarchical_node(self): + # Non-hierarchical node is a single instance of a module + self.op( + op_type="NonHierarchicalNode", + inputs=[] + ) + def add_hierarchical_node(self, hierarchy: List[Tuple[str, str]]): # Hierarchy is a list of tuples, where each tuple contains (instance name, module type) @@ -37,6 +44,13 @@ def add_hierarchical_node(self, hierarchy: List[Tuple[str, str]]): metadata_props=build_metadata(instance_hierarchy, class_hierarchy), ) + +def add_node_expect_success(P, node): + assert P.add_node(node) is True + +def add_node_expect_failure(P, node): + assert P.add_node(node) is False + # add a basic test to make sure the test file is working def test_onenode(): @@ -44,11 +58,13 @@ def test_onenode(): B.add_hierarchical_node([("", "class0")]) P = gvu.PytorchHierarchyNode() + print("\nadding bnode") - P.add_node( - B.nodes[0] - ) + + add_node_expect_success(P, B.nodes[0]) + print("\nadded bnode") + nodes = P.get_nodes([""]) assert len(nodes) == 1 assert nodes[0] is B.nodes[0] @@ -59,7 +75,8 @@ def test_twonodes(): B.add_hierarchical_node([("", "class0")]) P = gvu.PytorchHierarchyNode() - P.add_nodes(B.nodes) + add_node_expect_success(P, B.nodes[0]) + add_node_expect_success(P, B.nodes[1]) nodes = P.get_nodes([""]) assert len(nodes) == 2 @@ -73,21 +90,49 @@ def test_twonodes_one_with_hierarchy(): P = gvu.PytorchHierarchyNode() - P.add_nodes(B.nodes) + add_node_expect_success(P, B.nodes[0]) + add_node_expect_success(P, B.nodes[1]) print("Printing hierarchy") - import pdb - pdb.set_trace() P.print_hierarchy() - print("test 1") nodes = P.get_nodes([""]) assert len(nodes) == 2 assert nodes[0] in B.nodes assert nodes[1] in B.nodes assert nodes[0] is not nodes[1] - print("test 2") nodes = P.get_nodes(["", "a"]) assert len(nodes) == 1 assert nodes[0] is B.nodes[1] + +def test_three_levels_of_hierarchy(): + B = HierarchyBuilder() + B.add_hierarchical_node([("", "class0"), ("a", "class1"), ("b", "class2")]) + B.add_hierarchical_node([("", "class0"), ("a", "class1"), ("b", "class2")]) + B.add_hierarchical_node([("", "class0"), ("a", "class1"), ("b", "class2")]) + B.add_hierarchical_node([("", "class0"), ("a", "class1"), ("b", "class2")]) + + P = gvu.PytorchHierarchyNode() + add_node_expect_success(P, B.nodes[0]) + add_node_expect_success(P, B.nodes[1]) + add_node_expect_success(P, B.nodes[2]) + add_node_expect_success(P, B.nodes[3]) + + nodes = P.get_nodes(["", "a", "b"]) + assert len(nodes) == 4 + assert nodes[0] in B.nodes + assert nodes[1] in B.nodes + assert nodes[2] in B.nodes + assert nodes[3] in B.nodes + + +def test_non_hierarchical_node(): + B = HierarchyBuilder() + B.add_non_hierarchical_node() + + P = gvu.PytorchHierarchyNode() + add_node_expect_failure(P, B.nodes[0]) + + assert len(P.get_nodes([""])) == 0 + assert len(P.children) == 0 From e47161a3736a0e8b3102cff2cbacffcf86f3aa79 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Mon, 14 Apr 2025 22:44:12 +0000 Subject: [PATCH 10/14] remove old code --- onnxscript/utils/graph_view_utils.py | 121 --------------------------- 1 file changed, 121 deletions(-) diff --git a/onnxscript/utils/graph_view_utils.py b/onnxscript/utils/graph_view_utils.py index a1f8f32f90..9c6868cea0 100644 --- a/onnxscript/utils/graph_view_utils.py +++ b/onnxscript/utils/graph_view_utils.py @@ -230,124 +230,3 @@ def add_node(self, node, level=0): new_child.module_type = node.get_class_name(level + 1) self.children.append(new_child) return new_child.add_node(node, level + 1) - -################################################## -## TODO (JSM): encapsulte this into a function ## -################################################## - -# model = onnx.load('mistral.onnx') - -# model_ir = ir.serde.deserialize_model(model) - -# layer_dict = {} - -# no_name_scopes = set() -# for node in ir.traversal.RecursiveGraphIterator(model_ir.graph): -# if 'pkg.torch.onnx.name_scopes' in node.metadata_props: -# name_scopes = ast.literal_eval(node.metadata_props['pkg.torch.onnx.name_scopes']) -# if name_scopes[1].startswith('layer'): -# if name_scopes[1] not in layer_dict: -# layer_dict[name_scopes[1]] = [] -# layer_dict[name_scopes[1]].append(node) -# else: -# print(node) -# else: -# no_name_scopes.add(node) - -# scoped_nodes = set() -# stop = False -# for node in no_name_scopes: -# #input('pause for enter') -# print(node) -# layer_usage = set() -# for value in node.outputs: -# print(f"\t{value}") -# print(f"\t{node.name}") -# if value.name == 'val_39': -# stop = True -# input('found val_39') -# for use in value.uses(): -# used_node = use[0] -# print(f"\t\t{used_node}") -# if 'pkg.torch.onnx.name_scopes' in used_node.metadata_props: -# print(f"\t\t\t{used_node.metadata_props['pkg.torch.onnx.name_scopes']}") -# name_scopes = ast.literal_eval(used_node.metadata_props['pkg.torch.onnx.name_scopes']) -# layer_usage.add(name_scopes[1]) -# else: -# print(f"\t\t\tno scope") -# layer_usage.add("") -# print(f'\t\tlayer usage {layer_usage}') -# if len(layer_usage) == 1: -# scope = next(iter(layer_usage)) -# print(scope) -# print(layer_dict.keys()) -# print(node) -# if scope in layer_dict: -# layer_dict[scope].insert(0, node) -# scoped_nodes.add(node) -# if stop == True: -# pass - -# for key in layer_dict: -# print(key+"********") -# layer = bGraphView(key, layer_dict[key]) - -# print(key+"STARTIO********") -# print(f"inputs: {layer.inputs}") -# print(f"outputs: {layer.outputs}") -# print(key+"ENDIO********") -# print("\n\n") - -# #exit(1) -# #layer0_gv = bGraphView('layers.0', layer_dict['layers.0']) -# layer1_gv = bGraphView('layers_1',layer_dict['layers.1']) - -# print(f"inputs: {layer1_gv.inputs}") -# print(f"outputs: {layer1_gv.outputs}") -# #print(f"initializers: {.initializers}") - -# #exit(1) - -# layer0 = layer_dict['layers.0'] -# layer1 = layer_dict['layers.1'] - -# d = {} - -# for a in layer0: -# d[a] = {a} -# for b in layer1: -# if ir_node__eq__(a, b): -# d[a].add(b) - -# for node in d: -# if len(d[node]) == 1: -# print(f"single node set: {node}") -# print(f"value: {node.outputs[0]}") -# #print(f"uses: {node.outputs[0].uses()}") -# print(f"len: {len(node.outputs[0].uses())}") -# for use in node.outputs[0].uses(): -# print(str(use)+'\n\n') -# #if len(d[node]) > 2: -# # print(f"set with {len(d[node])} nodes") -# # for n in d[node]: -# # print(f'\t{n}') -# # input('press n to continue') - -# #print(model.graph) - -# layer = bGraphView('layers_0', layer_dict['layers.0']) - -# for node in layer._nodes: -# if node.name == 'node_Constant_2143': -# print(node) -# input('found it!') - - -# model = ir.Model(layer, ir_version=10) -# proto = ir.serde.serialize_model(model) -# print('saving') -# onnx.save(proto, 'layer0.onnx') - - -# print('done') - From cbbfe784530d3720d074df972e51a5bba199a6e7 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Tue, 15 Apr 2025 21:51:08 +0000 Subject: [PATCH 11/14] add comprehensive mistral test checking --- onnxscript/utils/test_PytorchHierarchyNode.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/onnxscript/utils/test_PytorchHierarchyNode.py b/onnxscript/utils/test_PytorchHierarchyNode.py index f1d522edbd..e2a85770ff 100644 --- a/onnxscript/utils/test_PytorchHierarchyNode.py +++ b/onnxscript/utils/test_PytorchHierarchyNode.py @@ -1,5 +1,7 @@ import pytest +import ast +import onnx from onnxscript import script from typing import List, Tuple @@ -136,3 +138,47 @@ def test_non_hierarchical_node(): assert len(P.get_nodes([""])) == 0 assert len(P.children) == 0 + +def build_golden_results(nodes): + golden_get_node_results = {} + for node in nodes: + metadata = node.metadata_props + if metadata: + new_key = "/".join(ast.literal_eval(metadata.get("pkg.torch.onnx.name_scopes"))) + "/" + + # search the hierarchy_dict for entries that are a prefix of the key + # add the current node to the list of nodes for that key + for key in golden_get_node_results: + if new_key != key and new_key.startswith(key): + golden_get_node_results[key].append(node) + + # check if the new_key is already in the hierarchy_dict + if new_key not in golden_get_node_results: + golden_get_node_results[new_key] = [] + + # add the current node to the list of nodes for the new_key + golden_get_node_results[new_key].append(node) + + return golden_get_node_results + +def test_mistral_pytorch_with_metadata(): + model_proto = onnx.load('/home/joshmonson/Projects/experiments/finn_mlo_graphs/demo/mistral.onnx') + model = ir.serde.deserialize_model(model_proto) + graph = model.graph + + + P = gvu.PytorchHierarchyNode() + count_not_added = 0 + for node in graph._nodes: + added = P.add_node(node) + if not added: + count_not_added += 1 + + golden = build_golden_results(graph._nodes) + for key, gnodes in golden.items(): + key = key.rstrip("/") + nodes = P.get_nodes(key.split("/")) + # check if the nodes in the result are in the list of nodes for that key + print(f"got nodes for key {key}: {nodes}") + for gnode in gnodes: + assert gnode in nodes, f"Node {gnode.metadata_props.get('pkg.torch.onnx.name_scopes')} not found in nodes for key {key}" From 26f906c1df8f2899f133b7e9ebc24f88057ce79b Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Tue, 15 Apr 2025 23:43:22 +0000 Subject: [PATCH 12/14] add code to place unannotated constant nodes --- onnxscript/utils/graph_view_utils.py | 48 ++++++++++++++++++- onnxscript/utils/test_PytorchHierarchyNode.py | 12 +++-- 2 files changed, 56 insertions(+), 4 deletions(-) diff --git a/onnxscript/utils/graph_view_utils.py b/onnxscript/utils/graph_view_utils.py index 9c6868cea0..2b53ec01a1 100644 --- a/onnxscript/utils/graph_view_utils.py +++ b/onnxscript/utils/graph_view_utils.py @@ -190,9 +190,12 @@ def get_nodes(self, search_hierarchy: List[str], instance_hierarchy: List[str] = # def add_nodes(self, nodes): # for node in nodes: # self.add_node(node) - def add_node(self, node, level=0): + # if node.name == 'node_Constant_2153': + # import pdb + # pdb.set_trace() + print("calling add_node") if not isinstance(node, PytorchMetadataNode): node = PytorchMetadataNode(node) @@ -230,3 +233,46 @@ def add_node(self, node, level=0): new_child.module_type = node.get_class_name(level + 1) self.children.append(new_child) return new_child.add_node(node, level + 1) + +def add_metadata_to_unannotated_constant_nodes(graph): + for node in graph._nodes: + if node.op_type == 'Constant' and not node.metadata_props: + # search all of the uses to determine which hierarhcy to add + # to the constant node + # if all users have the same hierarchy, add that hierarchy to the constant node + # if the users have different hierarchies, use the one level above the highest + # level in the hierarchy + metadata = set() + for output in node.outputs: + for user in output.uses(): + user_node = user[0] + if user_node.metadata_props: + metadata.add((user_node.metadata_props['pkg.torch.onnx.name_scopes'], + user_node.metadata_props['pkg.torch.onnx.class_hierarchy'])) + + if len(metadata) == 1: + name, class_hier = metadata.pop() + node.metadata_props['pkg.torch.onnx.name_scopes'] = name + node.metadata_props['pkg.torch.onnx.class_hierarchy'] = class_hier + else: + # convert the metadata_namescope set to a list of lists + metadata_list = [(ast.literal_eval(x[0]),ast.literal_eval(x[1])) for x in list(metadata)] + + # find the index of namescope_list with the shortest length + min_index = min(range(len(metadata_list)), key=lambda i: len(metadata_list[i][0])) + + # get the shortest namescope + shortest_hierarchy = metadata_list[min_index] + + # remove the last level of the hierarchy + target_name = shortest_hierarchy[0][:len(shortest_hierarchy[0]) - 1] + target_class = shortest_hierarchy[1][:len(shortest_hierarchy[1]) - 1] + + # convert the target_hierarchy to a string + target_name_str = str(target_name) + target_class_str = str(target_class) + + # add the target_hierarchy to the node + node.metadata_props['pkg.torch.onnx.name_scopes'] = target_name_str + node.metadata_props['pkg.torch.onnx.class_hierarchy'] = target_class_str + return graph diff --git a/onnxscript/utils/test_PytorchHierarchyNode.py b/onnxscript/utils/test_PytorchHierarchyNode.py index e2a85770ff..18f7f4189e 100644 --- a/onnxscript/utils/test_PytorchHierarchyNode.py +++ b/onnxscript/utils/test_PytorchHierarchyNode.py @@ -166,19 +166,25 @@ def test_mistral_pytorch_with_metadata(): model = ir.serde.deserialize_model(model_proto) graph = model.graph + graph = gvu.add_metadata_to_unannotated_constant_nodes(graph) P = gvu.PytorchHierarchyNode() - count_not_added = 0 + unadded_nodes = [] for node in graph._nodes: added = P.add_node(node) if not added: - count_not_added += 1 + unadded_nodes.append(node) + print(f"Number of nodes not added: {len(unadded_nodes)}") + print(f"unadded nodes: {unadded_nodes}") + for node in unadded_nodes: + print(f"unadded node: {node}") + print(f"unadded node metadata: {node.metadata_props}") golden = build_golden_results(graph._nodes) for key, gnodes in golden.items(): key = key.rstrip("/") nodes = P.get_nodes(key.split("/")) # check if the nodes in the result are in the list of nodes for that key - print(f"got nodes for key {key}: {nodes}") + #print(f"got nodes for key {key}: {nodes}") for gnode in gnodes: assert gnode in nodes, f"Node {gnode.metadata_props.get('pkg.torch.onnx.name_scopes')} not found in nodes for key {key}" From abf892d916956be461817af04470c5eeeca23497 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Wed, 16 Apr 2025 14:27:16 +0000 Subject: [PATCH 13/14] remove print and old comments --- onnxscript/utils/graph_view_utils.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/onnxscript/utils/graph_view_utils.py b/onnxscript/utils/graph_view_utils.py index 2b53ec01a1..a9cd684a7f 100644 --- a/onnxscript/utils/graph_view_utils.py +++ b/onnxscript/utils/graph_view_utils.py @@ -187,39 +187,27 @@ def get_nodes(self, search_hierarchy: List[str], instance_hierarchy: List[str] = return nodes_to_return - # def add_nodes(self, nodes): - # for node in nodes: - # self.add_node(node) def add_node(self, node, level=0): - # if node.name == 'node_Constant_2153': - # import pdb - # pdb.set_trace() - - print("calling add_node") if not isinstance(node, PytorchMetadataNode): node = PytorchMetadataNode(node) if node.check_node_metadata_exists() is False: return False if self.instance_name is None: - print(f"setting instance name to {node.get_instance_name(level)}") self.instance_name = node.get_instance_name(level) if self.module_type is None: self.module_type = node.get_class_name(level) # check that instance name and module type match if self.instance_name != node.get_instance_name(level): - #raise ValueError(f"Instance name mismatch: {self.instance_name} != {node.get_instance_name(level)}") return False if self.module_type != node.get_class_name(level): - #raise ValueError(f"Module type mismatch: {self.module_type} != {node.get_class_name(level)}") return False # if this is the last level of the hierarchy, add the node to this node # otherwise find the child node that matches the next level of the hierarchy # and add the node to that child if node.is_last_level(level): - print(f"Adding node {node} to {self.instance_name}") self.nodes.append(node) return True else: From b8c606825100f7e6b2551a23dd9161b0b5470c56 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Wed, 16 Apr 2025 14:28:48 +0000 Subject: [PATCH 14/14] remove additional comments --- onnxscript/utils/graph_view_utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/onnxscript/utils/graph_view_utils.py b/onnxscript/utils/graph_view_utils.py index a9cd684a7f..4e5884bd8b 100644 --- a/onnxscript/utils/graph_view_utils.py +++ b/onnxscript/utils/graph_view_utils.py @@ -105,7 +105,6 @@ def __init__(self, node): if self.check_node_metadata_exists(): self.instance_metadata = ast.literal_eval(self._node.metadata_props['pkg.torch.onnx.name_scopes']) self.class_metadata = ast.literal_eval(self._node.metadata_props['pkg.torch.onnx.class_hierarchy']) - print(f'self.node.metadata_props: {self._node.metadata_props}') def check_node_metadata_exists(self): if 'pkg.torch.onnx.name_scopes' in self._node.metadata_props and \ @@ -172,8 +171,6 @@ def get_nodes(self, search_hierarchy: List[str], instance_hierarchy: List[str] = # base case for recursion # 1 - search_hierarchy does not match instance_hierarchy instance_hierarchy.append(self.instance_name) - #print(f"search_hierarchy: {search_hierarchy}") - #print(f"instance_hierarchy: {instance_hierarchy}") if not self.hierarchy_matches(search_hierarchy, instance_hierarchy): return []