Skip to content

Commit 82afd9d

Browse files
[mlir][Transforms][NFC] Dialect conversion: Eagerly build reverse mapping
The "inverse mapping" is an inverse IRMapping that points from replaced values to their original values. This inverse mapping is needed when legalizing unresolved materializations, to figure out if a value has any uses. (It is not sufficient to examine the IR, because some IR changes have not been materialized yet.) There was a check in `OperationConverter::finalize` that computed the inverse mapping only when needed. This check is not needed. `legalizeUnresolvedMaterializations` always computes the inverse mapping, so we can just do that in `OperationConverter::finalize` before calling `legalizeUnresolvedMaterializations`.
1 parent 2fc71e4 commit 82afd9d

File tree

1 file changed

+12
-14
lines changed

1 file changed

+12
-14
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2352,7 +2352,7 @@ struct OperationConverter {
23522352
LogicalResult legalizeUnresolvedMaterializations(
23532353
ConversionPatternRewriter &rewriter,
23542354
ConversionPatternRewriterImpl &rewriterImpl,
2355-
std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping);
2355+
DenseMap<Value, SmallVector<Value>> &inverseMapping);
23562356

23572357
/// Legalize an operation result that was marked as "erased".
23582358
LogicalResult
@@ -2454,10 +2454,12 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
24542454

24552455
LogicalResult
24562456
OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
2457-
std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
24582457
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2459-
if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)) ||
2460-
failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
2458+
if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2459+
return failure();
2460+
DenseMap<Value, SmallVector<Value>> inverseMapping =
2461+
rewriterImpl.mapping.getInverse();
2462+
if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
24612463
inverseMapping)))
24622464
return failure();
24632465

@@ -2483,15 +2485,11 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
24832485
if (result.getType() == newValue.getType())
24842486
continue;
24852487

2486-
// Compute the inverse mapping only if it is really needed.
2487-
if (!inverseMapping)
2488-
inverseMapping = rewriterImpl.mapping.getInverse();
2489-
24902488
// Legalize this result.
24912489
rewriter.setInsertionPoint(op);
24922490
if (failed(legalizeChangedResultType(
24932491
op, result, newValue, opReplacement->getConverter(), rewriter,
2494-
rewriterImpl, *inverseMapping)))
2492+
rewriterImpl, inverseMapping)))
24952493
return failure();
24962494
}
24972495
}
@@ -2503,6 +2501,8 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
25032501
ConversionPatternRewriterImpl &rewriterImpl) {
25042502
// Functor used to check if all users of a value will be dead after
25052503
// conversion.
2504+
// TODO: This should probably query the inverse mapping, same as in
2505+
// `legalizeChangedResultType`.
25062506
auto findLiveUser = [&](Value val) {
25072507
auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
25082508
return rewriterImpl.isOpIgnored(user);
@@ -2796,20 +2796,18 @@ static LogicalResult legalizeUnresolvedMaterialization(
27962796
LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
27972797
ConversionPatternRewriter &rewriter,
27982798
ConversionPatternRewriterImpl &rewriterImpl,
2799-
std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping) {
2800-
inverseMapping = rewriterImpl.mapping.getInverse();
2801-
2799+
DenseMap<Value, SmallVector<Value>> &inverseMapping) {
28022800
// As an initial step, compute all of the inserted materializations that we
28032801
// expect to persist beyond the conversion process.
28042802
DenseMap<Operation *, UnresolvedMaterializationRewrite *> materializationOps;
28052803
SetVector<UnresolvedMaterializationRewrite *> necessaryMaterializations;
28062804
computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl,
2807-
*inverseMapping, necessaryMaterializations);
2805+
inverseMapping, necessaryMaterializations);
28082806

28092807
// Once computed, legalize any necessary materializations.
28102808
for (auto *mat : necessaryMaterializations) {
28112809
if (failed(legalizeUnresolvedMaterialization(
2812-
*mat, materializationOps, rewriter, rewriterImpl, *inverseMapping)))
2810+
*mat, materializationOps, rewriter, rewriterImpl, inverseMapping)))
28132811
return failure();
28142812
}
28152813
return success();

0 commit comments

Comments
 (0)