diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py index 3e910edd52..a797be745f 100644 --- a/onnxscript/rewriter/_rewrite_rule.py +++ b/onnxscript/rewriter/_rewrite_rule.py @@ -531,6 +531,23 @@ def _apply_to_graph_or_function( f = ir.Function(domain, name, overload, graph=graph, attributes=()) model.functions[f.identifier()] = f + # If we are fusing nodes, update the docstring of the new node(s) + attributes = ["namespace", "pkg.torch.onnx.class_hierarchy", "pkg.torch.onnx.fx_node", "pkg.torch.onnx.name_scopes", "pkg.torch.onnx.stack_trace"] + if delta.match.nodes and delta.new_nodes: + # Concatenate docstrings from all original nodes + for attribute in attributes: + fused_attribute = "\n".join( + n.metadata_props[attribute] for n in delta.match.nodes if getattr(n, "metadata_props", None) and attribute in n.metadata_props + ) + if fused_attribute.strip(): + fused_attribute = "Fused from nodes with following attributes: " + fused_attribute + for node in delta.new_nodes: + # Assign to all new nodes + if attribute in node.metadata_props: + node.metadata_props[attribute] += f"\n{fused_attribute}" + else: + node.metadata_props[attribute] = fused_attribute + if verbose: name = f"{rule.name}: " if rule.name else "" print(f"----{name}Matched Nodes----")