diff --git a/llvm/include/llvm/Analysis/CmpInstAnalysis.h b/llvm/include/llvm/Analysis/CmpInstAnalysis.h index c7862a6d39d07..aeda58ac7535d 100644 --- a/llvm/include/llvm/Analysis/CmpInstAnalysis.h +++ b/llvm/include/llvm/Analysis/CmpInstAnalysis.h @@ -108,6 +108,12 @@ namespace llvm { bool LookThroughTrunc = true, bool AllowNonZeroC = false); + /// Decompose an icmp into the form ((X & Mask) pred C) if + /// possible. Unless \p AllowNonZeroC is true, C will always be 0. + std::optional + decomposeBitTest(Value *Cond, bool LookThroughTrunc = true, + bool AllowNonZeroC = false); + } // end namespace llvm #endif diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp index 2580ea7e97248..3599428c5ff41 100644 --- a/llvm/lib/Analysis/CmpInstAnalysis.cpp +++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp @@ -165,3 +165,17 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred, return Result; } + +std::optional +llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) { + if (auto *ICmp = dyn_cast(Cond)) { + // Don't allow pointers. Splat vectors are fine. + if (!ICmp->getOperand(0)->getType()->isIntOrIntVectorTy()) + return std::nullopt; + return decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1), + ICmp->getPredicate(), LookThruTrunc, + AllowNonZeroC); + } + + return std::nullopt; +} diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index f82a557e5760c..f7d17b1aa3865 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -179,10 +179,10 @@ static unsigned conjugateICmpMask(unsigned Mask) { } // Adapts the external decomposeBitTestICmp for local use. -static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pred, +static bool decomposeBitTestICmp(Value *Cond, CmpInst::Predicate &Pred, Value *&X, Value *&Y, Value *&Z) { - auto Res = llvm::decomposeBitTestICmp( - LHS, RHS, Pred, /*LookThroughTrunc=*/true, /*AllowNonZeroC=*/true); + auto Res = llvm::decomposeBitTest(Cond, /*LookThroughTrunc=*/true, + /*AllowNonZeroC=*/true); if (!Res) return false; @@ -198,13 +198,10 @@ static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pre /// the right hand side as a pair. /// LHS and RHS are the left hand side and the right hand side ICmps and PredL /// and PredR are their predicates, respectively. -static std::optional> getMaskedTypeForICmpPair( - Value *&A, Value *&B, Value *&C, Value *&D, Value *&E, ICmpInst *LHS, - ICmpInst *RHS, ICmpInst::Predicate &PredL, ICmpInst::Predicate &PredR) { - // Don't allow pointers. Splat vectors are fine. - if (!LHS->getOperand(0)->getType()->isIntOrIntVectorTy() || - !RHS->getOperand(0)->getType()->isIntOrIntVectorTy()) - return std::nullopt; +static std::optional> +getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, Value *&D, Value *&E, + Value *LHS, Value *RHS, ICmpInst::Predicate &PredL, + ICmpInst::Predicate &PredR) { // Here comes the tricky part: // LHS might be of the form L11 & L12 == X, X == L21 & L22, @@ -212,13 +209,23 @@ static std::optional> getMaskedTypeForICmpPair( // Now we must find those components L** and R**, that are equal, so // that we can extract the parameters A, B, C, D, and E for the canonical // above. - Value *L1 = LHS->getOperand(0); - Value *L2 = LHS->getOperand(1); - Value *L11, *L12, *L21, *L22; + // Check whether the icmp can be decomposed into a bit test. - if (decomposeBitTestICmp(L1, L2, PredL, L11, L12, L2)) { + Value *L1, *L11, *L12, *L2, *L21, *L22; + if (decomposeBitTestICmp(LHS, PredL, L11, L12, L2)) { L21 = L22 = L1 = nullptr; } else { + auto *LHSCMP = dyn_cast(LHS); + if (!LHSCMP) + return std::nullopt; + + // Don't allow pointers. Splat vectors are fine. + if (!LHSCMP->getOperand(0)->getType()->isIntOrIntVectorTy()) + return std::nullopt; + + PredL = LHSCMP->getPredicate(); + L1 = LHSCMP->getOperand(0); + L2 = LHSCMP->getOperand(1); // Look for ANDs in the LHS icmp. if (!match(L1, m_And(m_Value(L11), m_Value(L12)))) { // Any icmp can be viewed as being trivially masked; if it allows us to @@ -237,11 +244,8 @@ static std::optional> getMaskedTypeForICmpPair( if (!ICmpInst::isEquality(PredL)) return std::nullopt; - Value *R1 = RHS->getOperand(0); - Value *R2 = RHS->getOperand(1); - Value *R11, *R12; - bool Ok = false; - if (decomposeBitTestICmp(R1, R2, PredR, R11, R12, R2)) { + Value *R11, *R12, *R2; + if (decomposeBitTestICmp(RHS, PredR, R11, R12, R2)) { if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) { A = R11; D = R12; @@ -252,9 +256,19 @@ static std::optional> getMaskedTypeForICmpPair( return std::nullopt; } E = R2; - R1 = nullptr; - Ok = true; } else { + auto *RHSCMP = dyn_cast(RHS); + if (!RHSCMP) + return std::nullopt; + // Don't allow pointers. Splat vectors are fine. + if (!RHSCMP->getOperand(0)->getType()->isIntOrIntVectorTy()) + return std::nullopt; + + PredR = RHSCMP->getPredicate(); + + Value *R1 = RHSCMP->getOperand(0); + R2 = RHSCMP->getOperand(1); + bool Ok = false; if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) { // As before, model no mask as a trivial mask if it'll let us do an // optimization. @@ -277,36 +291,32 @@ static std::optional> getMaskedTypeForICmpPair( // Avoid matching against the -1 value we created for unmasked operand. if (Ok && match(A, m_AllOnes())) Ok = false; + + // Look for ANDs on the right side of the RHS icmp. + if (!Ok) { + if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) { + R11 = R2; + R12 = Constant::getAllOnesValue(R2->getType()); + } + + if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) { + A = R11; + D = R12; + E = R1; + } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) { + A = R12; + D = R11; + E = R1; + } else { + return std::nullopt; + } + } } // Bail if RHS was a icmp that can't be decomposed into an equality. if (!ICmpInst::isEquality(PredR)) return std::nullopt; - // Look for ANDs on the right side of the RHS icmp. - if (!Ok) { - if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) { - R11 = R2; - R12 = Constant::getAllOnesValue(R2->getType()); - } - - if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) { - A = R11; - D = R12; - E = R1; - Ok = true; - } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) { - A = R12; - D = R11; - E = R1; - Ok = true; - } else { - return std::nullopt; - } - - assert(Ok && "Failed to find AND on the right side of the RHS icmp."); - } - if (L11 == A) { B = L12; C = L2; @@ -333,8 +343,8 @@ static std::optional> getMaskedTypeForICmpPair( /// (icmp (A & 12) != 0) & (icmp (A & 15) == 8) -> (icmp (A & 15) == 8). /// Also used for logical and/or, must be poison safe. static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed( - ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *D, - Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR, + Value *LHS, Value *RHS, bool IsAnd, Value *A, Value *B, Value *D, Value *E, + ICmpInst::Predicate PredL, ICmpInst::Predicate PredR, InstCombiner::BuilderTy &Builder) { // We are given the canonical form: // (icmp ne (A & B), 0) & (icmp eq (A & D), E). @@ -457,7 +467,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed( // (icmp ne (A & 15), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8). if (IsSuperSetOrEqual(BCst, DCst)) { // We can't guarantee that samesign hold after this fold. - RHS->setSameSign(false); + if (auto *ICmp = dyn_cast(RHS)) + ICmp->setSameSign(false); return RHS; } // Otherwise, B is a subset of D. If B and E have a common bit set, @@ -466,7 +477,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed( assert(IsSubSetOrEqual(BCst, DCst) && "Precondition due to above code"); if ((*BCst & ECst) != 0) { // We can't guarantee that samesign hold after this fold. - RHS->setSameSign(false); + if (auto *ICmp = dyn_cast(RHS)) + ICmp->setSameSign(false); return RHS; } // Otherwise, LHS and RHS contradict and the whole expression becomes false @@ -481,8 +493,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed( /// aren't of the common mask pattern type. /// Also used for logical and/or, must be poison safe. static Value *foldLogOpOfMaskedICmpsAsymmetric( - ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *C, - Value *D, Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR, + Value *LHS, Value *RHS, bool IsAnd, Value *A, Value *B, Value *C, Value *D, + Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR, unsigned LHSMask, unsigned RHSMask, InstCombiner::BuilderTy &Builder) { assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) && "Expected equality predicates for masked type of icmps."); @@ -511,12 +523,12 @@ static Value *foldLogOpOfMaskedICmpsAsymmetric( /// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) /// into a single (icmp(A & X) ==/!= Y). -static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, +static Value *foldLogOpOfMaskedICmps(Value *LHS, Value *RHS, bool IsAnd, bool IsLogical, InstCombiner::BuilderTy &Builder, const SimplifyQuery &Q) { Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr; - ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); + ICmpInst::Predicate PredL, PredR; std::optional> MaskPair = getMaskedTypeForICmpPair(A, B, C, D, E, LHS, RHS, PredL, PredR); if (!MaskPair) @@ -1066,8 +1078,7 @@ static Value *foldPowerOf2AndShiftedMask(ICmpInst *Cmp0, ICmpInst *Cmp1, if (!JoinedByAnd) return nullptr; Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr; - ICmpInst::Predicate CmpPred0 = Cmp0->getPredicate(), - CmpPred1 = Cmp1->getPredicate(); + ICmpInst::Predicate CmpPred0, CmpPred1; // Assuming P is a 2^n, getMaskedTypeForICmpPair will normalize (icmp X u< // 2^n) into (icmp (X & ~(2^n-1)) == 0) and (icmp X s> -1) into (icmp (X & // SignMask) == 0).