Skip to content

Conversation

@matthias-springer
Copy link
Member

Add 1:N support for the arith.select lowering. Only cases where the entire true/false value is selected are supported.

@llvmbot
Copy link
Member

llvmbot commented Aug 16, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Matthias Springer (matthias-springer)

Changes

Add 1:N support for the arith.select lowering. Only cases where the entire true/false value is selected are supported.


Full diff: https://github.com/llvm/llvm-project/pull/153944.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp (+41)
  • (modified) mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir (+15)
  • (modified) mlir/test/lib/Dialect/LLVM/TestPatterns.cpp (+2)
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 18e857c81af8d..3d759e0fb6361 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -238,6 +238,16 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+  using Adaptor =
+      typename ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor;
+
+  LogicalResult
+  matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -479,6 +489,36 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
       rewriter);
 }
 
+//===----------------------------------------------------------------------===//
+// SelectOpOneToNLowering
+//===----------------------------------------------------------------------===//
+
+/// Pattern for arith.select where the true/false values lower to multiple
+/// SSA values (1:N conversion). This pattern generates multiple arith.select
+/// than can be lowered by the 1:1 arith.select pattern.
+LogicalResult SelectOpOneToNLowering::matchAndRewrite(
+    arith::SelectOp op, Adaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // In case of a 1:1 conversion, the 1:1 pattern will match.
+  if (llvm::hasSingleElement(adaptor.getTrueValue()))
+    return rewriter.notifyMatchFailure(
+        op, "not a 1:N conversion, 1:1 pattern will match");
+  if (!op.getCondition().getType().isInteger(1))
+    return rewriter.notifyMatchFailure(op,
+                                       "non-i1 conditions are not supported");
+  if (!llvm::hasSingleElement(adaptor.getCondition()))
+    return rewriter.notifyMatchFailure(
+        op, "1:N condition conversion is not supported");
+  SmallVector<Value> results;
+  for (auto [trueValue, falseValue] :
+       llvm::zip_equal(adaptor.getTrueValue(), adaptor.getFalseValue()))
+    results.push_back(arith::SelectOp::create(
+        rewriter, op.getLoc(), llvm::getSingleElement(adaptor.getCondition()),
+        trueValue, falseValue));
+  rewriter.replaceOpWithMultiple(op, {results});
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Pass Definition
 //===----------------------------------------------------------------------===//
@@ -587,6 +627,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
     RemSIOpLowering,
     RemUIOpLowering,
     SelectOpLowering,
+    SelectOpOneToNLowering,
     ShLIOpLowering,
     ShRSIOpLowering,
     ShRUIOpLowering,
diff --git a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
index c1751f282b002..2e050887cc1d3 100644
--- a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
@@ -138,3 +138,18 @@ func.func @caller(%arg0: i1, %arg1: i17) -> (i17, i1, i17) {
   %res:2 = func.call @multi_return(%arg1, %arg0) : (i17, i1) -> (i17, i1)
   return %res#0, %res#1, %res#0 : i17, i1, i17
 }
+
+// -----
+
+// CHECK-LABEL: llvm.func @arith_select(
+//  CHECK-SAME:     %[[arg0:.*]]: i1, %[[arg1:.*]]: i18, %[[arg2:.*]]: i18, %[[arg3:.*]]: i18, %[[arg4:.*]]: i18) -> !llvm.struct<(i18, i18)>
+//       CHECK:   %[[select0:.*]] = llvm.select %[[arg0]], %[[arg1]], %[[arg3]] : i1, i18
+//       CHECK:   %[[select1:.*]] = llvm.select %[[arg0]], %[[arg2]], %[[arg4]] : i1, i18
+//       CHECK:   %[[i0:.*]] = llvm.mlir.poison : !llvm.struct<(i18, i18)>
+//       CHECK:   %[[i1:.*]] = llvm.insertvalue %[[select0]], %[[i0]][0] : !llvm.struct<(i18, i18)>
+//       CHECK:   %[[i2:.*]] = llvm.insertvalue %[[select1]], %[[i1]][1] : !llvm.struct<(i18, i18)>
+//       CHECK:   llvm.return %[[i2]]
+func.func @arith_select(%arg0: i1, %arg1: i17, %arg2: i17) -> (i17) {
+  %0 = arith.select %arg0, %arg1, %arg2 : i17
+  return %0 : i17
+}
diff --git a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
index fe9aa0f2a9902..c2a75836b77b9 100644
--- a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -69,6 +70,7 @@ struct TestLLVMLegalizePatternsPass
     // Populate patterns.
     mlir::RewritePatternSet patterns(ctx);
     patterns.add<TestDirectReplacementOp>(ctx, converter);
+    arith::populateArithToLLVMConversionPatterns(converter, patterns);
     populateFuncToLLVMConversionPatterns(converter, patterns);
 
     // Define the conversion target used for the test.

Copy link
Member Author

@matthias-springer matthias-springer Aug 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: This utilizes #153605. If it is a 1:N conversion, the 1:1 pattern will not match.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/arith_1n branch from ae9611f to c891df2 Compare August 16, 2025 10:53
@matthias-springer matthias-springer force-pushed the users/matthias-springer/arith_1n branch from c891df2 to fb7422f Compare August 18, 2025 06:58
@matthias-springer matthias-springer force-pushed the users/matthias-springer/arith_1n branch from fb7422f to 94b91ff Compare August 18, 2025 07:04
Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@matthias-springer matthias-springer merged commit f7b09ad into main Aug 18, 2025
9 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/arith_1n branch August 18, 2025 07:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants