Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 22 additions & 57 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,23 +343,6 @@ struct ArgConverter {
const TypeConverter *converter;
};

/// Return if the signature of the given block has already been converted.
bool hasBeenConverted(Block *block) const {
return conversionInfo.count(block) || convertedBlocks.count(block);
}

/// Set the type converter to use for the given region.
void setConverter(Region *region, const TypeConverter *typeConverter) {
assert(typeConverter && "expected valid type converter");
regionToConverter[region] = typeConverter;
}

/// Return the type converter to use for the given region, or null if there
/// isn't one.
const TypeConverter *getConverter(Region *region) {
return regionToConverter.lookup(region);
}

//===--------------------------------------------------------------------===//
// Rewrite Application
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -409,24 +392,10 @@ struct ArgConverter {
ConversionValueMapping &mapping,
SmallVectorImpl<BlockArgument> &argReplacements);

/// Insert a new conversion into the cache.
void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);

/// A collection of blocks that have had their arguments converted. This is a
/// map from the new replacement block, back to the original block.
llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;

/// The set of original blocks that were converted.
DenseSet<Block *> convertedBlocks;

/// A mapping from valid regions, to those containing the original blocks of a
/// conversion.
DenseMap<Region *, std::unique_ptr<Region>> regionMapping;

/// A mapping of regions to type converters that should be used when
/// converting the arguments of blocks within that region.
DenseMap<Region *, const TypeConverter *> regionToConverter;

/// The pattern rewriter to use when materializing conversions.
PatternRewriter &rewriter;

Expand Down Expand Up @@ -474,12 +443,12 @@ void ArgConverter::discardRewrites(Block *block) {
block->getArgument(i).dropAllUses();
block->replaceAllUsesWith(origBlock);

// Move the operations back the original block and the delete the new block.
// Move the operations back the original block, move the original block back
// into its original location and the delete the new block.
origBlock->getOperations().splice(origBlock->end(), block->getOperations());
origBlock->moveBefore(block);
block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
block->erase();

convertedBlocks.erase(origBlock);
conversionInfo.erase(it);
}

Expand Down Expand Up @@ -510,6 +479,9 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
mapping.lookupOrDefault(castValue, origArg.getType()));
}
}

delete origBlock;
blockInfo.origBlock = nullptr;
}
}

Expand Down Expand Up @@ -572,9 +544,11 @@ FailureOr<Block *> ArgConverter::convertSignature(
Block *block, const TypeConverter *converter,
ConversionValueMapping &mapping,
SmallVectorImpl<BlockArgument> &argReplacements) {
// Check if the block was already converted. If the block is detached,
// conservatively assume it is going to be deleted.
if (hasBeenConverted(block) || !block->getParent())
// Check if the block was already converted.
// * If the block is mapped in `conversionInfo`, it is a converted block.
// * If the block is detached, conservatively assume that it is going to be
// deleted; it is likely the old block (before it was converted).
if (conversionInfo.count(block) || !block->getParent())
return block;
// If a converter wasn't provided, and the block wasn't already converted,
// there is nothing we can do.
Expand Down Expand Up @@ -603,6 +577,9 @@ Block *ArgConverter::applySignatureConversion(
// signature.
Block *newBlock = block->splitBlock(block->begin());
block->replaceAllUsesWith(newBlock);
// Unlink the block, but do not erase it yet, so that the change can be rolled
// back.
block->getParent()->getBlocks().remove(block);

// Map all new arguments to the location of the argument they originate from.
SmallVector<Location> newLocs(convertedTypes.size(),
Expand Down Expand Up @@ -679,24 +656,8 @@ Block *ArgConverter::applySignatureConversion(
ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
}

// Remove the original block from the region and return the new one.
insertConversion(newBlock, std::move(info));
return newBlock;
}

void ArgConverter::insertConversion(Block *newBlock,
ConvertedBlockInfo &&info) {
// Get a region to insert the old block.
Region *region = newBlock->getParent();
std::unique_ptr<Region> &mappedRegion = regionMapping[region];
if (!mappedRegion)
mappedRegion = std::make_unique<Region>(region->getParentOp());

// Move the original block to the mapped region and emplace the conversion.
mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(),
info.origBlock->getIterator());
convertedBlocks.insert(info.origBlock);
conversionInfo.insert({newBlock, std::move(info)});
return newBlock;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1227,6 +1188,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// active.
const TypeConverter *currentTypeConverter = nullptr;

/// A mapping of regions to type converters that should be used when
/// converting the arguments of blocks within that region.
DenseMap<Region *, const TypeConverter *> regionToConverter;

/// This allows the user to collect the match failure message.
function_ref<void(Diagnostic &)> notifyCallback;

Expand Down Expand Up @@ -1504,7 +1469,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion) {
argConverter.setConverter(region, &converter);
regionToConverter[region] = &converter;
if (region->empty())
return nullptr;

Expand All @@ -1519,7 +1484,7 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
Region *region, const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
argConverter.setConverter(region, &converter);
regionToConverter[region] = &converter;
if (region->empty())
return success();

Expand Down Expand Up @@ -2195,7 +2160,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(

// If the region of the block has a type converter, try to convert the block
// directly.
if (auto *converter = impl.argConverter.getConverter(block->getParent())) {
if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
if (failed(impl.convertBlockSignature(block, converter))) {
LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
"block"));
Expand Down