Skip to content

Commit 0c47363

Browse files
authored
[InstCombine] Simplify nested selects with implied condition (#83739)
This patch does the following simplification:

```
sel1 = select cond1, X, Y
sel2 = select cond2, sel1, Z
-->
sel2 = select cond2, X, Z if cond2 implies cond1
sel2 = select cond2, Y, Z if cond2 implies !cond1
```

Alive2: https://alive2.llvm.org/ce/z/9A_arU

It cannot be done in CVP/SCCP since we should guarantee that `cond2` is not undef.
1 parent ed62758 commit 0c47363

File tree

3 files changed

+128
-59
lines changed

3 files changed

+128
-59
lines changed

llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp

Lines changed: 33 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -2643,46 +2643,33 @@ static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy
26432643
return nullptr;
26442644
}
26452645

2646+
/// Given that \p CondVal is known to be \p CondIsTrue, try to simplify \p SI.
2647+
static Value *simplifyNestedSelectsUsingImpliedCond(SelectInst &SI,
2648+
Value *CondVal,
2649+
bool CondIsTrue,
2650+
const DataLayout &DL) {
2651+
Value *InnerCondVal = SI.getCondition();
2652+
Value *InnerTrueVal = SI.getTrueValue();
2653+
Value *InnerFalseVal = SI.getFalseValue();
2654+
assert(CondVal->getType() == InnerCondVal->getType() &&
2655+
"The type of inner condition must match with the outer.");
2656+
if (auto Implied = isImpliedCondition(CondVal, InnerCondVal, DL, CondIsTrue))
2657+
return *Implied ? InnerTrueVal : InnerFalseVal;
2658+
return nullptr;
2659+
}
2660+
26462661
Instruction *InstCombinerImpl::foldAndOrOfSelectUsingImpliedCond(Value *Op,
26472662
SelectInst &SI,
26482663
bool IsAnd) {
2649-
Value *CondVal = SI.getCondition();
2650-
Value *A = SI.getTrueValue();
2651-
Value *B = SI.getFalseValue();
2652-
26532664
assert(Op->getType()->isIntOrIntVectorTy(1) &&
26542665
"Op must be either i1 or vector of i1.");
2655-
2656-
std::optional<bool> Res = isImpliedCondition(Op, CondVal, DL, IsAnd);
2657-
if (!Res)
2666+
if (SI.getCondition()->getType() != Op->getType())
26582667
return nullptr;
2659-
2660-
Value *Zero = Constant::getNullValue(A->getType());
2661-
Value *One = Constant::getAllOnesValue(A->getType());
2662-
2663-
if (*Res == true) {
2664-
if (IsAnd)
2665-
// select op, (select cond, A, B), false => select op, A, false
2666-
// and op, (select cond, A, B) => select op, A, false
2667-
// if op = true implies condval = true.
2668-
return SelectInst::Create(Op, A, Zero);
2669-
else
2670-
// select op, true, (select cond, A, B) => select op, true, A
2671-
// or op, (select cond, A, B) => select op, true, A
2672-
// if op = false implies condval = true.
2673-
return SelectInst::Create(Op, One, A);
2674-
} else {
2675-
if (IsAnd)
2676-
// select op, (select cond, A, B), false => select op, B, false
2677-
// and op, (select cond, A, B) => select op, B, false
2678-
// if op = true implies condval = false.
2679-
return SelectInst::Create(Op, B, Zero);
2680-
else
2681-
// select op, true, (select cond, A, B) => select op, true, B
2682-
// or op, (select cond, A, B) => select op, true, B
2683-
// if op = false implies condval = false.
2684-
return SelectInst::Create(Op, One, B);
2685-
}
2668+
if (Value *V = simplifyNestedSelectsUsingImpliedCond(SI, Op, IsAnd, DL))
2669+
return SelectInst::Create(Op,
2670+
IsAnd ? V : ConstantInt::getTrue(Op->getType()),
2671+
IsAnd ? ConstantInt::getFalse(Op->getType()) : V);
2672+
return nullptr;
26862673
}
26872674

26882675
// Canonicalize select with fcmp to fabs(). -0.0 makes this tricky. We need
@@ -3138,11 +3125,6 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
31383125
return replaceInstUsesWith(SI, Op1);
31393126
}
31403127

3141-
if (auto *Op1SI = dyn_cast<SelectInst>(Op1))
3142-
if (auto *I = foldAndOrOfSelectUsingImpliedCond(CondVal, *Op1SI,
3143-
/* IsAnd */ IsAnd))
3144-
return I;
3145-
31463128
if (auto *ICmp0 = dyn_cast<ICmpInst>(CondVal))
31473129
if (auto *ICmp1 = dyn_cast<ICmpInst>(Op1))
31483130
if (auto *V = foldAndOrOfICmps(ICmp0, ICmp1, SI, IsAnd,
@@ -3643,12 +3625,12 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
36433625

36443626
if (SelectInst *TrueSI = dyn_cast<SelectInst>(TrueVal)) {
36453627
if (TrueSI->getCondition()->getType() == CondVal->getType()) {
3646-
// select(C, select(C, a, b), c) -> select(C, a, c)
3647-
if (TrueSI->getCondition() == CondVal) {
3648-
if (SI.getTrueValue() == TrueSI->getTrueValue())
3649-
return nullptr;
3650-
return replaceOperand(SI, 1, TrueSI->getTrueValue());
3651-
}
3628+
// Fold nested selects if the inner condition can be implied by the outer
3629+
// condition.
3630+
if (Value *V = simplifyNestedSelectsUsingImpliedCond(
3631+
*TrueSI, CondVal, /*CondIsTrue=*/true, DL))
3632+
return replaceOperand(SI, 1, V);
3633+
36523634
// select(C0, select(C1, a, b), b) -> select(C0&C1, a, b)
36533635
// We choose this as normal form to enable folding on the And and
36543636
// shortening paths for the values (this helps getUnderlyingObjects() for
@@ -3663,12 +3645,12 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
36633645
}
36643646
if (SelectInst *FalseSI = dyn_cast<SelectInst>(FalseVal)) {
36653647
if (FalseSI->getCondition()->getType() == CondVal->getType()) {
3666-
// select(C, a, select(C, b, c)) -> select(C, a, c)
3667-
if (FalseSI->getCondition() == CondVal) {
3668-
if (SI.getFalseValue() == FalseSI->getFalseValue())
3669-
return nullptr;
3670-
return replaceOperand(SI, 2, FalseSI->getFalseValue());
3671-
}
3648+
// Fold nested selects if the inner condition can be implied by the outer
3649+
// condition.
3650+
if (Value *V = simplifyNestedSelectsUsingImpliedCond(
3651+
*FalseSI, CondVal, /*CondIsTrue=*/false, DL))
3652+
return replaceOperand(SI, 2, V);
3653+
36723654
// select(C0, a, select(C1, a, b)) -> select(C0|C1, a, b)
36733655
if (FalseSI->getTrueValue() == TrueVal && FalseSI->hasOneUse()) {
36743656
Value *Or = Builder.CreateLogicalOr(CondVal, FalseSI->getCondition());

llvm/test/Transforms/InstCombine/canonicalize-clamp-like-pattern-between-negative-and-positive-thresholds.ll

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -189,10 +189,8 @@ define i32 @n9_ult_slt_neg17(i32 %x, i32 %replacement_low, i32 %replacement_high
189189
; Regression test for PR53252.
190190
define i32 @n10_ugt_slt(i32 %x, i32 %replacement_low, i32 %replacement_high) {
191191
; CHECK-LABEL: @n10_ugt_slt(
192-
; CHECK-NEXT: [[T0:%.*]] = icmp slt i32 [[X:%.*]], 0
193-
; CHECK-NEXT: [[T1:%.*]] = select i1 [[T0]], i32 [[REPLACEMENT_LOW:%.*]], i32 [[REPLACEMENT_HIGH:%.*]]
194-
; CHECK-NEXT: [[T2:%.*]] = icmp ugt i32 [[X]], 128
195-
; CHECK-NEXT: [[R:%.*]] = select i1 [[T2]], i32 [[X]], i32 [[T1]]
192+
; CHECK-NEXT: [[T2:%.*]] = icmp ugt i32 [[X:%.*]], 128
193+
; CHECK-NEXT: [[R:%.*]] = select i1 [[T2]], i32 [[X]], i32 [[REPLACEMENT_HIGH:%.*]]
196194
; CHECK-NEXT: ret i32 [[R]]
197195
;
198196
%t0 = icmp slt i32 %x, 0
@@ -204,10 +202,8 @@ define i32 @n10_ugt_slt(i32 %x, i32 %replacement_low, i32 %replacement_high) {
204202

205203
define i32 @n11_uge_slt(i32 %x, i32 %replacement_low, i32 %replacement_high) {
206204
; CHECK-LABEL: @n11_uge_slt(
207-
; CHECK-NEXT: [[T0:%.*]] = icmp slt i32 [[X:%.*]], 0
208-
; CHECK-NEXT: [[T1:%.*]] = select i1 [[T0]], i32 [[REPLACEMENT_LOW:%.*]], i32 [[REPLACEMENT_HIGH:%.*]]
209-
; CHECK-NEXT: [[T2:%.*]] = icmp ult i32 [[X]], 129
210-
; CHECK-NEXT: [[R:%.*]] = select i1 [[T2]], i32 [[T1]], i32 [[X]]
205+
; CHECK-NEXT: [[T2:%.*]] = icmp ult i32 [[X:%.*]], 129
206+
; CHECK-NEXT: [[R:%.*]] = select i1 [[T2]], i32 [[REPLACEMENT_HIGH:%.*]], i32 [[X]]
211207
; CHECK-NEXT: ret i32 [[R]]
212208
;
213209
%t0 = icmp slt i32 %x, 0

llvm/test/Transforms/InstCombine/nested-select.ll

Lines changed: 91 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -498,3 +498,94 @@ define i1 @orcond.111.inv.all.conds(i1 %inner.cond, i1 %alt.cond, i1 %inner.sel.
498498
%outer.sel = select i1 %not.outer.cond, i1 true, i1 %inner.sel
499499
ret i1 %outer.sel
500500
}
501+
502+
define i8 @test_implied_true(i8 %x) {
503+
; CHECK-LABEL: @test_implied_true(
504+
; CHECK-NEXT: [[CMP2:%.*]] = icmp slt i8 [[X:%.*]], 0
505+
; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i8 0, i8 20
506+
; CHECK-NEXT: ret i8 [[SEL2]]
507+
;
508+
%cmp1 = icmp slt i8 %x, 10
509+
%cmp2 = icmp slt i8 %x, 0
510+
%sel1 = select i1 %cmp1, i8 0, i8 5
511+
%sel2 = select i1 %cmp2, i8 %sel1, i8 20
512+
ret i8 %sel2
513+
}
514+
515+
define <2 x i8> @test_implied_true_vec(<2 x i8> %x) {
516+
; CHECK-LABEL: @test_implied_true_vec(
517+
; CHECK-NEXT: [[CMP2:%.*]] = icmp slt <2 x i8> [[X:%.*]], zeroinitializer
518+
; CHECK-NEXT: [[SEL2:%.*]] = select <2 x i1> [[CMP2]], <2 x i8> zeroinitializer, <2 x i8> <i8 20, i8 20>
519+
; CHECK-NEXT: ret <2 x i8> [[SEL2]]
520+
;
521+
%cmp1 = icmp slt <2 x i8> %x, <i8 10, i8 10>
522+
%cmp2 = icmp slt <2 x i8> %x, zeroinitializer
523+
%sel1 = select <2 x i1> %cmp1, <2 x i8> zeroinitializer, <2 x i8> <i8 5, i8 5>
524+
%sel2 = select <2 x i1> %cmp2, <2 x i8> %sel1, <2 x i8> <i8 20, i8 20>
525+
ret <2 x i8> %sel2
526+
}
527+
528+
define i8 @test_implied_true_falseval(i8 %x) {
529+
; CHECK-LABEL: @test_implied_true_falseval(
530+
; CHECK-NEXT: [[CMP2:%.*]] = icmp sgt i8 [[X:%.*]], 0
531+
; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i8 20, i8 0
532+
; CHECK-NEXT: ret i8 [[SEL2]]
533+
;
534+
%cmp1 = icmp slt i8 %x, 10
535+
%cmp2 = icmp sgt i8 %x, 0
536+
%sel1 = select i1 %cmp1, i8 0, i8 5
537+
%sel2 = select i1 %cmp2, i8 20, i8 %sel1
538+
ret i8 %sel2
539+
}
540+
541+
define i8 @test_implied_false(i8 %x) {
542+
; CHECK-LABEL: @test_implied_false(
543+
; CHECK-NEXT: [[CMP2:%.*]] = icmp slt i8 [[X:%.*]], 0
544+
; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i8 5, i8 20
545+
; CHECK-NEXT: ret i8 [[SEL2]]
546+
;
547+
%cmp1 = icmp sgt i8 %x, 10
548+
%cmp2 = icmp slt i8 %x, 0
549+
%sel1 = select i1 %cmp1, i8 0, i8 5
550+
%sel2 = select i1 %cmp2, i8 %sel1, i8 20
551+
ret i8 %sel2
552+
}
553+
554+
; Negative tests
555+
556+
define i8 @test_imply_fail(i8 %x) {
557+
; CHECK-LABEL: @test_imply_fail(
558+
; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i8 [[X:%.*]], -10
559+
; CHECK-NEXT: [[CMP2:%.*]] = icmp slt i8 [[X]], 0
560+
; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i8 0, i8 5
561+
; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[SEL1]], i8 20
562+
; CHECK-NEXT: ret i8 [[SEL2]]
563+
;
564+
%cmp1 = icmp slt i8 %x, -10
565+
%cmp2 = icmp slt i8 %x, 0
566+
%sel1 = select i1 %cmp1, i8 0, i8 5
567+
%sel2 = select i1 %cmp2, i8 %sel1, i8 20
568+
ret i8 %sel2
569+
}
570+
571+
define <2 x i8> @test_imply_type_mismatch(<2 x i8> %x, i8 %y) {
572+
; CHECK-LABEL: @test_imply_type_mismatch(
573+
; CHECK-NEXT: [[CMP1:%.*]] = icmp slt <2 x i8> [[X:%.*]], <i8 10, i8 10>
574+
; CHECK-NEXT: [[CMP2:%.*]] = icmp slt i8 [[Y:%.*]], 0
575+
; CHECK-NEXT: [[SEL1:%.*]] = select <2 x i1> [[CMP1]], <2 x i8> zeroinitializer, <2 x i8> <i8 5, i8 5>
576+
; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], <2 x i8> [[SEL1]], <2 x i8> <i8 20, i8 20>
577+
; CHECK-NEXT: ret <2 x i8> [[SEL2]]
578+
;
579+
%cmp1 = icmp slt <2 x i8> %x, <i8 10, i8 10>
580+
%cmp2 = icmp slt i8 %y, 0
581+
%sel1 = select <2 x i1> %cmp1, <2 x i8> zeroinitializer, <2 x i8> <i8 5, i8 5>
582+
%sel2 = select i1 %cmp2, <2 x i8> %sel1, <2 x i8> <i8 20, i8 20>
583+
ret <2 x i8> %sel2
584+
}
585+
586+
define <4 x i1> @test_dont_crash(i1 %cond, <4 x i1> %a, <4 x i1> %b) {
587+
entry:
588+
%sel = select i1 %cond, <4 x i1> %a, <4 x i1> zeroinitializer
589+
%and = and <4 x i1> %sel, %b
590+
ret <4 x i1> %and
591+
}

0 commit comments

Comments
 (0)