Skip to content

Commit f2d500c

Browse files
[mlir][Transforms] Dialect conversion: Fix bug in UnresolvedMaterializationRewrite rollback (#105949)
When an unresolved materialization (`unrealized_conversion_cast` op) is rolled back, the mapping should be rolled back as well, regardless of whether it is a source, target or argument materialization. Otherwise, we accumulate pointers to erased IR in the `mapping`. This is harmless in most cases, but can cause issues when a new operation is allocated at the same memory location and the pointer is "reused". It is not possible to write a test case for this because I cannot trigger the pointer reuse programmatically.
1 parent d106a39 commit f2d500c

File tree

1 file changed

+42
-32
lines changed

1 file changed

+42
-32
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -676,7 +676,8 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
676676
UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
677677
UnrealizedConversionCastOp op,
678678
const TypeConverter *converter,
679-
MaterializationKind kind, Type originalType);
679+
MaterializationKind kind, Type originalType,
680+
Value mappedValue);
680681

681682
static bool classof(const IRRewrite *rewrite) {
682683
return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -710,6 +711,10 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
710711
/// The original type of the SSA value. Only used for target
711712
/// materializations.
712713
Type originalType;
714+
715+
/// The value in the conversion value mapping that is being replaced by the
716+
/// results of this unresolved materialization.
717+
Value mappedValue;
713718
};
714719
} // namespace
715720

@@ -814,10 +819,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
814819

815820
/// Build an unresolved materialization operation given an output type and set
816821
/// of input operands.
822+
///
823+
/// If `valueToMap` is set to a non-null Value, then that value is mapped to
824+
/// the result of the unresolved materialization in the conversion value
825+
/// mapping.
817826
Value buildUnresolvedMaterialization(MaterializationKind kind,
818827
OpBuilder::InsertPoint ip, Location loc,
819-
ValueRange inputs, Type outputType,
820-
Type originalType,
828+
Value valueToMap, ValueRange inputs,
829+
Type outputType, Type originalType,
821830
const TypeConverter *converter);
822831

823832
/// Build an N:1 materialization for the given original value that was
@@ -1068,19 +1077,19 @@ void CreateOperationRewrite::rollback() {
10681077

10691078
UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
10701079
ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
1071-
const TypeConverter *converter, MaterializationKind kind, Type originalType)
1080+
const TypeConverter *converter, MaterializationKind kind, Type originalType,
1081+
Value mappedValue)
10721082
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
1073-
converterAndKind(converter, kind), originalType(originalType) {
1083+
converterAndKind(converter, kind), originalType(originalType),
1084+
mappedValue(mappedValue) {
10741085
assert((!originalType || kind == MaterializationKind::Target) &&
10751086
"original type is valid only for target materializations");
10761087
rewriterImpl.unresolvedMaterializations[op] = this;
10771088
}
10781089

10791090
void UnresolvedMaterializationRewrite::rollback() {
1080-
if (getMaterializationKind() == MaterializationKind::Target) {
1081-
for (Value input : op->getOperands())
1082-
rewriterImpl.mapping.erase(input);
1083-
}
1091+
if (mappedValue)
1092+
rewriterImpl.mapping.erase(mappedValue);
10841093
rewriterImpl.unresolvedMaterializations.erase(getOperation());
10851094
op->erase();
10861095
}
@@ -1176,10 +1185,9 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
11761185
// source materialization was created yet.
11771186
Value castValue = buildUnresolvedMaterialization(
11781187
MaterializationKind::Target, computeInsertPoint(newOperand),
1179-
operandLoc,
1180-
/*inputs=*/newOperand, /*outputType=*/desiredType,
1181-
/*originalType=*/origType, currentTypeConverter);
1182-
mapping.map(newOperand, castValue);
1188+
operandLoc, /*valueToMap=*/newOperand, /*inputs=*/newOperand,
1189+
/*outputType=*/desiredType, /*originalType=*/origType,
1190+
currentTypeConverter);
11831191
newOperand = castValue;
11841192
}
11851193
remapped.push_back(newOperand);
@@ -1293,12 +1301,11 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12931301
if (!inputMap) {
12941302
// This block argument was dropped and no replacement value was provided.
12951303
// Materialize a replacement value "out of thin air".
1296-
Value repl = buildUnresolvedMaterialization(
1304+
buildUnresolvedMaterialization(
12971305
MaterializationKind::Source,
12981306
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1299-
/*inputs=*/ValueRange(),
1307+
/*valueToMap=*/origArg, /*inputs=*/ValueRange(),
13001308
/*outputType=*/origArgType, /*originalType=*/Type(), converter);
1301-
mapping.map(origArg, repl);
13021309
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
13031310
continue;
13041311
}
@@ -1342,23 +1349,28 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13421349
/// of input operands.
13431350
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
13441351
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1345-
ValueRange inputs, Type outputType, Type originalType,
1352+
Value valueToMap, ValueRange inputs, Type outputType, Type originalType,
13461353
const TypeConverter *converter) {
13471354
assert((!originalType || kind == MaterializationKind::Target) &&
13481355
"original type is valid only for target materializations");
13491356

13501357
// Avoid materializing an unnecessary cast.
1351-
if (inputs.size() == 1 && inputs.front().getType() == outputType)
1358+
if (inputs.size() == 1 && inputs.front().getType() == outputType) {
1359+
if (valueToMap)
1360+
mapping.map(valueToMap, inputs.front());
13521361
return inputs.front();
1362+
}
13531363

13541364
// Create an unresolved materialization. We use a new OpBuilder to avoid
13551365
// tracking the materialization like we do for other operations.
13561366
OpBuilder builder(outputType.getContext());
13571367
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
13581368
auto convertOp =
13591369
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
1370+
if (valueToMap)
1371+
mapping.map(valueToMap, convertOp.getResult(0));
13601372
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
1361-
originalType);
1373+
originalType, valueToMap);
13621374
return convertOp.getResult(0);
13631375
}
13641376

@@ -1367,11 +1379,10 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
13671379
Value originalValue, const TypeConverter *converter) {
13681380
// Insert argument materialization back to the original type.
13691381
Type originalType = originalValue.getType();
1370-
Value argMat =
1371-
buildUnresolvedMaterialization(MaterializationKind::Argument, ip, loc,
1372-
/*inputs=*/replacements, originalType,
1373-
/*originalType=*/Type(), converter);
1374-
mapping.map(originalValue, argMat);
1382+
Value argMat = buildUnresolvedMaterialization(
1383+
MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
1384+
/*inputs=*/replacements, originalType, /*originalType=*/Type(),
1385+
converter);
13751386

13761387
// Insert target materialization to the legalized type.
13771388
Type legalOutputType;
@@ -1387,11 +1398,11 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
13871398
legalOutputType = replacements[0].getType();
13881399
}
13891400
if (legalOutputType && legalOutputType != originalType) {
1390-
Value targetMat = buildUnresolvedMaterialization(
1391-
MaterializationKind::Target, computeInsertPoint(argMat), loc,
1392-
/*inputs=*/argMat, /*outputType=*/legalOutputType,
1393-
/*originalType=*/originalType, converter);
1394-
mapping.map(argMat, targetMat);
1401+
buildUnresolvedMaterialization(MaterializationKind::Target,
1402+
computeInsertPoint(argMat), loc,
1403+
/*valueToMap=*/argMat, /*inputs=*/argMat,
1404+
/*outputType=*/legalOutputType,
1405+
/*originalType=*/originalType, converter);
13951406
}
13961407
}
13971408

@@ -1425,9 +1436,8 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
14251436
}
14261437
Value castValue = buildUnresolvedMaterialization(
14271438
MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
1428-
/*inputs=*/repl, /*outputType=*/value.getType(),
1439+
/*valueToMap=*/value, /*inputs=*/repl, /*outputType=*/value.getType(),
14291440
/*originalType=*/Type(), converter);
1430-
mapping.map(value, castValue);
14311441
return castValue;
14321442
}
14331443

@@ -1480,7 +1490,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
14801490
// Materialize a replacement value "out of thin air".
14811491
Value sourceMat = buildUnresolvedMaterialization(
14821492
MaterializationKind::Source, computeInsertPoint(result),
1483-
result.getLoc(), /*inputs=*/ValueRange(),
1493+
result.getLoc(), /*valueToMap=*/Value(), /*inputs=*/ValueRange(),
14841494
/*outputType=*/result.getType(), /*originalType=*/Type(),
14851495
currentTypeConverter);
14861496
repl.push_back(sourceMat);

0 commit comments

Comments
 (0)