diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index 96fa16970584d..c1ee4c02e0108 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -94,6 +94,10 @@ void computeKnownBitsFromRangeMetadata(const MDNode &Ranges, KnownBits &Known);
 void computeKnownBitsFromContext(const Value *V, KnownBits &Known,
                                  unsigned Depth, const SimplifyQuery &Q);
 
+void computeKnownBitsFromCond(const Value *V, Value *Cond, KnownBits &Known,
+                              unsigned Depth, const SimplifyQuery &SQ,
+                              bool Invert);
+
 /// Using KnownBits LHS/RHS produce the known bits for logic op (and/xor/or).
 KnownBits analyzeKnownBitsFromAndXorOr(const Operator *I,
                                        const KnownBits &KnownLHS,
diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index ebcbd5d9e8880..27bcaad49e5b4 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -438,6 +438,13 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
     return llvm::computeKnownBits(V, Depth, SQ.getWithInstruction(CxtI));
   }
 
+  void computeKnownBitsFromCond(const Value *V, Value *Cmp, KnownBits &Known,
+                                unsigned Depth, const Instruction *CxtI,
+                                bool Invert) const {
+    llvm::computeKnownBitsFromCond(V, Cmp, Known, Depth,
+                                   SQ.getWithInstruction(CxtI), Invert);
+  }
+
   bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero = false,
                               unsigned Depth = 0,
                               const Instruction *CxtI = nullptr) {
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 285284dc27071..0d2d2d3bbbdbf 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -752,9 +752,9 @@ static void computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp,
   computeKnownBitsFromCmp(V, Pred, LHS, RHS, Known, SQ);
 }
 
-static void computeKnownBitsFromCond(const Value *V, Value *Cond,
-                                     KnownBits &Known, unsigned Depth,
-                                     const SimplifyQuery &SQ, bool Invert) {
+void llvm::computeKnownBitsFromCond(const Value *V, Value *Cond,
+                                    KnownBits &Known, unsigned Depth,
+                                    const SimplifyQuery &SQ, bool Invert) {
   Value *A, *B;
   if (Depth < MaxAnalysisRecursionDepth &&
       match(Cond, m_LogicalOp(m_Value(A), m_Value(B)))) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index a22ee1de0ac21..eaa8faaa2db0f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1078,6 +1078,62 @@ static Value *foldAbsDiff(ICmpInst *Cmp, Value *TVal, Value *FVal,
   return nullptr;
 }
 
+/// Attempts to fold (AND %A constant) --> %A
+/// if all bits that are set in the negated constant
+/// are known to be zero in %A.
+static Value *foldAndMaskPattern(Value *V, Value *Cmp, SelectInst &SI,
+                                 InstCombinerImpl &IC, unsigned Depth = 0) {
+
+  Value *A;
+  const APInt *MaskedConstant;
+
+  if (match(V, m_And(m_Value(A), m_APInt(MaskedConstant))) &&
+      isGuaranteedNotToBeUndef(A)) {
+    KnownBits Known = IC.computeKnownBits(A, 0, &SI);
+    IC.computeKnownBitsFromCond(A, Cmp, Known, 0, &SI, false);
+    if ((~(*MaskedConstant)).isSubsetOf(Known.Zero))
+      return A;
+  }
+
+  auto *I = dyn_cast<Instruction>(V);
+  if (!I || !isSafeToSpeculativelyExecute(I) || Depth >= 2)
+    return nullptr;
+
+  bool Changed = false;
+  for (unsigned i = 0; i < I->getNumOperands(); ++i) {
+    llvm::Value *Operand = I->getOperand(i);
+
+    if (std::any_of(Operand->user_begin(), Operand->user_end(),
+                    [I](const User *User) { return User != I; }))
+      break;
+
+    Value *NewOp = foldAndMaskPattern(Operand, Cmp, SI, IC, Depth + 1);
+    if (NewOp) {
+      IC.replaceOperand(*I, i, NewOp);
+      Changed = true;
+    }
+  }
+
+  return Changed ? I : nullptr;
+}
+
+/// Attempts to fold expressions in both branches of a select instruction
+/// based on KnownBits implied by the condition.
+// static Instruction *foldSelectWithIcmpEqAndPattern(Value *TVal, Value *FVal,
+//                                                    Value *CondVal,
+//                                                    SelectInst &SI,
+//                                                    InstCombinerImpl &IC) {
+//   if (TVal->hasOneUse())
+//     if (Value *newTrueOp = foldAndMaskPattern(TVal, CondVal, SI, IC))
+//       return IC.replaceOperand(SI, 1, newTrueOp);
+
+//   if (FVal->hasOneUse())
+//     if (Value *newFalseOp = foldAndMaskPattern(FVal, CondVal, SI, IC))
+//       return IC.replaceOperand(SI, 2, newFalseOp);
+
+//   return nullptr;
+// }
+
 /// Fold the following code sequence:
 /// \code
 ///   int a = ctlz(x & -x);
@@ -4110,5 +4166,12 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
     }
   }
 
+  // Attempts to recursively identify and fold (AND A constant) --> A
+  // in the true branch of the select if all bits
+  // that are set in the negated constant are known to be zero in A.
+  if (TrueVal->hasOneUse())
+    if (Value *newTrueOp = foldAndMaskPattern(TrueVal, CondVal, SI, *this))
+      return replaceOperand(SI, 1, newTrueOp);
+
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/select-known-bits.ll b/llvm/test/Transforms/InstCombine/select-known-bits.ll
new file mode 100644
index 0000000000000..52c56bed429a0
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/select-known-bits.ll
@@ -0,0 +1,121 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define i8 @select_icmp_eq_mul_and(i8 noundef %a, i8 %b) {
+; CHECK-LABEL: define i8 @select_icmp_eq_mul_and(
+; CHECK-SAME: i8 noundef [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[A]], [[A]]
+; CHECK-NEXT:    [[RETVAL:%.*]] = select i1 [[CMP]], i8 [[MUL]], i8 [[B]]
+; CHECK-NEXT:    ret i8 [[RETVAL]]
+;
+  %1 = and i8 %a, 1
+  %cmp = icmp eq i8 %1, 0
+  %div = and i8 %a, -2
+  %mul = mul i8 %div, %div
+  %retval = select i1 %cmp, i8 %mul, i8 %b
+  ret i8 %retval
+}
+
+define i8 @select_icmp_eq_mul_and_inv(i8 noundef %a, i8 %b) {
+; CHECK-LABEL: define i8 @select_icmp_eq_mul_and_inv(
+; CHECK-SAME: i8 noundef [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A]], 1
+; CHECK-NEXT:    [[CMP_NOT:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[A]], [[A]]
+; CHECK-NEXT:    [[RETVAL:%.*]] = select i1 [[CMP_NOT]], i8 [[MUL]], i8 [[B]]
+; CHECK-NEXT:    ret i8 [[RETVAL]]
+;
+  %1 = and i8 %a, 1
+  %cmp = icmp eq i8 %1, 1
+  %div = and i8 %a, -2
+  %mul = mul i8 %div, %div
+  %retval = select i1 %cmp, i8 %b, i8 %mul
+  ret i8 %retval
+}
+
+define i8 @select_icmp_eq_and(i8 noundef %a, i8 %b) {
+; CHECK-LABEL: define i8 @select_icmp_eq_and(
+; CHECK-SAME: i8 noundef [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT:    [[RETVAL:%.*]] = select i1 [[CMP]], i8 [[A]], i8 [[B]]
+; CHECK-NEXT:    ret i8 [[RETVAL]]
+;
+  %1 = and i8 %a, 1
+  %cmp = icmp eq i8 %1, 0
+  %div = and i8 %a, -2
+  %retval = select i1 %cmp, i8 %div, i8 %b
+  ret i8 %retval
+}
+
+define i8 @select_icmp_eq_and_inv(i8 noundef %a, i8 %b) {
+; CHECK-LABEL: define i8 @select_icmp_eq_and_inv(
+; CHECK-SAME: i8 noundef [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A]], 1
+; CHECK-NEXT:    [[CMP_NOT:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT:    [[RETVAL:%.*]] = select i1 [[CMP_NOT]], i8 [[A]], i8 [[B]]
+; CHECK-NEXT:    ret i8 [[RETVAL]]
+;
+  %1 = and i8 %a, 1
+  %cmp = icmp eq i8 %1, 1
+  %div = and i8 %a, -2
+  %retval = select i1 %cmp, i8 %b, i8 %div
+  ret i8 %retval
+}
+
+; negative test
+define i8 @select_icmp_eq_and_undef(i8 %a, i8 %b) {
+; CHECK-LABEL: define i8 @select_icmp_eq_and_undef(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT:    [[DIV:%.*]] = and i8 [[A]], -2
+; CHECK-NEXT:    [[RETVAL:%.*]] = select i1 [[CMP]], i8 [[DIV]], i8 [[B]]
+; CHECK-NEXT:    ret i8 [[RETVAL]]
+;
+  %1 = and i8 %a, 1
+  %cmp = icmp eq i8 %1, 0
+  %div = and i8 %a, -2
+  %retval = select i1 %cmp, i8 %div, i8 %b
+  ret i8 %retval
+}
+
+; negative test
+define i8 @select_icmp_eq_and_diff(i8 noundef %a, i8 %b, i8 %c) {
+; CHECK-LABEL: define i8 @select_icmp_eq_and_diff(
+; CHECK-SAME: i8 noundef [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT:    [[DIV:%.*]] = and i8 [[C]], -2
+; CHECK-NEXT:    [[RETVAL:%.*]] = select i1 [[CMP]], i8 [[DIV]], i8 [[B]]
+; CHECK-NEXT:    ret i8 [[RETVAL]]
+;
+  %1 = and i8 %a, 1
+  %cmp = icmp eq i8 %1, 0
+  %div = and i8 %c, -2
+  %retval = select i1 %cmp, i8 %div, i8 %b
+  ret i8 %retval
+}
+
+; negative test
+define i8 @select_icmp_eq_mul_and_extra_use(i8 noundef %a, i8 %b) {
+; CHECK-LABEL: define i8 @select_icmp_eq_mul_and_extra_use(
+; CHECK-SAME: i8 noundef [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT:    [[DIV:%.*]] = and i8 [[A]], -2
+; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[DIV]], [[DIV]]
+; CHECK-NEXT:    [[RETVAL:%.*]] = select i1 [[CMP]], i8 [[MUL]], i8 [[B]]
+; CHECK-NEXT:    [[SUM:%.*]] = add i8 [[MUL]], [[RETVAL]]
+; CHECK-NEXT:    ret i8 [[SUM]]
+;
+  %1 = and i8 %a, 1
+  %cmp = icmp eq i8 %1, 0
+  %div = and i8 %a, -2
+  %mul = mul i8 %div, %div
+  %retval = select i1 %cmp, i8 %mul, i8 %b
+  %sum = add i8 %mul, %retval
+  ret i8 %sum
+}
diff --git a/llvm/test/Transforms/InstCombine/select.ll b/llvm/test/Transforms/InstCombine/select.ll
index 1369be305ec13..1c7247ec6a8b3 100644
--- a/llvm/test/Transforms/InstCombine/select.ll
+++ b/llvm/test/Transforms/InstCombine/select.ll
@@ -2989,9 +2989,8 @@ define i8 @select_replacement_loop3(i32 noundef %x) {
 
 define i16 @select_replacement_loop4(i16 noundef %p_12) {
 ; CHECK-LABEL: @select_replacement_loop4(
-; CHECK-NEXT:    [[AND1:%.*]] = and i16 [[P_12:%.*]], 1
-; CHECK-NEXT:    [[CMP21:%.*]] = icmp ult i16 [[P_12]], 2
-; CHECK-NEXT:    [[AND3:%.*]] = select i1 [[CMP21]], i16 [[AND1]], i16 0
+; CHECK-NEXT:    [[CMP21:%.*]] = icmp ult i16 [[P_12:%.*]], 2
+; CHECK-NEXT:    [[AND3:%.*]] = select i1 [[CMP21]], i16 [[P_12]], i16 0
 ; CHECK-NEXT:    ret i16 [[AND3]]
 ;
   %cmp1 = icmp ult i16 %p_12, 2
@@ -4671,8 +4670,7 @@ define i8 @select_knownbits_simplify(i8 noundef %x) {
 ; CHECK-LABEL: @select_knownbits_simplify(
 ; CHECK-NEXT:    [[X_LO:%.*]] = and i8 [[X:%.*]], 1
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[X_LO]], 0
-; CHECK-NEXT:    [[AND:%.*]] = and i8 [[X]], -2
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[CMP]], i8 [[AND]], i8 0
+; CHECK-NEXT:    [[RES:%.*]] = select i1 [[CMP]], i8 [[X]], i8 0
 ; CHECK-NEXT:    ret i8 [[RES]]
 ;
   %x.lo = and i8 %x, 1
@@ -4686,8 +4684,7 @@ define i8 @select_knownbits_simplify_nested(i8 noundef %x) {
 ; CHECK-LABEL: @select_knownbits_simplify_nested(
 ; CHECK-NEXT:    [[X_LO:%.*]] = and i8 [[X:%.*]], 1
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[X_LO]], 0
-; CHECK-NEXT:    [[AND:%.*]] = and i8 [[X]], -2
-; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[AND]], [[AND]]
+; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[X]], [[X]]
 ; CHECK-NEXT:    [[RES:%.*]] = select i1 [[CMP]], i8 [[MUL]], i8 0
 ; CHECK-NEXT:    ret i8 [[RES]]
 ;