Skip to content

Commit 19a042e

Browse files
committed
Move the pattern into populateUpliftWhileToForPattern.
1 parent ecffe33 commit 19a042e

File tree

4 files changed

+165
-171
lines changed

4 files changed

+165
-171
lines changed

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 1 addition & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -3546,137 +3546,6 @@ LogicalResult scf::WhileOp::verify() {
35463546
}
35473547

35483548
namespace {
3549-
/// Move an scf.if op that is directly before the scf.condition op in the while
3550-
/// before region, and whose condition matches the condition of the
3551-
/// scf.condition op, down into the while after region.
3552-
///
3553-
/// scf.while (..) : (...) -> ... {
3554-
/// %additional_used_values = ...
3555-
/// %cond = ...
3556-
/// ...
3557-
/// %res = scf.if %cond -> (...) {
3558-
/// use(%additional_used_values)
3559-
/// ... // then block
3560-
/// scf.yield %then_value
3561-
/// } else {
3562-
/// scf.yield %else_value
3563-
/// }
3564-
/// scf.condition(%cond) %res, ...
3565-
/// } do {
3566-
/// ^bb0(%res_arg, ...):
3567-
/// use(%res_arg)
3568-
/// ...
3569-
///
3570-
/// becomes
3571-
/// scf.while (..) : (...) -> ... {
3572-
/// %additional_used_values = ...
3573-
/// %cond = ...
3574-
/// ...
3575-
/// scf.condition(%cond) %else_value, ..., %additional_used_values
3576-
/// } do {
3577-
/// ^bb0(%res_arg ..., %additional_args): :
3578-
/// use(%additional_args)
3579-
/// ... // if then block
3580-
/// use(%then_value)
3581-
/// ...
3582-
struct WhileMoveIfDown : public OpRewritePattern<WhileOp> {
3583-
using OpRewritePattern<WhileOp>::OpRewritePattern;
3584-
3585-
LogicalResult matchAndRewrite(WhileOp op,
3586-
PatternRewriter &rewriter) const override {
3587-
auto conditionOp =
3588-
cast<scf::ConditionOp>(op.getBeforeBody()->getTerminator());
3589-
auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
3590-
3591-
// Check that the ifOp is directly before the conditionOp and that it
3592-
// matches the condition of the conditionOp. Also ensure that the ifOp has
3593-
// no else block with content, as that would complicate the transformation.
3594-
// TODO: support else blocks with content.
3595-
if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
3596-
(ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
3597-
return failure();
3598-
3599-
assert(ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
3600-
*ifOp->user_begin() == conditionOp) &&
3601-
"ifOp has unexpected uses");
3602-
3603-
Location loc = op.getLoc();
3604-
3605-
// Replace uses of ifOp results in the conditionOp with the yielded values
3606-
// from the ifOp branches.
3607-
for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
3608-
auto it = llvm::find(ifOp->getResults(), arg);
3609-
if (it != ifOp->getResults().end()) {
3610-
size_t ifOpIdx = it.getIndex();
3611-
Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
3612-
Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
3613-
3614-
rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue);
3615-
rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue);
3616-
}
3617-
}
3618-
3619-
SmallVector<Value> additionalUsedValues;
3620-
auto isValueUsedInsideIf = [&](Value val) {
3621-
return llvm::any_of(val.getUsers(), [&](Operation *user) {
3622-
return ifOp.getThenRegion().isAncestor(user->getParentRegion());
3623-
});
3624-
};
3625-
3626-
// Collect additional used values from before region.
3627-
for (Operation *it = ifOp->getPrevNode(); it != nullptr;
3628-
it = it->getPrevNode())
3629-
llvm::copy_if(it->getResults(), std::back_inserter(additionalUsedValues),
3630-
isValueUsedInsideIf);
3631-
3632-
llvm::copy_if(op.getBeforeArguments(),
3633-
std::back_inserter(additionalUsedValues),
3634-
isValueUsedInsideIf);
3635-
3636-
// Create new whileOp with additional used values as results.
3637-
auto additionalValueTypes = llvm::map_to_vector(
3638-
additionalUsedValues, [](Value val) { return val.getType(); });
3639-
size_t additionalValueSize = additionalUsedValues.size();
3640-
SmallVector<Type> newResultTypes(op.getResultTypes());
3641-
newResultTypes.append(additionalValueTypes);
3642-
3643-
auto newWhileOp =
3644-
scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
3645-
3646-
newWhileOp.getBefore().takeBody(op.getBefore());
3647-
newWhileOp.getAfter().takeBody(op.getAfter());
3648-
newWhileOp.getAfter().addArguments(
3649-
additionalValueTypes, SmallVector<Location>(additionalValueSize, loc));
3650-
3651-
SmallVector<Value> conditionArgs = conditionOp.getArgs();
3652-
llvm::append_range(conditionArgs, additionalUsedValues);
3653-
3654-
// Update conditionOp inside new whileOp before region.
3655-
rewriter.setInsertionPoint(conditionOp);
3656-
rewriter.replaceOpWithNewOp<scf::ConditionOp>(
3657-
conditionOp, conditionOp.getCondition(), conditionArgs);
3658-
3659-
// Replace uses of additional used values inside the ifOp then region with
3660-
// the whileOp after region arguments.
3661-
rewriter.replaceUsesWithIf(
3662-
additionalUsedValues,
3663-
newWhileOp.getAfterArguments().take_back(additionalValueSize),
3664-
[&](OpOperand &use) {
3665-
return ifOp.getThenRegion().isAncestor(
3666-
use.getOwner()->getParentRegion());
3667-
});
3668-
3669-
// Inline ifOp then region into new whileOp after region.
3670-
rewriter.eraseOp(ifOp.thenYield());
3671-
rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
3672-
newWhileOp.getAfterBody()->begin());
3673-
rewriter.eraseOp(ifOp);
3674-
rewriter.replaceOp(op,
3675-
newWhileOp->getResults().drop_back(additionalValueSize));
3676-
return success();
3677-
}
3678-
};
3679-
36803549
/// Replace uses of the condition within the do block with true, since otherwise
36813550
/// the block would not be evaluated.
36823551
///
@@ -4389,8 +4258,7 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
43894258
results.add<RemoveLoopInvariantArgsFromBeforeBlock,
43904259
RemoveLoopInvariantValueYielded, WhileConditionTruth,
43914260
WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4392-
WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
4393-
context);
4261+
WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
43944262
}
43954263

43964264
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp

Lines changed: 133 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,137 @@
1919
using namespace mlir;
2020

2121
namespace {
22+
/// Move a scf.if op that is directly before the scf.condition op in the while
23+
/// before region, and whose condition matches the condition of the
24+
/// scf.condition op, down into the while after region.
25+
///
26+
/// scf.while (..) : (...) -> ... {
27+
/// %additional_used_values = ...
28+
/// %cond = ...
29+
/// ...
30+
/// %res = scf.if %cond -> (...) {
31+
/// use(%additional_used_values)
32+
/// ... // then block
33+
/// scf.yield %then_value
34+
/// } else {
35+
/// scf.yield %else_value
36+
/// }
37+
/// scf.condition(%cond) %res, ...
38+
/// } do {
39+
/// ^bb0(%res_arg, ...):
40+
/// use(%res_arg)
41+
/// ...
42+
///
43+
/// becomes
44+
/// scf.while (..) : (...) -> ... {
45+
/// %additional_used_values = ...
46+
/// %cond = ...
47+
/// ...
48+
/// scf.condition(%cond) %else_value, ..., %additional_used_values
49+
/// } do {
50+
/// ^bb0(%res_arg ..., %additional_args): :
51+
/// use(%additional_args)
52+
/// ... // if then block
53+
/// use(%then_value)
54+
/// ...
55+
struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
56+
using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
57+
58+
LogicalResult matchAndRewrite(scf::WhileOp op,
59+
PatternRewriter &rewriter) const override {
60+
auto conditionOp =
61+
cast<scf::ConditionOp>(op.getBeforeBody()->getTerminator());
62+
auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode());
63+
64+
// Check that the ifOp is directly before the conditionOp and that it
65+
// matches the condition of the conditionOp. Also ensure that the ifOp has
66+
// no else block with content, as that would complicate the transformation.
67+
// TODO: support else blocks with content.
68+
if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() ||
69+
(ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty()))
70+
return failure();
71+
72+
assert(ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) &&
73+
*ifOp->user_begin() == conditionOp) &&
74+
"ifOp has unexpected uses");
75+
76+
Location loc = op.getLoc();
77+
78+
// Replace uses of ifOp results in the conditionOp with the yielded values
79+
// from the ifOp branches.
80+
for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) {
81+
auto it = llvm::find(ifOp->getResults(), arg);
82+
if (it != ifOp->getResults().end()) {
83+
size_t ifOpIdx = it.getIndex();
84+
Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
85+
Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
86+
87+
rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue);
88+
rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue);
89+
}
90+
}
91+
92+
SmallVector<Value> additionalUsedValues;
93+
auto isValueUsedInsideIf = [&](Value val) {
94+
return llvm::any_of(val.getUsers(), [&](Operation *user) {
95+
return ifOp.getThenRegion().isAncestor(user->getParentRegion());
96+
});
97+
};
98+
99+
// Collect additional used values from before region.
100+
for (Operation *it = ifOp->getPrevNode(); it != nullptr;
101+
it = it->getPrevNode())
102+
llvm::copy_if(it->getResults(), std::back_inserter(additionalUsedValues),
103+
isValueUsedInsideIf);
104+
105+
llvm::copy_if(op.getBeforeArguments(),
106+
std::back_inserter(additionalUsedValues),
107+
isValueUsedInsideIf);
108+
109+
// Create new whileOp with additional used values as results.
110+
auto additionalValueTypes = llvm::map_to_vector(
111+
additionalUsedValues, [](Value val) { return val.getType(); });
112+
size_t additionalValueSize = additionalUsedValues.size();
113+
SmallVector<Type> newResultTypes(op.getResultTypes());
114+
newResultTypes.append(additionalValueTypes);
115+
116+
auto newWhileOp =
117+
scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits());
118+
119+
newWhileOp.getBefore().takeBody(op.getBefore());
120+
newWhileOp.getAfter().takeBody(op.getAfter());
121+
newWhileOp.getAfter().addArguments(
122+
additionalValueTypes, SmallVector<Location>(additionalValueSize, loc));
123+
124+
SmallVector<Value> conditionArgs = conditionOp.getArgs();
125+
llvm::append_range(conditionArgs, additionalUsedValues);
126+
127+
// Update conditionOp inside new whileOp before region.
128+
rewriter.setInsertionPoint(conditionOp);
129+
rewriter.replaceOpWithNewOp<scf::ConditionOp>(
130+
conditionOp, conditionOp.getCondition(), conditionArgs);
131+
132+
// Replace uses of additional used values inside the ifOp then region with
133+
// the whileOp after region arguments.
134+
rewriter.replaceUsesWithIf(
135+
additionalUsedValues,
136+
newWhileOp.getAfterArguments().take_back(additionalValueSize),
137+
[&](OpOperand &use) {
138+
return ifOp.getThenRegion().isAncestor(
139+
use.getOwner()->getParentRegion());
140+
});
141+
142+
// Inline ifOp then region into new whileOp after region.
143+
rewriter.eraseOp(ifOp.thenYield());
144+
rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
145+
newWhileOp.getAfterBody()->begin());
146+
rewriter.eraseOp(ifOp);
147+
rewriter.replaceOp(op,
148+
newWhileOp->getResults().drop_back(additionalValueSize));
149+
return success();
150+
}
151+
};
152+
22153
struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
23154
using OpRewritePattern::OpRewritePattern;
24155

@@ -267,5 +398,6 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
267398
}
268399

269400
void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
270-
patterns.add<UpliftWhileOp>(patterns.getContext());
401+
patterns.add<WhileMoveIfDown, UpliftWhileOp>(patterns.getContext());
402+
scf::WhileOp::getCanonicalizationPatterns(patterns, patterns.getContext());
271403
}

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -974,43 +974,6 @@ func.func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
974974

975975
// -----
976976

977-
// CHECK-LABEL: @while_move_if_down
978-
func.func @while_move_if_down() -> i32 {
979-
%0 = scf.while () : () -> (i32) {
980-
%additional_used_value = "test.get_some_value1" () : () -> (i32)
981-
%else_value = "test.get_some_value2" () : () -> (i32)
982-
%condition = "test.condition"() : () -> i1
983-
%res = scf.if %condition -> (i32) {
984-
"test.use1" (%additional_used_value) : (i32) -> ()
985-
%then_value = "test.get_some_value3" () : () -> (i32)
986-
scf.yield %then_value : i32
987-
} else {
988-
scf.yield %else_value : i32
989-
}
990-
scf.condition(%condition) %res : i32
991-
} do {
992-
^bb0(%res_arg: i32):
993-
"test.use2" (%res_arg) : (i32) -> ()
994-
scf.yield
995-
}
996-
return %0 : i32
997-
}
998-
// CHECK-NEXT: %[[WHILE_0:.*]]:2 = scf.while : () -> (i32, i32) {
999-
// CHECK-NEXT: %[[VAL_0:.*]] = "test.get_some_value1"() : () -> i32
1000-
// CHECK-NEXT: %[[VAL_1:.*]] = "test.get_some_value2"() : () -> i32
1001-
// CHECK-NEXT: %[[VAL_2:.*]] = "test.condition"() : () -> i1
1002-
// CHECK-NEXT: scf.condition(%[[VAL_2]]) %[[VAL_1]], %[[VAL_0]] : i32, i32
1003-
// CHECK-NEXT: } do {
1004-
// CHECK-NEXT: ^bb0(%[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32):
1005-
// CHECK-NEXT: "test.use1"(%[[VAL_4]]) : (i32) -> ()
1006-
// CHECK-NEXT: %[[VAL_5:.*]] = "test.get_some_value3"() : () -> i32
1007-
// CHECK-NEXT: "test.use2"(%[[VAL_5]]) : (i32) -> ()
1008-
// CHECK-NEXT: scf.yield
1009-
// CHECK-NEXT: }
1010-
// CHECK-NEXT: return %[[VAL_6:.*]]#0 : i32
1011-
1012-
// -----
1013-
1014977
// CHECK-LABEL: @while_cond_true
1015978
func.func @while_cond_true() -> i1 {
1016979
%0 = scf.while () : () -> i1 {

mlir/test/Dialect/SCF/uplift-while.mlir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,34 @@ func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32)
185185
// CHECK: %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
186186
// CHECK: scf.yield %[[T1]], %[[T2]] : i32, f32
187187
// CHECK: return %[[RES]]#0, %[[RES]]#1 : i32, f32
188+
189+
// -----
190+
191+
func.func @uplift_while(%low: index, %upper: index, %val : i32) -> i32 {
192+
%c1 = arith.constant 1 : index
193+
%1:2 = scf.while (%iv = %low, %iter = %val) : (index, i32) -> (index, i32) {
194+
%2 = arith.cmpi slt, %iv, %upper : index
195+
%3:2 = scf.if %2 -> (index, i32) {
196+
%4 = "test.test"(%iter) : (i32) -> i32
197+
%5 = arith.addi %iv, %c1 : index
198+
scf.yield %5, %4 : index, i32
199+
} else {
200+
scf.yield %iv, %iter : index, i32
201+
}
202+
scf.condition(%2) %3#0, %3#1 : index, i32
203+
} do {
204+
^bb0(%arg0: index, %arg1: i32):
205+
scf.yield %arg0, %arg1 : index, i32
206+
}
207+
return %1#1 : i32
208+
}
209+
210+
// CHECK-LABEL: func.func @uplift_while(
211+
// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: i32) -> i32 {
212+
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : index
213+
// CHECK: %[[FOR_0:.*]] = scf.for %[[VAL_0:.*]] = %[[ARG0]] to %[[ARG1]] step %[[CONSTANT_0]] iter_args(%[[VAL_1:.*]] = %[[ARG2]]) -> (i32) {
214+
// CHECK: %[[VAL_2:.*]] = "test.test"(%[[VAL_1]]) : (i32) -> i32
215+
// CHECK: scf.yield %[[VAL_2]] : i32
216+
// CHECK: }
217+
// CHECK: return %[[FOR_0]] : i32
218+
// CHECK: }

0 commit comments

Comments
 (0)