-
Notifications
You must be signed in to change notification settings - Fork 15.6k
[mlir][LLVM] ArithToLLVM: Add 1:N support for arith.select lowering
#153944
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Matthias Springer (matthias-springer) ChangesAdd 1:N support for the Full diff: https://github.com/llvm/llvm-project/pull/153944.diff 3 Files Affected:
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 18e857c81af8d..3d759e0fb6361 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -238,6 +238,16 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
ConversionPatternRewriter &rewriter) const override;
};
+struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+ using Adaptor =
+ typename ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor;
+
+ LogicalResult
+ matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -479,6 +489,36 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
rewriter);
}
+//===----------------------------------------------------------------------===//
+// SelectOpOneToNLowering
+//===----------------------------------------------------------------------===//
+
+/// Pattern for arith.select where the true/false values lower to multiple
+/// SSA values (1:N conversion). This pattern generates multiple arith.select
+/// than can be lowered by the 1:1 arith.select pattern.
+LogicalResult SelectOpOneToNLowering::matchAndRewrite(
+ arith::SelectOp op, Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // In case of a 1:1 conversion, the 1:1 pattern will match.
+ if (llvm::hasSingleElement(adaptor.getTrueValue()))
+ return rewriter.notifyMatchFailure(
+ op, "not a 1:N conversion, 1:1 pattern will match");
+ if (!op.getCondition().getType().isInteger(1))
+ return rewriter.notifyMatchFailure(op,
+ "non-i1 conditions are not supported");
+ if (!llvm::hasSingleElement(adaptor.getCondition()))
+ return rewriter.notifyMatchFailure(
+ op, "1:N condition conversion is not supported");
+ SmallVector<Value> results;
+ for (auto [trueValue, falseValue] :
+ llvm::zip_equal(adaptor.getTrueValue(), adaptor.getFalseValue()))
+ results.push_back(arith::SelectOp::create(
+ rewriter, op.getLoc(), llvm::getSingleElement(adaptor.getCondition()),
+ trueValue, falseValue));
+ rewriter.replaceOpWithMultiple(op, {results});
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
@@ -587,6 +627,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
RemSIOpLowering,
RemUIOpLowering,
SelectOpLowering,
+ SelectOpOneToNLowering,
ShLIOpLowering,
ShRSIOpLowering,
ShRUIOpLowering,
diff --git a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
index c1751f282b002..2e050887cc1d3 100644
--- a/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/type-conversion.mlir
@@ -138,3 +138,18 @@ func.func @caller(%arg0: i1, %arg1: i17) -> (i17, i1, i17) {
%res:2 = func.call @multi_return(%arg1, %arg0) : (i17, i1) -> (i17, i1)
return %res#0, %res#1, %res#0 : i17, i1, i17
}
+
+// -----
+
+// CHECK-LABEL: llvm.func @arith_select(
+// CHECK-SAME: %[[arg0:.*]]: i1, %[[arg1:.*]]: i18, %[[arg2:.*]]: i18, %[[arg3:.*]]: i18, %[[arg4:.*]]: i18) -> !llvm.struct<(i18, i18)>
+// CHECK: %[[select0:.*]] = llvm.select %[[arg0]], %[[arg1]], %[[arg3]] : i1, i18
+// CHECK: %[[select1:.*]] = llvm.select %[[arg0]], %[[arg2]], %[[arg4]] : i1, i18
+// CHECK: %[[i0:.*]] = llvm.mlir.poison : !llvm.struct<(i18, i18)>
+// CHECK: %[[i1:.*]] = llvm.insertvalue %[[select0]], %[[i0]][0] : !llvm.struct<(i18, i18)>
+// CHECK: %[[i2:.*]] = llvm.insertvalue %[[select1]], %[[i1]][1] : !llvm.struct<(i18, i18)>
+// CHECK: llvm.return %[[i2]]
+func.func @arith_select(%arg0: i1, %arg1: i17, %arg2: i17) -> (i17) {
+ %0 = arith.select %arg0, %arg1, %arg2 : i17
+ return %0 : i17
+}
diff --git a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
index fe9aa0f2a9902..c2a75836b77b9 100644
--- a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -69,6 +70,7 @@ struct TestLLVMLegalizePatternsPass
// Populate patterns.
mlir::RewritePatternSet patterns(ctx);
patterns.add<TestDirectReplacementOp>(ctx, converter);
+ arith::populateArithToLLVMConversionPatterns(converter, patterns);
populateFuncToLLVMConversionPatterns(converter, patterns);
// Define the conversion target used for the test.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: This utilizes #153605. If it is a 1:N conversion, the 1:1 pattern will not match.
ae9611f to
c891df2
Compare
c891df2 to
fb7422f
Compare
fb7422f to
94b91ff
Compare
zero9178
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Add 1:N support for the
arith.selectlowering. Only cases where the entire true/false value is selected are supported.