@@ -110,8 +110,29 @@ def rewrite_eye(g, ops):
110
110
OpTypePattern ("Const" , name = "fill_value" ),
111
111
]), "*"
112
112
])
113
+ pattern7 = \
114
+ OpTypePattern ("MatrixDiag" , name = "output_eye_matrix" , inputs = [
115
+ OpTypePattern ("Fill" , inputs = [
116
+ OpTypePattern ("Reshape" , inputs = [
117
+ OpTypePattern ("Minimum|Cast" , name = "min_or_cast" ),
118
+ "*" ,
119
+ ]),
120
+ OpTypePattern ("Const" , name = "fill_value" ),
121
+ ])
122
+ ])
123
+ pattern8 = \
124
+ OpTypePattern ("MatrixSetDiag" , name = "output_eye_matrix" , inputs = [
125
+ OpTypePattern ("Fill" ),
126
+ OpTypePattern ("Fill" , inputs = [
127
+ OpTypePattern ("Reshape" , inputs = [
128
+ OpTypePattern ("Minimum|Cast" , name = "min_or_cast" ),
129
+ "*" ,
130
+ ]),
131
+ OpTypePattern ("Const" , name = "fill_value" ),
132
+ ])
133
+ ])
113
134
114
- for pattern in [pattern1 , pattern2 , pattern3 , pattern4 , pattern5 , pattern6 ]:
135
+ for pattern in [pattern1 , pattern2 , pattern3 , pattern4 , pattern5 , pattern6 , pattern7 , pattern8 ]:
115
136
matcher = GraphMatcher (pattern , allow_reorder = True )
116
137
match_results = list (matcher .match_ops (ops ))
117
138
for match_result in match_results :
@@ -146,6 +167,6 @@ def rewrite_eye(g, ops):
146
167
zero_matrix = g .make_node ("ConstantOfShape" , matrix_shape_int64 .output )
147
168
148
169
g .make_node ("EyeLike" , zero_matrix .output , attr = {"dtype" : output_dtypes [0 ]},
149
- name = old_output .name , shapes = output_shapes , dtypes = output_dtypes )
170
+ name = old_output .name , shapes = output_shapes , dtypes = output_dtypes , outputs = [ old_output . output [ 0 ]] )
150
171
151
172
return g .get_nodes ()
0 commit comments