From 4ea392009450abc5a6c91efcdbf5e99d061ba224 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sat, 8 Jun 2024 21:15:08 +0200 Subject: [PATCH 1/6] [mlir][Transforms] Dialect Conversion: Simplify block conversion API This commit simplifies and improves documentation for the part of the `ConversionPatternRewriter` API that deals with signature conversions. There are now two public functions for signature conversion: * `applySignatureConversion` converts a single block signature. * `convertRegionTypes` converts all block signatures of a region. Note: `convertRegionTypes` could be renamed to `applySignatureConversion` (overload) in the future. Also clarify when a type converter and/or signature conversion object is needed and for what purpose. From a functional perspective, this change is NFC. However, the public API changes, thus not marking as NFC. --- mlir/docs/DialectConversion.md | 30 +++-- .../mlir/Transforms/DialectConversion.h | 43 +++--- mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp | 2 +- .../Dialect/Linalg/Transforms/Detensorize.cpp | 20 ++- .../Transforms/Utils/DialectConversion.cpp | 123 ++++-------------- mlir/test/lib/Dialect/Test/TestPatterns.cpp | 5 +- 6 files changed, 76 insertions(+), 147 deletions(-) diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md index a355d5a90e4d1..8338109eb97c3 100644 --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -372,19 +372,23 @@ class TypeConverter { From the perspective of type conversion, the types of block arguments are a bit special. Throughout the conversion process, blocks may move between regions of different operations. Given this, the conversion of the types for blocks must be -done explicitly via a conversion pattern. To convert the types of block -arguments within a Region, a custom hook on the `ConversionPatternRewriter` must -be invoked; `convertRegionTypes`. This hook uses a provided type converter to -apply type conversions to all blocks within a given region, and all blocks that -move into that region. As noted above, the conversions performed by this method -use the argument materialization hook on the `TypeConverter`. This hook also -takes an optional `TypeConverter::SignatureConversion` parameter that applies a -custom conversion to the entry block of the region. The types of the entry block -arguments are often tied semantically to details on the operation, e.g. func::FuncOp, -AffineForOp, etc. To convert the signature of just the region entry block, and -not any other blocks within the region, the `applySignatureConversion` hook may -be used instead. A signature conversion, `TypeConverter::SignatureConversion`, -can be built programmatically: +done explicitly via a conversion pattern. + +To convert the types of block arguments within a Region, a custom hook on the +`ConversionPatternRewriter` must be invoked; `convertRegionTypes`. This hook +uses a provided type converter to apply type conversions to all blocks of a +given region. As noted above, the conversions performed by this method use the +argument materialization hook on the `TypeConverter`. This hook also takes an +optional `TypeConverter::SignatureConversion` parameter that applies a custom +conversion to the entry block of the region. The types of the entry block +arguments are often tied semantically to details on the operation, e.g., +`func::FuncOp`, `AffineForOp`, etc. + +To convert the signature of just one given block, the +`applySignatureConversion` hook can be used. + +A signature conversion, `TypeConverter::SignatureConversion`, can be built +programmatically: ```c++ class SignatureConversion { diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 83198c9b0db54..5f4a972748ffc 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -247,7 +247,8 @@ class TypeConverter { /// Attempts a 1-1 type conversion, expecting the result type to be /// `TargetType`. Returns the converted type cast to `TargetType` on success, /// and a null type on conversion or cast failure. - template TargetType convertType(Type t) const { + template + TargetType convertType(Type t) const { return dyn_cast_or_null(convertType(t)); } @@ -661,42 +662,38 @@ class ConversionPatternRewriter final : public PatternRewriter { public: ~ConversionPatternRewriter() override; - /// Apply a signature conversion to the entry block of the given region. This - /// replaces the entry block with a new block containing the updated - /// signature. The new entry block to the region is returned for convenience. + /// Apply a signature conversion to given block. This replaces the block with + /// a new block containing the updated signature. The operations of the given + /// block are inlined into the newly-created block, which is returned. + /// /// If no block argument types are changing, the entry original block will be /// left in place and returned. /// - /// If provided, `converter` will be used for any materializations. + /// A signature converison must be provided. (Type converters can construct + /// signature conversion with `convertBlockSignature`.) Optionally, a type + /// converter can be provided to build materializations. Block * - applySignatureConversion(Region *region, + applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter = nullptr); - /// Convert the types of block arguments within the given region. This + /// Apply a signature conversion to each block in the given region. This /// replaces each block with a new block containing the updated signature. If /// an updated signature would match the current signature, the respective - /// block is left in place as is. + /// block is left in place as is. (See `applySignatureConversion` for + /// details.) The new entry block of the region is returned. + /// + /// SignatureConversions are computed with the specified type converter. + /// This function returns "failure" if the type converter failed to compute + /// a SignatureConversion for at least one block. /// - /// The entry block may have a special conversion if `entryConversion` is - /// provided. On success, the new entry block to the region is returned for - /// convenience. Otherwise, failure is returned. + /// Optionally, a special SignatureConversion can be specified for the entry + /// block. This is because the types of the entry block arguments are often + /// tied semantically to details on the operation. FailureOr convertRegionTypes( Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion = nullptr); - /// Convert the types of block arguments within the given region except for - /// the entry region. This replaces each non-entry block with a new block - /// containing the updated signature. If an updated signature would match the - /// current signature, the respective block is left in place as is. - /// - /// If special conversion behavior is needed for the non-entry blocks (for - /// example, we need to convert only a subset of a BB arguments), such - /// behavior can be specified in blockConversions. - LogicalResult convertNonEntryRegionTypes( - Region *region, const TypeConverter &converter, - ArrayRef blockConversions); - /// Replace all the uses of the block argument `from` with value `to`. void replaceUsesOfBlockArgument(BlockArgument from, Value to); diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp index d90cf931385fc..f62de1f17a666 100644 --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp @@ -162,7 +162,7 @@ struct ForOpConversion final : SCFToSPIRVPattern { signatureConverter.remapInput(0, newIndVar); for (unsigned i = 1, e = body->getNumArguments(); i < e; i++) signatureConverter.remapInput(i, header->getArgument(i)); - body = rewriter.applySignatureConversion(&forOp.getRegion(), + body = rewriter.applySignatureConversion(&forOp.getRegion().front(), signatureConverter); // Move the blocks from the forOp into the loopOp. This is the body of the diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp index 22968096a6891..af38485291182 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -106,27 +106,23 @@ struct FunctionNonEntryBlockConversion ConversionPatternRewriter &rewriter) const override { rewriter.startOpModification(op); Region ®ion = op.getFunctionBody(); - SmallVector conversions; - for (Block &block : llvm::drop_begin(region, 1)) { - conversions.emplace_back(block.getNumArguments()); - TypeConverter::SignatureConversion &back = conversions.back(); + for (Block &block : + llvm::make_early_inc_range(llvm::drop_begin(region, 1))) { + TypeConverter::SignatureConversion conversion( + /*numOrigInputs=*/block.getNumArguments()); for (BlockArgument blockArgument : block.getArguments()) { int idx = blockArgument.getArgNumber(); if (blockArgsToDetensor.count(blockArgument)) - back.addInputs(idx, {getTypeConverter()->convertType( - block.getArgumentTypes()[idx])}); + conversion.addInputs(idx, {getTypeConverter()->convertType( + block.getArgumentTypes()[idx])}); else - back.addInputs(idx, {block.getArgumentTypes()[idx]}); + conversion.addInputs(idx, {block.getArgumentTypes()[idx]}); } - } - if (failed(rewriter.convertNonEntryRegionTypes(®ion, *typeConverter, - conversions))) { - rewriter.cancelOpModification(op); - return failure(); + rewriter.applySignatureConversion(&block, conversion, getTypeConverter()); } rewriter.finalizeOpModification(op); diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index d407d60334c70..2f0efe1b1e454 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -839,27 +839,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { // Type Conversion //===--------------------------------------------------------------------===// - /// Attempt to convert the signature of the given block, if successful a new - /// block is returned containing the new arguments. Returns `block` if it did - /// not require conversion. - FailureOr convertBlockSignature( - ConversionPatternRewriter &rewriter, Block *block, - const TypeConverter *converter, - TypeConverter::SignatureConversion *conversion = nullptr); - - /// Convert the types of non-entry block arguments within the given region. - LogicalResult convertNonEntryRegionTypes( - ConversionPatternRewriter &rewriter, Region *region, - const TypeConverter &converter, - ArrayRef blockConversions = {}); - - /// Apply a signature conversion on the given region, using `converter` for - /// materializations if not null. - Block * - applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region, - TypeConverter::SignatureConversion &conversion, - const TypeConverter *converter); - /// Convert the types of block arguments within the given region. FailureOr convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, @@ -1294,34 +1273,6 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const { //===----------------------------------------------------------------------===// // Type Conversion -FailureOr ConversionPatternRewriterImpl::convertBlockSignature( - ConversionPatternRewriter &rewriter, Block *block, - const TypeConverter *converter, - TypeConverter::SignatureConversion *conversion) { - if (conversion) - return applySignatureConversion(rewriter, block, converter, *conversion); - - // If a converter wasn't provided, and the block wasn't already converted, - // there is nothing we can do. - if (!converter) - return failure(); - - // Try to convert the signature for the block with the provided converter. - if (auto conversion = converter->convertBlockSignature(block)) - return applySignatureConversion(rewriter, block, converter, *conversion); - return failure(); -} - -Block *ConversionPatternRewriterImpl::applySignatureConversion( - ConversionPatternRewriter &rewriter, Region *region, - TypeConverter::SignatureConversion &conversion, - const TypeConverter *converter) { - if (!region->empty()) - return *convertBlockSignature(rewriter, ®ion->front(), converter, - &conversion); - return nullptr; -} - FailureOr ConversionPatternRewriterImpl::convertRegionTypes( ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, @@ -1330,42 +1281,29 @@ FailureOr ConversionPatternRewriterImpl::convertRegionTypes( if (region->empty()) return nullptr; - if (failed(convertNonEntryRegionTypes(rewriter, region, converter))) - return failure(); - - FailureOr newEntry = convertBlockSignature( - rewriter, ®ion->front(), &converter, entryConversion); - return newEntry; -} - -LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes( - ConversionPatternRewriter &rewriter, Region *region, - const TypeConverter &converter, - ArrayRef blockConversions) { - regionToConverter[region] = &converter; - if (region->empty()) - return success(); - - // Convert the arguments of each block within the region. - int blockIdx = 0; - assert((blockConversions.empty() || - blockConversions.size() == region->getBlocks().size() - 1) && - "expected either to provide no SignatureConversions at all or to " - "provide a SignatureConversion for each non-entry block"); - + // Convert the arguments of each non-entry block within the region. for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) { - TypeConverter::SignatureConversion *blockConversion = - blockConversions.empty() - ? nullptr - : const_cast( - &blockConversions[blockIdx++]); - - if (failed(convertBlockSignature(rewriter, &block, &converter, - blockConversion))) + // Compute the signature for the block with the provided converter. + std::optional conversion = + converter.convertBlockSignature(&block); + if (!conversion) return failure(); - } - return success(); + // Convert the block with the computed signature. + applySignatureConversion(rewriter, &block, &converter, *conversion); + } + + // Convert the entry block. If an entry signature conversion was provided, + // use that one. Otherwise, compute the signature with the type converter. + if (entryConversion) + return applySignatureConversion(rewriter, ®ion->front(), &converter, + *entryConversion); + std::optional conversion = + converter.convertBlockSignature(®ion->front()); + if (!conversion) + return failure(); + return applySignatureConversion(rewriter, ®ion->front(), &converter, + *conversion); } Block *ConversionPatternRewriterImpl::applySignatureConversion( @@ -1676,12 +1614,12 @@ void ConversionPatternRewriter::eraseBlock(Block *block) { } Block *ConversionPatternRewriter::applySignatureConversion( - Region *region, TypeConverter::SignatureConversion &conversion, + Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter) { - assert(!impl->wasOpReplaced(region->getParentOp()) && + assert(!impl->wasOpReplaced(block->getParentOp()) && "attempting to apply a signature conversion to a block within a " "replaced/erased op"); - return impl->applySignatureConversion(*this, region, conversion, converter); + return impl->applySignatureConversion(*this, block, converter, conversion); } FailureOr ConversionPatternRewriter::convertRegionTypes( @@ -1693,16 +1631,6 @@ FailureOr ConversionPatternRewriter::convertRegionTypes( return impl->convertRegionTypes(*this, region, converter, entryConversion); } -LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes( - Region *region, const TypeConverter &converter, - ArrayRef blockConversions) { - assert(!impl->wasOpReplaced(region->getParentOp()) && - "attempting to apply a signature conversion to a block within a " - "replaced/erased op"); - return impl->convertNonEntryRegionTypes(*this, region, converter, - blockConversions); -} - void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, Value to) { LLVM_DEBUG({ @@ -2231,11 +2159,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites( // If the region of the block has a type converter, try to convert the block // directly. if (auto *converter = impl.regionToConverter.lookup(block->getParent())) { - if (failed(impl.convertBlockSignature(rewriter, block, converter))) { + std::optional conversion = + converter->convertBlockSignature(block); + if (!conversion) { LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved " "block")); return failure(); } + impl.applySignatureConversion(rewriter, block, converter, *conversion); continue; } diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index f9f7d4eacf948..a14a5da341098 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1516,8 +1516,9 @@ struct TestTestSignatureConversionNoConverter if (failed( converter.convertSignatureArgs(entry->getArgumentTypes(), result))) return failure(); - rewriter.modifyOpInPlace( - op, [&] { rewriter.applySignatureConversion(®ion, result); }); + rewriter.modifyOpInPlace(op, [&] { + rewriter.applySignatureConversion(®ion.front(), result); + }); return success(); } From 8bc081d53a098c5510c66eb7547d5bd8e07a08e9 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sun, 9 Jun 2024 12:17:51 +0200 Subject: [PATCH 2/6] Update mlir/docs/DialectConversion.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Markus Böck --- mlir/docs/DialectConversion.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md index 8338109eb97c3..69781bb868bbf 100644 --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -381,7 +381,7 @@ given region. As noted above, the conversions performed by this method use the argument materialization hook on the `TypeConverter`. This hook also takes an optional `TypeConverter::SignatureConversion` parameter that applies a custom conversion to the entry block of the region. The types of the entry block -arguments are often tied semantically to details on the operation, e.g., +arguments are often tied semantically to the operation, e.g., `func::FuncOp`, `AffineForOp`, etc. To convert the signature of just one given block, the From e6cca99f5e5d3628792660ffd2112308940b1f7d Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sun, 9 Jun 2024 12:18:00 +0200 Subject: [PATCH 3/6] Update mlir/include/mlir/Transforms/DialectConversion.h MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Markus Böck --- mlir/include/mlir/Transforms/DialectConversion.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 5f4a972748ffc..6f6e494919d60 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -666,7 +666,7 @@ class ConversionPatternRewriter final : public PatternRewriter { /// a new block containing the updated signature. The operations of the given /// block are inlined into the newly-created block, which is returned. /// - /// If no block argument types are changing, the entry original block will be + /// If no block argument types are changing, the original block will be /// left in place and returned. /// /// A signature converison must be provided. (Type converters can construct From 49511919a998aec3a9dad166d3ea881a00ac123b Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sun, 9 Jun 2024 12:18:13 +0200 Subject: [PATCH 4/6] Update mlir/include/mlir/Transforms/DialectConversion.h MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Markus Böck --- mlir/include/mlir/Transforms/DialectConversion.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 6f6e494919d60..8a06fad7c5114 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -670,7 +670,7 @@ class ConversionPatternRewriter final : public PatternRewriter { /// left in place and returned. /// /// A signature converison must be provided. (Type converters can construct - /// signature conversion with `convertBlockSignature`.) Optionally, a type + /// a signature conversion with `convertBlockSignature`.) Optionally, a type /// converter can be provided to build materializations. Block * applySignatureConversion(Block *block, From 4187ddf3070995cbabaedf908c72915cdf01bdcc Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sun, 9 Jun 2024 12:21:44 +0200 Subject: [PATCH 5/6] Update mlir/include/mlir/Transforms/DialectConversion.h MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Markus Böck --- mlir/include/mlir/Transforms/DialectConversion.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 8a06fad7c5114..d944796d1ba4a 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -689,7 +689,7 @@ class ConversionPatternRewriter final : public PatternRewriter { /// /// Optionally, a special SignatureConversion can be specified for the entry /// block. This is because the types of the entry block arguments are often - /// tied semantically to details on the operation. + /// tied semantically to the operation. FailureOr convertRegionTypes( Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion = nullptr); From 78bd094e0186346e774c7be374a2e0e813ff54e1 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sun, 9 Jun 2024 13:09:47 +0200 Subject: [PATCH 6/6] Expand comment --- mlir/include/mlir/Transforms/DialectConversion.h | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index d944796d1ba4a..a8488af9dcb2d 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -670,8 +670,12 @@ class ConversionPatternRewriter final : public PatternRewriter { /// left in place and returned. /// /// A signature converison must be provided. (Type converters can construct - /// a signature conversion with `convertBlockSignature`.) Optionally, a type - /// converter can be provided to build materializations. + /// a signature conversion with `convertBlockSignature`.) + /// + /// Optionally, a type converter can be provided to build materializations. + /// Note: If no type converter was provided or the type converter does not + /// specify any suitable argument/target materialization rules, the dialect + /// conversion may fail to legalize unresolved materializations. Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion,