2
2
# Licensed under the MIT License.
3
3
from __future__ import annotations
4
4
5
- from typing import ClassVar
5
+ from typing import ClassVar , Sequence
6
6
7
7
from onnxscript import ir
8
8
from onnxscript .rewriter import _ir_utils as ir_utils
@@ -32,26 +32,23 @@ def check(self, context, x) -> orp.MatchResult:
32
32
return check_result
33
33
34
34
35
- class CastIdentity (orp .RewriteRuleAsClass ):
35
+ class CastIdentity (orp .RewriteRuleClassBase ):
36
36
"""Replaces ``Cast(., to=to)`` by ``Identity`` if possible."""
37
37
38
- @classmethod
39
- def pattern (cls , op , x , to ):
38
+ def pattern (self , op , x , to ):
40
39
return op .Cast (x , to = to )
41
40
42
- @classmethod
43
- def rewrite (cls , op , x : ir .Value , to : ir .Attr ):
41
+ def rewrite (self , op , x : ir .Value , to : ir .Attr ):
44
42
return op .Identity (x )
45
43
46
- @classmethod
47
- def check (cls , context , x , to ) -> orp .MatchResult :
44
+ def check (self , context , x , to ) -> orp .MatchResult :
48
45
check_result = orp .MatchResult ()
49
- if x .dtype != to .value :
46
+ if x .dtype != to .as_int () :
50
47
return check_result .fail ("Input and output types are not the same" )
51
48
return check_result
52
49
53
50
54
- class CastCast (orp .RewriteRuleAsClass ):
51
+ class CastCast (orp .RewriteRuleClassBase ):
55
52
"""Replaces ``Cast(Cast(X, ...), to=to)`` by ``Cast(X, to=to)``."""
56
53
57
54
_allowed_tensor_types : ClassVar = {
@@ -61,37 +58,31 @@ class CastCast(orp.RewriteRuleAsClass):
61
58
ir .DataType .DOUBLE ,
62
59
}
63
60
64
- @classmethod
65
- def pattern (cls , op , x , to , to_ignored ):
61
+ def pattern (self , op , x , to , to_ignored ):
66
62
return op .Cast (op .Cast (x , to = to_ignored ), to = to )
67
63
68
- @classmethod
69
- def check (cls , context , x : ir .Value , to : ir .Attr , to_ignored : ir .Attr ) -> orp .MatchResult :
64
+ def check (self , context , x : ir .Value , to : ir .Attr , to_ignored : ir .Attr ) -> orp .MatchResult :
70
65
check_result = orp .MatchResult ()
71
- if to .value not in cls ._allowed_tensor_types :
72
- return check_result .fail (f"Output type { to .value } is not allowed" )
73
- if to_ignored .as_int () not in cls ._allowed_tensor_types :
74
- return check_result .fail (f"Ignored type { to_ignored .value } is not allowed" )
66
+ if to .as_int () not in self ._allowed_tensor_types :
67
+ return check_result .fail (f"Output type { to .as_int () } is not allowed" )
68
+ if to_ignored .as_int () not in self ._allowed_tensor_types :
69
+ return check_result .fail (f"Ignored type { to_ignored .as_int () } is not allowed" )
75
70
return check_result
76
71
77
- @classmethod
78
- def rewrite (cls , op , x : ir .Value , to : ir .Attr , to_ignored : ir .Attr ):
72
+ def rewrite (self , op , x : ir .Value , to : ir .Attr , to_ignored : ir .Attr ):
79
73
return op .Cast (x , to = to )
80
74
81
75
82
- class ExpandIdentity (orp .RewriteRuleAsClass ):
76
+ class ExpandIdentity (orp .RewriteRuleClassBase ):
83
77
"""Replaces ``Expand(..., shape)`` by ``Identity`` if possible."""
84
78
85
- @classmethod
86
- def pattern (cls , op , x , shape ):
79
+ def pattern (self , op , x , shape ):
87
80
return op .Expand (x , shape )
88
81
89
- @classmethod
90
- def rewrite (cls , op , x : ir .Value , shape : ir .Value ):
82
+ def rewrite (self , op , x : ir .Value , shape : ir .Value ):
91
83
return op .Identity (x )
92
84
93
- @classmethod
94
- def check (cls , context , x , shape ) -> orp .MatchResult :
85
+ def check (self , context , x , shape ) -> orp .MatchResult :
95
86
check_result = orp .MatchResult ()
96
87
if shape .const_value is None :
97
88
# Shape is not a constant and cannot be guessed.
@@ -106,22 +97,19 @@ def check(cls, context, x, shape) -> orp.MatchResult:
106
97
return check_result
107
98
108
99
109
- class ReshapeReshape (orp .RewriteRuleAsClass ):
100
+ class ReshapeReshape (orp .RewriteRuleClassBase ):
110
101
"""Replaces ``Reshape(Reshape(X, ...), shape)`` by ``Reshape(X, shape)``.
111
102
The pattern matches only if second reshape reshapes into a shape
112
103
with positive values.
113
104
"""
114
105
115
- @classmethod
116
- def pattern (cls , op , x , shape_ignored , shape ):
106
+ def pattern (self , op , x , shape_ignored , shape ):
117
107
return op .Reshape (op .Reshape (x , shape_ignored ), shape )
118
108
119
- @classmethod
120
- def rewrite (cls , op , x : ir .Value , shape_ignored : ir .Value , shape : ir .Value ):
109
+ def rewrite (self , op , x : ir .Value , shape_ignored : ir .Value , shape : ir .Value ):
121
110
return op .Reshape (x , shape )
122
111
123
- @classmethod
124
- def check (cls , context , x , shape_ignored , shape ) -> orp .MatchResult :
112
+ def check (self , context , x , shape_ignored , shape ) -> orp .MatchResult :
125
113
check_result = orp .MatchResult ()
126
114
if shape_ignored .const_value is None :
127
115
return check_result .fail ("Shape ignored is not a constant." )
@@ -132,17 +120,15 @@ def check(cls, context, x, shape_ignored, shape) -> orp.MatchResult:
132
120
return check_result
133
121
134
122
135
- class SlicesSplit (orp .RewriteRuleAsClass ):
123
+ class SlicesSplit (orp .RewriteRuleClassBase ):
136
124
"""Replaces ``Slice(x, ...), Slice(x, ...)``
137
125
by ``Split(x, ...)`` if possible.
138
126
"""
139
127
140
- @classmethod
141
- def pattern (cls , op , x , begin0 , end0 , axes0 , begin1 , end1 , axes1 ):
128
+ def pattern (self , op , x , begin0 , end0 , axes0 , begin1 , end1 , axes1 ):
142
129
return op .Slice (x , begin0 , end0 , axes0 ), op .Slice (x , begin1 , end1 , axes1 )
143
130
144
- @classmethod
145
- def check (cls , context , x , begin0 , end0 , axes0 , begin1 , end1 , axes1 ) -> orp .MatchResult :
131
+ def check (self , context , x , begin0 , end0 , axes0 , begin1 , end1 , axes1 ) -> orp .MatchResult :
146
132
check_result = orp .MatchResult ()
147
133
if (
148
134
axes0 .const_value is None
@@ -187,94 +173,83 @@ def check(cls, context, x, begin0, end0, axes0, begin1, end1, axes1) -> orp.Matc
187
173
return check_result .fail ("Last dimension is not equal to Begin1." )
188
174
return check_result
189
175
190
- @classmethod
191
- def rewrite (cls , op , x , begin0 , end0 , axes0 , begin1 , end1 , axes1 ):
176
+ def rewrite (self , op , x , begin0 , end0 , axes0 , begin1 , end1 , axes1 ):
192
177
return op .Split (x , num_outputs = 2 , axis = - 1 , _outputs = 2 )
193
178
194
179
195
- class TransposeIdentity (orp .RewriteRuleAsClass ):
180
+ class TransposeIdentity (orp .RewriteRuleClassBase ):
196
181
"""Replaces ``Transpose(. perm=perm)``
197
182
when the permutation is identity.
198
183
"""
199
184
200
- @classmethod
201
- def pattern (cls , op , x , perm ):
185
+ def pattern (self , op , x , perm ):
202
186
return op .Transpose (x , perm = perm )
203
187
204
- @classmethod
205
- def check (cls , context , x : ir .Value , perm : ir .Attr ) -> orp .MatchResult :
188
+ def check (self , context , x : ir .Value , perm : ir .Attr ) -> orp .MatchResult :
206
189
check_result = orp .MatchResult ()
207
190
if isinstance (perm , ir .RefAttr ):
208
191
return check_result .fail ("Permutation is a reference attribute." )
209
192
if perm .type == ir .AttributeType .INTS :
210
- if perm .value == list (range (len (perm .value ))):
193
+ perm_ints = perm .as_ints ()
194
+ if perm_ints == list (range (len (perm_ints ))):
211
195
return check_result
212
196
return check_result .fail ("Permutation is not identity." )
213
197
214
- @classmethod
215
- def rewrite (cls , op , x : ir .Value , perm : ir .Attr ):
198
+ def rewrite (self , op , x : ir .Value , perm : ir .Attr ):
216
199
return op .Identity (x )
217
200
218
201
219
- class TransposeTranspose (orp .RewriteRuleAsClass ):
202
+ class TransposeTranspose (orp .RewriteRuleClassBase ):
220
203
"""Replaces ``Transpose(Transpose(., perm=perm1), perm=perm2)``
221
204
when both permutations are inverse.
222
205
"""
223
206
224
- @classmethod
225
- def pattern (cls , op , x , perm1 , perm2 ):
207
+ def pattern (self , op , x , perm1 , perm2 ):
226
208
return op .Transpose (op .Transpose (x , perm = perm1 ), perm = perm2 )
227
209
228
- @classmethod
229
- def check (cls , context , x : ir .Value , perm1 : ir .Attr , perm2 : ir .Attr ) -> orp .MatchResult :
210
+ def check (self , context , x : ir .Value , perm1 : ir .Attr , perm2 : ir .Attr ) -> orp .MatchResult :
230
211
check_result = orp .MatchResult ()
231
212
if isinstance (perm1 , ir .RefAttr ) or isinstance (perm2 , ir .RefAttr ):
232
213
return check_result .fail ("Permutation is a reference attribute." )
233
214
return check_result
234
215
235
- @classmethod
236
- def _apply_transpose (cls , perm : tuple [int , ...], on : list [int ]) -> list [int ]:
216
+ def _apply_transpose (self , perm : Sequence [int ], on : list [int ]) -> list [int ]:
237
217
assert len (perm ) == len (on ), "length mismatch"
238
218
res = [- 1 for i in on ]
239
219
for i , p in enumerate (perm ):
240
220
res [i ] = on [p ]
241
221
return res
242
222
243
- @classmethod
244
223
def _apply_transposes (
245
- cls , perms : list [tuple [int , ... ]], on : list [int ] | None = None
224
+ self , perms : list [Sequence [int ]], on : list [int ] | None = None
246
225
) -> list [int ]:
247
226
if on is None :
248
227
on = list (range (len (perms [0 ])))
249
228
for p in perms :
250
- on = cls ._apply_transpose (p , on )
229
+ on = self ._apply_transpose (p , on )
251
230
return on
252
231
253
- @classmethod
254
- def rewrite (cls , op , x : ir .Value , perm1 : ir .Attr , perm2 : ir .Attr ):
255
- first = list (range (len (perm1 .value )))
256
- last = cls ._apply_transposes ([perm1 .value , perm2 .value ])
232
+ def rewrite (self , op , x : ir .Value , perm1 : ir .Attr , perm2 : ir .Attr ):
233
+ first = list (range (len (perm1 .as_ints ())))
234
+ last = self ._apply_transposes ([perm1 .as_ints (), perm2 .as_ints ()])
257
235
if first == last :
258
236
return op .Identity (x )
259
237
return op .Transpose (x , perm = last )
260
238
261
239
262
- class UnsqueezeUnsqueeze (orp .RewriteRuleAsClass ):
240
+ class UnsqueezeUnsqueeze (orp .RewriteRuleClassBase ):
263
241
"""Replaces ``Unsqueeze(Unsqueeze(., axes1), axes2)`` with one Unsqueeze."""
264
242
265
- @classmethod
266
- def pattern (cls , op , x , axes1 , axes2 ):
243
+ def pattern (self , op , x , axes1 , axes2 ):
267
244
return op .Unsqueeze (op .Unsqueeze (x , axes1 ), axes2 )
268
245
269
- @classmethod
270
- def rewrite (cls , op , x : ir .Value , axes1 : ir .Value , axes2 : ir .Value ):
246
+ def rewrite (self , op , x : ir .Value , axes1 : ir .Value , axes2 : ir .Value ):
271
247
v1 = ir_utils .get_singleton_value (axes1 )
272
248
v2 = ir_utils .get_singleton_value (axes2 )
273
249
axes = [v1 , v2 ] if v1 < v2 else [v2 , v1 + 1 ]
274
250
return op .Unsqueeze (x , op .Constant (value = ir .tensor (axes , dtype = ir .DataType .INT64 )))
275
251
276
- @classmethod
277
- def check (cls , context , x , axes1 , axes2 ) -> orp .MatchResult :
252
+ def check (self , context , x , axes1 , axes2 ) -> orp .MatchResult :
278
253
check_result = orp .MatchResult ()
279
254
del context # Unused
280
255
del x # Unused
@@ -288,14 +263,14 @@ def check(cls, context, x, axes1, axes2) -> orp.MatchResult:
288
263
return check_result
289
264
290
265
291
- cast_cast_rule = orp . make_rewrite_rule_from_class ( CastCast )
292
- cast_identity_rule = orp . make_rewrite_rule_from_class ( CastIdentity )
293
- expand_identity_rule = orp . make_rewrite_rule_from_class ( ExpandIdentity )
294
- reshape_reshape_rule = orp . make_rewrite_rule_from_class ( ReshapeReshape )
295
- slice_split_rule = orp . make_rewrite_rule_from_class ( SlicesSplit , True )
296
- transpose_identity_rule = orp . make_rewrite_rule_from_class ( TransposeIdentity )
297
- transpose_transpose_rule = orp . make_rewrite_rule_from_class ( TransposeTranspose )
298
- unsqueeze_unsqueeze_rule = orp . make_rewrite_rule_from_class ( UnsqueezeUnsqueeze )
266
+ cast_cast_rule = CastCast . rule ( )
267
+ cast_identity_rule = CastIdentity . rule ( )
268
+ expand_identity_rule = ExpandIdentity . rule ( )
269
+ reshape_reshape_rule = ReshapeReshape . rule ( )
270
+ slice_split_rule = SlicesSplit . rule ( )
271
+ transpose_identity_rule = TransposeIdentity . rule ( )
272
+ transpose_transpose_rule = TransposeTranspose . rule ( )
273
+ unsqueeze_unsqueeze_rule = UnsqueezeUnsqueeze . rule ( )
299
274
squeeze_reshape_1d_rule = SqueezeReshape .rule ()
300
275
301
276
0 commit comments