Skip to content

Unify rule implementations with classes #2288

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/api/rewriter_pattern.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
rewriter.pattern.PatternMatcher
rewriter.pattern.SimplePatternMatcher
rewriter.pattern.RewriteRule
rewriter.pattern.RewriteRuleAsClass
rewriter.pattern.RewriteRuleSet
rewriter.pattern.RewriteRuleClassBase
rewriter.pattern.MatchStatus
rewriter.pattern.MatchInfo
rewriter.pattern.MatchingTracer
Expand Down
131 changes: 53 additions & 78 deletions onnxscript/rewriter/llama_rule_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT License.
from __future__ import annotations

from typing import ClassVar
from typing import ClassVar, Sequence

from onnxscript import ir
from onnxscript.rewriter import _ir_utils as ir_utils
Expand Down Expand Up @@ -32,26 +32,23 @@ def check(self, context, x) -> orp.MatchResult:
return check_result


class CastIdentity(orp.RewriteRuleAsClass):
class CastIdentity(orp.RewriteRuleClassBase):
"""Replaces ``Cast(., to=to)`` by ``Identity`` if possible."""

@classmethod
def pattern(cls, op, x, to):
def pattern(self, op, x, to):
return op.Cast(x, to=to)

@classmethod
def rewrite(cls, op, x: ir.Value, to: ir.Attr):
def rewrite(self, op, x: ir.Value, to: ir.Attr):
return op.Identity(x)

@classmethod
def check(cls, context, x, to) -> orp.MatchResult:
def check(self, context, x, to) -> orp.MatchResult:
check_result = orp.MatchResult()
if x.dtype != to.value:
if x.dtype != to.as_int():
return check_result.fail("Input and output types are not the same")
return check_result


class CastCast(orp.RewriteRuleAsClass):
class CastCast(orp.RewriteRuleClassBase):
"""Replaces ``Cast(Cast(X, ...), to=to)`` by ``Cast(X, to=to)``."""

_allowed_tensor_types: ClassVar = {
Expand All @@ -61,37 +58,31 @@ class CastCast(orp.RewriteRuleAsClass):
ir.DataType.DOUBLE,
}

@classmethod
def pattern(cls, op, x, to, to_ignored):
def pattern(self, op, x, to, to_ignored):
return op.Cast(op.Cast(x, to=to_ignored), to=to)

@classmethod
def check(cls, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> orp.MatchResult:
def check(self, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> orp.MatchResult:
check_result = orp.MatchResult()
if to.value not in cls._allowed_tensor_types:
return check_result.fail(f"Output type {to.value} is not allowed")
if to_ignored.as_int() not in cls._allowed_tensor_types:
return check_result.fail(f"Ignored type {to_ignored.value} is not allowed")
if to.as_int() not in self._allowed_tensor_types:
return check_result.fail(f"Output type {to.as_int()} is not allowed")
if to_ignored.as_int() not in self._allowed_tensor_types:
return check_result.fail(f"Ignored type {to_ignored.as_int()} is not allowed")
return check_result

@classmethod
def rewrite(cls, op, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr):
def rewrite(self, op, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr):
return op.Cast(x, to=to)


class ExpandIdentity(orp.RewriteRuleAsClass):
class ExpandIdentity(orp.RewriteRuleClassBase):
"""Replaces ``Expand(..., shape)`` by ``Identity`` if possible."""

@classmethod
def pattern(cls, op, x, shape):
def pattern(self, op, x, shape):
return op.Expand(x, shape)

@classmethod
def rewrite(cls, op, x: ir.Value, shape: ir.Value):
def rewrite(self, op, x: ir.Value, shape: ir.Value):
return op.Identity(x)

@classmethod
def check(cls, context, x, shape) -> orp.MatchResult:
def check(self, context, x, shape) -> orp.MatchResult:
check_result = orp.MatchResult()
if shape.const_value is None:
# Shape is not a constant and cannot be guessed.
Expand All @@ -106,22 +97,19 @@ def check(cls, context, x, shape) -> orp.MatchResult:
return check_result


class ReshapeReshape(orp.RewriteRuleAsClass):
class ReshapeReshape(orp.RewriteRuleClassBase):
"""Replaces ``Reshape(Reshape(X, ...), shape)`` by ``Reshape(X, shape)``.
The pattern matches only if second reshape reshapes into a shape
with positive values.
"""

@classmethod
def pattern(cls, op, x, shape_ignored, shape):
def pattern(self, op, x, shape_ignored, shape):
return op.Reshape(op.Reshape(x, shape_ignored), shape)

@classmethod
def rewrite(cls, op, x: ir.Value, shape_ignored: ir.Value, shape: ir.Value):
def rewrite(self, op, x: ir.Value, shape_ignored: ir.Value, shape: ir.Value):
return op.Reshape(x, shape)

@classmethod
def check(cls, context, x, shape_ignored, shape) -> orp.MatchResult:
def check(self, context, x, shape_ignored, shape) -> orp.MatchResult:
check_result = orp.MatchResult()
if shape_ignored.const_value is None:
return check_result.fail("Shape ignored is not a constant.")
Expand All @@ -132,17 +120,15 @@ def check(cls, context, x, shape_ignored, shape) -> orp.MatchResult:
return check_result


class SlicesSplit(orp.RewriteRuleAsClass):
class SlicesSplit(orp.RewriteRuleClassBase):
"""Replaces ``Slice(x, ...), Slice(x, ...)``
by ``Split(x, ...)`` if possible.
"""

@classmethod
def pattern(cls, op, x, begin0, end0, axes0, begin1, end1, axes1):
def pattern(self, op, x, begin0, end0, axes0, begin1, end1, axes1):
return op.Slice(x, begin0, end0, axes0), op.Slice(x, begin1, end1, axes1)

@classmethod
def check(cls, context, x, begin0, end0, axes0, begin1, end1, axes1) -> orp.MatchResult:
def check(self, context, x, begin0, end0, axes0, begin1, end1, axes1) -> orp.MatchResult:
check_result = orp.MatchResult()
if (
axes0.const_value is None
Expand Down Expand Up @@ -187,94 +173,83 @@ def check(cls, context, x, begin0, end0, axes0, begin1, end1, axes1) -> orp.Matc
return check_result.fail("Last dimension is not equal to Begin1.")
return check_result

@classmethod
def rewrite(cls, op, x, begin0, end0, axes0, begin1, end1, axes1):
def rewrite(self, op, x, begin0, end0, axes0, begin1, end1, axes1):
return op.Split(x, num_outputs=2, axis=-1, _outputs=2)


class TransposeIdentity(orp.RewriteRuleAsClass):
class TransposeIdentity(orp.RewriteRuleClassBase):
"""Replaces ``Transpose(. perm=perm)``
when the permutation is identity.
"""

@classmethod
def pattern(cls, op, x, perm):
def pattern(self, op, x, perm):
return op.Transpose(x, perm=perm)

@classmethod
def check(cls, context, x: ir.Value, perm: ir.Attr) -> orp.MatchResult:
def check(self, context, x: ir.Value, perm: ir.Attr) -> orp.MatchResult:
check_result = orp.MatchResult()
if isinstance(perm, ir.RefAttr):
return check_result.fail("Permutation is a reference attribute.")
if perm.type == ir.AttributeType.INTS:
if perm.value == list(range(len(perm.value))):
perm_ints = perm.as_ints()
if perm_ints == list(range(len(perm_ints))):
return check_result
return check_result.fail("Permutation is not identity.")

@classmethod
def rewrite(cls, op, x: ir.Value, perm: ir.Attr):
def rewrite(self, op, x: ir.Value, perm: ir.Attr):
return op.Identity(x)


class TransposeTranspose(orp.RewriteRuleAsClass):
class TransposeTranspose(orp.RewriteRuleClassBase):
"""Replaces ``Transpose(Transpose(., perm=perm1), perm=perm2)``
when both permutations are inverse.
"""

@classmethod
def pattern(cls, op, x, perm1, perm2):
def pattern(self, op, x, perm1, perm2):
return op.Transpose(op.Transpose(x, perm=perm1), perm=perm2)

@classmethod
def check(cls, context, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr) -> orp.MatchResult:
def check(self, context, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr) -> orp.MatchResult:
check_result = orp.MatchResult()
if isinstance(perm1, ir.RefAttr) or isinstance(perm2, ir.RefAttr):
return check_result.fail("Permutation is a reference attribute.")
return check_result

@classmethod
def _apply_transpose(cls, perm: tuple[int, ...], on: list[int]) -> list[int]:
def _apply_transpose(self, perm: Sequence[int], on: list[int]) -> list[int]:
assert len(perm) == len(on), "length mismatch"
res = [-1 for i in on]
for i, p in enumerate(perm):
res[i] = on[p]
return res

@classmethod
def _apply_transposes(
cls, perms: list[tuple[int, ...]], on: list[int] | None = None
self, perms: list[Sequence[int]], on: list[int] | None = None
) -> list[int]:
if on is None:
on = list(range(len(perms[0])))
for p in perms:
on = cls._apply_transpose(p, on)
on = self._apply_transpose(p, on)
return on

@classmethod
def rewrite(cls, op, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr):
first = list(range(len(perm1.value)))
last = cls._apply_transposes([perm1.value, perm2.value])
def rewrite(self, op, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr):
first = list(range(len(perm1.as_ints())))
last = self._apply_transposes([perm1.as_ints(), perm2.as_ints()])
if first == last:
return op.Identity(x)
return op.Transpose(x, perm=last)


class UnsqueezeUnsqueeze(orp.RewriteRuleAsClass):
class UnsqueezeUnsqueeze(orp.RewriteRuleClassBase):
"""Replaces ``Unsqueeze(Unsqueeze(., axes1), axes2)`` with one Unsqueeze."""

@classmethod
def pattern(cls, op, x, axes1, axes2):
def pattern(self, op, x, axes1, axes2):
return op.Unsqueeze(op.Unsqueeze(x, axes1), axes2)

@classmethod
def rewrite(cls, op, x: ir.Value, axes1: ir.Value, axes2: ir.Value):
def rewrite(self, op, x: ir.Value, axes1: ir.Value, axes2: ir.Value):
v1 = ir_utils.get_singleton_value(axes1)
v2 = ir_utils.get_singleton_value(axes2)
axes = [v1, v2] if v1 < v2 else [v2, v1 + 1]
return op.Unsqueeze(x, op.Constant(value=ir.tensor(axes, dtype=ir.DataType.INT64)))

@classmethod
def check(cls, context, x, axes1, axes2) -> orp.MatchResult:
def check(self, context, x, axes1, axes2) -> orp.MatchResult:
check_result = orp.MatchResult()
del context # Unused
del x # Unused
Expand All @@ -288,14 +263,14 @@ def check(cls, context, x, axes1, axes2) -> orp.MatchResult:
return check_result


cast_cast_rule = orp.make_rewrite_rule_from_class(CastCast)
cast_identity_rule = orp.make_rewrite_rule_from_class(CastIdentity)
expand_identity_rule = orp.make_rewrite_rule_from_class(ExpandIdentity)
reshape_reshape_rule = orp.make_rewrite_rule_from_class(ReshapeReshape)
slice_split_rule = orp.make_rewrite_rule_from_class(SlicesSplit, True)
transpose_identity_rule = orp.make_rewrite_rule_from_class(TransposeIdentity)
transpose_transpose_rule = orp.make_rewrite_rule_from_class(TransposeTranspose)
unsqueeze_unsqueeze_rule = orp.make_rewrite_rule_from_class(UnsqueezeUnsqueeze)
cast_cast_rule = CastCast.rule()
cast_identity_rule = CastIdentity.rule()
expand_identity_rule = ExpandIdentity.rule()
reshape_reshape_rule = ReshapeReshape.rule()
slice_split_rule = SlicesSplit.rule()
transpose_identity_rule = TransposeIdentity.rule()
transpose_transpose_rule = TransposeTranspose.rule()
unsqueeze_unsqueeze_rule = UnsqueezeUnsqueeze.rule()
squeeze_reshape_1d_rule = SqueezeReshape.rule()


Expand Down
Loading
Loading