@@ -3546,137 +3546,6 @@ LogicalResult scf::WhileOp::verify() {
35463546}
35473547
35483548namespace {
3549- // / Move an scf.if op that is directly before the scf.condition op in the while
3550- // / before region, and whose condition matches the condition of the
3551- // / scf.condition op, down into the while after region.
3552- // /
3553- // / scf.while (..) : (...) -> ... {
3554- // / %additional_used_values = ...
3555- // / %cond = ...
3556- // / ...
3557- // / %res = scf.if %cond -> (...) {
3558- // / use(%additional_used_values)
3559- // / ... // then block
3560- // / scf.yield %then_value
3561- // / } else {
3562- // / scf.yield %else_value
3563- // / }
3564- // / scf.condition(%cond) %res, ...
3565- // / } do {
3566- // / ^bb0(%res_arg, ...):
3567- // / use(%res_arg)
3568- // / ...
3569- // /
3570- // / becomes
3571- // / scf.while (..) : (...) -> ... {
3572- // / %additional_used_values = ...
3573- // / %cond = ...
3574- // / ...
3575- // / scf.condition(%cond) %else_value, ..., %additional_used_values
3576- // / } do {
3577- // / ^bb0(%res_arg ..., %additional_args): :
3578- // / use(%additional_args)
3579- // / ... // if then block
3580- // / use(%then_value)
3581- // / ...
3582- struct WhileMoveIfDown : public OpRewritePattern <WhileOp> {
3583- using OpRewritePattern<WhileOp>::OpRewritePattern;
3584-
3585- LogicalResult matchAndRewrite (WhileOp op,
3586- PatternRewriter &rewriter) const override {
3587- auto conditionOp =
3588- cast<scf::ConditionOp>(op.getBeforeBody ()->getTerminator ());
3589- auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode ());
3590-
3591- // Check that the ifOp is directly before the conditionOp and that it
3592- // matches the condition of the conditionOp. Also ensure that the ifOp has
3593- // no else block with content, as that would complicate the transformation.
3594- // TODO: support else blocks with content.
3595- if (!ifOp || ifOp.getCondition () != conditionOp.getCondition () ||
3596- (ifOp.elseBlock () && !ifOp.elseBlock ()->without_terminator ().empty ()))
3597- return failure ();
3598-
3599- assert (ifOp->use_empty () || (llvm::all_equal (ifOp->getUsers ()) &&
3600- *ifOp->user_begin () == conditionOp) &&
3601- " ifOp has unexpected uses" );
3602-
3603- Location loc = op.getLoc ();
3604-
3605- // Replace uses of ifOp results in the conditionOp with the yielded values
3606- // from the ifOp branches.
3607- for (auto [idx, arg] : llvm::enumerate (conditionOp.getArgs ())) {
3608- auto it = llvm::find (ifOp->getResults (), arg);
3609- if (it != ifOp->getResults ().end ()) {
3610- size_t ifOpIdx = it.getIndex ();
3611- Value thenValue = ifOp.thenYield ()->getOperand (ifOpIdx);
3612- Value elseValue = ifOp.elseYield ()->getOperand (ifOpIdx);
3613-
3614- rewriter.replaceAllUsesWith (ifOp->getResults ()[ifOpIdx], elseValue);
3615- rewriter.replaceAllUsesWith (op.getAfterArguments ()[idx], thenValue);
3616- }
3617- }
3618-
3619- SmallVector<Value> additionalUsedValues;
3620- auto isValueUsedInsideIf = [&](Value val) {
3621- return llvm::any_of (val.getUsers (), [&](Operation *user) {
3622- return ifOp.getThenRegion ().isAncestor (user->getParentRegion ());
3623- });
3624- };
3625-
3626- // Collect additional used values from before region.
3627- for (Operation *it = ifOp->getPrevNode (); it != nullptr ;
3628- it = it->getPrevNode ())
3629- llvm::copy_if (it->getResults (), std::back_inserter (additionalUsedValues),
3630- isValueUsedInsideIf);
3631-
3632- llvm::copy_if (op.getBeforeArguments (),
3633- std::back_inserter (additionalUsedValues),
3634- isValueUsedInsideIf);
3635-
3636- // Create new whileOp with additional used values as results.
3637- auto additionalValueTypes = llvm::map_to_vector (
3638- additionalUsedValues, [](Value val) { return val.getType (); });
3639- size_t additionalValueSize = additionalUsedValues.size ();
3640- SmallVector<Type> newResultTypes (op.getResultTypes ());
3641- newResultTypes.append (additionalValueTypes);
3642-
3643- auto newWhileOp =
3644- scf::WhileOp::create (rewriter, loc, newResultTypes, op.getInits ());
3645-
3646- newWhileOp.getBefore ().takeBody (op.getBefore ());
3647- newWhileOp.getAfter ().takeBody (op.getAfter ());
3648- newWhileOp.getAfter ().addArguments (
3649- additionalValueTypes, SmallVector<Location>(additionalValueSize, loc));
3650-
3651- SmallVector<Value> conditionArgs = conditionOp.getArgs ();
3652- llvm::append_range (conditionArgs, additionalUsedValues);
3653-
3654- // Update conditionOp inside new whileOp before region.
3655- rewriter.setInsertionPoint (conditionOp);
3656- rewriter.replaceOpWithNewOp <scf::ConditionOp>(
3657- conditionOp, conditionOp.getCondition (), conditionArgs);
3658-
3659- // Replace uses of additional used values inside the ifOp then region with
3660- // the whileOp after region arguments.
3661- rewriter.replaceUsesWithIf (
3662- additionalUsedValues,
3663- newWhileOp.getAfterArguments ().take_back (additionalValueSize),
3664- [&](OpOperand &use) {
3665- return ifOp.getThenRegion ().isAncestor (
3666- use.getOwner ()->getParentRegion ());
3667- });
3668-
3669- // Inline ifOp then region into new whileOp after region.
3670- rewriter.eraseOp (ifOp.thenYield ());
3671- rewriter.inlineBlockBefore (ifOp.thenBlock (), newWhileOp.getAfterBody (),
3672- newWhileOp.getAfterBody ()->begin ());
3673- rewriter.eraseOp (ifOp);
3674- rewriter.replaceOp (op,
3675- newWhileOp->getResults ().drop_back (additionalValueSize));
3676- return success ();
3677- }
3678- };
3679-
36803549// / Replace uses of the condition within the do block with true, since otherwise
36813550// / the block would not be evaluated.
36823551// /
@@ -4389,8 +4258,7 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
43894258 results.add <RemoveLoopInvariantArgsFromBeforeBlock,
43904259 RemoveLoopInvariantValueYielded, WhileConditionTruth,
43914260 WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4392- WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
4393- context);
4261+ WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
43944262}
43954263
43964264// ===----------------------------------------------------------------------===//
0 commit comments