From 92cf7322559542e66b03d5099d582e57f2c746dd Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sat, 28 Dec 2024 17:24:41 +0100 Subject: [PATCH] test double replacement --- mlir/test/Transforms/test-legalizer.mlir | 24 +++++++++- mlir/test/lib/Dialect/Test/TestPatterns.cpp | 51 +++++++++++++++++++-- 2 files changed, 70 insertions(+), 5 deletions(-) diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index 2ca5f49637523..297eb5acef21b 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -450,7 +450,7 @@ func.func @fold_legalization() -> i32 { // ----- // CHECK-LABEL: func @convert_detached_signature() -// CHECK: "test.legal_op_with_region"() ({ +// CHECK: "test.legal_op"() ({ // CHECK: ^bb0(%arg0: f64): // CHECK: "test.return"() : () -> () // CHECK: }) : () -> () @@ -483,3 +483,25 @@ func.func @test_1_to_n_block_signature_conversion() { "test.return"() : () -> () } +// ----- + +// CHECK: notifyOperationInserted: test.step_1 +// CHECK: notifyOperationReplaced: test.multiple_1_to_n_replacement +// CHECK: notifyOperationErased: test.multiple_1_to_n_replacement +// CHECK: notifyOperationInserted: test.legal_op +// CHECK: notifyOperationReplaced: test.step_1 +// CHECK: notifyOperationErased: test.step_1 + +// CHECK-LABEL: func @test_multiple_1_to_n_replacement() +// CHECK: %[[legal_op:.*]]:4 = "test.legal_op"() : () -> (f16, f16, f16, f16) +// TODO: There should be a single cast (i.e., a single target materialization). +// This is currently not possible due to 1:N limitations of the conversion +// mapping. Instead, we have 3 argument materializations. +// CHECK: %[[cast1:.*]] = "test.cast"(%[[legal_op]]#2, %[[legal_op]]#3) : (f16, f16) -> f16 +// CHECK: %[[cast2:.*]] = "test.cast"(%[[legal_op]]#0, %[[legal_op]]#1) : (f16, f16) -> f16 +// CHECK: %[[cast3:.*]] = "test.cast"(%[[cast2]], %[[cast1]]) : (f16, f16) -> f16 +// CHECK: "test.valid"(%[[cast3]]) : (f16) -> () +func.func @test_multiple_1_to_n_replacement() { + %0 = "test.multiple_1_to_n_replacement"() : () -> (f16) + "test.invalid"(%0) : (f16) -> () +} diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index a470497fdbb56..826c222990be4 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -785,7 +785,7 @@ struct TestDetachedSignatureConversion : public ConversionPattern { ConversionPatternRewriter &rewriter) const final { if (op->getNumRegions() != 1) return failure(); - OperationState state(op->getLoc(), "test.legal_op_with_region", operands, + OperationState state(op->getLoc(), "test.legal_op", operands, op->getResultTypes(), {}, BlockRange()); Region *newRegion = state.addRegion(); rewriter.inlineRegionBefore(op->getRegion(0), *newRegion, @@ -1234,6 +1234,49 @@ class TestRepetitive1ToNConsumer : public ConversionPattern { } }; +/// A pattern that tests two back-to-back 1 -> 2 op replacements. +class TestMultiple1ToNReplacement : public ConversionPattern { +public: + TestMultiple1ToNReplacement(MLIRContext *ctx, const TypeConverter &converter) + : ConversionPattern(converter, "test.multiple_1_to_n_replacement", 1, + ctx) {} + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + // Helper function that replaces the given op with a new op of the given + // name and doubles each result (1 -> 2 replacement of each result). + auto replaceWithDoubleResults = [&](Operation *op, StringRef name) { + SmallVector types; + for (Type t : op->getResultTypes()) { + types.push_back(t); + types.push_back(t); + } + OperationState state(op->getLoc(), name, + /*operands=*/{}, types, op->getAttrs()); + auto *newOp = rewriter.create(state); + SmallVector repls; + for (size_t i = 0, e = op->getNumResults(); i < e; ++i) + repls.push_back(newOp->getResults().slice(2 * i, 2)); + rewriter.replaceOpWithMultiple(op, repls); + return newOp; + }; + + // Replace test.multiple_1_to_n_replacement with test.step_1. + Operation *repl1 = replaceWithDoubleResults(op, "test.step_1"); + // Now replace test.step_1 with test.legal_op. + // TODO: Ideally, it should not be necessary to reset the insertion point + // here. Based on the API calls, it looks like test.step_1 is entirely + // erased. But that's not the case: an argument materialization will + // survive. And that argument materialization will be used by the users of + // `op`. If we don't reset the insertion point here, we get dominance + // errors. This will be fixed when we have 1:N support in the conversion + // value mapping. + rewriter.setInsertionPoint(repl1); + replaceWithDoubleResults(repl1, "test.legal_op"); + return success(); + } +}; + } // namespace namespace { @@ -1319,7 +1362,8 @@ struct TestLegalizePatternDriver TestUndoPropertiesModification, TestEraseOp, TestRepetitive1ToNConsumer>(&getContext()); patterns.add(&getContext(), converter); + TestPassthroughInvalidOp, TestMultiple1ToNReplacement>( + &getContext(), converter); patterns.add(converter, &getContext()); mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, converter); @@ -1330,8 +1374,7 @@ struct TestLegalizePatternDriver target.addLegalOp(); target.addLegalOp(); - target.addLegalOp( - OperationName("test.legal_op_with_region", &getContext())); + target.addLegalOp(OperationName("test.legal_op", &getContext())); target .addIllegalOp(); target.addDynamicallyLegalOp([](TestReturnOp op) {