Skip to content

Commit 345ca6a

Browse files
[mlir][Transforms] Dialect conversion: extra signature conversion check (#117471)
This commit adds an extra assertion to `applySignatureConversion` to prevent incorrect API usage: The same block cannot be converted multiple times. That would mess with the underlying conversion value mapping. (Mappings would be overwritten.) This is similar to op replacements: The same op cannot be replaced multiple times. To simplify the check, `BlockTypeConversionRewrite::block` now stores the original block. The new block is stored in an extra field. (It used to be the other way around.) This commit is in preparation of adding 1:N support to the conversion value mapping. Before making any further changes to the mapping infrastructure, I'd like to make sure that the code base around it (that uses the mapping) is robust.
1 parent bb5bbe5 commit 345ca6a

File tree

1 file changed

+26
-9
lines changed

1 file changed

+26
-9
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -434,23 +434,25 @@ class MoveBlockRewrite : public BlockRewrite {
434434
class BlockTypeConversionRewrite : public BlockRewrite {
435435
public:
436436
BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
437-
Block *block, Block *origBlock)
438-
: BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
439-
origBlock(origBlock) {}
437+
Block *origBlock, Block *newBlock)
438+
: BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
439+
newBlock(newBlock) {}
440440

441441
static bool classof(const IRRewrite *rewrite) {
442442
return rewrite->getKind() == Kind::BlockTypeConversion;
443443
}
444444

445-
Block *getOrigBlock() const { return origBlock; }
445+
Block *getOrigBlock() const { return block; }
446+
447+
Block *getNewBlock() const { return newBlock; }
446448

447449
void commit(RewriterBase &rewriter) override;
448450

449451
void rollback() override;
450452

451453
private:
452-
/// The original block that was requested to have its signature converted.
453-
Block *origBlock;
454+
/// The new block that was created as part of this signature conversion.
455+
Block *newBlock;
454456
};
455457

456458
/// Replacing a block argument. This rewrite is not immediately reflected in the
@@ -721,6 +723,18 @@ static bool hasRewrite(R &&rewrites, Operation *op) {
721723
});
722724
}
723725

726+
#ifndef NDEBUG
727+
/// Return "true" if there is a block rewrite that matches the specified
728+
/// rewrite type and block among the given rewrites.
729+
template <typename RewriteTy, typename R>
730+
static bool hasRewrite(R &&rewrites, Block *block) {
731+
return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
732+
auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
733+
return rewriteTy && rewriteTy->getBlock() == block;
734+
});
735+
}
736+
#endif // NDEBUG
737+
724738
//===----------------------------------------------------------------------===//
725739
// ConversionPatternRewriterImpl
726740
//===----------------------------------------------------------------------===//
@@ -966,12 +980,12 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
966980
// block.
967981
if (auto *listener =
968982
dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
969-
for (Operation *op : block->getUsers())
983+
for (Operation *op : getNewBlock()->getUsers())
970984
listener->notifyOperationModified(op);
971985
}
972986

973987
void BlockTypeConversionRewrite::rollback() {
974-
block->replaceAllUsesWith(origBlock);
988+
getNewBlock()->replaceAllUsesWith(getOrigBlock());
975989
}
976990

977991
void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
@@ -1223,6 +1237,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12231237
ConversionPatternRewriter &rewriter, Block *block,
12241238
const TypeConverter *converter,
12251239
TypeConverter::SignatureConversion &signatureConversion) {
1240+
// A block cannot be converted multiple times.
1241+
assert(!hasRewrite<BlockTypeConversionRewrite>(rewrites, block) &&
1242+
"block was already converted");
12261243
OpBuilder::InsertionGuard g(rewriter);
12271244

12281245
// If no arguments are being changed or added, there is nothing to do.
@@ -1308,7 +1325,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13081325
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
13091326
}
13101327

1311-
appendRewrite<BlockTypeConversionRewrite>(newBlock, block);
1328+
appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
13121329

13131330
// Erase the old block. (It is just unlinked for now and will be erased during
13141331
// cleanup.)

0 commit comments

Comments
 (0)