Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,16 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
ConversionPatternRewriter &rewriter) const override;
};

struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
using Adaptor =
typename ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor;

LogicalResult
matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

} // namespace

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -479,6 +489,32 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
rewriter);
}

//===----------------------------------------------------------------------===//
// SelectOpOneToNLowering
//===----------------------------------------------------------------------===//

/// Pattern for arith.select where the true/false values lower to multiple
/// SSA values (1:N conversion). This pattern generates multiple arith.select
/// than can be lowered by the 1:1 arith.select pattern.
LogicalResult SelectOpOneToNLowering::matchAndRewrite(
arith::SelectOp op, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// In case of a 1:1 conversion, the 1:1 pattern will match.
Copy link
Member Author

@matthias-springer matthias-springer Aug 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: This utilizes #153605. If it is a 1:N conversion, the 1:1 pattern will not match.

if (llvm::hasSingleElement(adaptor.getTrueValue()))
return rewriter.notifyMatchFailure(
op, "not a 1:N conversion, 1:1 pattern will match");
if (!op.getCondition().getType().isInteger(1))
return rewriter.notifyMatchFailure(op,
"non-i1 conditions are not supported");
SmallVector<Value> results;
for (auto [trueValue, falseValue] :
llvm::zip_equal(adaptor.getTrueValue(), adaptor.getFalseValue()))
results.push_back(arith::SelectOp::create(
rewriter, op.getLoc(), op.getCondition(), trueValue, falseValue));
rewriter.replaceOpWithMultiple(op, {results});
return success();
}

//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -587,6 +623,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
RemSIOpLowering,
RemUIOpLowering,
SelectOpLowering,
SelectOpOneToNLowering,
ShLIOpLowering,
ShRSIOpLowering,
ShRUIOpLowering,
Expand Down
15 changes: 15 additions & 0 deletions mlir/test/Conversion/ArithToLLVM/type-conversion.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: mlir-opt %s -test-llvm-legalize-patterns -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-llvm-legalize-patterns="allow-pattern-rollback=0" -split-input-file | FileCheck %s

// CHECK-LABEL: llvm.func @arith_select(
// CHECK-SAME: %[[arg0:.*]]: i1, %[[arg1:.*]]: i18, %[[arg2:.*]]: i18, %[[arg3:.*]]: i18, %[[arg4:.*]]: i18) -> !llvm.struct<(i18, i18)>
// CHECK: %[[select0:.*]] = llvm.select %[[arg0]], %[[arg1]], %[[arg3]] : i1, i18
// CHECK: %[[select1:.*]] = llvm.select %[[arg0]], %[[arg2]], %[[arg4]] : i1, i18
// CHECK: %[[i0:.*]] = llvm.mlir.poison : !llvm.struct<(i18, i18)>
// CHECK: %[[i1:.*]] = llvm.insertvalue %[[select0]], %[[i0]][0] : !llvm.struct<(i18, i18)>
// CHECK: %[[i2:.*]] = llvm.insertvalue %[[select1]], %[[i1]][1] : !llvm.struct<(i18, i18)>
// CHECK: llvm.return %[[i2]]
func.func @arith_select(%arg0: i1, %arg1: i17, %arg2: i17) -> (i17) {
%0 = arith.select %arg0, %arg1, %arg2 : i17
return %0 : i17
}
2 changes: 2 additions & 0 deletions mlir/test/lib/Dialect/LLVM/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
Expand Down Expand Up @@ -70,6 +71,7 @@ struct TestLLVMLegalizePatternsPass
// Populate patterns.
mlir::RewritePatternSet patterns(ctx);
patterns.add<TestDirectReplacementOp>(ctx, converter);
arith::populateArithToLLVMConversionPatterns(converter, patterns);
populateFuncToLLVMConversionPatterns(converter, patterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);

Expand Down