Skip to content

Commit 08be950

Browse files
[mlir][IR] Add helper functions to compute insertion point
1 parent a230a52 commit 08be950

File tree

3 files changed

+91
-15
lines changed

3 files changed

+91
-15
lines changed

mlir/include/mlir/IR/Builders.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,18 @@ class OpBuilder : public Builder {
334334
/// This class represents a saved insertion point.
335335
class InsertPoint {
336336
public:
337+
/// Finds the closest insertion point where all given values are defined
338+
/// and can be used without a dominance violation. (Does not account for
339+
/// IsolatedFromAbove regions.) Returns an "empty" insertion point if there
340+
/// is no such insertion point.
341+
static InsertPoint findClosest(ArrayRef<Value> values);
342+
343+
/// Creates a new insertion point right after the given value. If the given
344+
/// value is a block argument, the insertion point points to the beginning
345+
/// of the block. Otherwise, it points to the place right after the op that
346+
/// defines the value.
347+
static InsertPoint after(Value value);
348+
337349
/// Creates a new insertion point which doesn't point to anything.
338350
InsertPoint() = default;
339351

mlir/lib/IR/Builders.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,3 +641,76 @@ void OpBuilder::cloneRegionBefore(Region &region, Region &parent,
641641
void OpBuilder::cloneRegionBefore(Region &region, Block *before) {
642642
cloneRegionBefore(region, *before->getParent(), before->getIterator());
643643
}
644+
645+
//===----------------------------------------------------------------------===//
646+
// InsertPoint
647+
//===----------------------------------------------------------------------===//
648+
649+
OpBuilder::InsertPoint OpBuilder::InsertPoint::after(Value value) {
650+
if (auto blockArg = dyn_cast<BlockArgument>(value))
651+
return InsertPoint(blockArg.getOwner(), blockArg.getOwner()->begin());
652+
Operation *op = value.getDefiningOp();
653+
return InsertPoint(op->getBlock(), ++op->getIterator());
654+
}
655+
656+
/// Helper function that returns "true" if:
657+
/// - `a` is a proper ancestor of `b`
658+
/// - or: there is a path from `a` to `b`
659+
static bool isAncestorOrBefore(Block *a, Block *b) {
660+
if (a->getParentOp()->isProperAncestor(b->getParentOp()))
661+
return true;
662+
if (a->getParent() != b->getParent())
663+
return false;
664+
return a->isReachable(b);
665+
}
666+
667+
OpBuilder::InsertPoint
668+
OpBuilder::InsertPoint::findClosest(ArrayRef<Value> values) {
669+
// Compute the insertion point after the first value.
670+
assert(!values.empty() && "expected at least one value");
671+
InsertPoint result = InsertPoint::after(values.front());
672+
673+
// Check all other values and update the insertion point as needed.
674+
for (Value v : values.drop_front()) {
675+
InsertPoint pt = InsertPoint::after(v);
676+
677+
if (pt.getBlock() == result.getBlock()) {
678+
// Both values belong to the same block. Modify the iterator (but keep
679+
// the block) if needed: take the later one of the two insertion points.
680+
Block *block = pt.getBlock();
681+
if (pt.getPoint() == block->end()) {
682+
// `pt` points to the end of the block: take `pt`.
683+
result = pt;
684+
continue;
685+
} else if (result.getPoint() == block->end()) {
686+
// `result` points to the end of the block: nothing to do.
687+
continue;
688+
}
689+
// Neither `pt` nor `result` point to the end of the block, so both
690+
// iterators point to an operation. Set `result` to the later one of the
691+
// two insertion point.
692+
if (result.getPoint()->isBeforeInBlock(&*pt.getPoint()))
693+
result = pt;
694+
continue;
695+
}
696+
697+
if (isAncestorOrBefore(result.getBlock(), pt.getBlock())) {
698+
// `result` is an ancestor of `pt`. Therefore, `pt` is a valid insertion
699+
// point for `v` and all previous values.
700+
result = pt;
701+
continue;
702+
}
703+
704+
if (isAncestorOrBefore(pt.getBlock(), result.getBlock())) {
705+
// `pt` is an ancestor of `result`. Therefore, `result` is a valid
706+
// insertion point for `v` and all previous values.
707+
continue;
708+
}
709+
710+
// `pt` and `result` are in different subtrees: neither is an ancestor of
711+
// the other. In that case, there is no suitable insertion point.
712+
return InsertPoint();
713+
}
714+
715+
return result;
716+
}

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,6 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
5353
});
5454
}
5555

56-
/// Helper function that computes an insertion point where the given value is
57-
/// defined and can be used without a dominance violation.
58-
static OpBuilder::InsertPoint computeInsertPoint(Value value) {
59-
Block *insertBlock = value.getParentBlock();
60-
Block::iterator insertPt = insertBlock->begin();
61-
if (OpResult inputRes = dyn_cast<OpResult>(value))
62-
insertPt = ++inputRes.getOwner()->getIterator();
63-
return OpBuilder::InsertPoint(insertBlock, insertPt);
64-
}
65-
6656
//===----------------------------------------------------------------------===//
6757
// ConversionValueMapping
6858
//===----------------------------------------------------------------------===//
@@ -1147,8 +1137,9 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
11471137
// that the value was replaced with a value of different type and no
11481138
// source materialization was created yet.
11491139
Value castValue = buildUnresolvedMaterialization(
1150-
MaterializationKind::Target, computeInsertPoint(newOperand),
1151-
operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
1140+
MaterializationKind::Target,
1141+
OpBuilder::InsertPoint::after(newOperand), operandLoc,
1142+
/*inputs=*/newOperand, /*outputType=*/desiredType,
11521143
/*originalType=*/origType, currentTypeConverter);
11531144
mapping.map(newOperand, castValue);
11541145
newOperand = castValue;
@@ -1309,7 +1300,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13091300
}
13101301
if (legalOutputType && legalOutputType != origArgType) {
13111302
Value targetMat = buildUnresolvedMaterialization(
1312-
MaterializationKind::Target, computeInsertPoint(argMat),
1303+
MaterializationKind::Target, OpBuilder::InsertPoint::after(argMat),
13131304
origArg.getLoc(), /*inputs=*/argMat, /*outputType=*/legalOutputType,
13141305
/*originalType=*/origArgType, converter);
13151306
mapping.map(argMat, targetMat);
@@ -1401,7 +1392,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
14011392

14021393
// Materialize a replacement value "out of thin air".
14031394
newValue = buildUnresolvedMaterialization(
1404-
MaterializationKind::Source, computeInsertPoint(result),
1395+
MaterializationKind::Source, OpBuilder::InsertPoint::after(result),
14051396
result.getLoc(), /*inputs=*/ValueRange(),
14061397
/*outputType=*/result.getType(), /*originalType=*/Type(),
14071398
currentTypeConverter);
@@ -2596,7 +2587,7 @@ void OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
25962587
Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue);
25972588
assert(newValue && "replacement value not found");
25982589
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
2599-
MaterializationKind::Source, computeInsertPoint(newValue),
2590+
MaterializationKind::Source, OpBuilder::InsertPoint::after(newValue),
26002591
originalValue.getLoc(),
26012592
/*inputs=*/newValue, /*outputType=*/originalValue.getType(),
26022593
/*originalType=*/Type(), converter);

0 commit comments

Comments
 (0)