diff --git a/onnxscript/rewriter/pattern_builder_jsm.py b/onnxscript/rewriter/pattern_builder_jsm.py new file mode 100644 index 0000000000..8f3ec1f324 --- /dev/null +++ b/onnxscript/rewriter/pattern_builder_jsm.py @@ -0,0 +1,546 @@ +import copy + +import numpy as np +import onnx + +from onnxscript import ir +from onnxscript import rewriter +from onnxscript.rewriter.pattern import ( + RewriterContext, MatchResult, ValuePattern, GraphPattern, OpsetPatternBuilder, pattern_builder, NodeOutputPattern, ReplacementSubgraph, ReplacementPatternFunction +) + + + +from collections.abc import Iterable +from onnxscript.utils.graph_view_utils import bGraphView + +#print("**************************************") +#print("********* Pattern Builder ************") +#print("**************************************") + + +def direct_convert_ir_graph_to_pattern(graph): + + + # Transform IR values to ValuePatterns + vmap = {} + for input in graph.inputs: + vmap[input] = ValuePattern(input.name) + + for init in graph.initializers: + vmap[init] = ValuePattern(init.name) + + + for node in graph._nodes: + if node.op_type == 'Constant': + vmap[node.outputs[0]] = ValuePattern(node.outputs[0].name) + + builder = OpsetPatternBuilder("", record=True) + + with pattern_builder(builder): + for node in graph._nodes: + ninputs = [] + for ninput in node.inputs: + ninputs.append(vmap[ninput]) + + #if len(node.outputs) > 1: + vp_outputs = builder.__getattr__(node.op_type)(*ninputs,_domain=node.domain, _outputs=len(node.outputs)) + #else: + # vp_outputs = builder.__getattr__(node.op_type)(*ninputs) + + + if isinstance(vp_outputs,NodeOutputPattern): + vp_outputs = [vp_outputs] + + for vp_output in iter(vp_outputs): + vmap[node.outputs[vp_output.output_index]] = vp_output + + + pinputs = [] + for input in graph.inputs: + pinputs.append(vmap[input]) + + # build graph outputs + poutputs = [] + for output in graph.outputs: + poutputs.append(vmap[output]) + + return GraphPattern(inputs=pinputs, outputs=poutputs, nodes=builder.nodes()) + +from enum import Enum + +def remove_input_from_node(node, inp): + node._inputs = [x for x in node._inputs if x is not inp] + inp._remove_usage(node) + + +class LoopBodyInputType(Enum): + UNDEFINED = 0 + ACTIVATION = 1 + CONSTANT = 2 + PARAMETER = 3 + ITERATOR = 4 + CONDITION = 5 + + def __str__(self): + return self.name + +class LoopBodyTemplate: + def __init__(self, filename): + self.load(filename) + self.pattern = direct_convert_ir_graph_to_pattern(self._ir_graph) + self.function = self._build_ir_function() + self.function_replace = self._build_function_replace_pattern() + self.signature = [LoopBodyInputType.UNDEFINED] * len(self._ir_graph.inputs) + + def _build_ir_function(self): + return ir.Function(domain='loop', + name='fn_' + self._ir_graph.name, + graph = self._ir_graph, + attributes=[]) + + def _build_function_replace_pattern(self): + + inputs = [vdisconnect(copy.copy(x)) for x in self._ir_graph.inputs] + outputs = [vdisconnect(copy.copy(x)) for x in self._ir_graph.outputs] + + node = ir.Node(domain=self.function.domain, + version=0, + op_type=self.function.name, + inputs=inputs, outputs=outputs) + + g = ir.Graph(inputs=inputs, outputs=outputs, nodes=[node]) + + return ReplacementPatternGraph(g) + + def get_iterator_index(self): + for i in range(len(self.signature)): + if self.signature[i] == LoopBodyInputType.ITERATOR: + return i + + def insert_gather_nodes(self, loop_iterations): + for index,LoopInputType in enumerate(self.signature): + if LoopInputType == LoopBodyInputType.PARAMETER: + # The 
Current Input Value will be the output of the gather node + gather_index = self.function.inputs[self.get_iterator_index()] + squeeze_out = self.function.inputs[index] + + gather_in = ir.Value(name=squeeze_out.name+"_gather_in", + shape=ir.Shape([loop_iterations, *squeeze_out.shape.dims]), + type=squeeze_out.type) + gather_out = ir.Value(name=squeeze_out.name+"_gather_out", + shape=ir.Shape([1, *squeeze_out.shape.dims]), + type=squeeze_out.type) + for usage in squeeze_out.uses(): + if usage.node.op_type == "Identity": + usage.node.replace_input_with(usage.idx, gather_in) + usage.node.outputs[0].shape = copy.copy(gather_in.shape) + + self.function.inputs[index] = gather_in + + self.function.append(ir.Node(domain='', + op_type='Gather', + inputs = [gather_in, gather_index], + outputs = [gather_out], + num_outputs = 1 + ) + ) + squeeze_out.name += "_squeeze_out" + squeeze_axis = build_constant_from_tensor(f'{gather_out.name}_squeeze_axis', ir.Tensor(np.array([0]))) + self.function.append(squeeze_axis) + self.function.append(ir.Node(domain='', + op_type='Squeeze', + inputs=[gather_out, squeeze_axis.outputs[0]], + outputs=[squeeze_out], + num_outputs= 1, + version = 13 + ) + ) + + self.function.sort() + + + def build_function_match_pattern(self, graph): + graph.sort() + nodes = find_nodes_of_optype(graph, self.function.name) + nodes.insert(0,graph.node('iteration_ext')) + nodes.insert(0,graph.node('condition_ext')) + + ir_model = ir.Model(bGraphView('inlined_pipe_pattern', nodes), ir_version=10) + + model = ir.serde.serialize_model(ir_model) + onnx.save(model, 'pipeline_match_pattern.onnx') + + pattern = direct_convert_ir_graph_to_pattern(ir_model.graph) + + return (pattern, nodes) + + + def load(self, filename): + self._model_proto = onnx.load(filename) + self._ir_model = ir.serde.deserialize_model(self._model_proto) + self._ir_graph = self._ir_model.graph + + def update(self): + self._ir_model = ir.Model(self._ir_graph, ir_version = 10) + self._model_proto = ir.serde.serialize_model(self._ir_model) + + def save(self, filename): + self.update() + onnx.save(self._model_proto, filename) + + def set_signature_index(self, index, stype): + self.signature[index] = stype + + @property + def output_signature(self): + # The output signature is the same as the input signature but without the iteration input + return self.signature[1:] + + +def same(input_list): + return len(set(input_list)) == 1 + +def append_output_to_node(node, output): + output._producer = node + output._index = node.outputs[-1]._index + 1 + node._outputs = (*node._outputs, output) + node._num_outputs = len(node._outputs) + +def prepend_output_to_node(node, output): + output._producer = node + output._index = 0 + for outp in node._outputs: + outp._index += 1 + node._outputs = (output, *node._outputs) + node._num_outputs = len(node._outputs) + +def prepend_input_to_node(node, input): + input._add_usage(node, 0) + # increment the index for all uses on this node + for i,inp in enumerate(node._inputs): + inp._remove_usage(node, i) + inp._add_usage(node, i+1) + + node._inputs = (input, *node._inputs) + +def normalize_io_for_loop_rolling(graph, LoopBody): + + # This takes a graph that has consecutive identical nodes + # and normalizes the i/o indicies prior to the loop rolling + # optimization. Specifically, this function identifies output- + # to-input pairs and permutes the indices of the input to match + # the previous output. + + # The ONNX loop node requires that interloop dependencies + # have identical input and output indices. 
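+    # Illustrative example (hypothetical indices, not taken from any
+    # specific model): if every layer's output 0 feeds input 2 of the next
+    # layer, the swap (0, 2) is recorded once and applied to every layer node:
+    #
+    #     before:  layer_k.outputs[0] --> layer_{k+1}.inputs[2]
+    #     after:   layer_k.outputs[0] --> layer_{k+1}.inputs[0]
+    #
+    # so that position 0 becomes a loop-carried dependency with matching
+    # input and output indices, as the Loop op requires.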
+
+    # Run a topological sort to ensure the layers are in order.
+    graph.sort()
+
+    # Get the consecutive node layers.
+    # TODO: write a check to ensure that there is only one
+    # set of consecutive nodes.
+    nodes = find_nodes_of_optype(graph, LoopBody.function.name)
+
+    # Loop through all the nodes (except the last one) and
+    # identify the output-to-input pairs.
+    input_swaps = []
+    for i in range(len(nodes)-1):
+        a_node = nodes[i]
+        b_node = nodes[i+1]
+
+        for a_out in a_node.outputs:
+            # Require that each output of a_node has a single use, and that
+            # this use is b_node.
+            assert len(a_out.uses()) == 1
+            assert a_out.uses()[0][0] is b_node
+
+            a_use_index = a_out.uses()[0][1]
+            input_swap = (a_out.index(), a_use_index)
+            if i == 0:
+                # record the swaps found on the first node pair
+                input_swaps.append(input_swap)
+            else:
+                # check that the remaining pairs use the same swaps
+                assert input_swap in input_swaps
+
+    # apply the input swaps to each node
+    for node in nodes:
+        for swap in input_swaps:
+            a = node.inputs[swap[0]]
+            b = node.inputs[swap[1]]
+            node.replace_input_with(swap[0], b)
+            node.replace_input_with(swap[1], a)
+
+    # apply the input swaps to the function graph and
+    # mark the swapped inputs as activations
+    activations = 0
+    for swap in input_swaps:
+        a = LoopBody.function.inputs[swap[0]]
+        b = LoopBody.function.inputs[swap[1]]
+        LoopBody.function.inputs[swap[0]] = b
+        LoopBody.function.inputs[swap[1]] = a
+        LoopBody.signature[swap[0]] = LoopBodyInputType.ACTIVATION
+        activations += 1
+
+    # Next, classify the remaining inputs according to how they are produced.
+    # Indexable inputs have differing (constant or absent) producers, while
+    # constant values broadcast to all nodes share the same producer. The
+    # activation inputs are skipped: they have been swapped to the beginning
+    # of the list.
+    for index in range(activations, len(nodes[0].inputs)):
+        inputs = []
+        for node in nodes:
+            cinput = node.inputs[index]
+            inputs.append(cinput)
+        if same(inputs):
+            # Constant with respect to the loop
+            LoopBody.signature[index] = LoopBodyInputType.CONSTANT
+        else:
+            # Must be indexed
+            LoopBody.signature[index] = LoopBodyInputType.PARAMETER
+
+    # Match the output signature to the input signature
+    for index,LoopInputType in enumerate(LoopBody.signature):
+        cinput = LoopBody.function.inputs[index]
+        noutput = vdisconnect(copy.copy(cinput))
+        noutput._uses = {}
+
+        if LoopInputType == LoopBodyInputType.CONSTANT or \
+           LoopInputType == LoopBodyInputType.PARAMETER:
+            # Update names and add the new output
+            cinput.name += "_" + str(LoopInputType) + "_in"
+            noutput.name += "_" + str(LoopInputType) + "_out"
+
+            LoopBody.function.outputs.append(noutput)
+
+            # Add an Identity to pass the input through to the new output
+            Ident = ir.Node(domain='',
+                            op_type='Identity',
+                            inputs=[cinput],
+                            outputs=[noutput],
+                            num_outputs=1)
+            LoopBody.function.append(Ident)
+
+            # Add the new output to the function call nodes
+            for i,node in enumerate(nodes):
+                output_copy = copy.copy(noutput)
+
+                # preserve single assignment by giving each copy a unique name
+                output_copy.name += f'_{i}'
+                append_output_to_node(node,output_copy)
+
+    # Add iterator and condition inputs (leave them unconnected within the function for now)
+    iteration = ir.Value(name='iteration', shape=ir.Shape([1]), type=ir.TensorType(ir.DataType.INT64))
+    condition = ir.Value(name='cond', shape=ir.Shape([1]), type=ir.TensorType(ir.DataType.BOOL))
+
+    LoopBody.signature.insert(0, LoopBodyInputType.CONDITION)
+    LoopBody.function.inputs.insert(0,condition)
+
+    LoopBody.signature.insert(0, LoopBodyInputType.ITERATOR)
+    LoopBody.function.inputs.insert(0,iteration)
+
+    iteration_ext =
build_constant_from_tensor('iteration_ext', ir.Tensor(np.array([len(nodes)]))) + condition_ext = build_constant_from_tensor('condition_ext', ir.Tensor(np.array([True]))) + #iteration_ext = ir.Value(name='iteration', shape=ir.Shape([1]), type=ir.TensorType(ir.DataType.INT64)) + #condition_ext = ir.Value(name='cond_ext', shape=ir.Shape([1]), type=ir.TensorType(ir.DataType.BOOL)) + + # add these nodes to the graph before the rest to maintain topological sorted-ness + nodes[0].prepend([iteration_ext,condition_ext]) + + for node in nodes: + prepend_input_to_node(node, condition_ext.outputs[0]) + prepend_input_to_node(node, iteration_ext.outputs[0]) + + # Add Identity Node for Condition + condition_out = ir.Value(name='cond_out', type=ir.TensorType(ir.DataType.BOOL)) + condition_ident = ir.Node(domain='', + op_type='Identity', + inputs = [condition], + outputs = [condition_out], + num_outputs =1) + LoopBody.function.append(condition_ident) + LoopBody.function.outputs.insert(0,condition_out) + + # Add New Condition Output to Node + for i,node in enumerate(nodes): + noutput = ir.Value(name=f'cond_out_{i}', type=ir.TensorType(ir.DataType.BOOL)) + prepend_output_to_node(node,noutput) + + graph.sort() + return graph + + +import copy + +def vdisconnect(value): + value._uses = {} + value._producer = None + value._index = None + return value + + +class ReplacementPatternGraph(ReplacementPatternFunction): + def __init__(self, ir_graph): + self._graph = ir_graph + + + def get_replacement(self, match: MatchResult) -> ReplacementSubgraph | None: + + context = RewriterContext() + # match.bindings is dictionary of value_name (str) in replacement subgraph pattern (i.e. ir_graph -> IR Value in actual graph) + vvmap = {} # Dictionary mapping values in replacement subgraph pattern -> values in the replacement subgraph + + for value in self._graph.inputs: + if value.name in match.bindings: + vvmap[value] = match.bindings[value.name] + else: + vvmap[value] = value + + for node in self._graph._nodes: + ninputs = [] + for ninput in node.inputs: + ninputs.append(vvmap[ninput]) + + + coutput = context.__getattr__(node.op_type)(*ninputs, **node.attributes, _outputs=len(node.outputs), _domain=node.domain, _version=node.version) + if not isinstance(coutput,Iterable): + coutput = [coutput] + + for i, cout in enumerate(coutput): + vvmap[node.outputs[cout.index()]] = cout + + new_outputs = [vvmap[x] for x in self._graph.outputs] + return ReplacementSubgraph( + match, new_outputs, context.nodes, context.initializers, context.used_opsets + ) + + +def convert_graph_to_function_call_pattern(graph): + + inputs = [vdisconnect(copy.copy(x)) for x in graph.inputs] + outputs = [vdisconnect(copy.copy(x)) for x in graph.outputs] + + node = ir.Node('', graph.name+'_fcall', inputs, outputs=outputs) + + g = ir.Graph(inputs=inputs, outputs=outputs, nodes=[node]) + + + return ReplacementPatternGraph(g) + + +def find_nodes_of_optype(graph, layername): + nodes = [] + for node in ir.traversal.RecursiveGraphIterator(graph): + if node.op_type == layername: + nodes.append(node) + return nodes + + +def build_layer_pipeline_pattern(graph, layername): + + nodes = find_nodes_of_optype(graph, layername) + ir_model = ir.Model(bGraphView('inlined_pipe_pattern', nodes), ir_version=10) + + model = ir.serde.serialize_model(ir_model) + onnx.save(model, 'pipeline_match_pattern.onnx') + + pattern = direct_convert_ir_graph_to_pattern(ir_model.graph) + + return (pattern, nodes) + +def build_constant_from_tensor(name, tensor): + value_attribute = 
ir.Attr(name='value', type=ir.AttributeType.TENSOR, value=tensor) + ir_value_out = ir.Value(name=name+'_out', type=ir.TensorType(tensor.dtype)) + return ir.Node('', 'Constant', name=name, inputs=[], outputs=[ir_value_out], attributes=[value_attribute]) + +def build_concat_node_from_inputs(inputs): + + axis = ir.Attr(name='axis', type=ir.AttributeType.INT, value=0) + ndim = len(inputs) * inputs[0].shape.dims[0] + output_shape = ir.Shape([ndim, *inputs[0].shape.dims[1:]]) + output = ir.Value(name=f'{inputs[0].name}_concat', shape=output_shape, type=inputs[0].type) + return ir.Node('', 'Concat', inputs=inputs, attributes=[axis], outputs=[output]) + +def build_reshape_node(inp, reshape_shape): + reshape_out = ir.Value(name=f'{inp.name}_reshape', type=inp.type) + return ir.Node('', 'Reshape', inputs=[inp, reshape_shape], outputs=[reshape_out]) + + + +def build_loop_replace_pattern(graph, LoopBody): + + nodes = find_nodes_of_optype(graph, LoopBody.function.name) + + graph_nodes = [] + loop_inputs = [] + + # Build max_iteration and condition constants + M = build_constant_from_tensor('M', ir.Tensor(np.array([len(nodes)]))) + cond = build_constant_from_tensor('cond', ir.Tensor(np.array([True]))) + + graph_nodes.append(M) + graph_nodes.append(cond) + + loop_inputs.append(M.outputs[0]) + loop_inputs.append(cond.outputs[0]) + + graph_inputs = [] + for i, LoopInputType in enumerate(LoopBody.signature): + + if LoopInputType == LoopBodyInputType.PARAMETER: + # Build Concat Node + concat_inputs = [] + for node in nodes: + nvalue = vdisconnect(copy.copy(node.inputs[i])) + graph_inputs.append(nvalue) + concat_inputs.append(nvalue) + + concat_node = build_concat_node_from_inputs(concat_inputs) + graph_nodes.append(concat_node) + + # Build Reshape Node + reshape_shape_const = build_constant_from_tensor(f'reshape_shape_const_{i}', ir.Tensor(np.array([len(nodes),*concat_inputs[0].shape.dims]))) + graph_nodes.append(reshape_shape_const) + + reshape_node = build_reshape_node(concat_node.outputs[0], reshape_shape_const.outputs[0]) + graph_nodes.append(reshape_node) + loop_inputs.append(reshape_node.outputs[0]) + elif LoopInputType == LoopBodyInputType.CONSTANT: + constant_input = nodes[0].inputs[i] + constant_node = constant_input.producer() + constant_value = constant_node.attributes['value'].value.numpy() + n_constant_node = build_constant_from_tensor(constant_input.name+"_const_val", ir.Tensor(constant_value)) + graph_nodes.append(n_constant_node) + loop_inputs.append(n_constant_node.outputs[0]) + elif LoopInputType == LoopBodyInputType.ACTIVATION: + cinp = vdisconnect(copy.copy(LoopBody.function.inputs[i])) + graph_inputs.append(cinp) + loop_inputs.append(cinp) + + + + loop_outputs = [] + graph_outputs = [] + for i, LoopOutputType in enumerate(LoopBody.output_signature): + output = vdisconnect(copy.copy(LoopBody.function.outputs[i])) + if LoopOutputType != LoopBodyInputType.CONDITION: + loop_outputs.append(output) + if LoopOutputType == LoopBodyInputType.ACTIVATION: + graph_outputs.append(output) + + LoopBody.insert_gather_nodes(len(nodes)) + body_attr = ir.Attr(name='body', type=ir.AttributeType.GRAPH, value=LoopBody.function._graph) + graph_nodes.append(ir.Node('', 'Loop', inputs=loop_inputs, attributes = [body_attr], outputs=loop_outputs, graph=None)) + + graph = ir.Graph(name='loop_replace',nodes=graph_nodes, inputs=graph_inputs, outputs=graph_outputs) + + graph.sort() + + model = ir.serde.serialize_model(ir.Model(graph, ir_version=8)) + onnx.save(model, 'replacementgraph.onnx') + + return 
ReplacementPatternGraph(graph)
diff --git a/onnxscript/utils/graph_view_utils.py b/onnxscript/utils/graph_view_utils.py
new file mode 100644
index 0000000000..4e5884bd8b
--- /dev/null
+++ b/onnxscript/utils/graph_view_utils.py
@@ -0,0 +1,263 @@
+from onnxscript import ir
+
+import ast
+
+
+def is_initializer(value):
+    return value.producer() is None
+
+def gather_initializers(inputs):
+    inits = set()
+    for value in inputs:
+        if is_initializer(value):
+            inits.add(value)
+    return inits
+
+def is_constant(value):
+    producer = value.producer()
+    return producer is not None and producer.op_type == 'Constant'
+
+def gather_constants(inputs):
+    consts = set()
+    for value in inputs:
+        if is_constant(value):
+            consts.add(value)
+    return consts
+
+
+def has_internal_usage(usage):
+    return "INTERNAL" in usage
+
+def has_external_usage(usage):
+    return "EXTERNAL" in usage
+
+def classify_usage(value, nodes):
+    usage = set()
+    for use in value.uses():
+        user_node = use[0]
+        if user_node in nodes:
+            usage.add("INTERNAL")
+        else:
+            usage.add("EXTERNAL")
+    return usage
+
+def find_subgraph_inputs(nodes):
+    inputs = set()
+    initializers = set()
+    for node in nodes:
+        for ninput in node.inputs:
+            if ninput in node.graph.inputs:
+                inputs.add(ninput)
+            elif any(ninput is init for init in node.graph.initializers):
+                initializers.add(ninput)
+            elif ninput.producer() is None:
+                inputs.add(ninput)
+            elif ninput.producer() not in nodes:
+                inputs.add(ninput)
+
+    return inputs, initializers
+
+def find_subgraph_outputs(nodes):
+    output = set()
+    used_output = set()
+    for node in nodes:
+        for noutput in node.outputs:
+            usage = classify_usage(noutput, nodes)
+            if has_external_usage(usage):
+                if has_internal_usage(usage):
+                    used_output.add(noutput)
+                else:
+                    output.add(noutput)
+    return [output, used_output]
+
+
+def bGraphView(name, nodes):
+
+    [view_inputs, view_initializers] = find_subgraph_inputs(nodes)
+    [view_outputs, used_outputs] = find_subgraph_outputs(nodes)
+
+    for used_output in used_outputs:
+        producer_node = used_output.producer()
+        nodes.remove(producer_node)
+        for output in producer_node.outputs:
+            usage = classify_usage(output,nodes)
+            if has_internal_usage(usage):
+                view_inputs.add(output)
+            if has_external_usage(usage):
+                if output in view_outputs:
+                    view_outputs.remove(output)
+
+    return ir.GraphView(name=name,
+                        inputs=view_inputs,
+                        outputs=view_outputs,
+                        nodes=nodes,
+                        initializers=view_initializers)
+
+########################################
+# rebuild_pytorch_dynamo_instance_code #
+########################################
+
+from typing import List, Optional
+
+
+class PytorchMetadataNode:
+    def __init__(self, node):
+        self._node = node
+
+        if self.check_node_metadata_exists():
+            self.instance_metadata = ast.literal_eval(self._node.metadata_props['pkg.torch.onnx.name_scopes'])
+            self.class_metadata = ast.literal_eval(self._node.metadata_props['pkg.torch.onnx.class_hierarchy'])
+
+    def check_node_metadata_exists(self):
+        return ('pkg.torch.onnx.name_scopes' in self._node.metadata_props and
+                'pkg.torch.onnx.class_hierarchy' in self._node.metadata_props)
+
+    def is_last_level(self, level):
+        return len(self.instance_metadata) - 1 == level
+
+    def get_instance_name(self, depth=0):
+        if depth >= len(self.instance_metadata):
+            return None
+        return self.instance_metadata[depth]
+
+    def get_class_name(self, depth=0):
+        if depth >= len(self.class_metadata):
+            return None
+        return self.class_metadata[depth]
+
+class PytorchHierarchyNode:
+    def __init__(self):
+        self.instance_name = None
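+        # Illustrative example (hypothetical metadata): a node annotated with
+        # name_scopes ['', 'layers', '0'] and class_hierarchy
+        # ['Model', 'ModuleList', 'Block'] is stored three levels deep:
+        # root '' (Model) -> child 'layers' (ModuleList) -> child '0' (Block),
+        # in that child's `nodes` list.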
+        self.module_type = None
+        self.children = []
+        self.nodes = []
+
+    def print_hierarchy(self, instance_hierarchy: Optional[List[str]] = None):
+        if instance_hierarchy is None:
+            instance_hierarchy = []
+        instance_hierarchy.append(self.instance_name)
+
+        for child in self.children:
+            child.print_hierarchy(list(instance_hierarchy))
+
+        for node in self.nodes:
+            print(f"Node: {node._node.name}, Instance: {'/'.join(instance_hierarchy)}, Module: {self.module_type}")
+
+    def get_unwrapped_nodes(self):
+        # Return the wrapped ir.Node objects
+        return [node._node for node in self.nodes]
+
+    # Check whether the search hierarchy is a prefix of the instance hierarchy
+    def hierarchy_matches(self, search_hierarchy: List[str], instance_hierarchy: Optional[List[str]] = None):
+        if instance_hierarchy is None:
+            instance_hierarchy = []
+        search_length = min(len(search_hierarchy), len(instance_hierarchy))
+        for i in range(search_length):
+            if search_hierarchy[i] != instance_hierarchy[i]:
+                return False
+        return True
+
+    # Return all nodes from the given name hierarchy on down
+    def get_nodes(self, search_hierarchy: List[str], instance_hierarchy: Optional[List[str]] = None):
+        if instance_hierarchy is None:
+            instance_hierarchy = []
+
+        nodes_to_return = []
+        instance_hierarchy.append(self.instance_name)
+
+        # Base case: stop recursing as soon as the search hierarchy no longer
+        # matches the instance hierarchy.
+        if not self.hierarchy_matches(search_hierarchy, instance_hierarchy):
+            return []
+
+        for child in self.children:
+            child_nodes = child.get_nodes(search_hierarchy, list(instance_hierarchy))
+            nodes_to_return.extend(child_nodes)
+
+        if len(instance_hierarchy) >= len(search_hierarchy):
+            nodes_to_return.extend(self.get_unwrapped_nodes())
+
+        return nodes_to_return
+
+    def add_node(self, node, level=0):
+        if not isinstance(node, PytorchMetadataNode):
+            node = PytorchMetadataNode(node)
+            if not node.check_node_metadata_exists():
+                return False
+
+        if self.instance_name is None:
+            self.instance_name = node.get_instance_name(level)
+        if self.module_type is None:
+            self.module_type = node.get_class_name(level)
+
+        # check that the instance name and module type match
+        if self.instance_name != node.get_instance_name(level):
+            return False
+        if self.module_type != node.get_class_name(level):
+            return False
+
+        # If this is the last level of the hierarchy, add the node to this
+        # node; otherwise find the child that matches the next level of the
+        # hierarchy and add the node to that child.
+        if node.is_last_level(level):
+            self.nodes.append(node)
+            return True
+        else:
+            for child in self.children:
+                if child.instance_name == node.get_instance_name(level + 1):
+                    return child.add_node(node, level + 1)
+
+            # if no child matches the next level of the hierarchy, create a new child node
+            new_child = PytorchHierarchyNode()
+            new_child.instance_name = node.get_instance_name(level + 1)
+            new_child.module_type = node.get_class_name(level + 1)
+            self.children.append(new_child)
+            return new_child.add_node(node, level + 1)
+
+def add_metadata_to_unannotated_constant_nodes(graph):
+    for node in graph._nodes:
+        if node.op_type == 'Constant' and not node.metadata_props:
+            # Search all of the uses to determine which hierarchy to add to
+            # the constant node: if all users share the same hierarchy, use
+            # it; if the users have different hierarchies, use the parent of
+            # the shallowest one.
+            metadata = set()
+            for output in node.outputs:
+                for user in output.uses():
+                    user_node = user[0]
+                    if user_node.metadata_props:
+                        metadata.add((user_node.metadata_props['pkg.torch.onnx.name_scopes'],
+                                      user_node.metadata_props['pkg.torch.onnx.class_hierarchy']))
+
+            if len(metadata) == 1:
+                name, class_hier = metadata.pop()
+                node.metadata_props['pkg.torch.onnx.name_scopes'] = name
+                node.metadata_props['pkg.torch.onnx.class_hierarchy'] = class_hier
+            else:
+                # convert the metadata set to a list of (name_scopes, class_hierarchy) lists
+                metadata_list = [(ast.literal_eval(x[0]), ast.literal_eval(x[1])) for x in list(metadata)]
+
+                # find the index of the entry with the shortest name scope
+                min_index = min(range(len(metadata_list)), key=lambda i: len(metadata_list[i][0]))
+
+                # get the shortest hierarchy
+                shortest_hierarchy = metadata_list[min_index]
+
+                # remove the last level of the hierarchy
+                target_name = shortest_hierarchy[0][:len(shortest_hierarchy[0]) - 1]
+                target_class = shortest_hierarchy[1][:len(shortest_hierarchy[1]) - 1]
+
+                # convert the target hierarchy back to strings
+                target_name_str = str(target_name)
+                target_class_str = str(target_class)
+
+                # add the target hierarchy to the node
+                node.metadata_props['pkg.torch.onnx.name_scopes'] = target_name_str
+                node.metadata_props['pkg.torch.onnx.class_hierarchy'] = target_class_str
+    return graph
diff --git a/onnxscript/utils/test_PytorchHierarchyNode.py b/onnxscript/utils/test_PytorchHierarchyNode.py
new file mode 100644
index 0000000000..18f7f4189e
--- /dev/null
+++ b/onnxscript/utils/test_PytorchHierarchyNode.py
@@ -0,0 +1,190 @@
+import ast
+
+import onnx
+from typing import List, Tuple
+
+from onnxscript.utils import graph_view_utils as gvu
+from onnxscript import ir
+
+
+def build_metadata(instance_hierarchy, class_hierarchy):
+    return {
+        "pkg.torch.onnx.class_hierarchy": str(class_hierarchy),
+        "pkg.torch.onnx.name_scopes": str(instance_hierarchy)
+    }
+
+class HierarchyBuilder(ir._tape.Tape):
+    def __init__(self, graph_like: ir.Graph | ir.Function | None = None) -> None:
+        super().__init__(graph_like)
+
+    def add_non_hierarchical_node(self):
+        # A non-hierarchical node carries no PyTorch hierarchy metadata
+        self.op(
+            op_type="NonHierarchicalNode",
+            inputs=[]
+        )
+
+    def add_hierarchical_node(self, hierarchy: List[Tuple[str, str]]):
+        # Hierarchy is a list of tuples, where each tuple contains (instance name, module type)
+        instance_hierarchy = []
+        class_hierarchy = []
+        for instance_name, module_type in hierarchy:
+            instance_hierarchy.append(instance_name)
+            class_hierarchy.append(module_type)
+
+        self.op(
+            op_type="HierarchyNode",
+            inputs=[],
+            metadata_props=build_metadata(instance_hierarchy, class_hierarchy),
+        )
+
+
+def add_node_expect_success(P, node):
+    assert P.add_node(node) is True
+
+def add_node_expect_failure(P, node):
+    assert P.add_node(node) is False
+
+# a basic test to make sure the test file is working
+def test_onenode():
+    B = HierarchyBuilder()
+    B.add_hierarchical_node([("", "class0")])
+
+    P = gvu.PytorchHierarchyNode()
+    add_node_expect_success(P, B.nodes[0])
+
+    nodes = P.get_nodes([""])
+    assert len(nodes) == 1
+    assert nodes[0] is B.nodes[0]
+
+def test_twonodes():
+    B = HierarchyBuilder()
+    B.add_hierarchical_node([("", "class0")])
+    B.add_hierarchical_node([("", "class0")])
+
+    P = gvu.PytorchHierarchyNode()
+    add_node_expect_success(P, B.nodes[0])
+    add_node_expect_success(P, B.nodes[1])
+
+    nodes = P.get_nodes([""])
+    assert len(nodes) == 2
+    assert nodes[0] is B.nodes[0]
+    assert nodes[1] is B.nodes[1]
+
+def test_twonodes_one_with_hierarchy():
+    B = HierarchyBuilder()
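+    # Each tuple below is (instance name, module type); the root module has
+    # the empty-string instance name, so ("", "class0") followed by
+    # ("a", "class1") describes a submodule `a` of type class1 inside the
+    # root module of type class0.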
B.add_hierarchical_node([("", "class0")]) + B.add_hierarchical_node([("", "class0"), ("a", "class1")]) + + + P = gvu.PytorchHierarchyNode() + add_node_expect_success(P, B.nodes[0]) + add_node_expect_success(P, B.nodes[1]) + + print("Printing hierarchy") + P.print_hierarchy() + + nodes = P.get_nodes([""]) + assert len(nodes) == 2 + assert nodes[0] in B.nodes + assert nodes[1] in B.nodes + assert nodes[0] is not nodes[1] + + nodes = P.get_nodes(["", "a"]) + assert len(nodes) == 1 + assert nodes[0] is B.nodes[1] + +def test_three_levels_of_hierarchy(): + B = HierarchyBuilder() + B.add_hierarchical_node([("", "class0"), ("a", "class1"), ("b", "class2")]) + B.add_hierarchical_node([("", "class0"), ("a", "class1"), ("b", "class2")]) + B.add_hierarchical_node([("", "class0"), ("a", "class1"), ("b", "class2")]) + B.add_hierarchical_node([("", "class0"), ("a", "class1"), ("b", "class2")]) + + P = gvu.PytorchHierarchyNode() + add_node_expect_success(P, B.nodes[0]) + add_node_expect_success(P, B.nodes[1]) + add_node_expect_success(P, B.nodes[2]) + add_node_expect_success(P, B.nodes[3]) + + nodes = P.get_nodes(["", "a", "b"]) + assert len(nodes) == 4 + assert nodes[0] in B.nodes + assert nodes[1] in B.nodes + assert nodes[2] in B.nodes + assert nodes[3] in B.nodes + + +def test_non_hierarchical_node(): + B = HierarchyBuilder() + B.add_non_hierarchical_node() + + P = gvu.PytorchHierarchyNode() + add_node_expect_failure(P, B.nodes[0]) + + assert len(P.get_nodes([""])) == 0 + assert len(P.children) == 0 + +def build_golden_results(nodes): + golden_get_node_results = {} + for node in nodes: + metadata = node.metadata_props + if metadata: + new_key = "/".join(ast.literal_eval(metadata.get("pkg.torch.onnx.name_scopes"))) + "/" + + # search the hierarchy_dict for entries that are a prefix of the key + # add the current node to the list of nodes for that key + for key in golden_get_node_results: + if new_key != key and new_key.startswith(key): + golden_get_node_results[key].append(node) + + # check if the new_key is already in the hierarchy_dict + if new_key not in golden_get_node_results: + golden_get_node_results[new_key] = [] + + # add the current node to the list of nodes for the new_key + golden_get_node_results[new_key].append(node) + + return golden_get_node_results + +def test_mistral_pytorch_with_metadata(): + model_proto = onnx.load('/home/joshmonson/Projects/experiments/finn_mlo_graphs/demo/mistral.onnx') + model = ir.serde.deserialize_model(model_proto) + graph = model.graph + + graph = gvu.add_metadata_to_unannotated_constant_nodes(graph) + + P = gvu.PytorchHierarchyNode() + unadded_nodes = [] + for node in graph._nodes: + added = P.add_node(node) + if not added: + unadded_nodes.append(node) + + print(f"Number of nodes not added: {len(unadded_nodes)}") + print(f"unadded nodes: {unadded_nodes}") + for node in unadded_nodes: + print(f"unadded node: {node}") + print(f"unadded node metadata: {node.metadata_props}") + golden = build_golden_results(graph._nodes) + for key, gnodes in golden.items(): + key = key.rstrip("/") + nodes = P.get_nodes(key.split("/")) + # check if the nodes in the result are in the list of nodes for that key + #print(f"got nodes for key {key}: {nodes}") + for gnode in gnodes: + assert gnode in nodes, f"Node {gnode.metadata_props.get('pkg.torch.onnx.name_scopes')} not found in nodes for key {key}" diff --git a/tests/loop_rolling/test_loop_rolling.py b/tests/loop_rolling/test_loop_rolling.py new file mode 100644 index 0000000000..7ff7b476d8 --- /dev/null +++ 
b/tests/loop_rolling/test_loop_rolling.py
@@ -0,0 +1,198 @@
+import argparse
+import os
+
+import numpy as np
+import onnx
+import onnxruntime as onnxrt
+from onnx import shape_inference
+
+from onnxscript import ir
+from onnxscript import rewriter
+from onnxscript.rewriter import rewrite, RewriteRuleSet
+
+from onnxscript.rewriter.pattern_builder_jsm import build_loop_replace_pattern
+from onnxscript.rewriter.pattern_builder_jsm import normalize_io_for_loop_rolling
+from onnxscript.rewriter.pattern_builder_jsm import LoopBodyTemplate
+
+def remove_existing_data_file(filename):
+    if os.path.exists(filename):
+        print(f"Removing existing data file: {filename}")
+        os.remove(filename)
+
+def open_ir(filename):
+    print('loading onnx')
+    f = onnx.load(filename)
+    print('deserializing')
+    return ir.serde.deserialize_model(f)
+
+# Command-line arguments
+parser = argparse.ArgumentParser(description="Roll repeated layers of an ONNX model into a Loop node")
+parser.add_argument("filename", type=str, help="The name of the input onnx file")
+parser.add_argument("patternfilename", type=str, help="An onnx file containing a single layer of the input model")
+args = parser.parse_args()
+
+def ort_make_rand_io(filename):
+    session = onnxrt.InferenceSession(filename)
+
+    input_name = session.get_inputs()[0].name
+    input_type = session.get_inputs()[0].type
+    input_shape = session.get_inputs()[0].shape
+
+    if input_type == 'tensor(int64)':
+        np_type = np.int64
+    elif input_type == 'tensor(float)':
+        np_type = np.float32
+    else:
+        raise ValueError(f"unsupported type {input_type}")
+
+    input_data = np.random.random(input_shape).astype(np_type)
+
+    return ({input_name: input_data}, session.get_outputs())
+
+
+def ort_run_graph(filename, input_dict, output_name):
+    sess_options = onnxrt.SessionOptions()
+    sess_options.log_severity_level = 0
+    sess_options.graph_optimization_level = onnxrt.GraphOptimizationLevel.ORT_ENABLE_BASIC
+
+    return onnxrt.InferenceSession(filename, sess_options=sess_options).run([output_name], input_dict)
+
+input_dict, outputs = ort_make_rand_io(args.filename)
+golden_results = ort_run_graph(args.filename, input_dict, outputs[0].name)
+
+
+LoopBody = LoopBodyTemplate(args.patternfilename)
+
+change_layers_to_function_calls = rewriter.PatternRewriteRule(
+    LoopBody.pattern,
+    LoopBody.function_replace
+)
+
+
+print("Find and replace layers with a single layer op.")
+
+mypipeline = onnx.load(args.filename)
+
+print('applying rewrite rule')
+
+mypipeline_layers_replaced = rewrite(
+    mypipeline,
+    pattern_rewrite_rules=[change_layers_to_function_calls]
+)
+
+
+mypipeline_model = ir.serde.deserialize_model(mypipeline_layers_replaced)
+mypipeline_model.functions[LoopBody.function.identifier()] = LoopBody.function
+mypipeline_model.graph.opset_imports['loop'] = 0
+mypipeline_layers_replaced = ir.serde.serialize_model(mypipeline_model)
+
+
+replaced_filename = "replaced_" + args.filename
+print(f"Writing updated graph to {replaced_filename}")
+remove_existing_data_file(replaced_filename + '.data')
+onnx.save(mypipeline_layers_replaced,
+          replaced_filename, save_as_external_data=True, location=replaced_filename + '.data')
+
+print("Replace layer ops with loop body")
+
+normalized_graph = normalize_io_for_loop_rolling(mypipeline_model.graph, LoopBody)
+
+
+model = ir.serde.serialize_model(mypipeline_model)
+remove_existing_data_file('normalized.onnx.data')
+onnx.save(model, 'normalized.onnx', save_as_external_data=True, location='normalized.onnx.data')
+
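+# The second rule matches the chain of single-layer function-call nodes (plus
+# the iteration/condition constants) and replaces the whole chain with one
+# ONNX Loop node whose body is the rolled layer. A minimal sketch of what the
+# rule built below amounts to:
+#
+#     rule = rewriter.PatternRewriteRule(LoopMatchPattern, loop_replace_pattern)
+#     RewriteRuleSet([rule]).apply_to_model(mypipeline_model)
+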
+LoopMatchPattern,nodes = LoopBody.build_function_match_pattern(normalized_graph) + + +loop_replace_pattern = build_loop_replace_pattern(normalized_graph, LoopBody) + + +change_function_calls_to_loop = rewriter.PatternRewriteRule( + LoopMatchPattern, + loop_replace_pattern +) + +# class AllTracer(pattern.MatchingTracer): +# def __init__(self): +# super().__init__() + +# def log( +# self, +# rule: PatternRewriteRule, +# container: ir.Graph | ir.Function, +# node: ir.Node, +# match_result: pattern.MatchResult, +# status: pattern.MatchStatus, +# ) -> None: +# this_match = pattern.MatchInfo(match_result, node, container, status) +# best_matches = self._best_matches_map[rule] +# best_matches.append(this_match) + + +# tracer = pattern.MatchingTracer() +rewrite_set = RewriteRuleSet([change_function_calls_to_loop]) +count = rewrite_set.apply_to_model(mypipeline_model, verbose=None) +print(f"Count {count}") + +# tracer.report() +# for rule in tracer.best_matches_map: +# matches = tracer.best_matches_map[rule] +# for match in matches: +# print(f'Reason: {match.match_result.reason}') +# print(f'root_node: {match.root_node}') +# pdb.set_trace() + +#loop_added = rewrite ( +# mypipeline_layers_replaced, +# pattern_rewrite_rules = [change_function_calls_to_loop] +#) + +#mypipeline_model.opset_imports.pop('') +mypipeline_model.opset_imports.pop('loop') +mypipeline_model._functions = {} + + + +# scanning for empty domains +# for node in ir.traversal.RecursiveGraphIterator(mypipeline_model.graph): +# if node.domain == '': +# print(node) + + +#mypipeline_model.opset_imports['main'] = 13 + +loop_added = ir.serde.serialize_model(mypipeline_model) + + + + +remove_existing_data_file('loop_added.onnx.data') +onnx.save(loop_added, 'loop_added.onnx', save_as_external_data=True, location='loop_added.onnx.data') + +onnx.checker.check_model(loop_added) +loop_added = shape_inference.infer_shapes(loop_added) + + + + +transformed_results = ort_run_graph('loop_added.onnx', input_dict, outputs[0].name) + + +if (np.isclose(transformed_results, golden_results[0], rtol=1e-5, atol=1e-6)).all(): + print("Results Equal!!") +else: + print("errors found") + print(transformed_results[0] == golden_results[0]) + print(transformed_results[0]) + print(golden_results[0]) + +print('done')
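+# Usage sketch (paths are illustrative):
+#
+#     python test_loop_rolling.py pipeline.onnx pipeline_layer.onnx
+#
+# where the second file contains a single layer of the first model. The
+# script rolls the repeated layers into a Loop node and checks the result
+# against the original model within the tolerances used above.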