-
Notifications
You must be signed in to change notification settings - Fork 13.7k
[mlir] Fix block merging #102038
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
[mlir] Fix block merging #102038
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
e4aee4f
Fix block merging
giuseros 100cf69
Add test case
giuseros 6facb30
Correcting remaining tests
giuseros 5307460
Fix integration tests
giuseros 8b5efca
Address review feedbacks
giuseros 0614d89
Address review feedbacks - 2
giuseros 5b2a018
Use a O(N) algorithm with a conservative tradeoff
giuseros 7a52f8f
Offset the new arguments by the number of old arguments
giuseros e576c6b
Address review feedbacks - 3
giuseros File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,18 +9,23 @@ | |
#include "mlir/Transforms/RegionUtils.h" | ||
#include "mlir/Analysis/TopologicalSortUtils.h" | ||
#include "mlir/IR/Block.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/IR/IRMapping.h" | ||
#include "mlir/IR/Operation.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/IR/RegionGraphTraits.h" | ||
#include "mlir/IR/Value.h" | ||
#include "mlir/Interfaces/ControlFlowInterfaces.h" | ||
#include "mlir/Interfaces/SideEffectInterfaces.h" | ||
#include "mlir/Support/LogicalResult.h" | ||
|
||
#include "llvm/ADT/DepthFirstIterator.h" | ||
#include "llvm/ADT/PostOrderIterator.h" | ||
#include "llvm/ADT/STLExtras.h" | ||
#include "llvm/ADT/SmallSet.h" | ||
|
||
#include <deque> | ||
#include <iterator> | ||
|
||
using namespace mlir; | ||
|
||
|
@@ -674,6 +679,94 @@ static bool ableToUpdatePredOperands(Block *block) { | |
return true; | ||
} | ||
|
||
/// Removes duplicated columns from the lists of new block arguments. E.g., an
/// argument list like [x, y, z, x] becomes [x, y, z]. A column `j` is dropped
/// only when, in *every* list, it holds the same value as the column where
/// `newArguments[0][j]` first appears in list 0.
///
/// `newArguments` is indexed as [list][argument]; the number of columns is
/// taken from list 0. The new arguments sit at the back of `block`, after its
/// `numOldArguments` pre-existing arguments, so column `j` corresponds to block
/// argument `numOldArguments + j`. For each dropped column, the uses of the
/// matching block argument are redirected to the surviving argument and the
/// redundant block argument is erased.
///
/// Returns the pruned copy of `newArguments` (or `newArguments` itself when it
/// is empty).
static SmallVector<SmallVector<Value, 8>, 2> pruneRedundantArguments(
    const SmallVector<SmallVector<Value, 8>, 2> &newArguments,
    RewriterBase &rewriter, unsigned numOldArguments, Block *block) {
  // Nothing to prune.
  if (newArguments.empty())
    return newArguments;

  // `newArguments` is a 2D array of size `listCount` x `argCount`.
  const unsigned listCount = newArguments.size();
  const unsigned argCount = newArguments[0].size();
  const SmallVector<Value, 8> &leaderArgs = newArguments[0];

  // Column at which each value of list 0 shows up for the first time.
  DenseMap<Value, unsigned> firstSeenAt;
  for (unsigned col = 0; col < argCount; ++col)
    firstSeenAt.try_emplace(leaderArgs[col], col);

  // Maps a redundant column to the column that can be used in its place.
  // E.g., for [x, y, z, x] we end up with replacementFor[3] = 0.
  llvm::DenseMap<unsigned, unsigned> replacementFor;

  // Walk list 0 and compare each column only against the first appearance of
  // its value, which keeps the search linear instead of quadratic. The price
  // is that some redundancies are missed, e.g. with
  //   list0: [%a, %a, %a]
  //   list1: [%c, %b, %b]
  // columns 1 and 2 are only ever compared against column 0 (the first %a),
  // and list1 disagrees there, so nothing is pruned even though columns 1 and
  // 2 are mutually redundant. The number of arguments is potentially
  // unbounded, so this trade-off is accepted.
  for (unsigned col = 0; col < argCount; ++col) {
    const unsigned firstCol = firstSeenAt.lookup(leaderArgs[col]);
    if (firstCol == col)
      continue;

    // The column can go only if it is redundant in every other list as well.
    bool redundantEverywhere = true;
    for (unsigned list = 1; list < listCount && redundantEverywhere; ++list)
      redundantEverywhere =
          newArguments[list][firstCol] == newArguments[list][col];

    if (redundantEverywhere)
      replacementFor[col] = firstCol;
  }

  // Copy over the columns that survive.
  SmallVector<SmallVector<Value, 8>, 2> prunedArguments(
      listCount, SmallVector<Value, 8>());
  for (unsigned list = 0; list < listCount; ++list) {
    for (unsigned col = 0; col < argCount; ++col) {
      if (replacementFor.contains(col))
        continue;
      prunedArguments[list].push_back(newArguments[list][col]);
    }
  }

  // Redirect the uses of each redundant block argument to its replacement and
  // remember its position for erasure.
  SmallVector<unsigned> erasedPositions;
  for (unsigned idx = 0, total = block->getNumArguments(); idx < total;
       ++idx) {
    auto it = replacementFor.find(idx);
    if (it == replacementFor.end())
      continue;
    Value redundantArg = block->getArgument(numOldArguments + idx);
    Value survivingArg = block->getArgument(numOldArguments + it->second);
    rewriter.replaceAllUsesWith(redundantArg, survivingArg);
    erasedPositions.push_back(numOldArguments + idx);
  }

  // Erase from the back so that the remaining positions stay valid.
  for (auto it = erasedPositions.rbegin(), end = erasedPositions.rend();
       it != end; ++it)
    block->eraseArgument(*it);

  return prunedArguments;
}
|
||
LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) { | ||
// Don't consider clusters that don't have blocks to merge. | ||
if (blocksToMerge.empty()) | ||
|
@@ -703,6 +796,7 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) { | |
1 + blocksToMerge.size(), | ||
SmallVector<Value, 8>(operandsToMerge.size())); | ||
unsigned curOpIndex = 0; | ||
unsigned numOldArguments = leaderBlock->getNumArguments(); | ||
for (const auto &it : llvm::enumerate(operandsToMerge)) { | ||
unsigned nextOpOffset = it.value().first - curOpIndex; | ||
curOpIndex = it.value().first; | ||
|
@@ -722,6 +816,11 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) { | |
} | ||
} | ||
} | ||
|
||
// Prune redundant arguments and update the leader block argument list | ||
newArguments = pruneRedundantArguments(newArguments, rewriter, | ||
numOldArguments, leaderBlock); | ||
|
||
// Update the predecessors for each of the blocks. | ||
auto updatePredecessors = [&](Block *block, unsigned clusterIndex) { | ||
for (auto predIt = block->pred_begin(), predE = block->pred_end(); | ||
|
@@ -818,6 +917,111 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter, | |
return success(anyChanged); | ||
} | ||
|
||
/// Drops every argument of `block` that is given the same value by all of the
/// block's predecessors: the uses of such an argument are redirected to that
/// value, and the argument is erased together with the matching forwarded
/// operand of each predecessor's branch. An argument is kept when some
/// predecessor's terminator is not a `BranchOpInterface`, when predecessors
/// disagree on the value, or when the block has no predecessors.
///
/// Returns success if at least one argument was dropped, failure otherwise.
static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
                                            Block &block) {
  // Returns the one value that every predecessor forwards to argument
  // `argIdx`, or a null Value when no such value exists.
  auto findCommonIncomingValue = [&block](size_t argIdx) -> Value {
    Value common;
    for (Block::pred_iterator predIt = block.pred_begin(),
                              predEnd = block.pred_end();
         predIt != predEnd; ++predIt) {
      auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
      if (!branch)
        return Value();

      SuccessorOperands succOperands =
          branch.getSuccessorOperands(predIt.getSuccessorIndex());
      Value incoming = succOperands.getForwardedOperands()[argIdx];

      // The first predecessor seen fixes the candidate value.
      if (!common) {
        common = incoming;
        continue;
      }
      if (incoming != common)
        return Value();
    }
    return common;
  };

  // Find the redundant arguments and redirect their uses.
  SmallVector<size_t> redundantArgs;
  for (auto [argIdx, blockArg] : llvm::enumerate(block.getArguments())) {
    Value replacement = findCommonIncomingValue(argIdx);
    if (!replacement)
      continue;
    redundantArgs.push_back(argIdx);
    rewriter.replaceAllUsesWith(blockArg, replacement);
  }

  // Erase back to front so that the recorded indices stay valid.
  for (size_t argIdx : llvm::reverse(redundantArgs)) {
    block.eraseArgument(argIdx);

    // Drop the matching operand from every predecessor branch. The `cast` is
    // safe: an argument is only recorded when every predecessor terminator is
    // a `BranchOpInterface`.
    for (auto predIt = block.pred_begin(), predEnd = block.pred_end();
         predIt != predEnd; ++predIt) {
      auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
      SuccessorOperands succOperands =
          branch.getSuccessorOperands(predIt.getSuccessorIndex());
      succOperands.erase(argIdx);
    }
  }
  return success(!redundantArgs.empty());
}
|
||
/// This optimization drops redundant arguments to blocks. I.e., if a given
/// argument to a block receives the same value from each of the block's
/// predecessors, we can remove the argument from the block and use the
/// original value directly. This is a simple example:
///
///  %cond = llvm.call @rand() : () -> i1
///  %val0 = llvm.mlir.constant(1 : i64) : i64
///  %val1 = llvm.mlir.constant(2 : i64) : i64
///  %val2 = llvm.mlir.constant(3 : i64) : i64
///  llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
///  : i64)
///
///  ^bb1(%arg0 : i64, %arg1 : i64):
///     llvm.call @foo(%arg0, %arg1)
///
/// The previous IR can be rewritten as:
///  %cond = llvm.call @rand() : () -> i1
///  %val0 = llvm.mlir.constant(1 : i64) : i64
///  %val1 = llvm.mlir.constant(2 : i64) : i64
///  %val2 = llvm.mlir.constant(3 : i64) : i64
///  llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
///
///  ^bb1(%arg0 : i64):
///     llvm.call @foo(%val0, %arg0)
///
/// The rewrite is applied to every block of `regions` and of all regions
/// nested inside their operations. Returns success if an argument was dropped
/// from at least one block, failure otherwise.
static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
                                            MutableArrayRef<Region> regions) {
  llvm::SmallSetVector<Region *, 1> worklist;
  for (Region &region : regions)
    worklist.insert(&region);
  bool anyChanged = false;
  while (!worklist.empty()) {
    Region *region = worklist.pop_back_val();

    for (Block &block : *region) {
      // Accumulate the result: a plain assignment would forget changes made
      // to earlier blocks whenever a later block is left untouched.
      anyChanged =
          succeeded(dropRedundantArguments(rewriter, block)) || anyChanged;

      // Add any nested regions to the worklist.
      for (Operation &op : block)
        for (Region &nestedRegion : op.getRegions())
          worklist.insert(&nestedRegion);
    }
  }
  return success(anyChanged);
}
|
||
//===----------------------------------------------------------------------===// | ||
// Region Simplification | ||
//===----------------------------------------------------------------------===// | ||
|
@@ -832,8 +1036,12 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter, | |
bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions)); | ||
bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions)); | ||
bool mergedIdenticalBlocks = false; | ||
if (mergeBlocks) | ||
bool droppedRedundantArguments = false; | ||
if (mergeBlocks) { | ||
mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions)); | ||
droppedRedundantArguments = | ||
succeeded(dropRedundantArguments(rewriter, regions)); | ||
} | ||
return success(eliminatedBlocks || eliminatedOpsOrArgs || | ||
mergedIdenticalBlocks); | ||
mergedIdenticalBlocks || droppedRedundantArguments); | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.