From 90f1a0e50c7fde08a66d25ac7b3ca665fd92e0ab Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 30 Aug 2024 23:06:11 +0200 Subject: [PATCH 1/2] [mlir][Transforms] Dialect conversion: Align handling of dropped values Handle dropped block arguments and dropped op results in the same way: build a source materialization (that may fold away if unused). This simplifies the code base a bit and makes it possible to merge `legalizeConvertedArgumentTypes` and `legalizeConvertedOpResultTypes` in a future commit. These two functions are almost doing the same thing now. This commit also fixes a bug where circular materializations were built, e.g.: ``` %0 = "builtin.unrealized_conversion_cast"(%1) : (!a) -> !b %1 = "builtin.unrealized_conversion_cast"(%0) : (!b) -> !a // No further uses of %0, %1. ``` This happened when: 1. An op was erased. (No replacement values provided.) 2. A conversion pattern for another op builds a replacement value (first cast op) during `remapValues`, but that SSA value is not used during the pattern application. 3. During the finalization phase, `legalizeConvertedOpResultTypes` thinks that the erased op is alive because of the cast op that was built in Step 2. It builds a cast from that replacement value to the original type. 4. During the commit phase, all uses of the original op are repalced with the casted value produced in Step 3. We have generated circular IR. --- .../Transforms/Utils/DialectConversion.cpp | 140 ++++-------------- .../test-legalize-erased-op-with-uses.mlir | 4 +- 2 files changed, 28 insertions(+), 116 deletions(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index f288c7fc2cb77..26f7ae3b2cbf1 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -941,6 +941,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// to modify/access them is invalid rewriter API usage. SetVector replacedOps; + DenseSet unresolvedMaterializations; + /// The current type converter, or nullptr if no type converter is currently /// active. const TypeConverter *currentTypeConverter = nullptr; @@ -1066,6 +1068,7 @@ void UnresolvedMaterializationRewrite::rollback() { for (Value input : op->getOperands()) rewriterImpl.mapping.erase(input); } + rewriterImpl.unresolvedMaterializations.erase(op); op->erase(); } @@ -1347,6 +1350,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization( builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); auto convertOp = builder.create(loc, outputType, inputs); + unresolvedMaterializations.insert(convertOp); appendRewrite(convertOp, converter, kind); return convertOp.getResult(0); } @@ -1385,9 +1389,21 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, // Create mappings for each of the new result values. for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) { if (!newValue) { - resultChanged = true; - continue; + // This result was dropped and no replacement value was provided. + if (unresolvedMaterializations.contains(op)) { + // Do not create another materializations if we are erasing a + // materialization. + resultChanged = true; + continue; + } + + // Materialize a replacement value "out of thin air". + newValue = buildUnresolvedMaterialization( + MaterializationKind::Source, computeInsertPoint(result), + result.getLoc(), /*inputs=*/ValueRange(), + /*outputType=*/result.getType(), currentTypeConverter); } + // Remap, and check for any result type changes. mapping.map(result, newValue); resultChanged |= (newValue.getType() != result.getType()); @@ -2359,11 +2375,6 @@ struct OperationConverter { ConversionPatternRewriterImpl &rewriterImpl, DenseMap> &inverseMapping); - /// Legalize an operation result that was marked as "erased". - LogicalResult - legalizeErasedResult(Operation *op, OpResult result, - ConversionPatternRewriterImpl &rewriterImpl); - /// Dialect conversion configuration. ConversionConfig config; @@ -2455,77 +2466,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter, return failure(); } -/// Erase all dead unrealized_conversion_cast ops. An op is dead if its results -/// are not used (transitively) by any op that is not in the given list of -/// cast ops. -/// -/// In particular, this function erases cyclic casts that may be inserted -/// during the dialect conversion process. E.g.: -/// %0 = unrealized_conversion_cast(%1) -/// %1 = unrealized_conversion_cast(%0) -// Note: This step will become unnecessary when -// https://github.com/llvm/llvm-project/pull/106760 has been merged. -static void eraseDeadUnrealizedCasts( - ArrayRef castOps, - SmallVectorImpl *remainingCastOps) { - // Ops that have already been visited or are currently being visited. - DenseSet visited; - // Set of all cast ops for faster lookups. - DenseSet castOpSet; - // Set of all cast ops that have been determined to be alive. - DenseSet live; - - for (UnrealizedConversionCastOp op : castOps) - castOpSet.insert(op); - - // Visit a cast operation. Return "true" if the operation is live. - std::function visit = [&](Operation *op) -> bool { - // No need to traverse any IR if the op was already marked as live. - if (live.contains(op)) - return true; - - // Do not visit ops multiple times. If we find a circle, no live user was - // found on the current path. - if (!visited.insert(op).second) - return false; - - // Visit all users. - for (Operation *user : op->getUsers()) { - // If the user is not an unrealized_conversion_cast op, then the given op - // is live. - if (!castOpSet.contains(user)) { - live.insert(op); - return true; - } - // Otherwise, it is live if a live op can be reached from one of its - // users (which must all be unrealized_conversion_cast ops). - if (visit(user)) { - live.insert(op); - return true; - } - } - - return false; - }; - - // Visit all cast ops. - for (UnrealizedConversionCastOp op : castOps) { - visit(op); - visited.clear(); - } - - // Erase all cast ops that are dead. - for (UnrealizedConversionCastOp op : castOps) { - if (live.contains(op)) { - if (remainingCastOps) - remainingCastOps->push_back(op); - continue; - } - op->dropAllUses(); - op->erase(); - } -} - LogicalResult OperationConverter::convertOperations(ArrayRef ops) { if (ops.empty()) return success(); @@ -2584,14 +2524,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { // Reconcile all UnrealizedConversionCastOps that were inserted by the // dialect conversion frameworks. (Not the one that were inserted by // patterns.) - SmallVector remainingCastOps1, remainingCastOps2; - eraseDeadUnrealizedCasts(allCastOps, &remainingCastOps1); - reconcileUnrealizedCasts(remainingCastOps1, &remainingCastOps2); + SmallVector remainingCastOps; + reconcileUnrealizedCasts(allCastOps, &remainingCastOps); // Try to legalize all unresolved materializations. if (config.buildMaterializations) { IRRewriter rewriter(rewriterImpl.context, config.listener); - for (UnrealizedConversionCastOp castOp : remainingCastOps2) { + for (UnrealizedConversionCastOp castOp : remainingCastOps) { auto it = rewriteMap.find(castOp.getOperation()); assert(it != rewriteMap.end() && "inconsistent state"); if (failed(legalizeUnresolvedMaterialization(rewriter, it->second))) @@ -2650,26 +2589,18 @@ LogicalResult OperationConverter::legalizeConvertedOpResultTypes( continue; Operation *op = opReplacement->getOperation(); for (OpResult result : op->getResults()) { - Value newValue = rewriterImpl.mapping.lookupOrNull(result); - - // If the operation result was replaced with null, all of the uses of this - // value should be replaced. - if (!newValue) { - if (failed(legalizeErasedResult(op, result, rewriterImpl))) - return failure(); - continue; - } - - // Otherwise, check to see if the type of the result changed. - if (result.getType() == newValue.getType()) + // If the type of this op result changed and the result is still live, + // we need to materialize a conversion. + if (rewriterImpl.mapping.lookupOrNull(result, result.getType())) continue; - Operation *liveUser = findLiveUserOfReplaced(result, rewriterImpl, inverseMapping); if (!liveUser) continue; // Legalize this result. + Value newValue = rewriterImpl.mapping.lookupOrNull(result); + assert(newValue && "replacement value not found"); Value castValue = rewriterImpl.buildUnresolvedMaterialization( MaterializationKind::Source, computeInsertPoint(result), op->getLoc(), /*inputs=*/newValue, /*outputType=*/result.getType(), @@ -2727,25 +2658,6 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes( return success(); } -LogicalResult OperationConverter::legalizeErasedResult( - Operation *op, OpResult result, - ConversionPatternRewriterImpl &rewriterImpl) { - // If the operation result was replaced with null, all of the uses of this - // value should be replaced. - auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) { - return rewriterImpl.isOpIgnored(user); - }); - if (liveUserIt != result.user_end()) { - InFlightDiagnostic diag = op->emitError("failed to legalize operation '") - << op->getName() << "' marked as erased"; - diag.attachNote(liveUserIt->getLoc()) - << "found live user of result #" << result.getResultNumber() << ": " - << *liveUserIt; - return failure(); - } - return success(); -} - //===----------------------------------------------------------------------===// // Reconcile Unrealized Casts //===----------------------------------------------------------------------===// diff --git a/mlir/test/Transforms/test-legalize-erased-op-with-uses.mlir b/mlir/test/Transforms/test-legalize-erased-op-with-uses.mlir index 49275e8008e74..6e8f0162e505d 100644 --- a/mlir/test/Transforms/test-legalize-erased-op-with-uses.mlir +++ b/mlir/test/Transforms/test-legalize-erased-op-with-uses.mlir @@ -3,8 +3,8 @@ // Test that an error is emitted when an operation is marked as "erased", but // has users that live across the conversion. func.func @remove_all_ops(%arg0: i32) -> i32 { - // expected-error@below {{failed to legalize operation 'test.illegal_op_a' marked as erased}} + // expected-error@below {{failed to legalize unresolved materialization from () to 'i32' that remained live after conversion}} %0 = "test.illegal_op_a"() : () -> i32 - // expected-note@below {{found live user of result #0: func.return %0 : i32}} + // expected-note@below {{see existing live user here}} return %0 : i32 } From e12b6848df4db721ce5f0cab24373733c4d5eaeb Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 30 Aug 2024 23:40:23 +0200 Subject: [PATCH 2/2] Remove `changedResults` --- .../Transforms/Utils/DialectConversion.cpp | 21 +++++-------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 26f7ae3b2cbf1..b58a95c3baf70 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -624,10 +624,9 @@ class ModifyOperationRewrite : public OperationRewrite { class ReplaceOperationRewrite : public OperationRewrite { public: ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, - Operation *op, const TypeConverter *converter, - bool changedResults) + Operation *op, const TypeConverter *converter) : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op), - converter(converter), changedResults(changedResults) {} + converter(converter) {} static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::ReplaceOperation; @@ -641,15 +640,10 @@ class ReplaceOperationRewrite : public OperationRewrite { const TypeConverter *getConverter() const { return converter; } - bool hasChangedResults() const { return changedResults; } - private: /// An optional type converter that can be used to materialize conversions /// between the new and old values if necessary. const TypeConverter *converter; - - /// A boolean flag that indicates whether result types have changed or not. - bool changedResults; }; class CreateOperationRewrite : public OperationRewrite { @@ -941,6 +935,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// to modify/access them is invalid rewriter API usage. SetVector replacedOps; + /// A set of all unresolved materializations. DenseSet unresolvedMaterializations; /// The current type converter, or nullptr if no type converter is currently @@ -1383,9 +1378,6 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, assert(newValues.size() == op->getNumResults()); assert(!ignoredOps.contains(op) && "operation was already replaced"); - // Track if any of the results changed, e.g. erased and replaced with null. - bool resultChanged = false; - // Create mappings for each of the new result values. for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) { if (!newValue) { @@ -1393,7 +1385,6 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, if (unresolvedMaterializations.contains(op)) { // Do not create another materializations if we are erasing a // materialization. - resultChanged = true; continue; } @@ -1406,11 +1397,9 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, // Remap, and check for any result type changes. mapping.map(result, newValue); - resultChanged |= (newValue.getType() != result.getType()); } - appendRewrite(op, currentTypeConverter, - resultChanged); + appendRewrite(op, currentTypeConverter); // Mark this operation and all nested ops as replaced. op->walk([&](Operation *op) { replacedOps.insert(op); }); @@ -2585,7 +2574,7 @@ LogicalResult OperationConverter::legalizeConvertedOpResultTypes( for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) { auto *opReplacement = dyn_cast(rewriterImpl.rewrites[i].get()); - if (!opReplacement || !opReplacement->hasChangedResults()) + if (!opReplacement) continue; Operation *op = opReplacement->getOperation(); for (OpResult result : op->getResults()) {