diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 09ad42364baaf..d4b1c8c7f0a74 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2257,37 +2257,6 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys, return success(); } -LogicalResult ConversionPatternRewriter::legalize(Region *r) { - // Fast path: If the region is empty, there is nothing to legalize. - if (r->empty()) - return success(); - - // Gather a list of all operations to legalize. This is done before - // converting the entry block signature because unrealized_conversion_cast - // ops should not be included. - SmallVector ops; - for (Block &b : *r) - for (Operation &op : b) - ops.push_back(&op); - - // If the current pattern runs with a type converter, convert the entry block - // signature. - if (const TypeConverter *converter = impl->currentTypeConverter) { - std::optional conversion = - converter->convertBlockSignature(&r->front()); - if (!conversion) - return failure(); - applySignatureConversion(&r->front(), *conversion, converter); - } - - // Legalize all operations in the region. - for (Operation *op : ops) - if (failed(legalize(op))) - return failure(); - - return success(); -} - void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues) { @@ -3287,8 +3256,20 @@ struct OperationConverter { : rewriter(ctx, config, *this), opLegalizer(rewriter, target, patterns), mode(mode) {} - /// Converts the given operations to the conversion target. - LogicalResult convertOperations(ArrayRef ops); + /// Applies the conversion to the given operations (and their nested + /// operations). + LogicalResult applyConversion(ArrayRef ops); + + /// Legalizes the given operations (and their nested operations) to the + /// conversion target. + template + LogicalResult legalizeOperations(ArrayRef ops, Fn onFailure, + bool isRecursiveLegalization = false); + LogicalResult legalizeOperations(ArrayRef ops, + bool isRecursiveLegalization = false) { + return legalizeOperations( + ops, /*onFailure=*/[&]() {}, isRecursiveLegalization); + } /// Converts a single operation. If `isRecursiveLegalization` is "true", the /// conversion is a recursive legalization request, triggered from within a @@ -3297,6 +3278,8 @@ struct OperationConverter { /// legalization mechanism). LogicalResult convert(Operation *op, bool isRecursiveLegalization = false); + const ConversionTarget &getTarget() { return opLegalizer.getTarget(); } + private: /// The rewriter to use when converting operations. ConversionPatternRewriter rewriter; @@ -3309,10 +3292,6 @@ struct OperationConverter { }; } // namespace mlir -LogicalResult ConversionPatternRewriter::legalize(Operation *op) { - return impl->opConverter.convert(op, /*isRecursiveLegalization=*/true); -} - LogicalResult OperationConverter::convert(Operation *op, bool isRecursiveLegalization) { const ConversionConfig &config = rewriter.getConfig(); @@ -3398,12 +3377,15 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter, return failure(); } -LogicalResult OperationConverter::convertOperations(ArrayRef ops) { +template +LogicalResult +OperationConverter::legalizeOperations(ArrayRef ops, Fn onFailure, + bool isRecursiveLegalization) { const ConversionTarget &target = opLegalizer.getTarget(); // Compute the set of operations and blocks to convert. SmallVector toConvert; - for (auto *op : ops) { + for (Operation *op : ops) { op->walk>( [&](Operation *op) { toConvert.push_back(op); @@ -3415,25 +3397,67 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { return WalkResult::advance(); }); } + for (Operation *op : toConvert) { + if (failed(convert(op, isRecursiveLegalization))) { + // Failed to convert an operation. + onFailure(); + return failure(); + } + } + return success(); +} - // Convert each operation and discard rewrites on failure. - ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); +LogicalResult ConversionPatternRewriter::legalize(Operation *op) { + return impl->opConverter.legalizeOperations(op, + /*isRecursiveLegalization=*/true); +} - for (auto *op : toConvert) { - if (failed(convert(op))) { - // Dialect conversion failed. - if (rewriterImpl.config.allowPatternRollback) { - // Rollback is allowed: restore the original IR. - rewriterImpl.undoRewrites(); - } else { - // Rollback is not allowed: apply all modifications that have been - // performed so far. - rewriterImpl.applyRewrites(); - } +LogicalResult ConversionPatternRewriter::legalize(Region *r) { + // Fast path: If the region is empty, there is nothing to legalize. + if (r->empty()) + return success(); + + // Gather a list of all operations to legalize. This is done before + // converting the entry block signature because unrealized_conversion_cast + // ops should not be included. + SmallVector ops; + for (Block &b : *r) + for (Operation &op : b) + ops.push_back(&op); + + // If the current pattern runs with a type converter, convert the entry block + // signature. + if (const TypeConverter *converter = impl->currentTypeConverter) { + std::optional conversion = + converter->convertBlockSignature(&r->front()); + if (!conversion) return failure(); - } + applySignatureConversion(&r->front(), *conversion, converter); } + // Legalize all operations in the region. This includes all nested + // operations. + return impl->opConverter.legalizeOperations(ops, + /*isRecursiveLegalization=*/true); +} + +LogicalResult OperationConverter::applyConversion(ArrayRef ops) { + // Convert each operation and discard rewrites on failure. + ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); + LogicalResult status = legalizeOperations(ops, /*onFailure=*/[&]() { + // Dialect conversion failed. + if (rewriterImpl.config.allowPatternRollback) { + // Rollback is allowed: restore the original IR. + rewriterImpl.undoRewrites(); + } else { + // Rollback is not allowed: apply all modifications that have been + // performed so far. + rewriterImpl.applyRewrites(); + } + }); + if (failed(status)) + return failure(); + // After a successful conversion, apply rewrites. rewriterImpl.applyRewrites(); @@ -4143,7 +4167,7 @@ static LogicalResult applyConversion(ArrayRef ops, [&] { OperationConverter opConverter(ops.front()->getContext(), target, patterns, config, mode); - status = opConverter.convertOperations(ops); + status = opConverter.applyConversion(ops); }, irUnits); return status; diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index 88a71cc26ab0c..8d854aff1992f 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -454,11 +454,16 @@ func.func @test_working_1to1_pattern(%arg0: f16) { // The region of "test.post_order_legalization" is converted before the op. // CHECK: notifyBlockInserted into test.post_order_legalization: was unlinked +// CHECK: notifyOperationInserted: test.remaining_consumer +// CHECK: notifyOperationInserted: test.legal_op // CHECK: notifyOperationInserted: test.invalid // CHECK: notifyBlockErased // CHECK: notifyOperationInserted: test.valid, was unlinked // CHECK: notifyOperationReplaced: test.invalid // CHECK: notifyOperationErased: test.invalid +// CHECK: notifyOperationInserted: test.valid, was unlinked +// CHECK: notifyOperationReplaced: test.invalid +// CHECK: notifyOperationErased: test.invalid // CHECK: notifyOperationModified: test.post_order_legalization // CHECK-LABEL: func @test_preorder_legalization @@ -475,6 +480,9 @@ func.func @test_preorder_legalization() { ^bb0(%arg0: i64): // expected-remark @+1 {{'test.remaining_consumer' is not legalizable}} "test.remaining_consumer"(%arg0) : (i64) -> () + "test.legal_op"() ({ + "test.invalid"(%arg0) : (i64) -> () + }) : () -> () "test.invalid"(%arg0) : (i64) -> () }) : () -> () // expected-remark @+1 {{'func.return' is not legalizable}}