Skip to content

Conversation

@matthias-springer
Copy link
Member

This commit align the implementation of ConversionPatternRewriter::legalize with its documentation:

  /// Attempt to legalize the given region. This can be used within
  ...
  LogicalResult legalize(Region *r);

This function now legalizes the entire region, including nested ops. The implementation follows the same logic as the "main" traversal: pre-order, forward-dominance.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Dec 13, 2025
@llvmbot
Copy link
Member

llvmbot commented Dec 13, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

This commit align the implementation of ConversionPatternRewriter::legalize with its documentation:

  /// Attempt to legalize the given region. This can be used within
  ...
  LogicalResult legalize(Region *r);

This function now legalizes the entire region, including nested ops. The implementation follows the same logic as the "main" traversal: pre-order, forward-dominance.


Full diff: https://github.com/llvm/llvm-project/pull/172158.diff

2 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+78-54)
  • (modified) mlir/test/Transforms/test-legalizer.mlir (+8)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 09ad42364baaf..d4b1c8c7f0a74 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2257,37 +2257,6 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
   return success();
 }
 
-LogicalResult ConversionPatternRewriter::legalize(Region *r) {
-  // Fast path: If the region is empty, there is nothing to legalize.
-  if (r->empty())
-    return success();
-
-  // Gather a list of all operations to legalize. This is done before
-  // converting the entry block signature because unrealized_conversion_cast
-  // ops should not be included.
-  SmallVector<Operation *> ops;
-  for (Block &b : *r)
-    for (Operation &op : b)
-      ops.push_back(&op);
-
-  // If the current pattern runs with a type converter, convert the entry block
-  // signature.
-  if (const TypeConverter *converter = impl->currentTypeConverter) {
-    std::optional<TypeConverter::SignatureConversion> conversion =
-        converter->convertBlockSignature(&r->front());
-    if (!conversion)
-      return failure();
-    applySignatureConversion(&r->front(), *conversion, converter);
-  }
-
-  // Legalize all operations in the region.
-  for (Operation *op : ops)
-    if (failed(legalize(op)))
-      return failure();
-
-  return success();
-}
-
 void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
                                                   Block::iterator before,
                                                   ValueRange argValues) {
@@ -3287,8 +3256,20 @@ struct OperationConverter {
       : rewriter(ctx, config, *this), opLegalizer(rewriter, target, patterns),
         mode(mode) {}
 
-  /// Converts the given operations to the conversion target.
-  LogicalResult convertOperations(ArrayRef<Operation *> ops);
+  /// Applies the conversion to the given operations (and their nested
+  /// operations).
+  LogicalResult applyConversion(ArrayRef<Operation *> ops);
+
+  /// Legalizes the given operations (and their nested operations) to the
+  /// conversion target.
+  template <typename Fn>
+  LogicalResult legalizeOperations(ArrayRef<Operation *> ops, Fn onFailure,
+                                   bool isRecursiveLegalization = false);
+  LogicalResult legalizeOperations(ArrayRef<Operation *> ops,
+                                   bool isRecursiveLegalization = false) {
+    return legalizeOperations(
+        ops, /*onFailure=*/[&]() {}, isRecursiveLegalization);
+  }
 
   /// Converts a single operation. If `isRecursiveLegalization` is "true", the
   /// conversion is a recursive legalization request, triggered from within a
@@ -3297,6 +3278,8 @@ struct OperationConverter {
   /// legalization mechanism).
   LogicalResult convert(Operation *op, bool isRecursiveLegalization = false);
 
+  const ConversionTarget &getTarget() { return opLegalizer.getTarget(); }
+
 private:
   /// The rewriter to use when converting operations.
   ConversionPatternRewriter rewriter;
@@ -3309,10 +3292,6 @@ struct OperationConverter {
 };
 } // namespace mlir
 
-LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
-  return impl->opConverter.convert(op, /*isRecursiveLegalization=*/true);
-}
-
 LogicalResult OperationConverter::convert(Operation *op,
                                           bool isRecursiveLegalization) {
   const ConversionConfig &config = rewriter.getConfig();
@@ -3398,12 +3377,15 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
   return failure();
 }
 
-LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
+template <typename Fn>
+LogicalResult
+OperationConverter::legalizeOperations(ArrayRef<Operation *> ops, Fn onFailure,
+                                       bool isRecursiveLegalization) {
   const ConversionTarget &target = opLegalizer.getTarget();
 
   // Compute the set of operations and blocks to convert.
   SmallVector<Operation *> toConvert;
-  for (auto *op : ops) {
+  for (Operation *op : ops) {
     op->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
         [&](Operation *op) {
           toConvert.push_back(op);
@@ -3415,25 +3397,67 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
           return WalkResult::advance();
         });
   }
+  for (Operation *op : toConvert) {
+    if (failed(convert(op, isRecursiveLegalization))) {
+      // Failed to convert an operation.
+      onFailure();
+      return failure();
+    }
+  }
+  return success();
+}
 
-  // Convert each operation and discard rewrites on failure.
-  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
+LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
+  return impl->opConverter.legalizeOperations(op,
+                                              /*isRecursiveLegalization=*/true);
+}
 
-  for (auto *op : toConvert) {
-    if (failed(convert(op))) {
-      // Dialect conversion failed.
-      if (rewriterImpl.config.allowPatternRollback) {
-        // Rollback is allowed: restore the original IR.
-        rewriterImpl.undoRewrites();
-      } else {
-        // Rollback is not allowed: apply all modifications that have been
-        // performed so far.
-        rewriterImpl.applyRewrites();
-      }
+LogicalResult ConversionPatternRewriter::legalize(Region *r) {
+  // Fast path: If the region is empty, there is nothing to legalize.
+  if (r->empty())
+    return success();
+
+  // Gather a list of all operations to legalize. This is done before
+  // converting the entry block signature because unrealized_conversion_cast
+  // ops should not be included.
+  SmallVector<Operation *> ops;
+  for (Block &b : *r)
+    for (Operation &op : b)
+      ops.push_back(&op);
+
+  // If the current pattern runs with a type converter, convert the entry block
+  // signature.
+  if (const TypeConverter *converter = impl->currentTypeConverter) {
+    std::optional<TypeConverter::SignatureConversion> conversion =
+        converter->convertBlockSignature(&r->front());
+    if (!conversion)
       return failure();
-    }
+    applySignatureConversion(&r->front(), *conversion, converter);
   }
 
+  // Legalize all operations in the region. This includes all nested
+  // operations.
+  return impl->opConverter.legalizeOperations(ops,
+                                              /*isRecursiveLegalization=*/true);
+}
+
+LogicalResult OperationConverter::applyConversion(ArrayRef<Operation *> ops) {
+  // Convert each operation and discard rewrites on failure.
+  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
+  LogicalResult status = legalizeOperations(ops, /*onFailure=*/[&]() {
+    // Dialect conversion failed.
+    if (rewriterImpl.config.allowPatternRollback) {
+      // Rollback is allowed: restore the original IR.
+      rewriterImpl.undoRewrites();
+    } else {
+      // Rollback is not allowed: apply all modifications that have been
+      // performed so far.
+      rewriterImpl.applyRewrites();
+    }
+  });
+  if (failed(status))
+    return failure();
+
   // After a successful conversion, apply rewrites.
   rewriterImpl.applyRewrites();
 
@@ -4143,7 +4167,7 @@ static LogicalResult applyConversion(ArrayRef<Operation *> ops,
       [&] {
         OperationConverter opConverter(ops.front()->getContext(), target,
                                        patterns, config, mode);
-        status = opConverter.convertOperations(ops);
+        status = opConverter.applyConversion(ops);
       },
       irUnits);
   return status;
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 88a71cc26ab0c..8d854aff1992f 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -454,11 +454,16 @@ func.func @test_working_1to1_pattern(%arg0: f16) {
 // The region of "test.post_order_legalization" is converted before the op.
 
 // CHECK: notifyBlockInserted into test.post_order_legalization: was unlinked
+// CHECK: notifyOperationInserted: test.remaining_consumer
+// CHECK: notifyOperationInserted: test.legal_op
 // CHECK: notifyOperationInserted: test.invalid
 // CHECK: notifyBlockErased
 // CHECK: notifyOperationInserted: test.valid, was unlinked
 // CHECK: notifyOperationReplaced: test.invalid
 // CHECK: notifyOperationErased: test.invalid
+// CHECK: notifyOperationInserted: test.valid, was unlinked
+// CHECK: notifyOperationReplaced: test.invalid
+// CHECK: notifyOperationErased: test.invalid
 // CHECK: notifyOperationModified: test.post_order_legalization
 
 // CHECK-LABEL: func @test_preorder_legalization
@@ -475,6 +480,9 @@ func.func @test_preorder_legalization() {
   ^bb0(%arg0: i64):
     // expected-remark @+1 {{'test.remaining_consumer' is not legalizable}}
     "test.remaining_consumer"(%arg0) : (i64) -> ()
+    "test.legal_op"() ({
+      "test.invalid"(%arg0) : (i64) -> ()
+    }) : () -> ()
     "test.invalid"(%arg0) : (i64) -> ()
   }) : () -> ()
   // expected-remark @+1 {{'func.return' is not legalizable}}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:core MLIR Core Infrastructure mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants