Skip to content

Commit 2fc71e4

Browse files
[mlir][Transforms][NFC] Dialect Conversion: Move argument materialization logic (#98805)
This commit moves the argument materialization logic from `legalizeConvertedArgumentTypes` to `legalizeUnresolvedMaterializations`. Before this change: - Argument materializations were created in `legalizeConvertedArgumentTypes` (which used to call `materializeLiveConversions`). After this change: - `legalizeConvertedArgumentTypes` creates a "placeholder" `unrealized_conversion_cast`. - The placeholder `unrealized_conversion_cast` is replaced with an argument materialization (using the type converter) in `legalizeUnresolvedMaterializations`. - All argument and target materializations now take place in the same location (`legalizeUnresolvedMaterializations`). This commit brings us closer towards creating all source/target/argument materializations in one central step, which can then be made optional (and delegated to the user) in the future. (There is one more source materialization step that has not been moved yet.) This commit also consolidates all `build*UnresolvedMaterialization` functions into a single `buildUnresolvedMaterialization` function. This is a re-upload of #96329.
1 parent 3eaca31 commit 2fc71e4

File tree

1 file changed

+54
-84
lines changed

1 file changed

+54
-84
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 54 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,16 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
5353
});
5454
}
5555

56+
/// Helper function that computes an insertion point where the given value is
57+
/// defined and can be used without a dominance violation.
58+
static OpBuilder::InsertPoint computeInsertPoint(Value value) {
59+
Block *insertBlock = value.getParentBlock();
60+
Block::iterator insertPt = insertBlock->begin();
61+
if (OpResult inputRes = dyn_cast<OpResult>(value))
62+
insertPt = ++inputRes.getOwner()->getIterator();
63+
return OpBuilder::InsertPoint(insertBlock, insertPt);
64+
}
65+
5666
//===----------------------------------------------------------------------===//
5767
// ConversionValueMapping
5868
//===----------------------------------------------------------------------===//
@@ -444,11 +454,9 @@ class BlockTypeConversionRewrite : public BlockRewrite {
444454
return rewrite->getKind() == Kind::BlockTypeConversion;
445455
}
446456

447-
/// Materialize any necessary conversions for converted arguments that have
448-
/// live users, using the provided `findLiveUser` to search for a user that
449-
/// survives the conversion process.
450-
LogicalResult
451-
materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
457+
Block *getOrigBlock() const { return origBlock; }
458+
459+
const TypeConverter *getConverter() const { return converter; }
452460

453461
void commit(RewriterBase &rewriter) override;
454462

@@ -829,15 +837,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
829837
/// Build an unresolved materialization operation given an output type and set
830838
/// of input operands.
831839
Value buildUnresolvedMaterialization(MaterializationKind kind,
832-
Block *insertBlock,
833-
Block::iterator insertPt, Location loc,
840+
OpBuilder::InsertPoint ip, Location loc,
834841
ValueRange inputs, Type outputType,
835842
const TypeConverter *converter);
836843

837-
Value buildUnresolvedTargetMaterialization(Location loc, Value input,
838-
Type outputType,
839-
const TypeConverter *converter);
840-
841844
//===--------------------------------------------------------------------===//
842845
// Rewriter Notification Hooks
843846
//===--------------------------------------------------------------------===//
@@ -969,49 +972,6 @@ void BlockTypeConversionRewrite::rollback() {
969972
block->replaceAllUsesWith(origBlock);
970973
}
971974

972-
LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
973-
function_ref<Operation *(Value)> findLiveUser) {
974-
// Process the remapping for each of the original arguments.
975-
for (auto it : llvm::enumerate(origBlock->getArguments())) {
976-
BlockArgument origArg = it.value();
977-
// Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used.
978-
OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl);
979-
builder.setInsertionPointToStart(block);
980-
981-
// If the type of this argument changed and the argument is still live, we
982-
// need to materialize a conversion.
983-
if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
984-
continue;
985-
Operation *liveUser = findLiveUser(origArg);
986-
if (!liveUser)
987-
continue;
988-
989-
Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
990-
assert(replacementValue && "replacement value not found");
991-
Value newArg;
992-
if (converter) {
993-
builder.setInsertionPointAfterValue(replacementValue);
994-
newArg = converter->materializeSourceConversion(
995-
builder, origArg.getLoc(), origArg.getType(), replacementValue);
996-
assert((!newArg || newArg.getType() == origArg.getType()) &&
997-
"materialization hook did not provide a value of the expected "
998-
"type");
999-
}
1000-
if (!newArg) {
1001-
InFlightDiagnostic diag =
1002-
emitError(origArg.getLoc())
1003-
<< "failed to materialize conversion for block argument #"
1004-
<< it.index() << " that remained live after conversion, type was "
1005-
<< origArg.getType();
1006-
diag.attachNote(liveUser->getLoc())
1007-
<< "see existing live user here: " << *liveUser;
1008-
return failure();
1009-
}
1010-
rewriterImpl.mapping.map(origArg, newArg);
1011-
}
1012-
return success();
1013-
}
1014-
1015975
void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
1016976
Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
1017977
if (!repl)
@@ -1184,8 +1144,10 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
11841144
Type newOperandType = newOperand.getType();
11851145
if (currentTypeConverter && desiredType && newOperandType != desiredType) {
11861146
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1187-
Value castValue = buildUnresolvedTargetMaterialization(
1188-
operandLoc, newOperand, desiredType, currentTypeConverter);
1147+
Value castValue = buildUnresolvedMaterialization(
1148+
MaterializationKind::Target, computeInsertPoint(newOperand),
1149+
operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
1150+
currentTypeConverter);
11891151
mapping.map(mapping.lookupOrDefault(newOperand), castValue);
11901152
newOperand = castValue;
11911153
}
@@ -1298,8 +1260,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12981260
// This block argument was dropped and no replacement value was provided.
12991261
// Materialize a replacement value "out of thin air".
13001262
Value repl = buildUnresolvedMaterialization(
1301-
MaterializationKind::Source, newBlock, newBlock->begin(),
1302-
origArg.getLoc(), /*inputs=*/ValueRange(),
1263+
MaterializationKind::Source,
1264+
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1265+
/*inputs=*/ValueRange(),
13031266
/*outputType=*/origArgType, converter);
13041267
mapping.map(origArg, repl);
13051268
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1323,8 +1286,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13231286
auto replArgs =
13241287
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
13251288
Value argMat = buildUnresolvedMaterialization(
1326-
MaterializationKind::Argument, newBlock, newBlock->begin(),
1327-
origArg.getLoc(), /*inputs=*/replArgs, origArgType, converter);
1289+
MaterializationKind::Argument,
1290+
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1291+
/*inputs=*/replArgs, origArgType, converter);
13281292
mapping.map(origArg, argMat);
13291293
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
13301294

@@ -1342,7 +1306,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13421306
legalOutputType = replArgs[0].getType();
13431307
}
13441308
if (legalOutputType && legalOutputType != origArgType) {
1345-
Value targetMat = buildUnresolvedTargetMaterialization(
1309+
Value targetMat = buildUnresolvedMaterialization(
1310+
MaterializationKind::Target, computeInsertPoint(argMat),
13461311
origArg.getLoc(), argMat, legalOutputType, converter);
13471312
mapping.map(argMat, targetMat);
13481313
}
@@ -1365,34 +1330,21 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13651330
/// Build an unresolved materialization operation given an output type and set
13661331
/// of input operands.
13671332
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
1368-
MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
1369-
Location loc, ValueRange inputs, Type outputType,
1370-
const TypeConverter *converter) {
1333+
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1334+
ValueRange inputs, Type outputType, const TypeConverter *converter) {
13711335
// Avoid materializing an unnecessary cast.
13721336
if (inputs.size() == 1 && inputs.front().getType() == outputType)
13731337
return inputs.front();
13741338

13751339
// Create an unresolved materialization. We use a new OpBuilder to avoid
13761340
// tracking the materialization like we do for other operations.
13771341
OpBuilder builder(outputType.getContext());
1378-
builder.setInsertionPoint(insertBlock, insertPt);
1342+
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
13791343
auto convertOp =
13801344
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
13811345
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
13821346
return convertOp.getResult(0);
13831347
}
1384-
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
1385-
Location loc, Value input, Type outputType,
1386-
const TypeConverter *converter) {
1387-
Block *insertBlock = input.getParentBlock();
1388-
Block::iterator insertPt = insertBlock->begin();
1389-
if (OpResult inputRes = dyn_cast<OpResult>(input))
1390-
insertPt = ++inputRes.getOwner()->getIterator();
1391-
1392-
return buildUnresolvedMaterialization(MaterializationKind::Target,
1393-
insertBlock, insertPt, loc, input,
1394-
outputType, converter);
1395-
}
13961348

13971349
//===----------------------------------------------------------------------===//
13981350
// Rewriter Notification Hooks
@@ -2504,9 +2456,9 @@ LogicalResult
25042456
OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
25052457
std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
25062458
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2507-
if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
2508-
inverseMapping)) ||
2509-
failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2459+
if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)) ||
2460+
failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
2461+
inverseMapping)))
25102462
return failure();
25112463

25122464
// Process requested operation replacements.
@@ -2562,10 +2514,28 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
25622514
++i) {
25632515
auto &rewrite = rewriterImpl.rewrites[i];
25642516
if (auto *blockTypeConversionRewrite =
2565-
dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
2566-
if (failed(blockTypeConversionRewrite->materializeLiveConversions(
2567-
findLiveUser)))
2568-
return failure();
2517+
dyn_cast<BlockTypeConversionRewrite>(rewrite.get())) {
2518+
// Process the remapping for each of the original arguments.
2519+
for (Value origArg :
2520+
blockTypeConversionRewrite->getOrigBlock()->getArguments()) {
2521+
// If the type of this argument changed and the argument is still live,
2522+
// we need to materialize a conversion.
2523+
if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
2524+
continue;
2525+
Operation *liveUser = findLiveUser(origArg);
2526+
if (!liveUser)
2527+
continue;
2528+
2529+
Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
2530+
assert(replacementValue && "replacement value not found");
2531+
Value repl = rewriterImpl.buildUnresolvedMaterialization(
2532+
MaterializationKind::Source, computeInsertPoint(replacementValue),
2533+
origArg.getLoc(), /*inputs=*/replacementValue,
2534+
/*outputType=*/origArg.getType(),
2535+
blockTypeConversionRewrite->getConverter());
2536+
rewriterImpl.mapping.map(origArg, repl);
2537+
}
2538+
}
25692539
}
25702540
return success();
25712541
}

0 commit comments

Comments
 (0)