99#include " mlir/Transforms/RegionUtils.h"
1010#include " mlir/Analysis/TopologicalSortUtils.h"
1111#include " mlir/IR/Block.h"
12- #include " mlir/IR/BuiltinOps.h"
1312#include " mlir/IR/IRMapping.h"
1413#include " mlir/IR/Operation.h"
1514#include " mlir/IR/PatternMatch.h"
1615#include " mlir/IR/RegionGraphTraits.h"
1716#include " mlir/IR/Value.h"
1817#include " mlir/Interfaces/ControlFlowInterfaces.h"
1918#include " mlir/Interfaces/SideEffectInterfaces.h"
20- #include " mlir/Support/LogicalResult.h"
2119
2220#include " llvm/ADT/DepthFirstIterator.h"
2321#include " llvm/ADT/PostOrderIterator.h"
24- #include " llvm/ADT/STLExtras.h"
25- #include " llvm/ADT/SmallSet.h"
2622
2723#include < deque>
28- #include < iterator>
2924
3025using namespace mlir ;
3126
@@ -704,8 +699,9 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
704699 blockIterators.push_back (mergeBlock->begin ());
705700
706701 // Update each of the predecessor terminators with the new arguments.
707- SmallVector<SmallVector<Value, 8 >, 2 > newArguments (1 + blocksToMerge.size (),
708- SmallVector<Value, 8 >());
702+ SmallVector<SmallVector<Value, 8 >, 2 > newArguments (
703+ 1 + blocksToMerge.size (),
704+ SmallVector<Value, 8 >(operandsToMerge.size ()));
709705 unsigned curOpIndex = 0 ;
710706 for (const auto &it : llvm::enumerate (operandsToMerge)) {
711707 unsigned nextOpOffset = it.value ().first - curOpIndex;
@@ -716,22 +712,13 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
716712 Block::iterator &blockIter = blockIterators[i];
717713 std::advance (blockIter, nextOpOffset);
718714 auto &operand = blockIter->getOpOperand (it.value ().second );
719- Value operandVal = operand.get ();
720- Value *it = std::find (newArguments[i].begin (), newArguments[i].end (),
721- operandVal);
722- if (it == newArguments[i].end ()) {
723- newArguments[i].push_back (operandVal);
724- // Update the operand and insert an argument if this is the leader.
725- if (i == 0 ) {
726- operand.set (leaderBlock->addArgument (operandVal.getType (),
727- operandVal.getLoc ()));
728- }
729- } else if (i == 0 ) {
730- // If this is the leader, update the operand but do not insert a new
731- // argument. Instead, the opearand should point to one of the
732- // arguments we already passed (and that contained `operandVal`)
733- operand.set (leaderBlock->getArgument (
734- std::distance (newArguments[i].begin (), it)));
715+ newArguments[i][it.index ()] = operand.get ();
716+
717+ // Update the operand and insert an argument if this is the leader.
718+ if (i == 0 ) {
719+ Value operandVal = operand.get ();
720+ operand.set (leaderBlock->addArgument (operandVal.getType (),
721+ operandVal.getLoc ()));
735722 }
736723 }
737724 }
@@ -831,109 +818,6 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
831818 return success (anyChanged);
832819}
833820
834- static LogicalResult dropRedundantArguments (RewriterBase &rewriter,
835- Block &block) {
836- SmallVector<size_t > argsToErase;
837-
838- // Go through the arguments of the block
839- for (size_t argIdx = 0 ; argIdx < block.getNumArguments (); argIdx++) {
840- bool sameArg = true ;
841- Value commonValue;
842-
843- // Go through the block predecessor and flag if they pass to the block
844- // different values for the same argument
845- for (auto predIt = block.pred_begin (), predE = block.pred_end ();
846- predIt != predE; ++predIt) {
847- auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator ());
848- if (!branch) {
849- sameArg = false ;
850- break ;
851- }
852- unsigned succIndex = predIt.getSuccessorIndex ();
853- SuccessorOperands succOperands = branch.getSuccessorOperands (succIndex);
854- auto operands = succOperands.getForwardedOperands ();
855- if (!commonValue) {
856- commonValue = operands[argIdx];
857- } else {
858- if (operands[argIdx] != commonValue) {
859- sameArg = false ;
860- break ;
861- }
862- }
863- }
864-
865- // If they are passing the same value, drop the argument
866- if (commonValue && sameArg) {
867- argsToErase.push_back (argIdx);
868-
869- // Remove the argument from the block
870- Value argVal = block.getArgument (argIdx);
871- rewriter.replaceAllUsesWith (argVal, commonValue);
872- }
873- }
874-
875- // Remove the arguments
876- for (auto argIdx : llvm::reverse (argsToErase)) {
877- block.eraseArgument (argIdx);
878-
879- // Remove the argument from the branch ops
880- for (auto predIt = block.pred_begin (), predE = block.pred_end ();
881- predIt != predE; ++predIt) {
882- auto branch = cast<BranchOpInterface>((*predIt)->getTerminator ());
883- unsigned succIndex = predIt.getSuccessorIndex ();
884- SuccessorOperands succOperands = branch.getSuccessorOperands (succIndex);
885- succOperands.erase (argIdx);
886- }
887- }
888- return success (!argsToErase.empty ());
889- }
890-
891- // / This optimization drops redundant argument to blocks. I.e., if a given
892- // / argument to a block receives the same value from each of the block
893- // / predecessors, we can remove the argument from the block and use directly the
894- // / original value. This is a simple example:
895- // /
896- // / %cond = llvm.call @rand() : () -> i1
897- // / %val0 = llvm.mlir.constant(1 : i64) : i64
898- // / %val1 = llvm.mlir.constant(2 : i64) : i64
899- // / %val2 = llvm.mlir.constant(3 : i64) : i64
900- // / llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
901- // / : i64)
902- // /
903- // / ^bb1(%arg0 : i64, %arg1 : i64):
904- // / llvm.call @foo(%arg0, %arg1)
905- // /
906- // / The previous IR can be rewritten as:
907- // / %cond = llvm.call @rand() : () -> i1
908- // / %val0 = llvm.mlir.constant(1 : i64) : i64
909- // / %val1 = llvm.mlir.constant(2 : i64) : i64
910- // / %val2 = llvm.mlir.constant(3 : i64) : i64
911- // / llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
912- // /
913- // / ^bb1(%arg0 : i64):
914- // / llvm.call @foo(%val0, %arg0)
915- // /
916- static LogicalResult dropRedundantArguments (RewriterBase &rewriter,
917- MutableArrayRef<Region> regions) {
918- llvm::SmallSetVector<Region *, 1 > worklist;
919- for (auto &region : regions)
920- worklist.insert (&region);
921- bool anyChanged = false ;
922- while (!worklist.empty ()) {
923- Region *region = worklist.pop_back_val ();
924-
925- // Add any nested regions to the worklist.
926- for (Block &block : *region) {
927- anyChanged = succeeded (dropRedundantArguments (rewriter, block));
928-
929- for (auto &op : block)
930- for (auto &nestedRegion : op.getRegions ())
931- worklist.insert (&nestedRegion);
932- }
933- }
934- return success (anyChanged);
935- }
936-
937821// ===----------------------------------------------------------------------===//
938822// Region Simplification
939823// ===----------------------------------------------------------------------===//
@@ -948,12 +832,8 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
948832 bool eliminatedBlocks = succeeded (eraseUnreachableBlocks (rewriter, regions));
949833 bool eliminatedOpsOrArgs = succeeded (runRegionDCE (rewriter, regions));
950834 bool mergedIdenticalBlocks = false ;
951- bool droppedRedundantArguments = false ;
952- if (mergeBlocks) {
835+ if (mergeBlocks)
953836 mergedIdenticalBlocks = succeeded (mergeIdenticalBlocks (rewriter, regions));
954- droppedRedundantArguments =
955- succeeded (dropRedundantArguments (rewriter, regions));
956- }
957837 return success (eliminatedBlocks || eliminatedOpsOrArgs ||
958- mergedIdenticalBlocks || droppedRedundantArguments );
838+ mergedIdenticalBlocks);
959839}
0 commit comments