Skip to content

Conversation

@zero9178
Copy link
Member

Prior to this PR, the default behaviour of a conversion pattern which receives operands of a 1:N is to abort the compilation. This has historically been useful when the 1:N type conversion got merged into the dialect conversion as it allowed us to easily find patterns that should be capable of handling 1:N type conversions but didn't.

However, this behaviour has the disadvantage of being non-composable: While the pattern in question cannot handle the 1:N type conversion, another pattern part of the set might, but doesn't get the chance as compilation is aborted.

This PR fixes this behaviour by failing to match and instead of aborting, giving other patterns the chance to legalize an op. The implementation uses a reusable function called dispatchTo1To1 to allow derived conversion patterns to also implement the behaviour.

Prior to this PR, the default behaviour of a conversion pattern which receives operands of a 1:N is to abort the compilation. This has historically been useful when the 1:N type conversion got merged into the dialect conversion as it allowed us to easily find patterns that should be capable of handling 1:N type conversions but didn't.

However, this behaviour has the disadvantage of being non-composable: While the pattern in question cannot handle the 1:N type conversion, another pattern part of the set might, but doesn't get the chance as compilation is aborted.

This PR fixes this behaviour by failing to match and instead of aborting, giving other patterns the chance to legalize an op. The implementation uses a reusable function called `dispatchTo1To1` to allow derived conversion patterns to also implement the behaviour.
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:llvm mlir labels Aug 14, 2025
@zero9178 zero9178 changed the title [mlir] Do not abort when encountering a 1:1 conversion pattern [mlir] Do not abort when encountering a 1:1 conversion pattern in a 1:N converison Aug 14, 2025
@zero9178 zero9178 changed the title [mlir] Do not abort when encountering a 1:1 conversion pattern in a 1:N converison [mlir] Do not abort when encountering a 1:1 conversion pattern in a 1:N conversion Aug 14, 2025
@llvmbot
Copy link
Member

llvmbot commented Aug 14, 2025

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-flang-codegen
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Markus Böck (zero9178)

Changes

Prior to this PR, the default behaviour of a conversion pattern which receives operands of a 1:N is to abort the compilation. This has historically been useful when the 1:N type conversion got merged into the dialect conversion as it allowed us to easily find patterns that should be capable of handling 1:N type conversions but didn't.

However, this behaviour has the disadvantage of being non-composable: While the pattern in question cannot handle the 1:N type conversion, another pattern part of the set might, but doesn't get the chance as compilation is aborted.

This PR fixes this behaviour by failing to match and instead of aborting, giving other patterns the chance to legalize an op. The implementation uses a reusable function called dispatchTo1To1 to allow derived conversion patterns to also implement the behaviour.


Full diff: https://github.com/llvm/llvm-project/pull/153605.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Conversion/LLVMCommon/Pattern.h (+2-4)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+52-10)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+4-4)
  • (modified) mlir/test/Transforms/test-legalizer.mlir (+21)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+19-2)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 969154abe8830..19148a5d783f3 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -233,9 +233,7 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const {
-    SmallVector<Value> oneToOneOperands =
-        getOneToOneAdaptorOperands(adaptor.getOperands());
-    return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+    return dispatchTo1To1(*this, op, adaptor, rewriter);
   }
 
 private:
@@ -276,7 +274,7 @@ class ConvertOpInterfaceToLLVMPattern : public ConvertToLLVMPattern {
   virtual LogicalResult
   matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
                   ConversionPatternRewriter &rewriter) const {
-    return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+    return dispatchTo1To1(*this, op, operands, rewriter);
   }
 
 private:
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index e601c821e1e4e..80a09b643ce38 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -521,8 +521,8 @@ class ConversionPattern : public RewritePattern {
 
   /// Hook for derived classes to implement combined matching and rewriting.
   /// This overload supports only 1:1 replacements. The 1:N overload is called
-  /// by the driver. By default, it calls this 1:1 overload or reports a fatal
-  /// error if 1:N replacements were found.
+  /// by the driver. By default, it calls this 1:1 overload or fails to match
+  /// if 1:N replacements were found.
   virtual LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const {
@@ -534,7 +534,7 @@ class ConversionPattern : public RewritePattern {
   virtual LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                   ConversionPatternRewriter &rewriter) const {
-    return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+    return dispatchTo1To1(*this, op, operands, rewriter);
   }
 
   /// Attempt to match and rewrite the IR root at the specified operation.
@@ -567,11 +567,26 @@ class ConversionPattern : public RewritePattern {
   /// try to extract the single value of each range to construct a the inputs
   /// for a 1:1 adaptor.
   ///
-  /// This function produces a fatal error if at least one range has 0 or
-  /// more than 1 value: "pattern 'name' does not support 1:N conversion"
-  SmallVector<Value>
+  /// Returns failure if at least one range has 0 or more than 1 value.
+  FailureOr<SmallVector<Value>>
   getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands) const;
 
+  /// Overloaded method used to dispatch to the 1:1 'matchAndRewrite' method
+  /// if possible and emit diagnostic with a failure return value otherwise.
+  /// 'self' should be '*this' of the derived-pattern and is used to dispatch
+  /// to the correct 'matchAndRewrite' method in the derived pattern.
+  template <typename Pattern, typename SourceOp>
+  static LogicalResult dispatchTo1To1(const Pattern &self, SourceOp op,
+                                      ArrayRef<ValueRange> operands,
+                                      ConversionPatternRewriter &rewriter);
+
+  /// Same as above, but accepts an adaptor as operand.
+  template <typename Pattern, typename SourceOp>
+  static LogicalResult dispatchTo1To1(
+      const Pattern &self, SourceOp op,
+      typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> adaptor,
+      ConversionPatternRewriter &rewriter);
+
 protected:
   /// An optional type converter for use by this pattern.
   const TypeConverter *typeConverter = nullptr;
@@ -620,9 +635,7 @@ class OpConversionPattern : public ConversionPattern {
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const {
-    SmallVector<Value> oneToOneOperands =
-        getOneToOneAdaptorOperands(adaptor.getOperands());
-    return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+    return dispatchTo1To1(*this, op, adaptor, rewriter);
   }
 
 private:
@@ -666,7 +679,7 @@ class OpInterfaceConversionPattern : public ConversionPattern {
   virtual LogicalResult
   matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
                   ConversionPatternRewriter &rewriter) const {
-    return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+    return dispatchTo1To1(*this, op, operands, rewriter);
   }
 
 private:
@@ -865,6 +878,35 @@ class ConversionPatternRewriter final : public PatternRewriter {
   std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
 };
 
+template <class Pattern, typename SourceOp>
+LogicalResult
+ConversionPattern::dispatchTo1To1(const Pattern &self, SourceOp op,
+                                  ArrayRef<ValueRange> operands,
+                                  ConversionPatternRewriter &rewriter) {
+  FailureOr<SmallVector<Value>> oneToOneOperands =
+      self.getOneToOneAdaptorOperands(operands);
+  if (failed(oneToOneOperands))
+    return rewriter.notifyMatchFailure(op,
+                                       "pattern '" + self.getDebugName() +
+                                           "' does not support 1:N conversion");
+  return self.matchAndRewrite(op, *oneToOneOperands, rewriter);
+}
+
+template <typename Pattern, typename SourceOp>
+LogicalResult ConversionPattern::dispatchTo1To1(
+    const Pattern &self, SourceOp op,
+    typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> adaptor,
+    ConversionPatternRewriter &rewriter) {
+  FailureOr<SmallVector<Value>> oneToOneOperands =
+      self.getOneToOneAdaptorOperands(adaptor.getOperands());
+  if (failed(oneToOneOperands))
+    return rewriter.notifyMatchFailure(op,
+                                       "pattern '" + self.getDebugName() +
+                                           "' does not support 1:N conversion");
+  return self.matchAndRewrite(
+      op, typename SourceOp::Adaptor(*oneToOneOperands, adaptor), rewriter);
+}
+
 //===----------------------------------------------------------------------===//
 // ConversionTarget
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 001c13e1ab08c..ff34a58965763 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2244,17 +2244,17 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
 // ConversionPattern
 //===----------------------------------------------------------------------===//
 
-SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands(
+FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
     ArrayRef<ValueRange> operands) const {
   SmallVector<Value> oneToOneOperands;
   oneToOneOperands.reserve(operands.size());
   for (ValueRange operand : operands) {
     if (operand.size() != 1)
-      llvm::report_fatal_error("pattern '" + getDebugName() +
-                               "' does not support 1:N conversion");
+      return failure();
+
     oneToOneOperands.push_back(operand.front());
   }
-  return oneToOneOperands;
+  return std::move(oneToOneOperands);
 }
 
 LogicalResult
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 9a04da7904863..55d153db7f4bb 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -439,3 +439,24 @@ func.func @test_lookup_without_converter() {
   // expected-remark@+1 {{op 'func.return' is not legalizable}}
   return
 }
+
+// -----
+// expected-remark@-1 {{applyPartialConversion failed}}
+
+func.func @test_skip_1to1_pattern(%arg0: f32) {
+  // expected-error@+1 {{failed to legalize operation 'test.type_consumer'}}
+  "test.type_consumer"(%arg0) : (f32) -> ()
+  return
+}
+
+// -----
+
+// Demonstrate that the pattern generally works, but only for 1:1 type
+// conversions.
+
+// CHECK-LABEL: @test_working_1to1_pattern(
+func.func @test_working_1to1_pattern(%arg0: f16) {
+  // CHECK-NEXT: "test.return"() : () -> ()
+  "test.type_consumer"(%arg0) : (f16) -> ()
+  "test.return"() : () -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 657dfd2bac6ec..6300c5b0ca21c 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1386,6 +1386,23 @@ class TestMultiple1ToNReplacement : public ConversionPattern {
   }
 };
 
+/// Pattern that erases 'test.type_consumers' iff the input operand is the
+/// result of a 1:1 type conversion.
+/// Used to test correct skipping of 1:1 patterns in the 1:N case.
+class TestTypeConsumerOpPattern
+    : public OpConversionPattern<TestTypeConsumerOp> {
+public:
+  TestTypeConsumerOpPattern(MLIRContext *ctx, const TypeConverter &converter)
+      : OpConversionPattern<TestTypeConsumerOp>(converter, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(TestTypeConsumerOp op, OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 /// Test unambiguous overload resolution of replaceOpWithMultiple. This
 /// function is just to trigger compiler errors. It is never executed.
 [[maybe_unused]] void testReplaceOpWithMultipleOverloads(
@@ -1497,8 +1514,8 @@ struct TestLegalizePatternDriver
         TestRepetitive1ToNConsumer>(&getContext());
     patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
                  TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
-                 TestBlockArgReplace, TestReplaceWithValidConsumer>(
-        &getContext(), converter);
+                 TestBlockArgReplace, TestReplaceWithValidConsumer,
+                 TestTypeConsumerOpPattern>(&getContext(), converter);
     patterns.add<TestConvertBlockArgs>(converter, &getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:codegen labels Aug 14, 2025
@matthias-springer matthias-springer changed the title [mlir] Do not abort when encountering a 1:1 conversion pattern in a 1:N conversion [mlir][Transforms] Turn 1:N -> 1:1 dispatch fatal error into match failure Aug 15, 2025
@zero9178 zero9178 merged commit 8582025 into llvm:main Aug 15, 2025
9 checks passed
@zero9178 zero9178 deleted the 1to1-pattern-composability branch August 15, 2025 09:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

flang:codegen flang:fir-hlfir flang Flang issues not falling into any other category mlir:core MLIR Core Infrastructure mlir:llvm mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants