Skip to content

Commit db3dc8c

Browse files
justinchubyCopilot
andauthored
Unify rule implementations with classes (#2288)
- Replace RewriteRuleAsClass with RewriteRuleClassBase to unify rule implementations. - Also: Use as ints() on attributes in rewrite rules --------- Co-authored-by: Copilot <[email protected]>
1 parent 25a8a7e commit db3dc8c

File tree

4 files changed

+107
-190
lines changed

4 files changed

+107
-190
lines changed

docs/api/rewriter_pattern.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@
3232
rewriter.pattern.PatternMatcher
3333
rewriter.pattern.SimplePatternMatcher
3434
rewriter.pattern.RewriteRule
35-
rewriter.pattern.RewriteRuleAsClass
3635
rewriter.pattern.RewriteRuleSet
36+
rewriter.pattern.RewriteRuleClassBase
3737
rewriter.pattern.MatchStatus
3838
rewriter.pattern.MatchInfo
3939
rewriter.pattern.MatchingTracer

onnxscript/rewriter/llama_rule_sets.py

+53-78
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Licensed under the MIT License.
33
from __future__ import annotations
44

5-
from typing import ClassVar
5+
from typing import ClassVar, Sequence
66

77
from onnxscript import ir
88
from onnxscript.rewriter import _ir_utils as ir_utils
@@ -32,26 +32,23 @@ def check(self, context, x) -> orp.MatchResult:
3232
return check_result
3333

3434

35-
class CastIdentity(orp.RewriteRuleAsClass):
35+
class CastIdentity(orp.RewriteRuleClassBase):
3636
"""Replaces ``Cast(., to=to)`` by ``Identity`` if possible."""
3737

38-
@classmethod
39-
def pattern(cls, op, x, to):
38+
def pattern(self, op, x, to):
4039
return op.Cast(x, to=to)
4140

42-
@classmethod
43-
def rewrite(cls, op, x: ir.Value, to: ir.Attr):
41+
def rewrite(self, op, x: ir.Value, to: ir.Attr):
4442
return op.Identity(x)
4543

46-
@classmethod
47-
def check(cls, context, x, to) -> orp.MatchResult:
44+
def check(self, context, x, to) -> orp.MatchResult:
4845
check_result = orp.MatchResult()
49-
if x.dtype != to.value:
46+
if x.dtype != to.as_int():
5047
return check_result.fail("Input and output types are not the same")
5148
return check_result
5249

5350

54-
class CastCast(orp.RewriteRuleAsClass):
51+
class CastCast(orp.RewriteRuleClassBase):
5552
"""Replaces ``Cast(Cast(X, ...), to=to)`` by ``Cast(X, to=to)``."""
5653

5754
_allowed_tensor_types: ClassVar = {
@@ -61,37 +58,31 @@ class CastCast(orp.RewriteRuleAsClass):
6158
ir.DataType.DOUBLE,
6259
}
6360

64-
@classmethod
65-
def pattern(cls, op, x, to, to_ignored):
61+
def pattern(self, op, x, to, to_ignored):
6662
return op.Cast(op.Cast(x, to=to_ignored), to=to)
6763

68-
@classmethod
69-
def check(cls, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> orp.MatchResult:
64+
def check(self, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> orp.MatchResult:
7065
check_result = orp.MatchResult()
71-
if to.value not in cls._allowed_tensor_types:
72-
return check_result.fail(f"Output type {to.value} is not allowed")
73-
if to_ignored.as_int() not in cls._allowed_tensor_types:
74-
return check_result.fail(f"Ignored type {to_ignored.value} is not allowed")
66+
if to.as_int() not in self._allowed_tensor_types:
67+
return check_result.fail(f"Output type {to.as_int()} is not allowed")
68+
if to_ignored.as_int() not in self._allowed_tensor_types:
69+
return check_result.fail(f"Ignored type {to_ignored.as_int()} is not allowed")
7570
return check_result
7671

77-
@classmethod
78-
def rewrite(cls, op, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr):
72+
def rewrite(self, op, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr):
7973
return op.Cast(x, to=to)
8074

8175

82-
class ExpandIdentity(orp.RewriteRuleAsClass):
76+
class ExpandIdentity(orp.RewriteRuleClassBase):
8377
"""Replaces ``Expand(..., shape)`` by ``Identity`` if possible."""
8478

85-
@classmethod
86-
def pattern(cls, op, x, shape):
79+
def pattern(self, op, x, shape):
8780
return op.Expand(x, shape)
8881

89-
@classmethod
90-
def rewrite(cls, op, x: ir.Value, shape: ir.Value):
82+
def rewrite(self, op, x: ir.Value, shape: ir.Value):
9183
return op.Identity(x)
9284

93-
@classmethod
94-
def check(cls, context, x, shape) -> orp.MatchResult:
85+
def check(self, context, x, shape) -> orp.MatchResult:
9586
check_result = orp.MatchResult()
9687
if shape.const_value is None:
9788
# Shape is not a constant and cannot be guessed.
@@ -106,22 +97,19 @@ def check(cls, context, x, shape) -> orp.MatchResult:
10697
return check_result
10798

10899

109-
class ReshapeReshape(orp.RewriteRuleAsClass):
100+
class ReshapeReshape(orp.RewriteRuleClassBase):
110101
"""Replaces ``Reshape(Reshape(X, ...), shape)`` by ``Reshape(X, shape)``.
111102
The pattern matches only if second reshape reshapes into a shape
112103
with positive values.
113104
"""
114105

115-
@classmethod
116-
def pattern(cls, op, x, shape_ignored, shape):
106+
def pattern(self, op, x, shape_ignored, shape):
117107
return op.Reshape(op.Reshape(x, shape_ignored), shape)
118108

119-
@classmethod
120-
def rewrite(cls, op, x: ir.Value, shape_ignored: ir.Value, shape: ir.Value):
109+
def rewrite(self, op, x: ir.Value, shape_ignored: ir.Value, shape: ir.Value):
121110
return op.Reshape(x, shape)
122111

123-
@classmethod
124-
def check(cls, context, x, shape_ignored, shape) -> orp.MatchResult:
112+
def check(self, context, x, shape_ignored, shape) -> orp.MatchResult:
125113
check_result = orp.MatchResult()
126114
if shape_ignored.const_value is None:
127115
return check_result.fail("Shape ignored is not a constant.")
@@ -132,17 +120,15 @@ def check(cls, context, x, shape_ignored, shape) -> orp.MatchResult:
132120
return check_result
133121

134122

135-
class SlicesSplit(orp.RewriteRuleAsClass):
123+
class SlicesSplit(orp.RewriteRuleClassBase):
136124
"""Replaces ``Slice(x, ...), Slice(x, ...)``
137125
by ``Split(x, ...)`` if possible.
138126
"""
139127

140-
@classmethod
141-
def pattern(cls, op, x, begin0, end0, axes0, begin1, end1, axes1):
128+
def pattern(self, op, x, begin0, end0, axes0, begin1, end1, axes1):
142129
return op.Slice(x, begin0, end0, axes0), op.Slice(x, begin1, end1, axes1)
143130

144-
@classmethod
145-
def check(cls, context, x, begin0, end0, axes0, begin1, end1, axes1) -> orp.MatchResult:
131+
def check(self, context, x, begin0, end0, axes0, begin1, end1, axes1) -> orp.MatchResult:
146132
check_result = orp.MatchResult()
147133
if (
148134
axes0.const_value is None
@@ -187,94 +173,83 @@ def check(cls, context, x, begin0, end0, axes0, begin1, end1, axes1) -> orp.Matc
187173
return check_result.fail("Last dimension is not equal to Begin1.")
188174
return check_result
189175

190-
@classmethod
191-
def rewrite(cls, op, x, begin0, end0, axes0, begin1, end1, axes1):
176+
def rewrite(self, op, x, begin0, end0, axes0, begin1, end1, axes1):
192177
return op.Split(x, num_outputs=2, axis=-1, _outputs=2)
193178

194179

195-
class TransposeIdentity(orp.RewriteRuleAsClass):
180+
class TransposeIdentity(orp.RewriteRuleClassBase):
196181
"""Replaces ``Transpose(. perm=perm)``
197182
when the permutation is identity.
198183
"""
199184

200-
@classmethod
201-
def pattern(cls, op, x, perm):
185+
def pattern(self, op, x, perm):
202186
return op.Transpose(x, perm=perm)
203187

204-
@classmethod
205-
def check(cls, context, x: ir.Value, perm: ir.Attr) -> orp.MatchResult:
188+
def check(self, context, x: ir.Value, perm: ir.Attr) -> orp.MatchResult:
206189
check_result = orp.MatchResult()
207190
if isinstance(perm, ir.RefAttr):
208191
return check_result.fail("Permutation is a reference attribute.")
209192
if perm.type == ir.AttributeType.INTS:
210-
if perm.value == list(range(len(perm.value))):
193+
perm_ints = perm.as_ints()
194+
if perm_ints == list(range(len(perm_ints))):
211195
return check_result
212196
return check_result.fail("Permutation is not identity.")
213197

214-
@classmethod
215-
def rewrite(cls, op, x: ir.Value, perm: ir.Attr):
198+
def rewrite(self, op, x: ir.Value, perm: ir.Attr):
216199
return op.Identity(x)
217200

218201

219-
class TransposeTranspose(orp.RewriteRuleAsClass):
202+
class TransposeTranspose(orp.RewriteRuleClassBase):
220203
"""Replaces ``Transpose(Transpose(., perm=perm1), perm=perm2)``
221204
when both permutations are inverse.
222205
"""
223206

224-
@classmethod
225-
def pattern(cls, op, x, perm1, perm2):
207+
def pattern(self, op, x, perm1, perm2):
226208
return op.Transpose(op.Transpose(x, perm=perm1), perm=perm2)
227209

228-
@classmethod
229-
def check(cls, context, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr) -> orp.MatchResult:
210+
def check(self, context, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr) -> orp.MatchResult:
230211
check_result = orp.MatchResult()
231212
if isinstance(perm1, ir.RefAttr) or isinstance(perm2, ir.RefAttr):
232213
return check_result.fail("Permutation is a reference attribute.")
233214
return check_result
234215

235-
@classmethod
236-
def _apply_transpose(cls, perm: tuple[int, ...], on: list[int]) -> list[int]:
216+
def _apply_transpose(self, perm: Sequence[int], on: list[int]) -> list[int]:
237217
assert len(perm) == len(on), "length mismatch"
238218
res = [-1 for i in on]
239219
for i, p in enumerate(perm):
240220
res[i] = on[p]
241221
return res
242222

243-
@classmethod
244223
def _apply_transposes(
245-
cls, perms: list[tuple[int, ...]], on: list[int] | None = None
224+
self, perms: list[Sequence[int]], on: list[int] | None = None
246225
) -> list[int]:
247226
if on is None:
248227
on = list(range(len(perms[0])))
249228
for p in perms:
250-
on = cls._apply_transpose(p, on)
229+
on = self._apply_transpose(p, on)
251230
return on
252231

253-
@classmethod
254-
def rewrite(cls, op, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr):
255-
first = list(range(len(perm1.value)))
256-
last = cls._apply_transposes([perm1.value, perm2.value])
232+
def rewrite(self, op, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr):
233+
first = list(range(len(perm1.as_ints())))
234+
last = self._apply_transposes([perm1.as_ints(), perm2.as_ints()])
257235
if first == last:
258236
return op.Identity(x)
259237
return op.Transpose(x, perm=last)
260238

261239

262-
class UnsqueezeUnsqueeze(orp.RewriteRuleAsClass):
240+
class UnsqueezeUnsqueeze(orp.RewriteRuleClassBase):
263241
"""Replaces ``Unsqueeze(Unsqueeze(., axes1), axes2)`` with one Unsqueeze."""
264242

265-
@classmethod
266-
def pattern(cls, op, x, axes1, axes2):
243+
def pattern(self, op, x, axes1, axes2):
267244
return op.Unsqueeze(op.Unsqueeze(x, axes1), axes2)
268245

269-
@classmethod
270-
def rewrite(cls, op, x: ir.Value, axes1: ir.Value, axes2: ir.Value):
246+
def rewrite(self, op, x: ir.Value, axes1: ir.Value, axes2: ir.Value):
271247
v1 = ir_utils.get_singleton_value(axes1)
272248
v2 = ir_utils.get_singleton_value(axes2)
273249
axes = [v1, v2] if v1 < v2 else [v2, v1 + 1]
274250
return op.Unsqueeze(x, op.Constant(value=ir.tensor(axes, dtype=ir.DataType.INT64)))
275251

276-
@classmethod
277-
def check(cls, context, x, axes1, axes2) -> orp.MatchResult:
252+
def check(self, context, x, axes1, axes2) -> orp.MatchResult:
278253
check_result = orp.MatchResult()
279254
del context # Unused
280255
del x # Unused
@@ -288,14 +263,14 @@ def check(cls, context, x, axes1, axes2) -> orp.MatchResult:
288263
return check_result
289264

290265

291-
cast_cast_rule = orp.make_rewrite_rule_from_class(CastCast)
292-
cast_identity_rule = orp.make_rewrite_rule_from_class(CastIdentity)
293-
expand_identity_rule = orp.make_rewrite_rule_from_class(ExpandIdentity)
294-
reshape_reshape_rule = orp.make_rewrite_rule_from_class(ReshapeReshape)
295-
slice_split_rule = orp.make_rewrite_rule_from_class(SlicesSplit, True)
296-
transpose_identity_rule = orp.make_rewrite_rule_from_class(TransposeIdentity)
297-
transpose_transpose_rule = orp.make_rewrite_rule_from_class(TransposeTranspose)
298-
unsqueeze_unsqueeze_rule = orp.make_rewrite_rule_from_class(UnsqueezeUnsqueeze)
266+
cast_cast_rule = CastCast.rule()
267+
cast_identity_rule = CastIdentity.rule()
268+
expand_identity_rule = ExpandIdentity.rule()
269+
reshape_reshape_rule = ReshapeReshape.rule()
270+
slice_split_rule = SlicesSplit.rule()
271+
transpose_identity_rule = TransposeIdentity.rule()
272+
transpose_transpose_rule = TransposeTranspose.rule()
273+
unsqueeze_unsqueeze_rule = UnsqueezeUnsqueeze.rule()
299274
squeeze_reshape_1d_rule = SqueezeReshape.rule()
300275

301276

0 commit comments

Comments
 (0)