-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][Transforms] Dialect Conversion: No target mat. for 1:N replacement #117513
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Transforms] Dialect Conversion: No target mat. for 1:N replacement #117513
Conversation
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesDuring a 1:N replacement (block signature conversion or
The target materialization is unnecessary. Subsequent patterns receive the replacement values via their adaptors. These patterns have their own type converter. When they see a replacement value of type Special case: If a subsequent pattern does not have a type converter, it does not insert any target materializations. That's because the absence of a type converter indicates the pattern does not care about type legality. Therefore, it is correct to pass an SSA value of type This commit is in preparation of adding 1:N support to the conversion value mapping. Before making any further changes to the mapping infrastructure, I'd like to make sure that the code base around it (that uses the mapping) is robust. Full diff: https://github.com/llvm/llvm-project/pull/117513.diff 3 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 710c976281dc3d..c2a70bfa54b1b0 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -830,8 +830,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// function will be deleted when full 1:N support has been added.
///
/// This function inserts an argument materialization back to the original
- /// type, followed by a target materialization to the legalized type (if
- /// applicable).
+ /// type.
void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
ValueRange replacements, Value originalValue,
const TypeConverter *converter);
@@ -1372,27 +1371,6 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
/*inputs=*/replacements, originalType,
/*originalType=*/Type(), converter);
mapping.map(originalValue, argMat);
-
- // Insert target materialization to the legalized type.
- Type legalOutputType;
- if (converter) {
- legalOutputType = converter->convertType(originalType);
- } else if (replacements.size() == 1) {
- // When there is no type converter, assume that the replacement value
- // types are legal. This is reasonable to assume because they were
- // specified by the user.
- // FIXME: This won't work for 1->N conversions because multiple output
- // types are not supported in parts of the dialect conversion. In such a
- // case, we currently use the original value type.
- legalOutputType = replacements[0].getType();
- }
- if (legalOutputType && legalOutputType != originalType) {
- Value targetMat = buildUnresolvedMaterialization(
- MaterializationKind::Target, computeInsertPoint(argMat), loc,
- /*inputs=*/argMat, /*outputType=*/legalOutputType,
- /*originalType=*/originalType, converter);
- mapping.map(argMat, targetMat);
- }
}
Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index e05f444afa68f0..624add08846a28 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -124,10 +124,10 @@ func.func @no_remap_nested() {
// CHECK-NEXT: "foo.region"
// expected-remark@+1 {{op 'foo.region' is not legalizable}}
"foo.region"() ({
- // CHECK-NEXT: ^bb0(%{{.*}}: i64, %{{.*}}: i16, %{{.*}}: i64):
- ^bb0(%i0: i64, %unused: i16, %i1: i64):
- // CHECK-NEXT: "test.valid"{{.*}} : (i64, i64)
- "test.invalid"(%i0, %i1) : (i64, i64) -> ()
+ // CHECK-NEXT: ^bb0(%{{.*}}: f64, %{{.*}}: i16, %{{.*}}: f64):
+ ^bb0(%i0: f64, %unused: i16, %i1: f64):
+ // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64)
+ "test.invalid"(%i0, %i1) : (f64, f64) -> ()
}) : () -> ()
// expected-remark@+1 {{op 'func.return' is not legalizable}}
return
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index bbd55938718fe7..e931b394c86210 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -979,8 +979,8 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
};
/// This pattern simply updates the operands of the given operation.
struct TestPassthroughInvalidOp : public ConversionPattern {
- TestPassthroughInvalidOp(MLIRContext *ctx)
- : ConversionPattern("test.invalid", 1, ctx) {}
+ TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
+ : ConversionPattern(converter, "test.invalid", 1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
@@ -1254,18 +1254,18 @@ struct TestLegalizePatternDriver
TestTypeConverter converter;
mlir::RewritePatternSet patterns(&getContext());
populateWithGenerated(patterns);
- patterns.add<
- TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
- TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
- TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp,
- TestSplitReturnType, TestChangeProducerTypeI32ToF32,
- TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
- TestUpdateConsumerType, TestNonRootReplacement,
- TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
- TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
- TestUndoPropertiesModification, TestEraseOp>(&getContext());
- patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
- &getContext(), converter);
+ patterns
+ .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
+ TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
+ TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
+ TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
+ TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
+ TestNonRootReplacement, TestBoundedRecursiveRewrite,
+ TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
+ TestCreateUnregisteredOp, TestUndoMoveOpBefore,
+ TestUndoPropertiesModification, TestEraseOp>(&getContext());
+ patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
+ TestPassthroughInvalidOp>(&getContext(), converter);
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
mlir::populateCallOpTypeConversionPattern(patterns, converter);
@@ -1697,8 +1697,9 @@ struct TestTypeConversionAnotherProducer
};
struct TestReplaceWithLegalOp : public ConversionPattern {
- TestReplaceWithLegalOp(MLIRContext *ctx)
- : ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {}
+ TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx)
+ : ConversionPattern(converter, "test.replace_with_legal_op",
+ /*benefit=*/1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
@@ -1820,12 +1821,12 @@ struct TestTypeConversionDriver
// Initialize the set of rewrite patterns.
RewritePatternSet patterns(&getContext());
- patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
- TestSignatureConversionUndo,
- TestTestSignatureConversionNoConverter>(converter,
- &getContext());
- patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>(
- &getContext());
+ patterns
+ .add<TestTypeConsumerForward, TestTypeConversionProducer,
+ TestSignatureConversionUndo,
+ TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>(
+ converter, &getContext());
+ patterns.add<TestTypeConversionAnotherProducer>(&getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
|
@llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesDuring a 1:N replacement (block signature conversion or
The target materialization is unnecessary. Subsequent patterns receive the replacement values via their adaptors. These patterns have their own type converter. When they see a replacement value of type Special case: If a subsequent pattern does not have a type converter, it does not insert any target materializations. That's because the absence of a type converter indicates the pattern does not care about type legality. Therefore, it is correct to pass an SSA value of type This commit is in preparation of adding 1:N support to the conversion value mapping. Before making any further changes to the mapping infrastructure, I'd like to make sure that the code base around it (that uses the mapping) is robust. Full diff: https://github.com/llvm/llvm-project/pull/117513.diff 3 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 710c976281dc3d..c2a70bfa54b1b0 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -830,8 +830,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// function will be deleted when full 1:N support has been added.
///
/// This function inserts an argument materialization back to the original
- /// type, followed by a target materialization to the legalized type (if
- /// applicable).
+ /// type.
void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
ValueRange replacements, Value originalValue,
const TypeConverter *converter);
@@ -1372,27 +1371,6 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
/*inputs=*/replacements, originalType,
/*originalType=*/Type(), converter);
mapping.map(originalValue, argMat);
-
- // Insert target materialization to the legalized type.
- Type legalOutputType;
- if (converter) {
- legalOutputType = converter->convertType(originalType);
- } else if (replacements.size() == 1) {
- // When there is no type converter, assume that the replacement value
- // types are legal. This is reasonable to assume because they were
- // specified by the user.
- // FIXME: This won't work for 1->N conversions because multiple output
- // types are not supported in parts of the dialect conversion. In such a
- // case, we currently use the original value type.
- legalOutputType = replacements[0].getType();
- }
- if (legalOutputType && legalOutputType != originalType) {
- Value targetMat = buildUnresolvedMaterialization(
- MaterializationKind::Target, computeInsertPoint(argMat), loc,
- /*inputs=*/argMat, /*outputType=*/legalOutputType,
- /*originalType=*/originalType, converter);
- mapping.map(argMat, targetMat);
- }
}
Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index e05f444afa68f0..624add08846a28 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -124,10 +124,10 @@ func.func @no_remap_nested() {
// CHECK-NEXT: "foo.region"
// expected-remark@+1 {{op 'foo.region' is not legalizable}}
"foo.region"() ({
- // CHECK-NEXT: ^bb0(%{{.*}}: i64, %{{.*}}: i16, %{{.*}}: i64):
- ^bb0(%i0: i64, %unused: i16, %i1: i64):
- // CHECK-NEXT: "test.valid"{{.*}} : (i64, i64)
- "test.invalid"(%i0, %i1) : (i64, i64) -> ()
+ // CHECK-NEXT: ^bb0(%{{.*}}: f64, %{{.*}}: i16, %{{.*}}: f64):
+ ^bb0(%i0: f64, %unused: i16, %i1: f64):
+ // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64)
+ "test.invalid"(%i0, %i1) : (f64, f64) -> ()
}) : () -> ()
// expected-remark@+1 {{op 'func.return' is not legalizable}}
return
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index bbd55938718fe7..e931b394c86210 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -979,8 +979,8 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
};
/// This pattern simply updates the operands of the given operation.
struct TestPassthroughInvalidOp : public ConversionPattern {
- TestPassthroughInvalidOp(MLIRContext *ctx)
- : ConversionPattern("test.invalid", 1, ctx) {}
+ TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
+ : ConversionPattern(converter, "test.invalid", 1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
@@ -1254,18 +1254,18 @@ struct TestLegalizePatternDriver
TestTypeConverter converter;
mlir::RewritePatternSet patterns(&getContext());
populateWithGenerated(patterns);
- patterns.add<
- TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
- TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
- TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp,
- TestSplitReturnType, TestChangeProducerTypeI32ToF32,
- TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
- TestUpdateConsumerType, TestNonRootReplacement,
- TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
- TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
- TestUndoPropertiesModification, TestEraseOp>(&getContext());
- patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
- &getContext(), converter);
+ patterns
+ .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
+ TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
+ TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
+ TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
+ TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
+ TestNonRootReplacement, TestBoundedRecursiveRewrite,
+ TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
+ TestCreateUnregisteredOp, TestUndoMoveOpBefore,
+ TestUndoPropertiesModification, TestEraseOp>(&getContext());
+ patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
+ TestPassthroughInvalidOp>(&getContext(), converter);
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
mlir::populateCallOpTypeConversionPattern(patterns, converter);
@@ -1697,8 +1697,9 @@ struct TestTypeConversionAnotherProducer
};
struct TestReplaceWithLegalOp : public ConversionPattern {
- TestReplaceWithLegalOp(MLIRContext *ctx)
- : ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {}
+ TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx)
+ : ConversionPattern(converter, "test.replace_with_legal_op",
+ /*benefit=*/1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
@@ -1820,12 +1821,12 @@ struct TestTypeConversionDriver
// Initialize the set of rewrite patterns.
RewritePatternSet patterns(&getContext());
- patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
- TestSignatureConversionUndo,
- TestTestSignatureConversionNoConverter>(converter,
- &getContext());
- patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>(
- &getContext());
+ patterns
+ .add<TestTypeConsumerForward, TestTypeConversionProducer,
+ TestSignatureConversionUndo,
+ TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>(
+ converter, &getContext());
+ patterns.add<TestTypeConversionAnotherProducer>(&getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
ec18101
to
d2d3eb9
Compare
7330a8b
to
4369986
Compare
4369986
to
2ec040d
Compare
f9e8758
to
ddc9294
Compare
e5926b6
to
03fcf33
Compare
ddc9294
to
6041464
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM as discussed in the comments, thank you! I think it is better to land this as early as reasonable compared to the larger PR. The more we can split out and phase in gradually the better probably
@@ -2864,6 +2843,9 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo, | |||
|
|||
LogicalResult TypeConverter::convertType(Type t, | |||
SmallVectorImpl<Type> &results) const { | |||
assert(this && "expected non-null type converter"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
assert(this && "expected non-null type converter"); |
This assert only protects against a case that is already in UB as far as C++ is concerned and would e.g. get optimized out in release builds.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know this is not really needed, but I spent an hour debugging patterns that run without type converters (running in debug mode). (Can also be helpful for finding bugs in the dialect conversion framework.) When convertType
is called on a null type converter, the stack trace looks really odd and makes it look like something is wrong with the caching logic / mutex synchronization inside of TypeConverter::convertType
.
…g 1:N replacement fix test experiement
6041464
to
759edf6
Compare
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/169/builds/6738 Here is the relevant piece of the build log for the reference
|
materialization. When replacing a block argument, previously to llvm#117513, we would automatically insert a N->1 argument materialization. After llvm#117513, this is no longer the case for 1->1 mappings. As a result, no materialization is added until `ReplaceBlockArgRewrite` is committed, where `findOrBuildReplacementValue` inserts a source materialization. The switch from an argument materialization to a source materialization causes legalization to fail.
During a 1:N replacement (
applySignatureConversion
orreplaceOpWithMultiple
), the dialect conversion driver used to insert two materializations:S
.T
.The target materialization is unnecessary. Subsequent patterns receive the replacement values via their adaptors. These patterns have their own type converter. When they see a replacement value of type
S
, they will automatically insert a target materialization to typeT
. There is no reason to do this already during the 1:N replacement. (The functionality used to be duplicated inremapValues
andinsertNTo1Materialization
.)Special case: If a subsequent pattern does not have a type converter, it does not insert any target materializations. That's because the absence of a type converter indicates that the pattern does not care about type legality. Therefore, it is correct to pass an SSA value of type
S
(or any other type) to the pattern.Note: Most patterns in
TestPatterns.cpp
run without a type converter. To make sure that the tests still behave the same, some of these patterns now have a type converter.This commit is in preparation of adding 1:N support to the conversion value mapping. Before making any further changes to the mapping infrastructure, I'd like to make sure that the code base around it (that uses the mapping) is robust.