Skip to content

[InstCombine] Simplify nested selects with implied condition #83739

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 33 additions & 51 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2643,46 +2643,33 @@ static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy
return nullptr;
}

/// Given that \p CondVal is known to be \p CondIsTrue, try to simplify \p SI.
static Value *simplifyNestedSelectsUsingImpliedCond(SelectInst &SI,
Value *CondVal,
bool CondIsTrue,
const DataLayout &DL) {
Value *InnerCondVal = SI.getCondition();
Value *InnerTrueVal = SI.getTrueValue();
Value *InnerFalseVal = SI.getFalseValue();
assert(CondVal->getType() == InnerCondVal->getType() &&
"The type of inner condition must match with the outer.");
if (auto Implied = isImpliedCondition(CondVal, InnerCondVal, DL, CondIsTrue))
return *Implied ? InnerTrueVal : InnerFalseVal;
return nullptr;
}

Instruction *InstCombinerImpl::foldAndOrOfSelectUsingImpliedCond(Value *Op,
SelectInst &SI,
bool IsAnd) {
Value *CondVal = SI.getCondition();
Value *A = SI.getTrueValue();
Value *B = SI.getFalseValue();

assert(Op->getType()->isIntOrIntVectorTy(1) &&
"Op must be either i1 or vector of i1.");

std::optional<bool> Res = isImpliedCondition(Op, CondVal, DL, IsAnd);
if (!Res)
if (SI.getCondition()->getType() != Op->getType())
return nullptr;

Value *Zero = Constant::getNullValue(A->getType());
Value *One = Constant::getAllOnesValue(A->getType());

if (*Res == true) {
if (IsAnd)
// select op, (select cond, A, B), false => select op, A, false
// and op, (select cond, A, B) => select op, A, false
// if op = true implies condval = true.
return SelectInst::Create(Op, A, Zero);
else
// select op, true, (select cond, A, B) => select op, true, A
// or op, (select cond, A, B) => select op, true, A
// if op = false implies condval = true.
return SelectInst::Create(Op, One, A);
} else {
if (IsAnd)
// select op, (select cond, A, B), false => select op, B, false
// and op, (select cond, A, B) => select op, B, false
// if op = true implies condval = false.
return SelectInst::Create(Op, B, Zero);
else
// select op, true, (select cond, A, B) => select op, true, B
// or op, (select cond, A, B) => select op, true, B
// if op = false implies condval = false.
return SelectInst::Create(Op, One, B);
}
if (Value *V = simplifyNestedSelectsUsingImpliedCond(SI, Op, IsAnd, DL))
return SelectInst::Create(Op,
IsAnd ? V : ConstantInt::getTrue(Op->getType()),
IsAnd ? ConstantInt::getFalse(Op->getType()) : V);
return nullptr;
}

// Canonicalize select with fcmp to fabs(). -0.0 makes this tricky. We need
Expand Down Expand Up @@ -3138,11 +3125,6 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
return replaceInstUsesWith(SI, Op1);
}

if (auto *Op1SI = dyn_cast<SelectInst>(Op1))
if (auto *I = foldAndOrOfSelectUsingImpliedCond(CondVal, *Op1SI,
/* IsAnd */ IsAnd))
return I;

if (auto *ICmp0 = dyn_cast<ICmpInst>(CondVal))
if (auto *ICmp1 = dyn_cast<ICmpInst>(Op1))
if (auto *V = foldAndOrOfICmps(ICmp0, ICmp1, SI, IsAnd,
Expand Down Expand Up @@ -3643,12 +3625,12 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {

if (SelectInst *TrueSI = dyn_cast<SelectInst>(TrueVal)) {
if (TrueSI->getCondition()->getType() == CondVal->getType()) {
// select(C, select(C, a, b), c) -> select(C, a, c)
if (TrueSI->getCondition() == CondVal) {
if (SI.getTrueValue() == TrueSI->getTrueValue())
return nullptr;
return replaceOperand(SI, 1, TrueSI->getTrueValue());
}
// Fold nested selects if the inner condition can be implied by the outer
// condition.
if (Value *V = simplifyNestedSelectsUsingImpliedCond(
*TrueSI, CondVal, /*CondIsTrue=*/true, DL))
return replaceOperand(SI, 1, V);

// select(C0, select(C1, a, b), b) -> select(C0&C1, a, b)
// We choose this as normal form to enable folding on the And and
// shortening paths for the values (this helps getUnderlyingObjects() for
Expand All @@ -3663,12 +3645,12 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
}
if (SelectInst *FalseSI = dyn_cast<SelectInst>(FalseVal)) {
if (FalseSI->getCondition()->getType() == CondVal->getType()) {
// select(C, a, select(C, b, c)) -> select(C, a, c)
if (FalseSI->getCondition() == CondVal) {
if (SI.getFalseValue() == FalseSI->getFalseValue())
return nullptr;
return replaceOperand(SI, 2, FalseSI->getFalseValue());
}
// Fold nested selects if the inner condition can be implied by the outer
// condition.
if (Value *V = simplifyNestedSelectsUsingImpliedCond(
*FalseSI, CondVal, /*CondIsTrue=*/false, DL))
return replaceOperand(SI, 2, V);

// select(C0, a, select(C1, a, b)) -> select(C0|C1, a, b)
if (FalseSI->getTrueValue() == TrueVal && FalseSI->hasOneUse()) {
Value *Or = Builder.CreateLogicalOr(CondVal, FalseSI->getCondition());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,8 @@ define i32 @n9_ult_slt_neg17(i32 %x, i32 %replacement_low, i32 %replacement_high
; Regression test for PR53252.
define i32 @n10_ugt_slt(i32 %x, i32 %replacement_low, i32 %replacement_high) {
; CHECK-LABEL: @n10_ugt_slt(
; CHECK-NEXT: [[T0:%.*]] = icmp slt i32 [[X:%.*]], 0
; CHECK-NEXT: [[T1:%.*]] = select i1 [[T0]], i32 [[REPLACEMENT_LOW:%.*]], i32 [[REPLACEMENT_HIGH:%.*]]
; CHECK-NEXT: [[T2:%.*]] = icmp ugt i32 [[X]], 128
; CHECK-NEXT: [[R:%.*]] = select i1 [[T2]], i32 [[X]], i32 [[T1]]
; CHECK-NEXT: [[T2:%.*]] = icmp ugt i32 [[X:%.*]], 128
; CHECK-NEXT: [[R:%.*]] = select i1 [[T2]], i32 [[X]], i32 [[REPLACEMENT_HIGH:%.*]]
; CHECK-NEXT: ret i32 [[R]]
;
%t0 = icmp slt i32 %x, 0
Expand All @@ -204,10 +202,8 @@ define i32 @n10_ugt_slt(i32 %x, i32 %replacement_low, i32 %replacement_high) {

define i32 @n11_uge_slt(i32 %x, i32 %replacement_low, i32 %replacement_high) {
; CHECK-LABEL: @n11_uge_slt(
; CHECK-NEXT: [[T0:%.*]] = icmp slt i32 [[X:%.*]], 0
; CHECK-NEXT: [[T1:%.*]] = select i1 [[T0]], i32 [[REPLACEMENT_LOW:%.*]], i32 [[REPLACEMENT_HIGH:%.*]]
; CHECK-NEXT: [[T2:%.*]] = icmp ult i32 [[X]], 129
; CHECK-NEXT: [[R:%.*]] = select i1 [[T2]], i32 [[T1]], i32 [[X]]
; CHECK-NEXT: [[T2:%.*]] = icmp ult i32 [[X:%.*]], 129
; CHECK-NEXT: [[R:%.*]] = select i1 [[T2]], i32 [[REPLACEMENT_HIGH:%.*]], i32 [[X]]
; CHECK-NEXT: ret i32 [[R]]
;
%t0 = icmp slt i32 %x, 0
Expand Down
91 changes: 91 additions & 0 deletions llvm/test/Transforms/InstCombine/nested-select.ll
Original file line number Diff line number Diff line change
Expand Up @@ -498,3 +498,94 @@ define i1 @orcond.111.inv.all.conds(i1 %inner.cond, i1 %alt.cond, i1 %inner.sel.
%outer.sel = select i1 %not.outer.cond, i1 true, i1 %inner.sel
ret i1 %outer.sel
}

define i8 @test_implied_true(i8 %x) {
; CHECK-LABEL: @test_implied_true(
; CHECK-NEXT: [[CMP2:%.*]] = icmp slt i8 [[X:%.*]], 0
; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i8 0, i8 20
; CHECK-NEXT: ret i8 [[SEL2]]
;
%cmp1 = icmp slt i8 %x, 10
%cmp2 = icmp slt i8 %x, 0
%sel1 = select i1 %cmp1, i8 0, i8 5
%sel2 = select i1 %cmp2, i8 %sel1, i8 20
ret i8 %sel2
}

define <2 x i8> @test_implied_true_vec(<2 x i8> %x) {
; CHECK-LABEL: @test_implied_true_vec(
; CHECK-NEXT: [[CMP2:%.*]] = icmp slt <2 x i8> [[X:%.*]], zeroinitializer
; CHECK-NEXT: [[SEL2:%.*]] = select <2 x i1> [[CMP2]], <2 x i8> zeroinitializer, <2 x i8> <i8 20, i8 20>
; CHECK-NEXT: ret <2 x i8> [[SEL2]]
;
%cmp1 = icmp slt <2 x i8> %x, <i8 10, i8 10>
%cmp2 = icmp slt <2 x i8> %x, zeroinitializer
%sel1 = select <2 x i1> %cmp1, <2 x i8> zeroinitializer, <2 x i8> <i8 5, i8 5>
%sel2 = select <2 x i1> %cmp2, <2 x i8> %sel1, <2 x i8> <i8 20, i8 20>
ret <2 x i8> %sel2
}

define i8 @test_implied_true_falseval(i8 %x) {
; CHECK-LABEL: @test_implied_true_falseval(
; CHECK-NEXT: [[CMP2:%.*]] = icmp sgt i8 [[X:%.*]], 0
; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i8 20, i8 0
; CHECK-NEXT: ret i8 [[SEL2]]
;
%cmp1 = icmp slt i8 %x, 10
%cmp2 = icmp sgt i8 %x, 0
%sel1 = select i1 %cmp1, i8 0, i8 5
%sel2 = select i1 %cmp2, i8 20, i8 %sel1
ret i8 %sel2
}

define i8 @test_implied_false(i8 %x) {
; CHECK-LABEL: @test_implied_false(
; CHECK-NEXT: [[CMP2:%.*]] = icmp slt i8 [[X:%.*]], 0
; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i8 5, i8 20
; CHECK-NEXT: ret i8 [[SEL2]]
;
%cmp1 = icmp sgt i8 %x, 10
%cmp2 = icmp slt i8 %x, 0
%sel1 = select i1 %cmp1, i8 0, i8 5
%sel2 = select i1 %cmp2, i8 %sel1, i8 20
ret i8 %sel2
}

; Negative tests

define i8 @test_imply_fail(i8 %x) {
; CHECK-LABEL: @test_imply_fail(
; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i8 [[X:%.*]], -10
; CHECK-NEXT: [[CMP2:%.*]] = icmp slt i8 [[X]], 0
; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i8 0, i8 5
; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[SEL1]], i8 20
; CHECK-NEXT: ret i8 [[SEL2]]
;
%cmp1 = icmp slt i8 %x, -10
%cmp2 = icmp slt i8 %x, 0
%sel1 = select i1 %cmp1, i8 0, i8 5
%sel2 = select i1 %cmp2, i8 %sel1, i8 20
ret i8 %sel2
}

define <2 x i8> @test_imply_type_mismatch(<2 x i8> %x, i8 %y) {
; CHECK-LABEL: @test_imply_type_mismatch(
; CHECK-NEXT: [[CMP1:%.*]] = icmp slt <2 x i8> [[X:%.*]], <i8 10, i8 10>
; CHECK-NEXT: [[CMP2:%.*]] = icmp slt i8 [[Y:%.*]], 0
; CHECK-NEXT: [[SEL1:%.*]] = select <2 x i1> [[CMP1]], <2 x i8> zeroinitializer, <2 x i8> <i8 5, i8 5>
; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], <2 x i8> [[SEL1]], <2 x i8> <i8 20, i8 20>
; CHECK-NEXT: ret <2 x i8> [[SEL2]]
;
%cmp1 = icmp slt <2 x i8> %x, <i8 10, i8 10>
%cmp2 = icmp slt i8 %y, 0
%sel1 = select <2 x i1> %cmp1, <2 x i8> zeroinitializer, <2 x i8> <i8 5, i8 5>
%sel2 = select i1 %cmp2, <2 x i8> %sel1, <2 x i8> <i8 20, i8 20>
ret <2 x i8> %sel2
}

define <4 x i1> @test_dont_crash(i1 %cond, <4 x i1> %a, <4 x i1> %b) {
entry:
%sel = select i1 %cond, <4 x i1> %a, <4 x i1> zeroinitializer
%and = and <4 x i1> %sel, %b
ret <4 x i1> %and
}