Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -463,10 +463,15 @@ struct BufferDeallocationSimplificationPass
SplitDeallocWhenNotAliasingAnyOther,
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
analysis);
// We don't want the block structure to change, since that would invalidate
// the `BufferOriginAnalysis`, so we apply the rewrites with a `Normal` level
// of region simplification.
GreedyRewriteConfig config;
config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());

if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config)))
signalPassFailure();
}
};
Expand Down
139 changes: 127 additions & 12 deletions mlir/lib/Transforms/Utils/RegionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,23 @@
#include "mlir/Transforms/RegionUtils.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/RegionGraphTraits.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LogicalResult.h"

#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"

#include <deque>
#include <iterator>

using namespace mlir;

Expand Down Expand Up @@ -699,9 +704,8 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
blockIterators.push_back(mergeBlock->begin());

// Update each of the predecessor terminators with the new arguments.
SmallVector<SmallVector<Value, 8>, 2> newArguments(
1 + blocksToMerge.size(),
SmallVector<Value, 8>(operandsToMerge.size()));
SmallVector<SmallVector<Value, 8>, 2> newArguments(1 + blocksToMerge.size(),
SmallVector<Value, 8>());
unsigned curOpIndex = 0;
for (const auto &it : llvm::enumerate(operandsToMerge)) {
unsigned nextOpOffset = it.value().first - curOpIndex;
Expand All @@ -712,13 +716,22 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
Block::iterator &blockIter = blockIterators[i];
std::advance(blockIter, nextOpOffset);
auto &operand = blockIter->getOpOperand(it.value().second);
newArguments[i][it.index()] = operand.get();

// Update the operand and insert an argument if this is the leader.
if (i == 0) {
Value operandVal = operand.get();
operand.set(leaderBlock->addArgument(operandVal.getType(),
operandVal.getLoc()));
Value operandVal = operand.get();
Value *it = std::find(newArguments[i].begin(), newArguments[i].end(),
operandVal);
if (it == newArguments[i].end()) {
newArguments[i].push_back(operandVal);
// Update the operand and insert an argument if this is the leader.
if (i == 0) {
operand.set(leaderBlock->addArgument(operandVal.getType(),
operandVal.getLoc()));
}
} else if (i == 0) {
// If this is the leader, update the operand but do not insert a new
// argument. Instead, the operand should point to one of the
// arguments we already passed (and that contained `operandVal`).
operand.set(leaderBlock->getArgument(
std::distance(newArguments[i].begin(), it)));
}
}
}
Expand Down Expand Up @@ -818,6 +831,104 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
return success(anyChanged);
}

/// Drops every argument of `block` for which all predecessors forward the
/// exact same value. Uses of such an argument are redirected to that value,
/// then the argument and the matching successor operand of every predecessor
/// terminator are erased. An argument is kept if the block has no
/// predecessors, if some predecessor terminator does not implement
/// `BranchOpInterface`, or if two predecessors forward different values.
///
/// Returns success if at least one argument was erased, failure otherwise.
static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
                                            Block &block) {
  SmallVector<size_t> redundantArgs;

  // Phase 1: find the redundant arguments and redirect their uses.
  for (size_t idx = 0; idx < block.getNumArguments(); ++idx) {
    // Value forwarded by the first predecessor seen; null until then.
    Value incoming;
    bool allAgree = true;

    for (auto pred = block.pred_begin(), predEnd = block.pred_end();
         pred != predEnd; ++pred) {
      auto branchOp = dyn_cast<BranchOpInterface>((*pred)->getTerminator());
      if (!branchOp) {
        // The forwarded operands cannot be inspected for this terminator.
        allAgree = false;
        break;
      }
      SuccessorOperands successorOperands =
          branchOp.getSuccessorOperands(pred.getSuccessorIndex());
      auto forwarded = successorOperands.getForwardedOperands();
      if (!incoming) {
        incoming = forwarded[idx];
        continue;
      }
      if (forwarded[idx] != incoming) {
        allAgree = false;
        break;
      }
    }

    // Keep the argument unless every predecessor agreed on a single value.
    if (!incoming || !allAgree)
      continue;

    redundantArgs.push_back(idx);
    Value blockArg = block.getArgument(idx);
    rewriter.replaceAllUsesWith(blockArg, incoming);
  }

  // Phase 2: erase from the highest index down so that the indices still to
  // be processed remain valid.
  for (auto idx : llvm::reverse(redundantArgs)) {
    block.eraseArgument(idx);

    // Every predecessor terminator was checked to be a `BranchOpInterface`
    // in phase 1 for the arguments recorded here, hence the plain `cast`.
    for (auto pred = block.pred_begin(), predEnd = block.pred_end();
         pred != predEnd; ++pred) {
      auto branchOp = cast<BranchOpInterface>((*pred)->getTerminator());
      SuccessorOperands successorOperands =
          branchOp.getSuccessorOperands(pred.getSuccessorIndex());
      successorOperands.erase(idx);
    }
  }

  const bool erasedAny = !redundantArgs.empty();
  return success(erasedAny);
}

/// This optimization drops redundant arguments from blocks. I.e., if a given
/// argument to a block receives the same value from each of the block
/// predecessors, we can remove the argument from the block and use the
/// original value directly. This is a simple example:
///
///   %cond = llvm.call @rand() : () -> i1
///   %val0 = llvm.mlir.constant(1 : i64) : i64
///   %val1 = llvm.mlir.constant(2 : i64) : i64
///   %val2 = llvm.mlir.constant(3 : i64) : i64
///   llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64),
///                       ^bb1(%val0 : i64, %val2 : i64)
/// ^bb1(%arg0 : i64, %arg1 : i64):
///   llvm.call @foo(%arg0, %arg1)
///
/// Both edges into ^bb1 forward %val0 as the first argument, so the previous
/// IR can be rewritten as:
///
///   %cond = llvm.call @rand() : () -> i1
///   %val0 = llvm.mlir.constant(1 : i64) : i64
///   %val1 = llvm.mlir.constant(2 : i64) : i64
///   %val2 = llvm.mlir.constant(3 : i64) : i64
///   llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb1(%val2 : i64)
/// ^bb1(%arg1 : i64):
///   llvm.call @foo(%val0, %arg1)
///
/// The rewrite is applied to every block of `regions` and of all regions
/// nested within them. Returns success if an argument was dropped from at
/// least one block, failure otherwise.
static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
                                            MutableArrayRef<Region> regions) {
  llvm::SmallSetVector<Region *, 1> worklist;
  for (auto &region : regions)
    worklist.insert(&region);
  bool anyChanged = false;
  while (!worklist.empty()) {
    Region *region = worklist.pop_back_val();

    for (Block &block : *region) {
      // Accumulate the result: a plain assignment here would let a later
      // block that needs no change hide a change made to an earlier block.
      anyChanged |= succeeded(dropRedundantArguments(rewriter, block));

      // Add any nested regions to the worklist.
      for (auto &op : block)
        for (auto &nestedRegion : op.getRegions())
          worklist.insert(&nestedRegion);
    }
  }
  return success(anyChanged);
}

//===----------------------------------------------------------------------===//
// Region Simplification
//===----------------------------------------------------------------------===//
Expand All @@ -832,8 +943,12 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
bool mergedIdenticalBlocks = false;
if (mergeBlocks)
bool droppedRedundantArguments = false;
if (mergeBlocks) {
mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
droppedRedundantArguments =
succeeded(dropRedundantArguments(rewriter, regions));
}
return success(eliminatedBlocks || eliminatedOpsOrArgs ||
mergedIdenticalBlocks);
mergedIdenticalBlocks || droppedRedundantArguments);
}
6 changes: 3 additions & 3 deletions mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK-LABEL: @main
// CHECK-SAME: (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
// CHECK: %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
// CHECK: cf.br ^{{.*}}
// CHECK: ^{{.*}}:
// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32>
// CHECK: return %[[ELEMENTS]] : tensor<f32>