18
18
#include " mlir/Dialect/Tosa/IR/TosaOps.h"
19
19
#include " mlir/Dialect/Tosa/Utils/ShapeUtils.h"
20
20
#include " mlir/IR/Builders.h"
21
- #include " mlir/IR/BuiltinOps.h"
22
- #include " mlir/IR/IRMapping.h"
23
- #include " mlir/IR/Matchers.h"
24
21
#include " mlir/Interfaces/InferTypeOpInterface.h"
25
22
#include " mlir/Pass/Pass.h"
26
23
#include " mlir/Transforms/DialectConversion.h"
27
- #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
28
- #include " llvm/Support/FormatVariadic.h"
29
24
30
25
namespace mlir {
31
26
namespace tosa {
@@ -39,9 +34,87 @@ using namespace mlir::tosa;
39
34
40
35
namespace {
41
36
42
- void propagateShapesInRegion (Region ®ion);
37
+ // Check whether this use case is replaceable. We define an op as
38
+ // being replaceable if it is used by a TosaOp, or an op with a
39
+ // type-inference related interface.
40
+ // When a non-replaceable use is encountered, the value is wrapped in a
41
+ // cast back to the original type after inference.
42
+ bool isReplaceableUser (Operation *user) {
43
+ // Handle unregistered dialects.
44
+ if (!user->getDialect ())
45
+ return false ;
46
+
47
+ return user->getDialect ()->getNamespace () ==
48
+ TosaDialect::getDialectNamespace () ||
49
+ isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
50
+ }
51
+
52
+ // During type propagation, the types of values in the operator graph are
53
+ // updated. For the tosa.while_loop operation, types are speculatively updated
54
+ // within the body region to determine the output type of the while_loop. This
55
+ // process is performed until a fixed point is reached, then the types are
56
+ // reverted.
57
+ //
58
+ // This class encapsulates the state information needed to perform the reversion
59
+ // process or to commit to the final changes.
60
+ class TypeModificationState {
61
+ public:
62
+ TypeModificationState () = default ;
63
+
64
+ ~TypeModificationState () {
65
+ // Ensure the recorded modifications are either committed or reverted.
66
+ assert (oldTypes.empty () && " unhandled type modifications" );
67
+ }
68
+
69
+ // Update the state of the value and record the old type.
70
+ void setType (Value value, Type type) {
71
+ if (value.getType () != type) {
72
+ oldTypes.emplace_back (value, value.getType ());
73
+ value.setType (type);
74
+ }
75
+ }
43
76
44
- void propagateShapesToTosaIf (Operation &op) {
77
+ // Revert changes made to the types in the IR by setting all the affected
78
+ // values to their old types.
79
+ void revert () {
80
+ // Otherwise revert the changes.
81
+ for (auto [value, type] : oldTypes)
82
+ value.setType (type);
83
+
84
+ oldTypes.clear ();
85
+ }
86
+
87
+ // Commit the changes to the types in the IR.
88
+ // This requires inserting tensor.cast operations to mediate the newly
89
+ // inferred result types with users that do not support type inference.
90
+ void commit () {
91
+ // For each use whose type changed, cast the value with the new type back to
92
+ // the old type.
93
+ for (auto [value, oldType] : oldTypes) {
94
+ for (auto &use : value.getUses ()) {
95
+ if (isReplaceableUser (use.getOwner ()))
96
+ continue ;
97
+
98
+ OpBuilder builder (value.getContext ());
99
+ builder.setInsertionPoint (use.getOwner ());
100
+
101
+ Location loc = value.getLoc ();
102
+ use.set (builder.create <tensor::CastOp>(loc, oldType, value));
103
+ }
104
+ }
105
+
106
+ oldTypes.clear ();
107
+ }
108
+
109
+ private:
110
+ // A record of each value whose type was updated along with that value's
111
+ // previous type.
112
+ llvm::SmallVector<std::pair<Value, Type>> oldTypes;
113
+ };
114
+
115
+ void propagateShapesInRegion (Region ®ion, TypeModificationState &state);
116
+
117
+ void propagateShapesToTosaIf (Operation &op, TypeModificationState &state) {
45
118
IfOp ifOp = dyn_cast<IfOp>(op);
46
119
if (!ifOp)
47
120
return ;
@@ -58,7 +131,7 @@ void propagateShapesToTosaIf(Operation &op) {
58
131
59
132
if (inferredTy.hasRank ()) {
60
133
Type newType = oldType.clone (inferredTy.getShape ());
61
- blockArg .setType (newType);
134
+ state .setType (blockArg, newType);
62
135
}
63
136
}
64
137
@@ -71,64 +144,44 @@ void propagateShapesToTosaIf(Operation &op) {
71
144
ValueKnowledge::join (operandKnowledge, blockKnowledge);
72
145
if (!joinedKnowledge)
73
146
continue ;
74
- frontBlock.getArgument (i). setType ( joinedKnowledge.getType ());
147
+ state. setType ( frontBlock.getArgument (i), joinedKnowledge.getType ());
75
148
}
76
149
77
- propagateShapesInRegion (region);
150
+ propagateShapesInRegion (region, state );
78
151
}
79
152
}
80
153
81
- void propagateShapesToTosaWhile (Operation &op) {
154
+ void propagateShapesToTosaWhile (Operation &op, TypeModificationState &state ) {
82
155
WhileOp whileOp = dyn_cast<WhileOp>(op);
83
156
if (!whileOp)
84
157
return ;
85
158
86
159
// Determine what the expected argument types are to the cond/body blocks.
87
160
// The expected arguments should be compatible with ever iteration of the
88
161
// loop body / condition for tosa.while.
89
- llvm::SmallVector<Type> argTypes;
90
- for (auto operand : op.getOperands ()) {
91
- auto operandTy = cast<ShapedType>(operand.getType ());
92
- if (operandTy.hasRank ()) {
93
- auto newTy = operandTy.clone (operandTy.getShape ());
94
- argTypes.push_back (newTy);
95
- } else {
96
- argTypes.push_back (operand.getType ());
97
- }
98
- }
99
-
100
- // Save out the type information so we can restore at the end.
101
- llvm::DenseMap<Value, Type> originalTypeMap;
102
- for (auto &block : op.getRegion (1 )) {
103
- for (auto arg : block.getArguments ())
104
- originalTypeMap[arg] = arg.getType ();
105
- for (auto &op : block)
106
- for (auto result : op.getResults ())
107
- originalTypeMap[result] = result.getType ();
108
- }
162
+ SmallVector<Type> argTypes = llvm::to_vector (op.getOperandTypes ());
109
163
110
164
bool hasNewTypes = true ;
111
165
while (hasNewTypes) {
166
+ TypeModificationState localState;
112
167
113
168
// Set types on the block args.
114
169
Region &bodyRegion = op.getRegion (1 );
115
170
Block &block = bodyRegion.front ();
116
171
for (int i = 0 , s = argTypes.size (); i < s; i++) {
117
- block.getArgument (i). setType ( argTypes[i]);
172
+ localState. setType ( block.getArgument (i), argTypes[i]);
118
173
}
119
174
120
175
// Propagate to the end.
121
- propagateShapesInRegion (bodyRegion);
176
+ propagateShapesInRegion (bodyRegion, localState );
122
177
123
- // Find all the tosa yield types and verify there is atleast one.
178
+ // Find all the tosa yield types and verify there is a single one.
124
179
llvm::SmallVector<YieldOp> yieldOps;
125
180
for (auto &block : bodyRegion)
126
181
if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator ()))
127
182
yieldOps.push_back (yieldOp);
128
183
129
- if (yieldOps.empty ())
130
- return ;
131
-
184
+ assert (yieldOps.size () == 1 && " missing or non-unique yield op" );
132
185
// Using the new tosa.yield operand types, infer the new subtypes.
133
186
llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
134
187
for (auto ty : argTypes) {
@@ -158,59 +211,31 @@ void propagateShapesToTosaWhile(Operation &op) {
158
211
argTypes[i] = newType;
159
212
}
160
213
161
- // The types inferred in the block assume the operand types specified for
162
- // this iteration. We need to restore the original types to ensure that
163
- // future iterations only use the already specified types, not possible
164
- // types from previous iterations.
165
- for (auto &block : bodyRegion) {
166
- for (auto arg : block.getArguments ())
167
- arg.setType (originalTypeMap[arg]);
168
- for (auto &op : block)
169
- for (auto result : op.getResults ())
170
- result.setType (originalTypeMap[result]);
171
- }
214
+ // Revert all changes made during the speculative part of the algorithm.
215
+ localState.revert ();
172
216
}
173
217
174
218
// We now set the block arguments according to the most recent shape
175
219
// inference results. This gives us the block arg types for the next
176
220
// iteration.
177
221
for (auto ®ion : op.getRegions ()) {
178
222
for (unsigned int i = 0 , s = argTypes.size (); i < s; i++) {
179
- region.front ().getArgument (i). setType ( argTypes[i]);
223
+ state. setType ( region.front ().getArgument (i), argTypes[i]);
180
224
}
181
225
182
- propagateShapesInRegion (region);
226
+ propagateShapesInRegion (region, state );
183
227
}
184
228
}
185
229
186
- // Track the old type for each operand whose type was updated
187
- // during inference. This information is used to introduce casts
188
- // back to the type expected by the operand after inference.
189
- struct TypeRewriteInfo {
190
- OpOperand *operand;
191
- Type oldType;
192
- };
193
-
194
- void propagateShapesInRegion (Region ®ion) {
195
- // Check whether this use case is replaceable. We define an op as
196
- // being replaceable if it is used by a TosaOp, or an op with a
197
- // type-inference related interface.
198
- // When a non-replaceable use is encountered, the value is wrapped in a
199
- // cast back to the original type after inference.
200
- auto isReplaceableUser = [](Operation *user) -> bool {
201
- return user->getDialect ()->getNamespace () ==
202
- TosaDialect::getDialectNamespace () ||
203
- isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
204
- };
205
-
206
- llvm::SmallVector<TypeRewriteInfo> requiresUpdate;
230
+ void propagateShapesInRegion (Region ®ion, TypeModificationState &state) {
207
231
for (auto &block : region) {
208
232
for (Operation &op : block) {
209
- if (op.getDialect ()->getNamespace () != TosaDialect::getDialectNamespace ())
233
+ if (!op.getDialect () ||
234
+ op.getDialect ()->getNamespace () != TosaDialect::getDialectNamespace ())
210
235
continue ;
211
236
212
- propagateShapesToTosaIf (op);
213
- propagateShapesToTosaWhile (op);
237
+ propagateShapesToTosaIf (op, state );
238
+ propagateShapesToTosaWhile (op, state );
214
239
215
240
InferShapedTypeOpInterface shapeInterface =
216
241
dyn_cast<InferShapedTypeOpInterface>(op);
@@ -252,30 +277,11 @@ void propagateShapesInRegion(Region ®ion) {
252
277
continue ;
253
278
254
279
// Set new type
255
- result.setType (newKnowledge.getType ());
256
-
257
- // Collect all uses of the operation which require update.
258
- for (auto &user : result.getUses ()) {
259
- if (!isReplaceableUser (user.getOwner ()))
260
- requiresUpdate.push_back ({&user, resultTy});
261
- }
280
+ state.setType (result, newKnowledge.getType ());
262
281
}
263
282
}
264
283
}
265
284
}
266
-
267
- // For each use whose type changed, cast the value with the new type back to
268
- // the old type.
269
- IRRewriter rewriter (region.getContext ());
270
- for (auto [operand, oldType] : requiresUpdate) {
271
- rewriter.setInsertionPoint (operand->getOwner ());
272
-
273
- auto oldValue = operand->get ();
274
-
275
- auto loc = oldValue.getLoc ();
276
- auto castOp = rewriter.create <tensor::CastOp>(loc, oldType, oldValue);
277
- operand->set (castOp);
278
- }
279
285
}
280
286
281
287
// / Pass that performs shape propagation across TOSA operations. This includes
@@ -285,7 +291,9 @@ struct TosaInferShapes
285
291
public:
286
292
void runOnOperation () override {
287
293
func::FuncOp func = getOperation ();
288
- propagateShapesInRegion (func.getBody ());
294
+ TypeModificationState state;
295
+ propagateShapesInRegion (func.getBody (), state);
296
+ state.commit ();
289
297
}
290
298
};
291
299
} // namespace
0 commit comments