diff --git a/mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h b/mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h index 22df7f1c5dcf2..acc39e6acf726 100644 --- a/mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h +++ b/mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h @@ -9,6 +9,7 @@ #ifndef MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H #define MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H +#include "mlir/Transforms/DialectConversion.h" #include namespace mlir { @@ -19,7 +20,8 @@ class RewritePatternSet; #include "mlir/Conversion/Passes.h.inc" /// Collect a set of patterns to convert SCF operations to the EmitC dialect. -void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns); +void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter); } // namespace mlir #endif // MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp index 67a43c43d608b..92523ca4f12b2 100644 --- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp +++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -39,21 +40,22 @@ struct SCFToEmitCPass : public impl::SCFToEmitCBase { // Lower scf::for to emitc::for, implementing result values using // emitc::variable's updated within the loop body. -struct ForLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct ForLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(ForOp forOp, - PatternRewriter &rewriter) const override; + LogicalResult + matchAndRewrite(ForOp forOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; }; // Create an uninitialized emitc::variable op for each result of the given op. template -static SmallVector createVariablesForResults(T op, - PatternRewriter &rewriter) { - SmallVector resultVariables; - +static LogicalResult +createVariablesForResults(T op, const TypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + SmallVector &resultVariables) { if (!op.getNumResults()) - return resultVariables; + return success(); Location loc = op->getLoc(); MLIRContext *context = op.getContext(); @@ -62,7 +64,9 @@ static SmallVector createVariablesForResults(T op, rewriter.setInsertionPoint(op); for (OpResult result : op.getResults()) { - Type resultType = result.getType(); + Type resultType = typeConverter->convertType(result.getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "result type conversion failed"); Type varType = emitc::LValueType::get(resultType); emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, ""); emitc::VariableOp var = @@ -70,13 +74,13 @@ static SmallVector createVariablesForResults(T op, resultVariables.push_back(var); } - return resultVariables; + return success(); } // Create a series of assign ops assigning given values to given variables at // the current insertion point of given rewriter. -static void assignValues(ValueRange values, SmallVector &variables, - PatternRewriter &rewriter, Location loc) { +static void assignValues(ValueRange values, ValueRange variables, + ConversionPatternRewriter &rewriter, Location loc) { for (auto [value, var] : llvm::zip(values, variables)) rewriter.create(loc, var, value); } @@ -89,18 +93,25 @@ SmallVector loadValues(const SmallVector &variables, }); } -static void lowerYield(SmallVector &resultVariables, - PatternRewriter &rewriter, scf::YieldOp yield) { +static LogicalResult lowerYield(Operation *op, ValueRange resultVariables, + ConversionPatternRewriter &rewriter, + scf::YieldOp yield) { Location loc = yield.getLoc(); - ValueRange operands = yield.getOperands(); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(yield); - assignValues(operands, resultVariables, rewriter, loc); + SmallVector yieldOperands; + if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) { + return rewriter.notifyMatchFailure(op, "failed to lower yield operands"); + } + + assignValues(yieldOperands, resultVariables, rewriter, loc); rewriter.create(loc); rewriter.eraseOp(yield); + + return success(); } // Lower the contents of an scf::if/scf::index_switch regions to an @@ -108,27 +119,32 @@ static void lowerYield(SmallVector &resultVariables, // moved into the respective lowered region, but the scf::yield is replaced not // only with an emitc::yield, but also with a sequence of emitc::assign ops that // set the yielded values into the result variables. -static void lowerRegion(SmallVector &resultVariables, - PatternRewriter &rewriter, Region ®ion, - Region &loweredRegion) { +static LogicalResult lowerRegion(Operation *op, ValueRange resultVariables, + ConversionPatternRewriter &rewriter, + Region ®ion, Region &loweredRegion) { rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end()); Operation *terminator = loweredRegion.back().getTerminator(); - lowerYield(resultVariables, rewriter, cast(terminator)); + return lowerYield(op, resultVariables, rewriter, + cast(terminator)); } -LogicalResult ForLowering::matchAndRewrite(ForOp forOp, - PatternRewriter &rewriter) const { +LogicalResult +ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { Location loc = forOp.getLoc(); // Create an emitc::variable op for each result. These variables will be // assigned to by emitc::assign ops within the loop body. - SmallVector resultVariables = - createVariablesForResults(forOp, rewriter); + SmallVector resultVariables; + if (failed(createVariablesForResults(forOp, getTypeConverter(), rewriter, + resultVariables))) + return rewriter.notifyMatchFailure(forOp, + "create variables for results failed"); - assignValues(forOp.getInits(), resultVariables, rewriter, loc); + assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc); emitc::ForOp loweredFor = rewriter.create( - loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()); + loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep()); Block *loweredBody = loweredFor.getBody(); @@ -143,13 +159,27 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp, rewriter.restoreInsertionPoint(ip); + // Convert the original region types into the new types by adding unrealized + // casts in the beginning of the loop. This performs the conversion in place. + if (failed(rewriter.convertRegionTypes(&forOp.getRegion(), + *getTypeConverter(), nullptr))) { + return rewriter.notifyMatchFailure(forOp, "region types conversion failed"); + } + + // Register the replacements for the block arguments and inline the body of + // the scf.for loop into the body of the emitc::for loop. + Block *scfBody = &(forOp.getRegion().front()); SmallVector replacingValues; replacingValues.push_back(loweredFor.getInductionVar()); replacingValues.append(iterArgsValues.begin(), iterArgsValues.end()); + rewriter.mergeBlocks(scfBody, loweredBody, replacingValues); - rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues); - lowerYield(resultVariables, rewriter, - cast(loweredBody->getTerminator())); + auto result = lowerYield(forOp, resultVariables, rewriter, + cast(loweredBody->getTerminator())); + + if (failed(result)) { + return result; + } // Load variables into SSA values after the for loop. SmallVector resultValues = loadValues(resultVariables, rewriter, loc); @@ -160,38 +190,66 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp, // Lower scf::if to emitc::if, implementing result values as emitc::variable's // updated within the then and else regions. -struct IfLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct IfLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(IfOp ifOp, - PatternRewriter &rewriter) const override; + LogicalResult + matchAndRewrite(IfOp ifOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; }; } // namespace -LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, - PatternRewriter &rewriter) const { +LogicalResult +IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { Location loc = ifOp.getLoc(); // Create an emitc::variable op for each result. These variables will be // assigned to by emitc::assign ops within the then & else regions. - SmallVector resultVariables = - createVariablesForResults(ifOp, rewriter); - - Region &thenRegion = ifOp.getThenRegion(); - Region &elseRegion = ifOp.getElseRegion(); + SmallVector resultVariables; + if (failed(createVariablesForResults(ifOp, getTypeConverter(), rewriter, + resultVariables))) + return rewriter.notifyMatchFailure(ifOp, + "create variables for results failed"); + + // Utility function to lower the contents of an scf::if region to an emitc::if + // region. The contents of the scf::if regions is moved into the respective + // emitc::if regions, but the scf::yield is replaced not only with an + // emitc::yield, but also with a sequence of emitc::assign ops that set the + // yielded values into the result variables. + auto lowerRegion = [&resultVariables, &rewriter, + &ifOp](Region ®ion, Region &loweredRegion) { + rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end()); + Operation *terminator = loweredRegion.back().getTerminator(); + auto result = lowerYield(ifOp, resultVariables, rewriter, + cast(terminator)); + if (failed(result)) { + return result; + } + return success(); + }; + + Region &thenRegion = adaptor.getThenRegion(); + Region &elseRegion = adaptor.getElseRegion(); bool hasElseBlock = !elseRegion.empty(); auto loweredIf = - rewriter.create(loc, ifOp.getCondition(), false, false); + rewriter.create(loc, adaptor.getCondition(), false, false); Region &loweredThenRegion = loweredIf.getThenRegion(); - lowerRegion(resultVariables, rewriter, thenRegion, loweredThenRegion); + auto result = lowerRegion(thenRegion, loweredThenRegion); + if (failed(result)) { + return result; + } if (hasElseBlock) { Region &loweredElseRegion = loweredIf.getElseRegion(); - lowerRegion(resultVariables, rewriter, elseRegion, loweredElseRegion); + auto result = lowerRegion(elseRegion, loweredElseRegion); + if (failed(result)) { + return result; + } } rewriter.setInsertionPointAfter(ifOp); @@ -203,37 +261,46 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, // Lower scf::index_switch to emitc::switch, implementing result values as // emitc::variable's updated within the case and default regions. -struct IndexSwitchOpLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct IndexSwitchOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp, - PatternRewriter &rewriter) const override; + LogicalResult + matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; }; -LogicalResult -IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp, - PatternRewriter &rewriter) const { +LogicalResult IndexSwitchOpLowering::matchAndRewrite( + IndexSwitchOp indexSwitchOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { Location loc = indexSwitchOp.getLoc(); // Create an emitc::variable op for each result. These variables will be // assigned to by emitc::assign ops within the case and default regions. - SmallVector resultVariables = - createVariablesForResults(indexSwitchOp, rewriter); + SmallVector resultVariables; + if (failed(createVariablesForResults(indexSwitchOp, getTypeConverter(), + rewriter, resultVariables))) { + return rewriter.notifyMatchFailure(indexSwitchOp, + "create variables for results failed"); + } auto loweredSwitch = rewriter.create( - loc, indexSwitchOp.getArg(), indexSwitchOp.getCases(), - indexSwitchOp.getNumCases()); + loc, adaptor.getArg(), adaptor.getCases(), indexSwitchOp.getNumCases()); // Lowering all case regions. - for (auto pair : llvm::zip(indexSwitchOp.getCaseRegions(), - loweredSwitch.getCaseRegions())) { - lowerRegion(resultVariables, rewriter, std::get<0>(pair), - std::get<1>(pair)); + for (auto pair : + llvm::zip(adaptor.getCaseRegions(), loweredSwitch.getCaseRegions())) { + if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter, + *std::get<0>(pair), std::get<1>(pair)))) { + return failure(); + } } // Lowering default region. - lowerRegion(resultVariables, rewriter, indexSwitchOp.getDefaultRegion(), - loweredSwitch.getDefaultRegion()); + if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter, + adaptor.getDefaultRegion(), + loweredSwitch.getDefaultRegion()))) { + return failure(); + } rewriter.setInsertionPointAfter(indexSwitchOp); SmallVector results = loadValues(resultVariables, rewriter, loc); @@ -242,15 +309,22 @@ IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp, return success(); } -void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); +void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add(typeConverter, patterns.getContext()); + patterns.add(typeConverter, patterns.getContext()); + patterns.add(typeConverter, patterns.getContext()); } void SCFToEmitCPass::runOnOperation() { RewritePatternSet patterns(&getContext()); - populateSCFToEmitCConversionPatterns(patterns); + TypeConverter typeConverter; + // Fallback converter + // See note https://mlir.llvm.org/docs/DialectConversion/#type-converter + // Type converters are called most to least recently inserted + typeConverter.addConversion([](Type t) { return t; }); + populateEmitCSizeTTypeConversions(typeConverter); + populateSCFToEmitCConversionPatterns(patterns, typeConverter); // Configure conversion to lower out SCF operations. ConversionTarget target(getContext()); diff --git a/mlir/test/Conversion/SCFToEmitC/for.mlir b/mlir/test/Conversion/SCFToEmitC/for.mlir index 83592187a9b68..7f41e636936b8 100644 --- a/mlir/test/Conversion/SCFToEmitC/for.mlir +++ b/mlir/test/Conversion/SCFToEmitC/for.mlir @@ -7,8 +7,11 @@ func.func @simple_std_for_loop(%arg0 : index, %arg1 : index, %arg2 : index) { return } // CHECK-LABEL: func.func @simple_std_for_loop( -// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) { -// CHECK-NEXT: emitc.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] { +// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) { +// CHECK-NEXT: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t +// CHECK-NEXT: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t +// CHECK-NEXT: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t +// CHECK-NEXT: emitc.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] : !emitc.size_t { // CHECK-NEXT: %[[VAL_4:.*]] = arith.constant 1 : index // CHECK-NEXT: } // CHECK-NEXT: return @@ -24,10 +27,13 @@ func.func @simple_std_2_for_loops(%arg0 : index, %arg1 : index, %arg2 : index) { return } // CHECK-LABEL: func.func @simple_std_2_for_loops( -// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) { -// CHECK-NEXT: emitc.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] { +// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) { +// CHECK-NEXT: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t +// CHECK-NEXT: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t +// CHECK-NEXT: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t +// CHECK-NEXT: emitc.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] : !emitc.size_t { // CHECK-NEXT: %[[VAL_4:.*]] = arith.constant 1 : index -// CHECK-NEXT: emitc.for %[[VAL_5:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] { +// CHECK-NEXT: emitc.for %[[VAL_5:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] : !emitc.size_t { // CHECK-NEXT: %[[VAL_6:.*]] = arith.constant 1 : index // CHECK-NEXT: } // CHECK-NEXT: } @@ -44,14 +50,17 @@ func.func @for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> (f32, f32) return %result#0, %result#1 : f32, f32 } // CHECK-LABEL: func.func @for_yield( -// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) -> (f32, f32) { +// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> (f32, f32) { +// CHECK-NEXT: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t +// CHECK-NEXT: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t +// CHECK-NEXT: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t // CHECK-NEXT: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-NEXT: %[[VAL_4:.*]] = arith.constant 1.000000e+00 : f32 // CHECK-NEXT: %[[VAL_5:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue // CHECK-NEXT: %[[VAL_6:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue // CHECK-NEXT: emitc.assign %[[VAL_3]] : f32 to %[[VAL_5]] : // CHECK-NEXT: emitc.assign %[[VAL_4]] : f32 to %[[VAL_6]] : -// CHECK-NEXT: emitc.for %[[VAL_7:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] { +// CHECK-NEXT: emitc.for %[[VAL_7:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] : !emitc.size_t { // CHECK-NEXT: %[[VAL_8:.*]] = emitc.load %[[VAL_5]] : // CHECK-NEXT: %[[VAL_9:.*]] = emitc.load %[[VAL_6]] : // CHECK-NEXT: %[[VAL_10:.*]] = arith.addf %[[VAL_8]], %[[VAL_9]] : f32 @@ -75,15 +84,18 @@ func.func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 return %r : f32 } // CHECK-LABEL: func.func @nested_for_yield( -// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) -> f32 { +// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> f32 { +// CHECK-NEXT: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t +// CHECK-NEXT: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t +// CHECK-NEXT: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t // CHECK-NEXT: %[[VAL_3:.*]] = arith.constant 1.000000e+00 : f32 // CHECK-NEXT: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue // CHECK-NEXT: emitc.assign %[[VAL_3]] : f32 to %[[VAL_4]] : -// CHECK-NEXT: emitc.for %[[VAL_5:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] { +// CHECK-NEXT: emitc.for %[[VAL_5:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] : !emitc.size_t { // CHECK-NEXT: %[[VAL_6:.*]] = emitc.load %[[VAL_4]] : // CHECK-NEXT: %[[VAL_7:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue // CHECK-NEXT: emitc.assign %[[VAL_6]] : f32 to %[[VAL_7]] : -// CHECK-NEXT: emitc.for %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] { +// CHECK-NEXT: emitc.for %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] : !emitc.size_t { // CHECK-NEXT: %[[VAL_9:.*]] = emitc.load %[[VAL_7]] : // CHECK-NEXT: %[[VAL_10:.*]] = arith.addf %[[VAL_9]], %[[VAL_9]] : f32 // CHECK-NEXT: emitc.assign %[[VAL_10]] : f32 to %[[VAL_7]] : @@ -94,3 +106,60 @@ func.func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 // CHECK-NEXT: %[[VAL_12:.*]] = emitc.load %[[VAL_4]] : // CHECK-NEXT: return %[[VAL_12]] : f32 // CHECK-NEXT: } + +func.func @for_yield_index(%arg0 : index, %arg1 : index, %arg2 : index) -> index { + %zero = arith.constant 0 : index + %r = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%acc = %zero) -> index { + scf.yield %acc : index + } + return %r : index +} + +// CHECK-LABEL: func.func @for_yield_index( +// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> index { +// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t +// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t +// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t +// CHECK: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue +// CHECK: emitc.assign %[[VAL_3]] : !emitc.size_t to %[[VAL_4]] : +// CHECK: emitc.for %[[VAL_5:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_0]] : !emitc.size_t { +// CHECK: %[[V:.*]] = emitc.load %[[VAL_4]] : +// CHECK: emitc.assign %[[V]] : !emitc.size_t to %[[VAL_4]] : +// CHECK: } +// CHECK: %[[V2:.*]] = emitc.load %[[VAL_4]] : +// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[V2]] : !emitc.size_t to index +// CHECK: return %[[VAL_8]] : index +// CHECK: } + + +func.func @for_yield_update_loop_carried_var(%arg0 : index, %arg1 : index, %arg2 : index) -> index { + %zero = arith.constant 0 : index + %r = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%acc = %zero) -> index { + %sn = arith.addi %acc, %acc : index + scf.yield %sn: index + } + return %r : index + } + +// CHECK-LABEL: func.func @for_yield_update_loop_carried_var( +// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> index { +// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t +// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t +// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t +// CHECK: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue +// CHECK: emitc.assign %[[VAL_3]] : !emitc.size_t to %[[VAL_4]] : +// CHECK: emitc.for %[[ARG_3:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_0]] : !emitc.size_t { +// CHECK: %[[V:.*]] = emitc.load %[[VAL_4]] : +// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[V]] : !emitc.size_t to index +// CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_5]] : index +// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : index to !emitc.size_t +// CHECK: emitc.assign %[[VAL_8]] : !emitc.size_t to %[[VAL_4]] : +// CHECK: } +// CHECK: %[[V2:.*]] = emitc.load %[[VAL_4]] : +// CHECK: %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[V2]] : !emitc.size_t to index +// CHECK: return %[[VAL_9]] : index +// CHECK: } diff --git a/mlir/test/Conversion/SCFToEmitC/switch.mlir b/mlir/test/Conversion/SCFToEmitC/switch.mlir index 86d96ed21f1b5..61015b0ae483b 100644 --- a/mlir/test/Conversion/SCFToEmitC/switch.mlir +++ b/mlir/test/Conversion/SCFToEmitC/switch.mlir @@ -1,7 +1,8 @@ // RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-emitc %s | FileCheck %s // CHECK-LABEL: func.func @switch_no_result( -// CHECK-SAME: %[[VAL_0:.*]]: index) { +// CHECK-SAME: %[[ARG_0:.*]]: index) { +// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t // CHECK: emitc.switch %[[VAL_0]] // CHECK: case 2 { // CHECK: %[[VAL_1:.*]] = arith.constant 10 : i32 @@ -33,7 +34,8 @@ func.func @switch_no_result(%arg0 : index) { } // CHECK-LABEL: func.func @switch_one_result( -// CHECK-SAME: %[[VAL_0:.*]]: index) { +// CHECK-SAME: %[[ARG_0:.*]]: index) { +// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t // CHECK: %[[VAL_1:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue // CHECK: emitc.switch %[[VAL_0]] // CHECK: case 2 { @@ -70,7 +72,8 @@ func.func @switch_one_result(%arg0 : index) { } // CHECK-LABEL: func.func @switch_two_results( -// CHECK-SAME: %[[VAL_0:.*]]: index) -> (i32, f32) { +// CHECK-SAME: %[[ARG_0:.*]]: index) -> (i32, f32) { +// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t // CHECK: %[[VAL_1:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue // CHECK: %[[VAL_2:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue // CHECK: emitc.switch %[[VAL_0]]