14
14
15
15
#include " mlir/Dialect/Arith/IR/Arith.h"
16
16
#include " mlir/Dialect/EmitC/IR/EmitC.h"
17
+ #include " mlir/Dialect/EmitC/Transforms/TypeConversions.h"
17
18
#include " mlir/Dialect/SCF/IR/SCF.h"
18
19
#include " mlir/IR/Builders.h"
19
20
#include " mlir/IR/BuiltinOps.h"
@@ -39,21 +40,22 @@ struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
39
40
40
41
// Lower scf::for to emitc::for, implementing result values using
41
42
// emitc::variable's updated within the loop body.
42
- struct ForLowering : public OpRewritePattern <ForOp> {
43
- using OpRewritePattern <ForOp>::OpRewritePattern ;
43
+ struct ForLowering : public OpConversionPattern <ForOp> {
44
+ using OpConversionPattern <ForOp>::OpConversionPattern ;
44
45
45
- LogicalResult matchAndRewrite (ForOp forOp,
46
- PatternRewriter &rewriter) const override ;
46
+ LogicalResult
47
+ matchAndRewrite (ForOp forOp, OpAdaptor adaptor,
48
+ ConversionPatternRewriter &rewriter) const override ;
47
49
};
48
50
49
51
// Create an uninitialized emitc::variable op for each result of the given op.
50
52
template <typename T>
51
- static SmallVector<Value> createVariablesForResults (T op,
52
- PatternRewriter &rewriter) {
53
- SmallVector<Value> resultVariables;
54
-
53
+ static LogicalResult
54
+ createVariablesForResults (T op, const TypeConverter *typeConverter,
55
+ ConversionPatternRewriter &rewriter,
56
+ SmallVector<Value> &resultVariables) {
55
57
if (!op.getNumResults ())
56
- return resultVariables ;
58
+ return success () ;
57
59
58
60
Location loc = op->getLoc ();
59
61
MLIRContext *context = op.getContext ();
@@ -62,21 +64,23 @@ static SmallVector<Value> createVariablesForResults(T op,
62
64
rewriter.setInsertionPoint (op);
63
65
64
66
for (OpResult result : op.getResults ()) {
65
- Type resultType = result.getType ();
67
+ Type resultType = typeConverter->convertType (result.getType ());
68
+ if (!resultType)
69
+ return rewriter.notifyMatchFailure (op, " result type conversion failed" );
66
70
Type varType = emitc::LValueType::get (resultType);
67
71
emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get (context, " " );
68
72
emitc::VariableOp var =
69
73
rewriter.create <emitc::VariableOp>(loc, varType, noInit);
70
74
resultVariables.push_back (var);
71
75
}
72
76
73
- return resultVariables ;
77
+ return success () ;
74
78
}
75
79
76
80
// Create a series of assign ops assigning given values to given variables at
77
81
// the current insertion point of given rewriter.
78
- static void assignValues (ValueRange values, SmallVector<Value> & variables,
79
- PatternRewriter &rewriter, Location loc) {
82
+ static void assignValues (ValueRange values, ValueRange variables,
83
+ ConversionPatternRewriter &rewriter, Location loc) {
80
84
for (auto [value, var] : llvm::zip (values, variables))
81
85
rewriter.create <emitc::AssignOp>(loc, var, value);
82
86
}
@@ -89,46 +93,58 @@ SmallVector<Value> loadValues(const SmallVector<Value> &variables,
89
93
});
90
94
}
91
95
92
- static void lowerYield (SmallVector<Value> &resultVariables,
93
- PatternRewriter &rewriter, scf::YieldOp yield) {
96
+ static LogicalResult lowerYield (Operation *op, ValueRange resultVariables,
97
+ ConversionPatternRewriter &rewriter,
98
+ scf::YieldOp yield) {
94
99
Location loc = yield.getLoc ();
95
- ValueRange operands = yield.getOperands ();
96
100
97
101
OpBuilder::InsertionGuard guard (rewriter);
98
102
rewriter.setInsertionPoint (yield);
99
103
100
- assignValues (operands, resultVariables, rewriter, loc);
104
+ SmallVector<Value> yieldOperands;
105
+ if (failed (rewriter.getRemappedValues (yield.getOperands (), yieldOperands))) {
106
+ return rewriter.notifyMatchFailure (op, " failed to lower yield operands" );
107
+ }
108
+
109
+ assignValues (yieldOperands, resultVariables, rewriter, loc);
101
110
102
111
rewriter.create <emitc::YieldOp>(loc);
103
112
rewriter.eraseOp (yield);
113
+
114
+ return success ();
104
115
}
105
116
106
117
// Lower the contents of an scf::if/scf::index_switch regions to an
107
118
// emitc::if/emitc::switch region. The contents of the lowering region is
108
119
// moved into the respective lowered region, but the scf::yield is replaced not
109
120
// only with an emitc::yield, but also with a sequence of emitc::assign ops that
110
121
// set the yielded values into the result variables.
111
- static void lowerRegion (SmallVector<Value> & resultVariables,
112
- PatternRewriter &rewriter, Region ®ion ,
113
- Region &loweredRegion) {
122
+ static LogicalResult lowerRegion (Operation *op, ValueRange resultVariables,
123
+ ConversionPatternRewriter &rewriter ,
124
+ Region ®ion, Region &loweredRegion) {
114
125
rewriter.inlineRegionBefore (region, loweredRegion, loweredRegion.end ());
115
126
Operation *terminator = loweredRegion.back ().getTerminator ();
116
- lowerYield (resultVariables, rewriter, cast<scf::YieldOp>(terminator));
127
+ return lowerYield (op, resultVariables, rewriter,
128
+ cast<scf::YieldOp>(terminator));
117
129
}
118
130
119
- LogicalResult ForLowering::matchAndRewrite (ForOp forOp,
120
- PatternRewriter &rewriter) const {
131
+ LogicalResult
132
+ ForLowering::matchAndRewrite (ForOp forOp, OpAdaptor adaptor,
133
+ ConversionPatternRewriter &rewriter) const {
121
134
Location loc = forOp.getLoc ();
122
135
123
136
// Create an emitc::variable op for each result. These variables will be
124
137
// assigned to by emitc::assign ops within the loop body.
125
- SmallVector<Value> resultVariables =
126
- createVariablesForResults (forOp, rewriter);
138
+ SmallVector<Value> resultVariables;
139
+ if (failed (createVariablesForResults (forOp, getTypeConverter (), rewriter,
140
+ resultVariables)))
141
+ return rewriter.notifyMatchFailure (forOp,
142
+ " create variables for results failed" );
127
143
128
- assignValues (forOp. getInits (), resultVariables, rewriter, loc);
144
+ assignValues (adaptor. getInitArgs (), resultVariables, rewriter, loc);
129
145
130
146
emitc::ForOp loweredFor = rewriter.create <emitc::ForOp>(
131
- loc, forOp .getLowerBound (), forOp .getUpperBound (), forOp .getStep ());
147
+ loc, adaptor .getLowerBound (), adaptor .getUpperBound (), adaptor .getStep ());
132
148
133
149
Block *loweredBody = loweredFor.getBody ();
134
150
@@ -143,13 +159,27 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
143
159
144
160
rewriter.restoreInsertionPoint (ip);
145
161
162
+ // Convert the original region types into the new types by adding unrealized
163
+ // casts in the beginning of the loop. This performs the conversion in place.
164
+ if (failed (rewriter.convertRegionTypes (&forOp.getRegion (),
165
+ *getTypeConverter (), nullptr ))) {
166
+ return rewriter.notifyMatchFailure (forOp, " region types conversion failed" );
167
+ }
168
+
169
+ // Register the replacements for the block arguments and inline the body of
170
+ // the scf.for loop into the body of the emitc::for loop.
171
+ Block *scfBody = &(forOp.getRegion ().front ());
146
172
SmallVector<Value> replacingValues;
147
173
replacingValues.push_back (loweredFor.getInductionVar ());
148
174
replacingValues.append (iterArgsValues.begin (), iterArgsValues.end ());
175
+ rewriter.mergeBlocks (scfBody, loweredBody, replacingValues);
149
176
150
- rewriter.mergeBlocks (forOp.getBody (), loweredBody, replacingValues);
151
- lowerYield (resultVariables, rewriter,
152
- cast<scf::YieldOp>(loweredBody->getTerminator ()));
177
+ auto result = lowerYield (forOp, resultVariables, rewriter,
178
+ cast<scf::YieldOp>(loweredBody->getTerminator ()));
179
+
180
+ if (failed (result)) {
181
+ return result;
182
+ }
153
183
154
184
// Load variables into SSA values after the for loop.
155
185
SmallVector<Value> resultValues = loadValues (resultVariables, rewriter, loc);
@@ -160,38 +190,66 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
160
190
161
191
// Lower scf::if to emitc::if, implementing result values as emitc::variable's
162
192
// updated within the then and else regions.
163
- struct IfLowering : public OpRewritePattern <IfOp> {
164
- using OpRewritePattern <IfOp>::OpRewritePattern ;
193
+ struct IfLowering : public OpConversionPattern <IfOp> {
194
+ using OpConversionPattern <IfOp>::OpConversionPattern ;
165
195
166
- LogicalResult matchAndRewrite (IfOp ifOp,
167
- PatternRewriter &rewriter) const override ;
196
+ LogicalResult
197
+ matchAndRewrite (IfOp ifOp, OpAdaptor adaptor,
198
+ ConversionPatternRewriter &rewriter) const override ;
168
199
};
169
200
170
201
} // namespace
171
202
172
- LogicalResult IfLowering::matchAndRewrite (IfOp ifOp,
173
- PatternRewriter &rewriter) const {
203
+ LogicalResult
204
+ IfLowering::matchAndRewrite (IfOp ifOp, OpAdaptor adaptor,
205
+ ConversionPatternRewriter &rewriter) const {
174
206
Location loc = ifOp.getLoc ();
175
207
176
208
// Create an emitc::variable op for each result. These variables will be
177
209
// assigned to by emitc::assign ops within the then & else regions.
178
- SmallVector<Value> resultVariables =
179
- createVariablesForResults (ifOp, rewriter);
180
-
181
- Region &thenRegion = ifOp.getThenRegion ();
182
- Region &elseRegion = ifOp.getElseRegion ();
210
+ SmallVector<Value> resultVariables;
211
+ if (failed (createVariablesForResults (ifOp, getTypeConverter (), rewriter,
212
+ resultVariables)))
213
+ return rewriter.notifyMatchFailure (ifOp,
214
+ " create variables for results failed" );
215
+
216
+ // Utility function to lower the contents of an scf::if region to an emitc::if
217
+ // region. The contents of the scf::if regions is moved into the respective
218
+ // emitc::if regions, but the scf::yield is replaced not only with an
219
+ // emitc::yield, but also with a sequence of emitc::assign ops that set the
220
+ // yielded values into the result variables.
221
+ auto lowerRegion = [&resultVariables, &rewriter,
222
+ &ifOp](Region ®ion, Region &loweredRegion) {
223
+ rewriter.inlineRegionBefore (region, loweredRegion, loweredRegion.end ());
224
+ Operation *terminator = loweredRegion.back ().getTerminator ();
225
+ auto result = lowerYield (ifOp, resultVariables, rewriter,
226
+ cast<scf::YieldOp>(terminator));
227
+ if (failed (result)) {
228
+ return result;
229
+ }
230
+ return success ();
231
+ };
232
+
233
+ Region &thenRegion = adaptor.getThenRegion ();
234
+ Region &elseRegion = adaptor.getElseRegion ();
183
235
184
236
bool hasElseBlock = !elseRegion.empty ();
185
237
186
238
auto loweredIf =
187
- rewriter.create <emitc::IfOp>(loc, ifOp .getCondition (), false , false );
239
+ rewriter.create <emitc::IfOp>(loc, adaptor .getCondition (), false , false );
188
240
189
241
Region &loweredThenRegion = loweredIf.getThenRegion ();
190
- lowerRegion (resultVariables, rewriter, thenRegion, loweredThenRegion);
242
+ auto result = lowerRegion (thenRegion, loweredThenRegion);
243
+ if (failed (result)) {
244
+ return result;
245
+ }
191
246
192
247
if (hasElseBlock) {
193
248
Region &loweredElseRegion = loweredIf.getElseRegion ();
194
- lowerRegion (resultVariables, rewriter, elseRegion, loweredElseRegion);
249
+ auto result = lowerRegion (elseRegion, loweredElseRegion);
250
+ if (failed (result)) {
251
+ return result;
252
+ }
195
253
}
196
254
197
255
rewriter.setInsertionPointAfter (ifOp);
@@ -203,37 +261,46 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
203
261
204
262
// Lower scf::index_switch to emitc::switch, implementing result values as
205
263
// emitc::variable's updated within the case and default regions.
206
- struct IndexSwitchOpLowering : public OpRewritePattern <IndexSwitchOp> {
207
- using OpRewritePattern<IndexSwitchOp>::OpRewritePattern ;
264
+ struct IndexSwitchOpLowering : public OpConversionPattern <IndexSwitchOp> {
265
+ using OpConversionPattern::OpConversionPattern ;
208
266
209
- LogicalResult matchAndRewrite (IndexSwitchOp indexSwitchOp,
210
- PatternRewriter &rewriter) const override ;
267
+ LogicalResult
268
+ matchAndRewrite (IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
269
+ ConversionPatternRewriter &rewriter) const override ;
211
270
};
212
271
213
- LogicalResult
214
- IndexSwitchOpLowering::matchAndRewrite ( IndexSwitchOp indexSwitchOp,
215
- PatternRewriter &rewriter) const {
272
+ LogicalResult IndexSwitchOpLowering::matchAndRewrite (
273
+ IndexSwitchOp indexSwitchOp, OpAdaptor adaptor ,
274
+ ConversionPatternRewriter &rewriter) const {
216
275
Location loc = indexSwitchOp.getLoc ();
217
276
218
277
// Create an emitc::variable op for each result. These variables will be
219
278
// assigned to by emitc::assign ops within the case and default regions.
220
- SmallVector<Value> resultVariables =
221
- createVariablesForResults (indexSwitchOp, rewriter);
279
+ SmallVector<Value> resultVariables;
280
+ if (failed (createVariablesForResults (indexSwitchOp, getTypeConverter (),
281
+ rewriter, resultVariables))) {
282
+ return rewriter.notifyMatchFailure (indexSwitchOp,
283
+ " create variables for results failed" );
284
+ }
222
285
223
286
auto loweredSwitch = rewriter.create <emitc::SwitchOp>(
224
- loc, indexSwitchOp.getArg (), indexSwitchOp.getCases (),
225
- indexSwitchOp.getNumCases ());
287
+ loc, adaptor.getArg (), adaptor.getCases (), indexSwitchOp.getNumCases ());
226
288
227
289
// Lowering all case regions.
228
- for (auto pair : llvm::zip (indexSwitchOp.getCaseRegions (),
229
- loweredSwitch.getCaseRegions ())) {
230
- lowerRegion (resultVariables, rewriter, std::get<0 >(pair),
231
- std::get<1 >(pair));
290
+ for (auto pair :
291
+ llvm::zip (adaptor.getCaseRegions (), loweredSwitch.getCaseRegions ())) {
292
+ if (failed (lowerRegion (indexSwitchOp, resultVariables, rewriter,
293
+ *std::get<0 >(pair), std::get<1 >(pair)))) {
294
+ return failure ();
295
+ }
232
296
}
233
297
234
298
// Lowering default region.
235
- lowerRegion (resultVariables, rewriter, indexSwitchOp.getDefaultRegion (),
236
- loweredSwitch.getDefaultRegion ());
299
+ if (failed (lowerRegion (indexSwitchOp, resultVariables, rewriter,
300
+ adaptor.getDefaultRegion (),
301
+ loweredSwitch.getDefaultRegion ()))) {
302
+ return failure ();
303
+ }
237
304
238
305
rewriter.setInsertionPointAfter (indexSwitchOp);
239
306
SmallVector<Value> results = loadValues (resultVariables, rewriter, loc);
@@ -242,15 +309,22 @@ IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp,
242
309
return success ();
243
310
}
244
311
245
- void mlir::populateSCFToEmitCConversionPatterns (RewritePatternSet &patterns) {
246
- patterns.add <ForLowering>(patterns.getContext ());
247
- patterns.add <IfLowering>(patterns.getContext ());
248
- patterns.add <IndexSwitchOpLowering>(patterns.getContext ());
312
+ void mlir::populateSCFToEmitCConversionPatterns (RewritePatternSet &patterns,
313
+ TypeConverter &typeConverter) {
314
+ patterns.add <ForLowering>(typeConverter, patterns.getContext ());
315
+ patterns.add <IfLowering>(typeConverter, patterns.getContext ());
316
+ patterns.add <IndexSwitchOpLowering>(typeConverter, patterns.getContext ());
249
317
}
250
318
251
319
void SCFToEmitCPass::runOnOperation () {
252
320
RewritePatternSet patterns (&getContext ());
253
- populateSCFToEmitCConversionPatterns (patterns);
321
+ TypeConverter typeConverter;
322
+ // Fallback converter
323
+ // See note https://mlir.llvm.org/docs/DialectConversion/#type-converter
324
+ // Type converters are called most to least recently inserted
325
+ typeConverter.addConversion ([](Type t) { return t; });
326
+ populateEmitCSizeTTypeConversions (typeConverter);
327
+ populateSCFToEmitCConversionPatterns (patterns, typeConverter);
254
328
255
329
// Configure conversion to lower out SCF operations.
256
330
ConversionTarget target (getContext ());
0 commit comments