diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index cc61bc6b6260c..88709bb261874 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -343,23 +343,6 @@ struct ArgConverter { const TypeConverter *converter; }; - /// Return if the signature of the given block has already been converted. - bool hasBeenConverted(Block *block) const { - return conversionInfo.count(block) || convertedBlocks.count(block); - } - - /// Set the type converter to use for the given region. - void setConverter(Region *region, const TypeConverter *typeConverter) { - assert(typeConverter && "expected valid type converter"); - regionToConverter[region] = typeConverter; - } - - /// Return the type converter to use for the given region, or null if there - /// isn't one. - const TypeConverter *getConverter(Region *region) { - return regionToConverter.lookup(region); - } - //===--------------------------------------------------------------------===// // Rewrite Application //===--------------------------------------------------------------------===// @@ -409,24 +392,10 @@ struct ArgConverter { ConversionValueMapping &mapping, SmallVectorImpl &argReplacements); - /// Insert a new conversion into the cache. - void insertConversion(Block *newBlock, ConvertedBlockInfo &&info); - /// A collection of blocks that have had their arguments converted. This is a /// map from the new replacement block, back to the original block. llvm::MapVector conversionInfo; - /// The set of original blocks that were converted. - DenseSet convertedBlocks; - - /// A mapping from valid regions, to those containing the original blocks of a - /// conversion. - DenseMap> regionMapping; - - /// A mapping of regions to type converters that should be used when - /// converting the arguments of blocks within that region. - DenseMap regionToConverter; - /// The pattern rewriter to use when materializing conversions. PatternRewriter &rewriter; @@ -474,12 +443,12 @@ void ArgConverter::discardRewrites(Block *block) { block->getArgument(i).dropAllUses(); block->replaceAllUsesWith(origBlock); - // Move the operations back the original block and the delete the new block. + // Move the operations back the original block, move the original block back + // into its original location and the delete the new block. origBlock->getOperations().splice(origBlock->end(), block->getOperations()); - origBlock->moveBefore(block); + block->getParent()->getBlocks().insert(Region::iterator(block), origBlock); block->erase(); - convertedBlocks.erase(origBlock); conversionInfo.erase(it); } @@ -510,6 +479,9 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) { mapping.lookupOrDefault(castValue, origArg.getType())); } } + + delete origBlock; + blockInfo.origBlock = nullptr; } } @@ -572,9 +544,11 @@ FailureOr ArgConverter::convertSignature( Block *block, const TypeConverter *converter, ConversionValueMapping &mapping, SmallVectorImpl &argReplacements) { - // Check if the block was already converted. If the block is detached, - // conservatively assume it is going to be deleted. - if (hasBeenConverted(block) || !block->getParent()) + // Check if the block was already converted. + // * If the block is mapped in `conversionInfo`, it is a converted block. + // * If the block is detached, conservatively assume that it is going to be + // deleted; it is likely the old block (before it was converted). + if (conversionInfo.count(block) || !block->getParent()) return block; // If a converter wasn't provided, and the block wasn't already converted, // there is nothing we can do. @@ -603,6 +577,9 @@ Block *ArgConverter::applySignatureConversion( // signature. Block *newBlock = block->splitBlock(block->begin()); block->replaceAllUsesWith(newBlock); + // Unlink the block, but do not erase it yet, so that the change can be rolled + // back. + block->getParent()->getBlocks().remove(block); // Map all new arguments to the location of the argument they originate from. SmallVector newLocs(convertedTypes.size(), @@ -679,24 +656,8 @@ Block *ArgConverter::applySignatureConversion( ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg); } - // Remove the original block from the region and return the new one. - insertConversion(newBlock, std::move(info)); - return newBlock; -} - -void ArgConverter::insertConversion(Block *newBlock, - ConvertedBlockInfo &&info) { - // Get a region to insert the old block. - Region *region = newBlock->getParent(); - std::unique_ptr &mappedRegion = regionMapping[region]; - if (!mappedRegion) - mappedRegion = std::make_unique(region->getParentOp()); - - // Move the original block to the mapped region and emplace the conversion. - mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(), - info.origBlock->getIterator()); - convertedBlocks.insert(info.origBlock); conversionInfo.insert({newBlock, std::move(info)}); + return newBlock; } //===----------------------------------------------------------------------===// @@ -1227,6 +1188,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// active. const TypeConverter *currentTypeConverter = nullptr; + /// A mapping of regions to type converters that should be used when + /// converting the arguments of blocks within that region. + DenseMap regionToConverter; + /// This allows the user to collect the match failure message. function_ref notifyCallback; @@ -1504,7 +1469,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( FailureOr ConversionPatternRewriterImpl::convertRegionTypes( Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion) { - argConverter.setConverter(region, &converter); + regionToConverter[region] = &converter; if (region->empty()) return nullptr; @@ -1519,7 +1484,7 @@ FailureOr ConversionPatternRewriterImpl::convertRegionTypes( LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes( Region *region, const TypeConverter &converter, ArrayRef blockConversions) { - argConverter.setConverter(region, &converter); + regionToConverter[region] = &converter; if (region->empty()) return success(); @@ -2195,7 +2160,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites( // If the region of the block has a type converter, try to convert the block // directly. - if (auto *converter = impl.argConverter.getConverter(block->getParent())) { + if (auto *converter = impl.regionToConverter.lookup(block->getParent())) { if (failed(impl.convertBlockSignature(block, converter))) { LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved " "block"));