Skip to content

Commit 450c6b0

Browse files
authored
[MLIR][SCFToEmitC] Convert types while converting from SCF to EmitC (#118940)
Switch from rewrite patterns to conversion patterns. This allows to perform type conversions together with other parts of the IR. For example, this allows to convert from index to emit.size_t types.
1 parent 328ff04 commit 450c6b0

File tree

4 files changed

+228
-80
lines changed

4 files changed

+228
-80
lines changed

mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H
1010
#define MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H
1111

12+
#include "mlir/Transforms/DialectConversion.h"
1213
#include <memory>
1314

1415
namespace mlir {
@@ -19,7 +20,8 @@ class RewritePatternSet;
1920
#include "mlir/Conversion/Passes.h.inc"
2021

2122
/// Collect a set of patterns to convert SCF operations to the EmitC dialect.
22-
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns);
23+
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns,
24+
TypeConverter &typeConverter);
2325
} // namespace mlir
2426

2527
#endif // MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H

mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp

Lines changed: 140 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "mlir/Dialect/Arith/IR/Arith.h"
1616
#include "mlir/Dialect/EmitC/IR/EmitC.h"
17+
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
1718
#include "mlir/Dialect/SCF/IR/SCF.h"
1819
#include "mlir/IR/Builders.h"
1920
#include "mlir/IR/BuiltinOps.h"
@@ -39,21 +40,22 @@ struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
3940

4041
// Lower scf::for to emitc::for, implementing result values using
4142
// emitc::variable's updated within the loop body.
42-
struct ForLowering : public OpRewritePattern<ForOp> {
43-
using OpRewritePattern<ForOp>::OpRewritePattern;
43+
struct ForLowering : public OpConversionPattern<ForOp> {
44+
using OpConversionPattern<ForOp>::OpConversionPattern;
4445

45-
LogicalResult matchAndRewrite(ForOp forOp,
46-
PatternRewriter &rewriter) const override;
46+
LogicalResult
47+
matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
48+
ConversionPatternRewriter &rewriter) const override;
4749
};
4850

4951
// Create an uninitialized emitc::variable op for each result of the given op.
5052
template <typename T>
51-
static SmallVector<Value> createVariablesForResults(T op,
52-
PatternRewriter &rewriter) {
53-
SmallVector<Value> resultVariables;
54-
53+
static LogicalResult
54+
createVariablesForResults(T op, const TypeConverter *typeConverter,
55+
ConversionPatternRewriter &rewriter,
56+
SmallVector<Value> &resultVariables) {
5557
if (!op.getNumResults())
56-
return resultVariables;
58+
return success();
5759

5860
Location loc = op->getLoc();
5961
MLIRContext *context = op.getContext();
@@ -62,21 +64,23 @@ static SmallVector<Value> createVariablesForResults(T op,
6264
rewriter.setInsertionPoint(op);
6365

6466
for (OpResult result : op.getResults()) {
65-
Type resultType = result.getType();
67+
Type resultType = typeConverter->convertType(result.getType());
68+
if (!resultType)
69+
return rewriter.notifyMatchFailure(op, "result type conversion failed");
6670
Type varType = emitc::LValueType::get(resultType);
6771
emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
6872
emitc::VariableOp var =
6973
rewriter.create<emitc::VariableOp>(loc, varType, noInit);
7074
resultVariables.push_back(var);
7175
}
7276

73-
return resultVariables;
77+
return success();
7478
}
7579

7680
// Create a series of assign ops assigning given values to given variables at
7781
// the current insertion point of given rewriter.
78-
static void assignValues(ValueRange values, SmallVector<Value> &variables,
79-
PatternRewriter &rewriter, Location loc) {
82+
static void assignValues(ValueRange values, ValueRange variables,
83+
ConversionPatternRewriter &rewriter, Location loc) {
8084
for (auto [value, var] : llvm::zip(values, variables))
8185
rewriter.create<emitc::AssignOp>(loc, var, value);
8286
}
@@ -89,46 +93,58 @@ SmallVector<Value> loadValues(const SmallVector<Value> &variables,
8993
});
9094
}
9195

92-
static void lowerYield(SmallVector<Value> &resultVariables,
93-
PatternRewriter &rewriter, scf::YieldOp yield) {
96+
static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
97+
ConversionPatternRewriter &rewriter,
98+
scf::YieldOp yield) {
9499
Location loc = yield.getLoc();
95-
ValueRange operands = yield.getOperands();
96100

97101
OpBuilder::InsertionGuard guard(rewriter);
98102
rewriter.setInsertionPoint(yield);
99103

100-
assignValues(operands, resultVariables, rewriter, loc);
104+
SmallVector<Value> yieldOperands;
105+
if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) {
106+
return rewriter.notifyMatchFailure(op, "failed to lower yield operands");
107+
}
108+
109+
assignValues(yieldOperands, resultVariables, rewriter, loc);
101110

102111
rewriter.create<emitc::YieldOp>(loc);
103112
rewriter.eraseOp(yield);
113+
114+
return success();
104115
}
105116

106117
// Lower the contents of an scf::if/scf::index_switch regions to an
107118
// emitc::if/emitc::switch region. The contents of the lowering region is
108119
// moved into the respective lowered region, but the scf::yield is replaced not
109120
// only with an emitc::yield, but also with a sequence of emitc::assign ops that
110121
// set the yielded values into the result variables.
111-
static void lowerRegion(SmallVector<Value> &resultVariables,
112-
PatternRewriter &rewriter, Region &region,
113-
Region &loweredRegion) {
122+
static LogicalResult lowerRegion(Operation *op, ValueRange resultVariables,
123+
ConversionPatternRewriter &rewriter,
124+
Region &region, Region &loweredRegion) {
114125
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
115126
Operation *terminator = loweredRegion.back().getTerminator();
116-
lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
127+
return lowerYield(op, resultVariables, rewriter,
128+
cast<scf::YieldOp>(terminator));
117129
}
118130

119-
LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
120-
PatternRewriter &rewriter) const {
131+
LogicalResult
132+
ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
133+
ConversionPatternRewriter &rewriter) const {
121134
Location loc = forOp.getLoc();
122135

123136
// Create an emitc::variable op for each result. These variables will be
124137
// assigned to by emitc::assign ops within the loop body.
125-
SmallVector<Value> resultVariables =
126-
createVariablesForResults(forOp, rewriter);
138+
SmallVector<Value> resultVariables;
139+
if (failed(createVariablesForResults(forOp, getTypeConverter(), rewriter,
140+
resultVariables)))
141+
return rewriter.notifyMatchFailure(forOp,
142+
"create variables for results failed");
127143

128-
assignValues(forOp.getInits(), resultVariables, rewriter, loc);
144+
assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc);
129145

130146
emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
131-
loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
147+
loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep());
132148

133149
Block *loweredBody = loweredFor.getBody();
134150

@@ -143,13 +159,27 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
143159

144160
rewriter.restoreInsertionPoint(ip);
145161

162+
// Convert the original region types into the new types by adding unrealized
163+
// casts in the beginning of the loop. This performs the conversion in place.
164+
if (failed(rewriter.convertRegionTypes(&forOp.getRegion(),
165+
*getTypeConverter(), nullptr))) {
166+
return rewriter.notifyMatchFailure(forOp, "region types conversion failed");
167+
}
168+
169+
// Register the replacements for the block arguments and inline the body of
170+
// the scf.for loop into the body of the emitc::for loop.
171+
Block *scfBody = &(forOp.getRegion().front());
146172
SmallVector<Value> replacingValues;
147173
replacingValues.push_back(loweredFor.getInductionVar());
148174
replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());
175+
rewriter.mergeBlocks(scfBody, loweredBody, replacingValues);
149176

150-
rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues);
151-
lowerYield(resultVariables, rewriter,
152-
cast<scf::YieldOp>(loweredBody->getTerminator()));
177+
auto result = lowerYield(forOp, resultVariables, rewriter,
178+
cast<scf::YieldOp>(loweredBody->getTerminator()));
179+
180+
if (failed(result)) {
181+
return result;
182+
}
153183

154184
// Load variables into SSA values after the for loop.
155185
SmallVector<Value> resultValues = loadValues(resultVariables, rewriter, loc);
@@ -160,38 +190,66 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
160190

161191
// Lower scf::if to emitc::if, implementing result values as emitc::variable's
162192
// updated within the then and else regions.
163-
struct IfLowering : public OpRewritePattern<IfOp> {
164-
using OpRewritePattern<IfOp>::OpRewritePattern;
193+
struct IfLowering : public OpConversionPattern<IfOp> {
194+
using OpConversionPattern<IfOp>::OpConversionPattern;
165195

166-
LogicalResult matchAndRewrite(IfOp ifOp,
167-
PatternRewriter &rewriter) const override;
196+
LogicalResult
197+
matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
198+
ConversionPatternRewriter &rewriter) const override;
168199
};
169200

170201
} // namespace
171202

172-
LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
173-
PatternRewriter &rewriter) const {
203+
LogicalResult
204+
IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
205+
ConversionPatternRewriter &rewriter) const {
174206
Location loc = ifOp.getLoc();
175207

176208
// Create an emitc::variable op for each result. These variables will be
177209
// assigned to by emitc::assign ops within the then & else regions.
178-
SmallVector<Value> resultVariables =
179-
createVariablesForResults(ifOp, rewriter);
180-
181-
Region &thenRegion = ifOp.getThenRegion();
182-
Region &elseRegion = ifOp.getElseRegion();
210+
SmallVector<Value> resultVariables;
211+
if (failed(createVariablesForResults(ifOp, getTypeConverter(), rewriter,
212+
resultVariables)))
213+
return rewriter.notifyMatchFailure(ifOp,
214+
"create variables for results failed");
215+
216+
// Utility function to lower the contents of an scf::if region to an emitc::if
217+
// region. The contents of the scf::if regions is moved into the respective
218+
// emitc::if regions, but the scf::yield is replaced not only with an
219+
// emitc::yield, but also with a sequence of emitc::assign ops that set the
220+
// yielded values into the result variables.
221+
auto lowerRegion = [&resultVariables, &rewriter,
222+
&ifOp](Region &region, Region &loweredRegion) {
223+
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
224+
Operation *terminator = loweredRegion.back().getTerminator();
225+
auto result = lowerYield(ifOp, resultVariables, rewriter,
226+
cast<scf::YieldOp>(terminator));
227+
if (failed(result)) {
228+
return result;
229+
}
230+
return success();
231+
};
232+
233+
Region &thenRegion = adaptor.getThenRegion();
234+
Region &elseRegion = adaptor.getElseRegion();
183235

184236
bool hasElseBlock = !elseRegion.empty();
185237

186238
auto loweredIf =
187-
rewriter.create<emitc::IfOp>(loc, ifOp.getCondition(), false, false);
239+
rewriter.create<emitc::IfOp>(loc, adaptor.getCondition(), false, false);
188240

189241
Region &loweredThenRegion = loweredIf.getThenRegion();
190-
lowerRegion(resultVariables, rewriter, thenRegion, loweredThenRegion);
242+
auto result = lowerRegion(thenRegion, loweredThenRegion);
243+
if (failed(result)) {
244+
return result;
245+
}
191246

192247
if (hasElseBlock) {
193248
Region &loweredElseRegion = loweredIf.getElseRegion();
194-
lowerRegion(resultVariables, rewriter, elseRegion, loweredElseRegion);
249+
auto result = lowerRegion(elseRegion, loweredElseRegion);
250+
if (failed(result)) {
251+
return result;
252+
}
195253
}
196254

197255
rewriter.setInsertionPointAfter(ifOp);
@@ -203,37 +261,46 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
203261

204262
// Lower scf::index_switch to emitc::switch, implementing result values as
205263
// emitc::variable's updated within the case and default regions.
206-
struct IndexSwitchOpLowering : public OpRewritePattern<IndexSwitchOp> {
207-
using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
264+
struct IndexSwitchOpLowering : public OpConversionPattern<IndexSwitchOp> {
265+
using OpConversionPattern::OpConversionPattern;
208266

209-
LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp,
210-
PatternRewriter &rewriter) const override;
267+
LogicalResult
268+
matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
269+
ConversionPatternRewriter &rewriter) const override;
211270
};
212271

213-
LogicalResult
214-
IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp,
215-
PatternRewriter &rewriter) const {
272+
LogicalResult IndexSwitchOpLowering::matchAndRewrite(
273+
IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
274+
ConversionPatternRewriter &rewriter) const {
216275
Location loc = indexSwitchOp.getLoc();
217276

218277
// Create an emitc::variable op for each result. These variables will be
219278
// assigned to by emitc::assign ops within the case and default regions.
220-
SmallVector<Value> resultVariables =
221-
createVariablesForResults(indexSwitchOp, rewriter);
279+
SmallVector<Value> resultVariables;
280+
if (failed(createVariablesForResults(indexSwitchOp, getTypeConverter(),
281+
rewriter, resultVariables))) {
282+
return rewriter.notifyMatchFailure(indexSwitchOp,
283+
"create variables for results failed");
284+
}
222285

223286
auto loweredSwitch = rewriter.create<emitc::SwitchOp>(
224-
loc, indexSwitchOp.getArg(), indexSwitchOp.getCases(),
225-
indexSwitchOp.getNumCases());
287+
loc, adaptor.getArg(), adaptor.getCases(), indexSwitchOp.getNumCases());
226288

227289
// Lowering all case regions.
228-
for (auto pair : llvm::zip(indexSwitchOp.getCaseRegions(),
229-
loweredSwitch.getCaseRegions())) {
230-
lowerRegion(resultVariables, rewriter, std::get<0>(pair),
231-
std::get<1>(pair));
290+
for (auto pair :
291+
llvm::zip(adaptor.getCaseRegions(), loweredSwitch.getCaseRegions())) {
292+
if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
293+
*std::get<0>(pair), std::get<1>(pair)))) {
294+
return failure();
295+
}
232296
}
233297

234298
// Lowering default region.
235-
lowerRegion(resultVariables, rewriter, indexSwitchOp.getDefaultRegion(),
236-
loweredSwitch.getDefaultRegion());
299+
if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
300+
adaptor.getDefaultRegion(),
301+
loweredSwitch.getDefaultRegion()))) {
302+
return failure();
303+
}
237304

238305
rewriter.setInsertionPointAfter(indexSwitchOp);
239306
SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
@@ -242,15 +309,22 @@ IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp,
242309
return success();
243310
}
244311

245-
void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) {
246-
patterns.add<ForLowering>(patterns.getContext());
247-
patterns.add<IfLowering>(patterns.getContext());
248-
patterns.add<IndexSwitchOpLowering>(patterns.getContext());
312+
void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns,
313+
TypeConverter &typeConverter) {
314+
patterns.add<ForLowering>(typeConverter, patterns.getContext());
315+
patterns.add<IfLowering>(typeConverter, patterns.getContext());
316+
patterns.add<IndexSwitchOpLowering>(typeConverter, patterns.getContext());
249317
}
250318

251319
void SCFToEmitCPass::runOnOperation() {
252320
RewritePatternSet patterns(&getContext());
253-
populateSCFToEmitCConversionPatterns(patterns);
321+
TypeConverter typeConverter;
322+
// Fallback converter
323+
// See note https://mlir.llvm.org/docs/DialectConversion/#type-converter
324+
// Type converters are called most to least recently inserted
325+
typeConverter.addConversion([](Type t) { return t; });
326+
populateEmitCSizeTTypeConversions(typeConverter);
327+
populateSCFToEmitCConversionPatterns(patterns, typeConverter);
254328

255329
// Configure conversion to lower out SCF operations.
256330
ConversionTarget target(getContext());

0 commit comments

Comments
 (0)