From 0d8bece127b90610ae7ebba3f0dc068a52bdce1d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 30 Apr 2025 11:13:29 -0700 Subject: [PATCH 01/14] [rewriter] Transpose rule --- onnxscript/rewriter/transpose_initializer.py | 38 ++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 onnxscript/rewriter/transpose_initializer.py diff --git a/onnxscript/rewriter/transpose_initializer.py b/onnxscript/rewriter/transpose_initializer.py new file mode 100644 index 000000000..bec0edba9 --- /dev/null +++ b/onnxscript/rewriter/transpose_initializer.py @@ -0,0 +1,38 @@ +"""Rules to collapse Transpose nodes into initializers.""" +from __future__ import annotations +from onnxscript import ir +from onnxscript.rewriter import _ir_utils as ir_utils +from onnxscript.rewriter import pattern as orp + +import logging + +logger = logging.getLogger(__name__) + +class TransposeInitializer(orp.RewriteRuleClassBase): + """Folds Transpose nodes into initializers.""" + + def __init__(self): + super().__init__("TransposeInitializer", remove_nodes=True) + + def pattern(self, op, initializer): + return op.Transpose(initializer, _allow_other_attributes=True) + + def rewrite(self, op, initializer: ir.Value) -> ir.Value: + array = ir_utils.get_const_value(initializer) + if array is None: + # Do nothing + logger.debug("Failed to obtain the initializer value. Do nothing") + return op.Transpose(initializer, dims) + return op.initializer(ir.tensor() + + def check(self, context, initializer: ir.Value) -> orp.MatchResult: + del context # Unused + check_result = orp.MatchResult() + if initializer.const_value is None: + return check_result.fail("Value is not an initializer, const_value is None") + if initializer.producer() is not None: + return check_result.fail("Value is not an initializer, producer is not None") + return check_result + + +rule = TransposeInitializer.rule() From 43b9f26b46744345368ee48e26ccdcea50cbd6a7 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 30 Apr 2025 11:15:12 -0700 Subject: [PATCH 02/14] WIP --- onnxscript/rewriter/transpose_initializer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxscript/rewriter/transpose_initializer.py b/onnxscript/rewriter/transpose_initializer.py index bec0edba9..07b54db7c 100644 --- a/onnxscript/rewriter/transpose_initializer.py +++ b/onnxscript/rewriter/transpose_initializer.py @@ -22,8 +22,10 @@ def rewrite(self, op, initializer: ir.Value) -> ir.Value: if array is None: # Do nothing logger.debug("Failed to obtain the initializer value. Do nothing") - return op.Transpose(initializer, dims) - return op.initializer(ir.tensor() + # TODO: Handle both when perms is None and when perms is not None + return op.Transpose(initializer, perms) + # TODO Obtain perms from the matched node + return op.initializer(ir.tensor()) def check(self, context, initializer: ir.Value) -> orp.MatchResult: del context # Unused From c9e93e1f2e2ba66ea2395de1a5863f5c92efe0e0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 30 Apr 2025 11:54:09 -0700 Subject: [PATCH 03/14] Implement the rule --- onnxscript/rewriter/transpose_initializer.py | 26 +++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/onnxscript/rewriter/transpose_initializer.py b/onnxscript/rewriter/transpose_initializer.py index 07b54db7c..67cc71b9f 100644 --- a/onnxscript/rewriter/transpose_initializer.py +++ b/onnxscript/rewriter/transpose_initializer.py @@ -1,13 +1,18 @@ """Rules to collapse Transpose nodes into initializers.""" + from __future__ import annotations + +import logging + +import numpy as np + from onnxscript import ir from onnxscript.rewriter import _ir_utils as ir_utils from onnxscript.rewriter import pattern as orp -import logging - logger = logging.getLogger(__name__) + class TransposeInitializer(orp.RewriteRuleClassBase): """Folds Transpose nodes into initializers.""" @@ -18,14 +23,21 @@ def pattern(self, op, initializer): return op.Transpose(initializer, _allow_other_attributes=True) def rewrite(self, op, initializer: ir.Value) -> ir.Value: + original_transpose = initializer.consumers()[0] + perm_attr = original_transpose.attributes.get("perm") + if perm_attr is not None: + perm = perm_attr.as_ints() + else: + perm = None + array = ir_utils.get_const_value(initializer) if array is None: # Do nothing logger.debug("Failed to obtain the initializer value. Do nothing") - # TODO: Handle both when perms is None and when perms is not None - return op.Transpose(initializer, perms) - # TODO Obtain perms from the matched node - return op.initializer(ir.tensor()) + return op.Transpose(initializer, perm=perm) + + transposed = np.transpose(array, axes=perm) + return op.initializer(ir.tensor(transposed)) def check(self, context, initializer: ir.Value) -> orp.MatchResult: del context # Unused @@ -34,6 +46,8 @@ def check(self, context, initializer: ir.Value) -> orp.MatchResult: return check_result.fail("Value is not an initializer, const_value is None") if initializer.producer() is not None: return check_result.fail("Value is not an initializer, producer is not None") + if initializer.uses() != 1: + return check_result.fail("Initializer is used by more than one node") return check_result From 37232d3eb342dbd4a08a02c3ad287c5df7697c3d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 30 Apr 2025 11:54:47 -0700 Subject: [PATCH 04/14] init --- onnxscript/rewriter/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index 5efaf784b..656b9c0ab 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -22,6 +22,7 @@ llama_rule_sets, no_op, pattern, + transpose_initializer, ) _ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model) @@ -32,6 +33,7 @@ *cast_constant_of_shape.rules.rules, *collapse_slices.rules.rules, *llama_rule_sets.llama_p0_rule_set().rules, + transpose_initializer.rule, ) From 0c4722608d94998b5cec9120be2b9dea08db9fdd Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 30 Apr 2025 11:56:55 -0700 Subject: [PATCH 05/14] attr --- onnxscript/rewriter/transpose_initializer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxscript/rewriter/transpose_initializer.py b/onnxscript/rewriter/transpose_initializer.py index 67cc71b9f..f8109707e 100644 --- a/onnxscript/rewriter/transpose_initializer.py +++ b/onnxscript/rewriter/transpose_initializer.py @@ -25,17 +25,17 @@ def pattern(self, op, initializer): def rewrite(self, op, initializer: ir.Value) -> ir.Value: original_transpose = initializer.consumers()[0] perm_attr = original_transpose.attributes.get("perm") - if perm_attr is not None: - perm = perm_attr.as_ints() - else: - perm = None - array = ir_utils.get_const_value(initializer) if array is None: # Do nothing logger.debug("Failed to obtain the initializer value. Do nothing") - return op.Transpose(initializer, perm=perm) + # perm=None is filtered out when the attribute is constructed so we are ok + return op.Transpose(initializer, perm=perm_attr) + if perm_attr is not None: + perm = perm_attr.as_ints() + else: + perm = None transposed = np.transpose(array, axes=perm) return op.initializer(ir.tensor(transposed)) From f2d2ccf45a115ece7c64970dfa84f52cd9161882 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 30 Apr 2025 11:59:37 -0700 Subject: [PATCH 06/14] new_name --- onnxscript/rewriter/transpose_initializer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxscript/rewriter/transpose_initializer.py b/onnxscript/rewriter/transpose_initializer.py index f8109707e..3d4b342c2 100644 --- a/onnxscript/rewriter/transpose_initializer.py +++ b/onnxscript/rewriter/transpose_initializer.py @@ -25,7 +25,7 @@ def pattern(self, op, initializer): def rewrite(self, op, initializer: ir.Value) -> ir.Value: original_transpose = initializer.consumers()[0] perm_attr = original_transpose.attributes.get("perm") - array = ir_utils.get_const_value(initializer) + array = ir_utils.get_numpy_value(initializer) if array is None: # Do nothing logger.debug("Failed to obtain the initializer value. Do nothing") @@ -37,7 +37,8 @@ def rewrite(self, op, initializer: ir.Value) -> ir.Value: else: perm = None transposed = np.transpose(array, axes=perm) - return op.initializer(ir.tensor(transposed)) + new_name = f"{initializer.const_value.name}_transposed" + return op.initializer(ir.tensor(transposed, name=new_name)) def check(self, context, initializer: ir.Value) -> orp.MatchResult: del context # Unused From 278e97756e3ba5c5a53749d1ab3586b9d0a28274 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 30 Apr 2025 12:00:26 -0700 Subject: [PATCH 07/14] Copyright --- onnxscript/rewriter/transpose_initializer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxscript/rewriter/transpose_initializer.py b/onnxscript/rewriter/transpose_initializer.py index 3d4b342c2..b7e000047 100644 --- a/onnxscript/rewriter/transpose_initializer.py +++ b/onnxscript/rewriter/transpose_initializer.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """Rules to collapse Transpose nodes into initializers.""" from __future__ import annotations From d028482370c07564170da487aa600418e3ae941d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 30 Apr 2025 12:11:25 -0700 Subject: [PATCH 08/14] typing --- onnxscript/rewriter/transpose_initializer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxscript/rewriter/transpose_initializer.py b/onnxscript/rewriter/transpose_initializer.py index b7e000047..1d0503cde 100644 --- a/onnxscript/rewriter/transpose_initializer.py +++ b/onnxscript/rewriter/transpose_initializer.py @@ -27,6 +27,7 @@ def pattern(self, op, initializer): def rewrite(self, op, initializer: ir.Value) -> ir.Value: original_transpose = initializer.consumers()[0] perm_attr = original_transpose.attributes.get("perm") + assert isinstance(perm_attr, ir.Attr) array = ir_utils.get_numpy_value(initializer) if array is None: # Do nothing From 3c2369cac4e391f58e3968ae575360fa5ca42a2b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 30 Apr 2025 12:13:03 -0700 Subject: [PATCH 09/14] todo --- onnxscript/rewriter/transpose_initializer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxscript/rewriter/transpose_initializer.py b/onnxscript/rewriter/transpose_initializer.py index 1d0503cde..82eb678ed 100644 --- a/onnxscript/rewriter/transpose_initializer.py +++ b/onnxscript/rewriter/transpose_initializer.py @@ -52,6 +52,7 @@ def check(self, context, initializer: ir.Value) -> orp.MatchResult: return check_result.fail("Value is not an initializer, producer is not None") if initializer.uses() != 1: return check_result.fail("Initializer is used by more than one node") + # TODO(justinchuby): Avoid matching when it is a graph input return check_result From f271a7caf6f15f30075cb3b8220faa638f40e6f1 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 30 Apr 2025 16:16:32 -0700 Subject: [PATCH 10/14] Update onnxscript/rewriter/transpose_initializer.py Co-authored-by: Shubham Bhokare <32080845+shubhambhokare1@users.noreply.github.com> --- onnxscript/rewriter/transpose_initializer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/rewriter/transpose_initializer.py b/onnxscript/rewriter/transpose_initializer.py index 82eb678ed..e637ed37e 100644 --- a/onnxscript/rewriter/transpose_initializer.py +++ b/onnxscript/rewriter/transpose_initializer.py @@ -50,7 +50,7 @@ def check(self, context, initializer: ir.Value) -> orp.MatchResult: return check_result.fail("Value is not an initializer, const_value is None") if initializer.producer() is not None: return check_result.fail("Value is not an initializer, producer is not None") - if initializer.uses() != 1: + if len(initializer.uses()) != 1: return check_result.fail("Initializer is used by more than one node") # TODO(justinchuby): Avoid matching when it is a graph input return check_result From 19aaedbef9a6971335c5d007bc9407fc6b02c0b9 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 9 May 2025 08:52:02 -0700 Subject: [PATCH 11/14] update --- onnxscript/rewriter/transpose_initializer.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/onnxscript/rewriter/transpose_initializer.py b/onnxscript/rewriter/transpose_initializer.py index e637ed37e..7559d71dc 100644 --- a/onnxscript/rewriter/transpose_initializer.py +++ b/onnxscript/rewriter/transpose_initializer.py @@ -28,6 +28,12 @@ def rewrite(self, op, initializer: ir.Value) -> ir.Value: original_transpose = initializer.consumers()[0] perm_attr = original_transpose.attributes.get("perm") assert isinstance(perm_attr, ir.Attr) + + if perm_attr is not None: + perm = perm_attr.as_ints() + else: + perm = None + array = ir_utils.get_numpy_value(initializer) if array is None: # Do nothing @@ -35,10 +41,6 @@ def rewrite(self, op, initializer: ir.Value) -> ir.Value: # perm=None is filtered out when the attribute is constructed so we are ok return op.Transpose(initializer, perm=perm_attr) - if perm_attr is not None: - perm = perm_attr.as_ints() - else: - perm = None transposed = np.transpose(array, axes=perm) new_name = f"{initializer.const_value.name}_transposed" return op.initializer(ir.tensor(transposed, name=new_name)) @@ -46,10 +48,12 @@ def rewrite(self, op, initializer: ir.Value) -> ir.Value: def check(self, context, initializer: ir.Value) -> orp.MatchResult: del context # Unused check_result = orp.MatchResult() + if not initializer.is_initializer(): + return check_result.fail("Value is not an initializer") + if initializer.is_graph_input(): + return check_result.fail("Value is a graph input") if initializer.const_value is None: - return check_result.fail("Value is not an initializer, const_value is None") - if initializer.producer() is not None: - return check_result.fail("Value is not an initializer, producer is not None") + return check_result.fail("Value.const_value is None") if len(initializer.uses()) != 1: return check_result.fail("Initializer is used by more than one node") # TODO(justinchuby): Avoid matching when it is a graph input From f71773e74e3687100140b9dc3cc26d69b17bc51d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 13 May 2025 10:07:09 -0700 Subject: [PATCH 12/14] unregister_initializer --- onnxscript/ir/_core.py | 7 +++++++ onnxscript/rewriter/transpose_initializer.py | 1 + 2 files changed, 8 insertions(+) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index f699916f0..32ab2b2b7 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -2088,6 +2088,13 @@ def is_initializer(self) -> bool: """Whether the value is an initializer of a graph.""" return self._is_initializer + def unregister_initializer(self) -> None: + """Unregister the value as an initializer of a graph.""" + if not self.is_initializer(): + raise ValueError("The value is not an initializer.") + assert self.graph.initializers[self.name] is self + self.graph.initializers.pop(self.name) + def Input( name: str | None = None, diff --git a/onnxscript/rewriter/transpose_initializer.py b/onnxscript/rewriter/transpose_initializer.py index 7559d71dc..162e42e07 100644 --- a/onnxscript/rewriter/transpose_initializer.py +++ b/onnxscript/rewriter/transpose_initializer.py @@ -42,6 +42,7 @@ def rewrite(self, op, initializer: ir.Value) -> ir.Value: return op.Transpose(initializer, perm=perm_attr) transposed = np.transpose(array, axes=perm) + initializer.unregister_initializer() new_name = f"{initializer.const_value.name}_transposed" return op.initializer(ir.tensor(transposed, name=new_name)) From 99fcd4e48e83dd595abbceb02420542580e6ff54 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 13 May 2025 10:19:14 -0700 Subject: [PATCH 13/14] name --- onnxscript/rewriter/transpose_initializer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/rewriter/transpose_initializer.py b/onnxscript/rewriter/transpose_initializer.py index 162e42e07..32a279da9 100644 --- a/onnxscript/rewriter/transpose_initializer.py +++ b/onnxscript/rewriter/transpose_initializer.py @@ -43,7 +43,7 @@ def rewrite(self, op, initializer: ir.Value) -> ir.Value: transposed = np.transpose(array, axes=perm) initializer.unregister_initializer() - new_name = f"{initializer.const_value.name}_transposed" + new_name = f"{initializer.name}_transposed" return op.initializer(ir.tensor(transposed, name=new_name)) def check(self, context, initializer: ir.Value) -> orp.MatchResult: From 41cb1193a5e1490dec75c05482701dc3a94b1db9 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 13 May 2025 18:06:58 -0700 Subject: [PATCH 14/14] Fix value naming --- onnxscript/ir/_convenience/__init__.py | 12 ++++++++---- onnxscript/ir/_core.py | 7 ------- onnxscript/rewriter/transpose_initializer.py | 1 - 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/onnxscript/ir/_convenience/__init__.py b/onnxscript/ir/_convenience/__init__.py index 839c5d330..1287fa31c 100644 --- a/onnxscript/ir/_convenience/__init__.py +++ b/onnxscript/ir/_convenience/__init__.py @@ -360,10 +360,14 @@ def replace_nodes_and_values( # Propagate relevant info from old value to new value # TODO(Rama): Perhaps this should be a separate utility function. Also, consider # merging old and new type/shape info. - new_value.type = old_value.type - new_value.shape = old_value.shape - new_value.const_value = old_value.const_value - new_value.name = old_value.name + if new_value.type is None: + new_value.type = old_value.type + if new_value.shape is None: + new_value.shape = old_value.shape + if new_value.const_value is None: + new_value.const_value = old_value.const_value + if new_value.name is None: + new_value.name = old_value.name # Reconnect the users of the deleted values to use the new values replace_all_uses_with(old_values, new_values) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 32ab2b2b7..f699916f0 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -2088,13 +2088,6 @@ def is_initializer(self) -> bool: """Whether the value is an initializer of a graph.""" return self._is_initializer - def unregister_initializer(self) -> None: - """Unregister the value as an initializer of a graph.""" - if not self.is_initializer(): - raise ValueError("The value is not an initializer.") - assert self.graph.initializers[self.name] is self - self.graph.initializers.pop(self.name) - def Input( name: str | None = None, diff --git a/onnxscript/rewriter/transpose_initializer.py b/onnxscript/rewriter/transpose_initializer.py index 32a279da9..f60bc0693 100644 --- a/onnxscript/rewriter/transpose_initializer.py +++ b/onnxscript/rewriter/transpose_initializer.py @@ -42,7 +42,6 @@ def rewrite(self, op, initializer: ir.Value) -> ir.Value: return op.Transpose(initializer, perm=perm_attr) transposed = np.transpose(array, axes=perm) - initializer.unregister_initializer() new_name = f"{initializer.name}_transposed" return op.initializer(ir.tensor(transposed, name=new_name))