@@ -3546,6 +3546,137 @@ LogicalResult scf::WhileOp::verify() {
35463546}
35473547
35483548namespace {
3549+ // / Move an scf.if op that is directly before the scf.condition op in the while
3550+ // / before region, and whose condition matches the condition of the
3551+ // / scf.condition op, down into the while after region.
3552+ // /
3553+ // / scf.while (..) : (...) -> ... {
3554+ // / %additional_used_values = ...
3555+ // / %cond = ...
3556+ // / ...
3557+ // / %res = scf.if %cond -> (...) {
3558+ // / use(%additional_used_values)
3559+ // / ... // then block
3560+ // / scf.yield %then_value
3561+ // / } else {
3562+ // / scf.yield %else_value
3563+ // / }
3564+ // / scf.condition(%cond) %res, ...
3565+ // / } do {
3566+ // / ^bb0(%res_arg, ...):
3567+ // / use(%res_arg)
3568+ // / ...
3569+ // /
3570+ // / becomes
3571+ // / scf.while (..) : (...) -> ... {
3572+ // / %additional_used_values = ...
3573+ // / %cond = ...
3574+ // / ...
3575+ // / scf.condition(%cond) %else_value, ..., %additional_used_values
3576+ // / } do {
3577+ // / ^bb0(%res_arg ..., %additional_args): :
3578+ // / use(%additional_args)
3579+ // / ... // if then block
3580+ // / use(%then_value)
3581+ // / ...
3582+ struct WhileMoveIfDown : public OpRewritePattern <WhileOp> {
3583+ using OpRewritePattern<WhileOp>::OpRewritePattern;
3584+
3585+ LogicalResult matchAndRewrite (WhileOp op,
3586+ PatternRewriter &rewriter) const override {
3587+ auto conditionOp =
3588+ cast<scf::ConditionOp>(op.getBeforeBody ()->getTerminator ());
3589+ auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode ());
3590+
3591+ // Check that the ifOp is directly before the conditionOp and that it
3592+ // matches the condition of the conditionOp. Also ensure that the ifOp has
3593+ // no else block with content, as that would complicate the transformation.
3594+ // TODO: support else blocks with content.
3595+ if (!ifOp || ifOp.getCondition () != conditionOp.getCondition () ||
3596+ (ifOp.elseBlock () && !ifOp.elseBlock ()->without_terminator ().empty ()))
3597+ return failure ();
3598+
3599+ assert (ifOp->use_empty () || (llvm::all_equal (ifOp->getUsers ()) &&
3600+ *ifOp->user_begin () == conditionOp) &&
3601+ " ifOp has unexpected uses" );
3602+
3603+ Location loc = op.getLoc ();
3604+
3605+ // Replace uses of ifOp results in the conditionOp with the yielded values
3606+ // from the ifOp branches.
3607+ for (auto [idx, arg] : llvm::enumerate (conditionOp.getArgs ())) {
3608+ auto it = llvm::find (ifOp->getResults (), arg);
3609+ if (it != ifOp->getResults ().end ()) {
3610+ size_t ifOpIdx = it.getIndex ();
3611+ Value thenValue = ifOp.thenYield ()->getOperand (ifOpIdx);
3612+ Value elseValue = ifOp.elseYield ()->getOperand (ifOpIdx);
3613+
3614+ rewriter.replaceAllUsesWith (ifOp->getResults ()[ifOpIdx], elseValue);
3615+ rewriter.replaceAllUsesWith (op.getAfterArguments ()[idx], thenValue);
3616+ }
3617+ }
3618+
3619+ SmallVector<Value> additionalUsedValues;
3620+ auto isValueUsedInsideIf = [&](Value val) {
3621+ return llvm::any_of (val.getUsers (), [&](Operation *user) {
3622+ return ifOp.getThenRegion ().isAncestor (user->getParentRegion ());
3623+ });
3624+ };
3625+
3626+ // Collect additional used values from before region.
3627+ for (Operation *it = ifOp->getPrevNode (); it != nullptr ;
3628+ it = it->getPrevNode ())
3629+ llvm::copy_if (it->getResults (), std::back_inserter (additionalUsedValues),
3630+ isValueUsedInsideIf);
3631+
3632+ llvm::copy_if (op.getBeforeArguments (),
3633+ std::back_inserter (additionalUsedValues),
3634+ isValueUsedInsideIf);
3635+
3636+ // Create new whileOp with additional used values as results.
3637+ auto additionalValueTypes = llvm::map_to_vector (
3638+ additionalUsedValues, [](Value val) { return val.getType (); });
3639+ size_t additionalValueSize = additionalUsedValues.size ();
3640+ SmallVector<Type> newResultTypes (op.getResultTypes ());
3641+ newResultTypes.append (additionalValueTypes);
3642+
3643+ auto newWhileOp =
3644+ scf::WhileOp::create (rewriter, loc, newResultTypes, op.getInits ());
3645+
3646+ newWhileOp.getBefore ().takeBody (op.getBefore ());
3647+ newWhileOp.getAfter ().takeBody (op.getAfter ());
3648+ newWhileOp.getAfter ().addArguments (
3649+ additionalValueTypes, SmallVector<Location>(additionalValueSize, loc));
3650+
3651+ SmallVector<Value> conditionArgs = conditionOp.getArgs ();
3652+ llvm::append_range (conditionArgs, additionalUsedValues);
3653+
3654+ // Update conditionOp inside new whileOp before region.
3655+ rewriter.setInsertionPoint (conditionOp);
3656+ rewriter.replaceOpWithNewOp <scf::ConditionOp>(
3657+ conditionOp, conditionOp.getCondition (), conditionArgs);
3658+
3659+ // Replace uses of additional used values inside the ifOp then region with
3660+ // the whileOp after region arguments.
3661+ rewriter.replaceUsesWithIf (
3662+ additionalUsedValues,
3663+ newWhileOp.getAfterArguments ().take_back (additionalValueSize),
3664+ [&](OpOperand &use) {
3665+ return ifOp.getThenRegion ().isAncestor (
3666+ use.getOwner ()->getParentRegion ());
3667+ });
3668+
3669+ // Inline ifOp then region into new whileOp after region.
3670+ rewriter.eraseOp (ifOp.thenYield ());
3671+ rewriter.inlineBlockBefore (ifOp.thenBlock (), newWhileOp.getAfterBody (),
3672+ newWhileOp.getAfterBody ()->begin ());
3673+ rewriter.eraseOp (ifOp);
3674+ rewriter.replaceOp (op,
3675+ newWhileOp->getResults ().drop_back (additionalValueSize));
3676+ return success ();
3677+ }
3678+ };
3679+
35493680// / Replace uses of the condition within the do block with true, since otherwise
35503681// / the block would not be evaluated.
35513682// /
@@ -4258,7 +4389,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
42584389 results.add <RemoveLoopInvariantArgsFromBeforeBlock,
42594390 RemoveLoopInvariantValueYielded, WhileConditionTruth,
42604391 WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
4261- WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context);
4392+ WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
4393+ context);
42624394}
42634395
42644396// ===----------------------------------------------------------------------===//
0 commit comments