Skip to content

[mlir][Transforms] Dialect conversion: Fix bug in UnresolvedMaterializationRewrite rollback #105949

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 42 additions & 32 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,8 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
UnrealizedConversionCastOp op,
const TypeConverter *converter,
MaterializationKind kind, Type originalType);
MaterializationKind kind, Type originalType,
Value mappedValue);

static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::UnresolvedMaterialization;
Expand Down Expand Up @@ -710,6 +711,10 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
/// The original type of the SSA value. Only used for target
/// materializations.
Type originalType;

/// The value in the conversion value mapping that is being replaced by the
/// results of this unresolved materialization.
Value mappedValue;
};
} // namespace

Expand Down Expand Up @@ -814,10 +819,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {

/// Build an unresolved materialization operation given an output type and set
/// of input operands.
///
/// If `valueToMap` is set to a non-null Value, then that value is mapped to
/// the result of the unresolved materialization in the conversion value
/// mapping.
Value buildUnresolvedMaterialization(MaterializationKind kind,
OpBuilder::InsertPoint ip, Location loc,
ValueRange inputs, Type outputType,
Type originalType,
Value valueToMap, ValueRange inputs,
Type outputType, Type originalType,
const TypeConverter *converter);

/// Build an N:1 materialization for the given original value that was
Expand Down Expand Up @@ -1068,19 +1077,19 @@ void CreateOperationRewrite::rollback() {

UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
const TypeConverter *converter, MaterializationKind kind, Type originalType)
const TypeConverter *converter, MaterializationKind kind, Type originalType,
Value mappedValue)
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
converterAndKind(converter, kind), originalType(originalType) {
converterAndKind(converter, kind), originalType(originalType),
mappedValue(mappedValue) {
assert((!originalType || kind == MaterializationKind::Target) &&
"original type is valid only for target materializations");
rewriterImpl.unresolvedMaterializations[op] = this;
}

void UnresolvedMaterializationRewrite::rollback() {
if (getMaterializationKind() == MaterializationKind::Target) {
for (Value input : op->getOperands())
rewriterImpl.mapping.erase(input);
}
if (mappedValue)
rewriterImpl.mapping.erase(mappedValue);
rewriterImpl.unresolvedMaterializations.erase(getOperation());
op->erase();
}
Expand Down Expand Up @@ -1176,10 +1185,9 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
// source materialization was created yet.
Value castValue = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(newOperand),
operandLoc,
/*inputs=*/newOperand, /*outputType=*/desiredType,
/*originalType=*/origType, currentTypeConverter);
mapping.map(newOperand, castValue);
operandLoc, /*valueToMap=*/newOperand, /*inputs=*/newOperand,
/*outputType=*/desiredType, /*originalType=*/origType,
currentTypeConverter);
newOperand = castValue;
}
remapped.push_back(newOperand);
Expand Down Expand Up @@ -1293,12 +1301,11 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
if (!inputMap) {
// This block argument was dropped and no replacement value was provided.
// Materialize a replacement value "out of thin air".
Value repl = buildUnresolvedMaterialization(
buildUnresolvedMaterialization(
MaterializationKind::Source,
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
/*inputs=*/ValueRange(),
/*valueToMap=*/origArg, /*inputs=*/ValueRange(),
/*outputType=*/origArgType, /*originalType=*/Type(), converter);
mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
continue;
}
Expand Down Expand Up @@ -1342,23 +1349,28 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/// of input operands.
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
ValueRange inputs, Type outputType, Type originalType,
Value valueToMap, ValueRange inputs, Type outputType, Type originalType,
const TypeConverter *converter) {
assert((!originalType || kind == MaterializationKind::Target) &&
"original type is valid only for target materializations");

// Avoid materializing an unnecessary cast.
if (inputs.size() == 1 && inputs.front().getType() == outputType)
if (inputs.size() == 1 && inputs.front().getType() == outputType) {
if (valueToMap)
mapping.map(valueToMap, inputs.front());
return inputs.front();
}

// Create an unresolved materialization. We use a new OpBuilder to avoid
// tracking the materialization like we do for other operations.
OpBuilder builder(outputType.getContext());
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
if (valueToMap)
mapping.map(valueToMap, convertOp.getResult(0));
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
originalType);
originalType, valueToMap);
return convertOp.getResult(0);
}

Expand All @@ -1367,11 +1379,10 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
Value originalValue, const TypeConverter *converter) {
// Insert argument materialization back to the original type.
Type originalType = originalValue.getType();
Value argMat =
buildUnresolvedMaterialization(MaterializationKind::Argument, ip, loc,
/*inputs=*/replacements, originalType,
/*originalType=*/Type(), converter);
mapping.map(originalValue, argMat);
Value argMat = buildUnresolvedMaterialization(
MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
/*inputs=*/replacements, originalType, /*originalType=*/Type(),
converter);

// Insert target materialization to the legalized type.
Type legalOutputType;
Expand All @@ -1387,11 +1398,11 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
legalOutputType = replacements[0].getType();
}
if (legalOutputType && legalOutputType != originalType) {
Value targetMat = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(argMat), loc,
/*inputs=*/argMat, /*outputType=*/legalOutputType,
/*originalType=*/originalType, converter);
mapping.map(argMat, targetMat);
buildUnresolvedMaterialization(MaterializationKind::Target,
computeInsertPoint(argMat), loc,
/*valueToMap=*/argMat, /*inputs=*/argMat,
/*outputType=*/legalOutputType,
/*originalType=*/originalType, converter);
}
}

Expand Down Expand Up @@ -1425,9 +1436,8 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
}
Value castValue = buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
/*inputs=*/repl, /*outputType=*/value.getType(),
/*valueToMap=*/value, /*inputs=*/repl, /*outputType=*/value.getType(),
/*originalType=*/Type(), converter);
mapping.map(value, castValue);
return castValue;
}

Expand Down Expand Up @@ -1480,7 +1490,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
// Materialize a replacement value "out of thin air".
Value sourceMat = buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(result),
result.getLoc(), /*inputs=*/ValueRange(),
result.getLoc(), /*valueToMap=*/Value(), /*inputs=*/ValueRange(),
/*outputType=*/result.getType(), /*originalType=*/Type(),
currentTypeConverter);
repl.push_back(sourceMat);
Expand Down