diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 71fa9b9ba41eb..c47bc33df0706 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -2643,46 +2643,33 @@ static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy return nullptr; } +/// Given that \p CondVal is known to be \p CondIsTrue, try to simplify \p SI. +static Value *simplifyNestedSelectsUsingImpliedCond(SelectInst &SI, + Value *CondVal, + bool CondIsTrue, + const DataLayout &DL) { + Value *InnerCondVal = SI.getCondition(); + Value *InnerTrueVal = SI.getTrueValue(); + Value *InnerFalseVal = SI.getFalseValue(); + assert(CondVal->getType() == InnerCondVal->getType() && + "The type of inner condition must match with the outer."); + if (auto Implied = isImpliedCondition(CondVal, InnerCondVal, DL, CondIsTrue)) + return *Implied ? InnerTrueVal : InnerFalseVal; + return nullptr; +} + Instruction *InstCombinerImpl::foldAndOrOfSelectUsingImpliedCond(Value *Op, SelectInst &SI, bool IsAnd) { - Value *CondVal = SI.getCondition(); - Value *A = SI.getTrueValue(); - Value *B = SI.getFalseValue(); - assert(Op->getType()->isIntOrIntVectorTy(1) && "Op must be either i1 or vector of i1."); - - std::optional<bool> Res = isImpliedCondition(Op, CondVal, DL, IsAnd); - if (!Res) + if (SI.getCondition()->getType() != Op->getType()) return nullptr; - - Value *Zero = Constant::getNullValue(A->getType()); - Value *One = Constant::getAllOnesValue(A->getType()); - - if (*Res == true) { - if (IsAnd) - // select op, (select cond, A, B), false => select op, A, false - // and op, (select cond, A, B) => select op, A, false - // if op = true implies condval = true. - return SelectInst::Create(Op, A, Zero); - else - // select op, true, (select cond, A, B) => select op, true, A - // or op, (select cond, A, B) => select op, true, A - // if op = false implies condval = true.
- return SelectInst::Create(Op, One, A); - } else { - if (IsAnd) - // select op, (select cond, A, B), false => select op, B, false - // and op, (select cond, A, B) => select op, B, false - // if op = true implies condval = false. - return SelectInst::Create(Op, B, Zero); - else - // select op, true, (select cond, A, B) => select op, true, B - // or op, (select cond, A, B) => select op, true, B - // if op = false implies condval = false. - return SelectInst::Create(Op, One, B); - } + if (Value *V = simplifyNestedSelectsUsingImpliedCond(SI, Op, IsAnd, DL)) + return SelectInst::Create(Op, + IsAnd ? V : ConstantInt::getTrue(Op->getType()), + IsAnd ? ConstantInt::getFalse(Op->getType()) : V); + return nullptr; } // Canonicalize select with fcmp to fabs(). -0.0 makes this tricky. We need @@ -3138,11 +3125,6 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { return replaceInstUsesWith(SI, Op1); } - if (auto *Op1SI = dyn_cast<SelectInst>(Op1)) - if (auto *I = foldAndOrOfSelectUsingImpliedCond(CondVal, *Op1SI, - /* IsAnd */ IsAnd)) - return I; - if (auto *ICmp0 = dyn_cast<ICmpInst>(CondVal)) if (auto *ICmp1 = dyn_cast<ICmpInst>(Op1)) if (auto *V = foldAndOrOfICmps(ICmp0, ICmp1, SI, IsAnd, @@ -3643,12 +3625,12 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { if (SelectInst *TrueSI = dyn_cast<SelectInst>(TrueVal)) { if (TrueSI->getCondition()->getType() == CondVal->getType()) { - // select(C, select(C, a, b), c) -> select(C, a, c) - if (TrueSI->getCondition() == CondVal) { - if (SI.getTrueValue() == TrueSI->getTrueValue()) - return nullptr; - return replaceOperand(SI, 1, TrueSI->getTrueValue()); - } + // Fold nested selects if the inner condition can be implied by the outer + // condition.
+ if (Value *V = simplifyNestedSelectsUsingImpliedCond( + *TrueSI, CondVal, /*CondIsTrue=*/true, DL)) + return replaceOperand(SI, 1, V); + // select(C0, select(C1, a, b), b) -> select(C0&C1, a, b) // We choose this as normal form to enable folding on the And and // shortening paths for the values (this helps getUnderlyingObjects() for @@ -3663,12 +3645,12 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { } if (SelectInst *FalseSI = dyn_cast<SelectInst>(FalseVal)) { if (FalseSI->getCondition()->getType() == CondVal->getType()) { - // select(C, a, select(C, b, c)) -> select(C, a, c) - if (FalseSI->getCondition() == CondVal) { - if (SI.getFalseValue() == FalseSI->getFalseValue()) - return nullptr; - return replaceOperand(SI, 2, FalseSI->getFalseValue()); - } + // Fold nested selects if the inner condition can be implied by the outer + // condition. + if (Value *V = simplifyNestedSelectsUsingImpliedCond( + *FalseSI, CondVal, /*CondIsTrue=*/false, DL)) + return replaceOperand(SI, 2, V); + // select(C0, a, select(C1, a, b)) -> select(C0|C1, a, b) if (FalseSI->getTrueValue() == TrueVal && FalseSI->hasOneUse()) { Value *Or = Builder.CreateLogicalOr(CondVal, FalseSI->getCondition()); diff --git a/llvm/test/Transforms/InstCombine/canonicalize-clamp-like-pattern-between-negative-and-positive-thresholds.ll b/llvm/test/Transforms/InstCombine/canonicalize-clamp-like-pattern-between-negative-and-positive-thresholds.ll index d03e22bc4c9fb..b5ef1f466958d 100644 --- a/llvm/test/Transforms/InstCombine/canonicalize-clamp-like-pattern-between-negative-and-positive-thresholds.ll +++ b/llvm/test/Transforms/InstCombine/canonicalize-clamp-like-pattern-between-negative-and-positive-thresholds.ll @@ -189,10 +189,8 @@ define i32 @n9_ult_slt_neg17(i32 %x, i32 %replacement_low, i32 %replacement_high ; Regression test for PR53252.
define i32 @n10_ugt_slt(i32 %x, i32 %replacement_low, i32 %replacement_high) { ; CHECK-LABEL: @n10_ugt_slt( -; CHECK-NEXT: [[T0:%.*]] = icmp slt i32 [[X:%.*]], 0 -; CHECK-NEXT: [[T1:%.*]] = select i1 [[T0]], i32 [[REPLACEMENT_LOW:%.*]], i32 [[REPLACEMENT_HIGH:%.*]] -; CHECK-NEXT: [[T2:%.*]] = icmp ugt i32 [[X]], 128 -; CHECK-NEXT: [[R:%.*]] = select i1 [[T2]], i32 [[X]], i32 [[T1]] +; CHECK-NEXT: [[T2:%.*]] = icmp ugt i32 [[X:%.*]], 128 +; CHECK-NEXT: [[R:%.*]] = select i1 [[T2]], i32 [[X]], i32 [[REPLACEMENT_HIGH:%.*]] ; CHECK-NEXT: ret i32 [[R]] ; %t0 = icmp slt i32 %x, 0 @@ -204,10 +202,8 @@ define i32 @n10_ugt_slt(i32 %x, i32 %replacement_low, i32 %replacement_high) { define i32 @n11_uge_slt(i32 %x, i32 %replacement_low, i32 %replacement_high) { ; CHECK-LABEL: @n11_uge_slt( -; CHECK-NEXT: [[T0:%.*]] = icmp slt i32 [[X:%.*]], 0 -; CHECK-NEXT: [[T1:%.*]] = select i1 [[T0]], i32 [[REPLACEMENT_LOW:%.*]], i32 [[REPLACEMENT_HIGH:%.*]] -; CHECK-NEXT: [[T2:%.*]] = icmp ult i32 [[X]], 129 -; CHECK-NEXT: [[R:%.*]] = select i1 [[T2]], i32 [[T1]], i32 [[X]] +; CHECK-NEXT: [[T2:%.*]] = icmp ult i32 [[X:%.*]], 129 +; CHECK-NEXT: [[R:%.*]] = select i1 [[T2]], i32 [[REPLACEMENT_HIGH:%.*]], i32 [[X]] ; CHECK-NEXT: ret i32 [[R]] ; %t0 = icmp slt i32 %x, 0 diff --git a/llvm/test/Transforms/InstCombine/nested-select.ll b/llvm/test/Transforms/InstCombine/nested-select.ll index 42a0f81e7b85a..d01dcf0793ade 100644 --- a/llvm/test/Transforms/InstCombine/nested-select.ll +++ b/llvm/test/Transforms/InstCombine/nested-select.ll @@ -498,3 +498,94 @@ define i1 @orcond.111.inv.all.conds(i1 %inner.cond, i1 %alt.cond, i1 %inner.sel. 
%outer.sel = select i1 %not.outer.cond, i1 true, i1 %inner.sel ret i1 %outer.sel } + +define i8 @test_implied_true(i8 %x) { +; CHECK-LABEL: @test_implied_true( +; CHECK-NEXT: [[CMP2:%.*]] = icmp slt i8 [[X:%.*]], 0 +; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i8 0, i8 20 +; CHECK-NEXT: ret i8 [[SEL2]] +; + %cmp1 = icmp slt i8 %x, 10 + %cmp2 = icmp slt i8 %x, 0 + %sel1 = select i1 %cmp1, i8 0, i8 5 + %sel2 = select i1 %cmp2, i8 %sel1, i8 20 + ret i8 %sel2 +} + +define <2 x i8> @test_implied_true_vec(<2 x i8> %x) { +; CHECK-LABEL: @test_implied_true_vec( +; CHECK-NEXT: [[CMP2:%.*]] = icmp slt <2 x i8> [[X:%.*]], zeroinitializer +; CHECK-NEXT: [[SEL2:%.*]] = select <2 x i1> [[CMP2]], <2 x i8> zeroinitializer, <2 x i8> <i8 20, i8 20> +; CHECK-NEXT: ret <2 x i8> [[SEL2]] +; + %cmp1 = icmp slt <2 x i8> %x, <i8 10, i8 10> + %cmp2 = icmp slt <2 x i8> %x, zeroinitializer + %sel1 = select <2 x i1> %cmp1, <2 x i8> zeroinitializer, <2 x i8> <i8 5, i8 5> + %sel2 = select <2 x i1> %cmp2, <2 x i8> %sel1, <2 x i8> <i8 20, i8 20> + ret <2 x i8> %sel2 +} + +define i8 @test_implied_true_falseval(i8 %x) { +; CHECK-LABEL: @test_implied_true_falseval( +; CHECK-NEXT: [[CMP2:%.*]] = icmp sgt i8 [[X:%.*]], 0 +; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i8 20, i8 0 +; CHECK-NEXT: ret i8 [[SEL2]] +; + %cmp1 = icmp slt i8 %x, 10 + %cmp2 = icmp sgt i8 %x, 0 + %sel1 = select i1 %cmp1, i8 0, i8 5 + %sel2 = select i1 %cmp2, i8 20, i8 %sel1 + ret i8 %sel2 +} + +define i8 @test_implied_false(i8 %x) { +; CHECK-LABEL: @test_implied_false( +; CHECK-NEXT: [[CMP2:%.*]] = icmp slt i8 [[X:%.*]], 0 +; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i8 5, i8 20 +; CHECK-NEXT: ret i8 [[SEL2]] +; + %cmp1 = icmp sgt i8 %x, 10 + %cmp2 = icmp slt i8 %x, 0 + %sel1 = select i1 %cmp1, i8 0, i8 5 + %sel2 = select i1 %cmp2, i8 %sel1, i8 20 + ret i8 %sel2 +} + +; Negative tests + +define i8 @test_imply_fail(i8 %x) { +; CHECK-LABEL: @test_imply_fail( +; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i8 [[X:%.*]], -10 +; CHECK-NEXT: [[CMP2:%.*]] = icmp slt i8 [[X]], 0 +;
CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i8 0, i8 5 +; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[SEL1]], i8 20 +; CHECK-NEXT: ret i8 [[SEL2]] +; + %cmp1 = icmp slt i8 %x, -10 + %cmp2 = icmp slt i8 %x, 0 + %sel1 = select i1 %cmp1, i8 0, i8 5 + %sel2 = select i1 %cmp2, i8 %sel1, i8 20 + ret i8 %sel2 +} + +define <2 x i8> @test_imply_type_mismatch(<2 x i8> %x, i8 %y) { +; CHECK-LABEL: @test_imply_type_mismatch( +; CHECK-NEXT: [[CMP1:%.*]] = icmp slt <2 x i8> [[X:%.*]], <i8 10, i8 10> +; CHECK-NEXT: [[CMP2:%.*]] = icmp slt i8 [[Y:%.*]], 0 +; CHECK-NEXT: [[SEL1:%.*]] = select <2 x i1> [[CMP1]], <2 x i8> zeroinitializer, <2 x i8> <i8 5, i8 5> +; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], <2 x i8> [[SEL1]], <2 x i8> <i8 20, i8 20> +; CHECK-NEXT: ret <2 x i8> [[SEL2]] +; + %cmp1 = icmp slt <2 x i8> %x, <i8 10, i8 10> + %cmp2 = icmp slt i8 %y, 0 + %sel1 = select <2 x i1> %cmp1, <2 x i8> zeroinitializer, <2 x i8> <i8 5, i8 5> + %sel2 = select i1 %cmp2, <2 x i8> %sel1, <2 x i8> <i8 20, i8 20> + ret <2 x i8> %sel2 +} + +define <4 x i1> @test_dont_crash(i1 %cond, <4 x i1> %a, <4 x i1> %b) { +entry: + %sel = select i1 %cond, <4 x i1> %a, <4 x i1> zeroinitializer + %and = and <4 x i1> %sel, %b + ret <4 x i1> %and +}