From a7073f8cc92008917a9c65c42840894f9df685bc Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sun, 24 Nov 2024 09:33:18 +0100 Subject: [PATCH] [mlir][Transforms] Dialect conversion: extra signature conversion checks --- .../Transforms/Utils/DialectConversion.cpp | 35 ++++++++++++++----- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 5acd095da8e38..710c976281dc3 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -434,23 +434,25 @@ class MoveBlockRewrite : public BlockRewrite { class BlockTypeConversionRewrite : public BlockRewrite { public: BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl, - Block *block, Block *origBlock) - : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block), - origBlock(origBlock) {} + Block *origBlock, Block *newBlock) + : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock), + newBlock(newBlock) {} static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::BlockTypeConversion; } - Block *getOrigBlock() const { return origBlock; } + Block *getOrigBlock() const { return block; } + + Block *getNewBlock() const { return newBlock; } void commit(RewriterBase &rewriter) override; void rollback() override; private: - /// The original block that was requested to have its signature converted. - Block *origBlock; + /// The new block that was created as part of this signature conversion. + Block *newBlock; }; /// Replacing a block argument. This rewrite is not immediately reflected in the @@ -721,6 +723,18 @@ static bool hasRewrite(R &&rewrites, Operation *op) { }); } +#ifndef NDEBUG +/// Return "true" if there is a block rewrite that matches the specified +/// rewrite type and block among the given rewrites. +template +static bool hasRewrite(R &&rewrites, Block *block) { + return any_of(std::forward(rewrites), [&](auto &rewrite) { + auto *rewriteTy = dyn_cast(rewrite.get()); + return rewriteTy && rewriteTy->getBlock() == block; + }); +} +#endif // NDEBUG + //===----------------------------------------------------------------------===// // ConversionPatternRewriterImpl //===----------------------------------------------------------------------===// @@ -966,12 +980,12 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) { // block. if (auto *listener = dyn_cast_or_null(rewriter.getListener())) - for (Operation *op : block->getUsers()) + for (Operation *op : getNewBlock()->getUsers()) listener->notifyOperationModified(op); } void BlockTypeConversionRewrite::rollback() { - block->replaceAllUsesWith(origBlock); + getNewBlock()->replaceAllUsesWith(getOrigBlock()); } void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { @@ -1223,6 +1237,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( ConversionPatternRewriter &rewriter, Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion) { + // A block cannot be converted multiple times. + assert(!hasRewrite(rewrites, block) && + "block was already converted"); OpBuilder::InsertionGuard g(rewriter); // If no arguments are being changed or added, there is nothing to do. @@ -1308,7 +1325,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( appendRewrite(block, origArg, converter); } - appendRewrite(newBlock, block); + appendRewrite(/*origBlock=*/block, newBlock); // Erase the old block. (It is just unlinked for now and will be erased during // cleanup.)