Skip to content

Commit 7330a8b

Browse files
[mlir][Transforms] Dialect Conversion: Do not build target mat. during 1:N replacement
fix test experiement
1 parent 3a11527 commit 7330a8b

File tree

3 files changed

+28
-49
lines changed

3 files changed

+28
-49
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -830,8 +830,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
830830
/// function will be deleted when full 1:N support has been added.
831831
///
832832
/// This function inserts an argument materialization back to the original
833-
/// type, followed by a target materialization to the legalized type (if
834-
/// applicable).
833+
/// type.
835834
void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
836835
ValueRange replacements, Value originalValue,
837836
const TypeConverter *converter);
@@ -1372,27 +1371,6 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
13721371
/*inputs=*/replacements, originalType,
13731372
/*originalType=*/Type(), converter);
13741373
mapping.map(originalValue, argMat);
1375-
1376-
// Insert target materialization to the legalized type.
1377-
Type legalOutputType;
1378-
if (converter) {
1379-
legalOutputType = converter->convertType(originalType);
1380-
} else if (replacements.size() == 1) {
1381-
// When there is no type converter, assume that the replacement value
1382-
// types are legal. This is reasonable to assume because they were
1383-
// specified by the user.
1384-
// FIXME: This won't work for 1->N conversions because multiple output
1385-
// types are not supported in parts of the dialect conversion. In such a
1386-
// case, we currently use the original value type.
1387-
legalOutputType = replacements[0].getType();
1388-
}
1389-
if (legalOutputType && legalOutputType != originalType) {
1390-
Value targetMat = buildUnresolvedMaterialization(
1391-
MaterializationKind::Target, computeInsertPoint(argMat), loc,
1392-
/*inputs=*/argMat, /*outputType=*/legalOutputType,
1393-
/*originalType=*/originalType, converter);
1394-
mapping.map(argMat, targetMat);
1395-
}
13961374
}
13971375

13981376
Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,10 @@ func.func @no_remap_nested() {
124124
// CHECK-NEXT: "foo.region"
125125
// expected-remark@+1 {{op 'foo.region' is not legalizable}}
126126
"foo.region"() ({
127-
// CHECK-NEXT: ^bb0(%{{.*}}: i64, %{{.*}}: i16, %{{.*}}: i64):
128-
^bb0(%i0: i64, %unused: i16, %i1: i64):
129-
// CHECK-NEXT: "test.valid"{{.*}} : (i64, i64)
130-
"test.invalid"(%i0, %i1) : (i64, i64) -> ()
127+
// CHECK-NEXT: ^bb0(%{{.*}}: f64, %{{.*}}: i16, %{{.*}}: f64):
128+
^bb0(%i0: f64, %unused: i16, %i1: f64):
129+
// CHECK-NEXT: "test.valid"{{.*}} : (f64, f64)
130+
"test.invalid"(%i0, %i1) : (f64, f64) -> ()
131131
}) : () -> ()
132132
// expected-remark@+1 {{op 'func.return' is not legalizable}}
133133
return

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -979,8 +979,8 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
979979
};
980980
/// This pattern simply updates the operands of the given operation.
981981
struct TestPassthroughInvalidOp : public ConversionPattern {
982-
TestPassthroughInvalidOp(MLIRContext *ctx)
983-
: ConversionPattern("test.invalid", 1, ctx) {}
982+
TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
983+
: ConversionPattern(converter, "test.invalid", 1, ctx) {}
984984
LogicalResult
985985
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
986986
ConversionPatternRewriter &rewriter) const final {
@@ -1254,18 +1254,18 @@ struct TestLegalizePatternDriver
12541254
TestTypeConverter converter;
12551255
mlir::RewritePatternSet patterns(&getContext());
12561256
populateWithGenerated(patterns);
1257-
patterns.add<
1258-
TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
1259-
TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
1260-
TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp,
1261-
TestSplitReturnType, TestChangeProducerTypeI32ToF32,
1262-
TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
1263-
TestUpdateConsumerType, TestNonRootReplacement,
1264-
TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
1265-
TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
1266-
TestUndoPropertiesModification, TestEraseOp>(&getContext());
1267-
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
1268-
&getContext(), converter);
1257+
patterns
1258+
.add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
1259+
TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
1260+
TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
1261+
TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
1262+
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
1263+
TestNonRootReplacement, TestBoundedRecursiveRewrite,
1264+
TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
1265+
TestCreateUnregisteredOp, TestUndoMoveOpBefore,
1266+
TestUndoPropertiesModification, TestEraseOp>(&getContext());
1267+
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
1268+
TestPassthroughInvalidOp>(&getContext(), converter);
12691269
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
12701270
converter);
12711271
mlir::populateCallOpTypeConversionPattern(patterns, converter);
@@ -1697,8 +1697,9 @@ struct TestTypeConversionAnotherProducer
16971697
};
16981698

16991699
struct TestReplaceWithLegalOp : public ConversionPattern {
1700-
TestReplaceWithLegalOp(MLIRContext *ctx)
1701-
: ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {}
1700+
TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx)
1701+
: ConversionPattern(converter, "test.replace_with_legal_op",
1702+
/*benefit=*/1, ctx) {}
17021703
LogicalResult
17031704
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
17041705
ConversionPatternRewriter &rewriter) const final {
@@ -1820,12 +1821,12 @@ struct TestTypeConversionDriver
18201821

18211822
// Initialize the set of rewrite patterns.
18221823
RewritePatternSet patterns(&getContext());
1823-
patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
1824-
TestSignatureConversionUndo,
1825-
TestTestSignatureConversionNoConverter>(converter,
1826-
&getContext());
1827-
patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>(
1828-
&getContext());
1824+
patterns
1825+
.add<TestTypeConsumerForward, TestTypeConversionProducer,
1826+
TestSignatureConversionUndo,
1827+
TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>(
1828+
converter, &getContext());
1829+
patterns.add<TestTypeConversionAnotherProducer>(&getContext());
18291830
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
18301831
converter);
18311832

0 commit comments

Comments
 (0)