Skip to content

[rewriter] Transpose rule #2255

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
2 changes: 2 additions & 0 deletions onnxscript/rewriter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
llama_rule_sets,
no_op,
pattern,
transpose_initializer,
)

_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)
Expand All @@ -32,6 +33,7 @@
*cast_constant_of_shape.rules.rules,
*collapse_slices.rules.rules,
*llama_rule_sets.llama_p0_rule_set().rules,
transpose_initializer.rule,
)


Expand Down
63 changes: 63 additions & 0 deletions onnxscript/rewriter/transpose_initializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Rules to collapse Transpose nodes into initializers."""

from __future__ import annotations

import logging

import numpy as np

from onnxscript import ir
from onnxscript.rewriter import _ir_utils as ir_utils
from onnxscript.rewriter import pattern as orp

logger = logging.getLogger(__name__)


class TransposeInitializer(orp.RewriteRuleClassBase):
    """Folds a Transpose node into the initializer it consumes.

    When an initializer is consumed only by a single Transpose node, the
    transpose can be applied to the constant data ahead of time and the
    Transpose node removed from the graph.
    """

    def __init__(self):
        # remove_nodes=True: the matched Transpose node is deleted after rewrite.
        super().__init__("TransposeInitializer", remove_nodes=True)

    def pattern(self, op, initializer):
        # Allow other attributes so Transpose nodes with a `perm` attribute
        # (or any future attributes) still match.
        return op.Transpose(initializer, _allow_other_attributes=True)

    def rewrite(self, op, initializer: ir.Value) -> ir.Value:
        """Replace Transpose(initializer) with a pre-transposed initializer.

        Falls back to re-emitting the original Transpose when the constant
        value cannot be materialized as a numpy array.
        """
        # check() guarantees the initializer has exactly one use, so its single
        # consumer is the Transpose node being rewritten.
        original_transpose = initializer.consumers()[0]
        # `perm` is an optional attribute of Transpose. Do NOT assert it is an
        # ir.Attr before the None check: when absent, attributes.get() returns
        # None and np.transpose(axes=None) reverses all axes, which matches the
        # ONNX Transpose default behavior.
        perm_attr = original_transpose.attributes.get("perm")
        perm = perm_attr.as_ints() if perm_attr is not None else None

        array = ir_utils.get_numpy_value(initializer)
        if array is None:
            # Do nothing
            logger.debug("Failed to obtain the initializer value. Do nothing")
            # perm=None is filtered out when the attribute is constructed so we are ok
            return op.Transpose(initializer, perm=perm_attr)

        transposed = np.transpose(array, axes=perm)
        new_name = f"{initializer.const_value.name}_transposed"
        return op.initializer(ir.tensor(transposed, name=new_name))

    def check(self, context, initializer: ir.Value) -> orp.MatchResult:
        """Match only single-use, non-graph-input initializers with known values."""
        del context  # Unused
        check_result = orp.MatchResult()
        if not initializer.is_initializer():
            return check_result.fail("Value is not an initializer")
        if initializer.is_graph_input():
            return check_result.fail("Value is a graph input")
        if initializer.const_value is None:
            return check_result.fail("Value.const_value is None")
        if len(initializer.uses()) != 1:
            return check_result.fail("Initializer is used by more than one node")
        return check_result


rule = TransposeInitializer.rule()
Loading