Closed
Labels
contribution welcome (We welcome code contributions for this), module: rewriter
Description
Currently, when applying a rewrite rule, I cannot pass a name with the replacement op.
Minimal example
import onnx
import onnx_ir as ir
from onnxscript.rewriter import pattern, rewrite


def custom_rewrite(model):
    def target_pattern(op, x):
        return op.Relu(x, _outputs=["relu"])

    def replacement_pattern(op, x, relu):
        ir_node = relu.producer()
        return op.Clip(x, name=ir_node.name)
        """Workaround
        out = op.Clip(x)
        clip_node = out.producer()
        clip_node.name = ir_node.name
        """

    rules = [pattern.RewriteRule(target_pattern, replacement_pattern)]
    model = rewrite(model, rules)
    return model


ir_model = ir.from_onnx_text("""
< ir_version: 10, opset_import: ["" : 20] >
test_model (float[N, 16, 16] X) => (float [N, ?, ?] Y)
{
    Y = Relu(X)
}
""")
ir_model.graph[0].name = "relu_1"
ir_model = custom_rewrite(ir_model)
proto = ir.to_proto(ir_model)
onnx.save(proto, "tmp.onnx")
Doing this results in a Clip node that carries an attribute called name, instead of a node whose name is set.
Potential fix
Other op parameters, such as name (see https://github.com/onnx/ir-py/blob/c7879f7bab341e27ca19ab31b084b449f8a82378/src/onnx_ir/_tape.py#L74), should be extracted from kwargs, just as is already done for domain, version, and outputs:
onnxscript/onnxscript/ir/_tape.py
Line 25 in 87baf8f
def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, Any]):
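To illustrate the idea, here is a minimal, library-independent sketch of separating node-level options from ONNX attributes before the node is built. The helper split_node_kwargs and the RESERVED_KEYS tuple are hypothetical names used only for this example; the keys "_domain" and "_version" are assumed by analogy with the "_outputs" key used in the minimal example above, and this is not the actual body of _make_node.

```python
from typing import Any

# Hypothetical list of keyword names that describe the node itself
# rather than ONNX attributes. "name" is the key this issue proposes to add.
RESERVED_KEYS = ("_domain", "_version", "_outputs", "name")


def split_node_kwargs(
    kwargs: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Split kwargs into (node options, ONNX attributes) without mutating the input."""
    options = {key: kwargs[key] for key in RESERVED_KEYS if key in kwargs}
    attributes = {
        key: value for key, value in kwargs.items() if key not in RESERVED_KEYS
    }
    return options, attributes


# "alpha" stands in for an ordinary ONNX attribute.
options, attributes = split_node_kwargs(
    {"name": "relu_1", "_outputs": 1, "alpha": 0.1}
)
```

With a split like this, name would be forwarded to the node constructor and would no longer end up among the attributes.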
I can submit a PR if this fix looks reasonable to you.