diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index f967e8352bf4c..5d399ce1eb9cf 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -798,13 +798,12 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { PatternRewriter &rewriter, ValueRange values, SmallVectorImpl &remapped); - /// Returns true if the given operation is ignored, and does not need to be + /// Return "true" if the given operation is ignored, and does not need to be /// converted. bool isOpIgnored(Operation *op) const; - /// Recursively marks the nested operations under 'op' as ignored. This - /// removes them from being considered for legalization. - void markNestedOpsIgnored(Operation *op); + /// Return "true" if the given operation was replaced or erased. + bool wasOpReplaced(Operation *op) const; //===--------------------------------------------------------------------===// // Type Conversion @@ -946,18 +945,15 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// Ordered list of block operations (creations, splits, motions). SmallVector> rewrites; - /// A set of operations that should no longer be considered for legalization, - /// but were not directly replace/erased/etc. by a pattern. These are - /// generally child operations of other operations who were - /// replaced/erased/etc. This is not meant to be an exhaustive list of all - /// operations, but the minimal set that can be used to detect if a given - /// operation should be `ignored`. For example, we may add the operations that - /// define non-empty regions to the set, but not any of the others. This - /// simplifies the amount of memory needed as we can query if the parent - /// operation was ignored. + /// A set of operations that should no longer be considered for legalization. + /// E.g., ops that are recursively legal. Ops that were replaced/erased are + /// tracked separately. SetVector ignoredOps; - // A set of operations that were erased. + /// A set of operations that were replaced/erased. Such ops are not erased + /// immediately but only when the dialect conversion succeeds. In the mean + /// time, they should no longer be considered for legalization and any attempt + /// to modify/access them is invalid rewriter API usage. SetVector replacedOps; /// The current type converter, or nullptr if no type converter is currently @@ -1237,24 +1233,14 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( return success(); } -// TODO: This function is a misnomer. It does not actually check if `op` is in -// `ignoredOps`. bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const { - // Check to see if this operation or the parent operation is ignored. - return ignoredOps.count(op->getParentOp()) || replacedOps.count(op); + // Check to see if this operation is ignored or was replaced. + return replacedOps.count(op) || ignoredOps.count(op); } -void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) { - // Walk this operation and collect nested operations that define non-empty - // regions. We mark such operations as 'ignored' so that we know we don't have - // to convert them, or their nested ops. - if (op->getNumRegions() == 0) - return; - op->walk([&](Operation *op) { - if (llvm::any_of(op->getRegions(), - [](Region ®ion) { return !region.empty(); })) - ignoredOps.insert(op); - }); +bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const { + // Check to see if this operation was replaced. + return replacedOps.count(op); } //===----------------------------------------------------------------------===// @@ -1476,6 +1462,9 @@ void ConversionPatternRewriterImpl::notifyOperationInserted( logger.startLine() << "** Insert : '" << op->getName() << "'(" << op << ")\n"; }); + assert(!wasOpReplaced(op->getParentOp()) && + "attempting to insert into a block within a replaced/erased op"); + if (!previous.isSet()) { // This is a newly created op. appendRewrite(op); @@ -1490,7 +1479,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted( void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, ValueRange newValues) { assert(newValues.size() == op->getNumResults()); - assert(!replacedOps.contains(op) && "operation was already replaced"); + assert(!ignoredOps.contains(op) && "operation was already replaced"); // Track if any of the results changed, e.g. erased and replaced with null. bool resultChanged = false; @@ -1509,10 +1498,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, appendRewrite(op, currentTypeConverter, resultChanged); - // Mark this operation as recursively ignored so that we don't need to - // convert any nested operations. - replacedOps.insert(op); - markNestedOpsIgnored(op); + // Mark this operation and all nested ops as replaced. + op->walk([&](Operation *op) { replacedOps.insert(op); }); } void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) { @@ -1523,6 +1510,9 @@ void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) { void ConversionPatternRewriterImpl::notifyBlockInserted( Block *block, Region *previous, Region::iterator previousIt) { + assert(!wasOpReplaced(block->getParentOp()) && + "attempting to insert into a region within a replaced/erased op"); + if (!previous) { // This is a newly created block. appendRewrite(block); @@ -1604,6 +1594,9 @@ void ConversionPatternRewriter::eraseOp(Operation *op) { } void ConversionPatternRewriter::eraseBlock(Block *block) { + assert(!impl->wasOpReplaced(block->getParentOp()) && + "attempting to erase a block within a replaced/erased op"); + // Mark all ops for erasure. for (Operation &op : *block) eraseOp(&op); @@ -1619,18 +1612,27 @@ void ConversionPatternRewriter::eraseBlock(Block *block) { Block *ConversionPatternRewriter::applySignatureConversion( Region *region, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter) { + assert(!impl->wasOpReplaced(region->getParentOp()) && + "attempting to apply a signature conversion to a block within a " + "replaced/erased op"); return impl->applySignatureConversion(region, conversion, converter); } FailureOr ConversionPatternRewriter::convertRegionTypes( Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion) { + assert(!impl->wasOpReplaced(region->getParentOp()) && + "attempting to apply a signature conversion to a block within a " + "replaced/erased op"); return impl->convertRegionTypes(region, converter, entryConversion); } LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes( Region *region, const TypeConverter &converter, ArrayRef blockConversions) { + assert(!impl->wasOpReplaced(region->getParentOp()) && + "attempting to apply a signature conversion to a block within a " + "replaced/erased op"); return impl->convertNonEntryRegionTypes(region, converter, blockConversions); } @@ -1665,6 +1667,8 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys, Block *ConversionPatternRewriter::splitBlock(Block *block, Block::iterator before) { + assert(!impl->wasOpReplaced(block->getParentOp()) && + "attempting to split a block within a replaced/erased op"); auto *continuation = block->splitBlock(before); impl->notifySplitBlock(block, continuation); return continuation; @@ -1673,15 +1677,19 @@ Block *ConversionPatternRewriter::splitBlock(Block *block, void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues) { +#ifndef NDEBUG assert(argValues.size() == source->getNumArguments() && "incorrect # of argument replacement values"); -#ifndef NDEBUG + assert(!impl->wasOpReplaced(source->getParentOp()) && + "attempting to inline a block from a replaced/erased op"); + assert(!impl->wasOpReplaced(dest->getParentOp()) && + "attempting to inline a block into a replaced/erased op"); auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); }; -#endif // NDEBUG // The source block will be deleted, so it should not have any users (i.e., // there should be no predecessors). assert(llvm::all_of(source->getUsers(), opIgnored) && "expected 'source' to have no predecessors"); +#endif // NDEBUG impl->notifyBlockBeingInlined(dest, source, before); for (auto it : llvm::zip(source->getArguments(), argValues)) @@ -1691,6 +1699,8 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, } void ConversionPatternRewriter::startOpModification(Operation *op) { + assert(!impl->wasOpReplaced(op) && + "attempting to modify a replaced/erased op"); #ifndef NDEBUG impl->pendingRootUpdates.insert(op); #endif @@ -1698,6 +1708,8 @@ void ConversionPatternRewriter::startOpModification(Operation *op) { } void ConversionPatternRewriter::finalizeOpModification(Operation *op) { + assert(!impl->wasOpReplaced(op) && + "attempting to modify a replaced/erased op"); PatternRewriter::finalizeOpModification(op); // There is nothing to do here, we only need to track the operation at the // start of the update. @@ -1912,8 +1924,13 @@ OperationLegalizer::legalize(Operation *op, // If this operation is recursively legal, mark its children as ignored so // that we don't consider them for legalization. - if (legalityInfo->isRecursivelyLegal) - rewriter.getImpl().markNestedOpsIgnored(op); + if (legalityInfo->isRecursivelyLegal) { + op->walk([&](Operation *nested) { + if (op != nested) + rewriter.getImpl().ignoredOps.insert(nested); + }); + } + return success(); } diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index bde4255ee4b36..abc0e43c7b7f2 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1768,7 +1768,6 @@ struct TestMergeSingleBlockOps rewriter.inlineBlockBefore(&innerBlock, op); rewriter.eraseOp(innerTerminator); rewriter.eraseOp(op); - rewriter.modifyOpInPlace(op, [] {}); return success(); } };