Skip to content

[InstCombine] Prepare foldLogOpOfMaskedICmps to handle trunc to i1. (NFC) #122179

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions llvm/include/llvm/Analysis/CmpInstAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ namespace llvm {
bool LookThroughTrunc = true,
bool AllowNonZeroC = false);

/// Decompose an icmp into the form ((X & Mask) pred C) if
/// possible. Unless \p AllowNonZeroC is true, C will always be 0.
std::optional<DecomposedBitTest>
decomposeBitTest(Value *Cond, bool LookThroughTrunc = true,
bool AllowNonZeroC = false);

} // end namespace llvm

#endif
14 changes: 14 additions & 0 deletions llvm/lib/Analysis/CmpInstAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,17 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,

return Result;
}

std::optional<DecomposedBitTest>
llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
if (auto *ICmp = dyn_cast<ICmpInst>(Cond)) {
// Don't allow pointers. Splat vectors are fine.
if (!ICmp->getOperand(0)->getType()->isIntOrIntVectorTy())
return std::nullopt;
return decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
ICmp->getPredicate(), LookThruTrunc,
AllowNonZeroC);
}

return std::nullopt;
}
121 changes: 66 additions & 55 deletions llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,10 @@ static unsigned conjugateICmpMask(unsigned Mask) {
}

// Adapts the external decomposeBitTestICmp for local use.
static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pred,
static bool decomposeBitTestICmp(Value *Cond, CmpInst::Predicate &Pred,
Value *&X, Value *&Y, Value *&Z) {
auto Res = llvm::decomposeBitTestICmp(
LHS, RHS, Pred, /*LookThroughTrunc=*/true, /*AllowNonZeroC=*/true);
auto Res = llvm::decomposeBitTest(Cond, /*LookThroughTrunc=*/true,
/*AllowNonZeroC=*/true);
if (!Res)
return false;

Expand All @@ -198,27 +198,34 @@ static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pre
/// the right hand side as a pair.
/// LHS and RHS are the left hand side and the right hand side ICmps and PredL
/// and PredR are their predicates, respectively.
static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
Value *&A, Value *&B, Value *&C, Value *&D, Value *&E, ICmpInst *LHS,
ICmpInst *RHS, ICmpInst::Predicate &PredL, ICmpInst::Predicate &PredR) {
// Don't allow pointers. Splat vectors are fine.
if (!LHS->getOperand(0)->getType()->isIntOrIntVectorTy() ||
!RHS->getOperand(0)->getType()->isIntOrIntVectorTy())
return std::nullopt;
static std::optional<std::pair<unsigned, unsigned>>
getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, Value *&D, Value *&E,
Value *LHS, Value *RHS, ICmpInst::Predicate &PredL,
ICmpInst::Predicate &PredR) {

// Here comes the tricky part:
// LHS might be of the form L11 & L12 == X, X == L21 & L22,
// and L11 & L12 == L21 & L22. The same goes for RHS.
// Now we must find those components L** and R**, that are equal, so
// that we can extract the parameters A, B, C, D, and E for the canonical
// above.
Value *L1 = LHS->getOperand(0);
Value *L2 = LHS->getOperand(1);
Value *L11, *L12, *L21, *L22;

// Check whether the icmp can be decomposed into a bit test.
if (decomposeBitTestICmp(L1, L2, PredL, L11, L12, L2)) {
Value *L1, *L11, *L12, *L2, *L21, *L22;
if (decomposeBitTestICmp(LHS, PredL, L11, L12, L2)) {
L21 = L22 = L1 = nullptr;
} else {
auto *LHSCMP = dyn_cast<ICmpInst>(LHS);
if (!LHSCMP)
return std::nullopt;

// Don't allow pointers. Splat vectors are fine.
if (!LHSCMP->getOperand(0)->getType()->isIntOrIntVectorTy())
return std::nullopt;

PredL = LHSCMP->getPredicate();
L1 = LHSCMP->getOperand(0);
L2 = LHSCMP->getOperand(1);
// Look for ANDs in the LHS icmp.
if (!match(L1, m_And(m_Value(L11), m_Value(L12)))) {
// Any icmp can be viewed as being trivially masked; if it allows us to
Expand All @@ -237,11 +244,8 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
if (!ICmpInst::isEquality(PredL))
return std::nullopt;

Value *R1 = RHS->getOperand(0);
Value *R2 = RHS->getOperand(1);
Value *R11, *R12;
bool Ok = false;
if (decomposeBitTestICmp(R1, R2, PredR, R11, R12, R2)) {
Value *R11, *R12, *R2;
if (decomposeBitTestICmp(RHS, PredR, R11, R12, R2)) {
if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
A = R11;
D = R12;
Expand All @@ -252,9 +256,19 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
return std::nullopt;
}
E = R2;
R1 = nullptr;
Ok = true;
} else {
auto *RHSCMP = dyn_cast<ICmpInst>(RHS);
if (!RHSCMP)
return std::nullopt;
// Don't allow pointers. Splat vectors are fine.
if (!RHSCMP->getOperand(0)->getType()->isIntOrIntVectorTy())
return std::nullopt;

PredR = RHSCMP->getPredicate();

Value *R1 = RHSCMP->getOperand(0);
R2 = RHSCMP->getOperand(1);
bool Ok = false;
if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) {
// As before, model no mask as a trivial mask if it'll let us do an
// optimization.
Expand All @@ -277,36 +291,32 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
// Avoid matching against the -1 value we created for unmasked operand.
if (Ok && match(A, m_AllOnes()))
Ok = false;

// Look for ANDs on the right side of the RHS icmp.
if (!Ok) {
if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) {
R11 = R2;
R12 = Constant::getAllOnesValue(R2->getType());
}

if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
A = R11;
D = R12;
E = R1;
} else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
A = R12;
D = R11;
E = R1;
} else {
return std::nullopt;
}
}
}

// Bail if RHS was a icmp that can't be decomposed into an equality.
if (!ICmpInst::isEquality(PredR))
return std::nullopt;

// Look for ANDs on the right side of the RHS icmp.
if (!Ok) {
if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) {
R11 = R2;
R12 = Constant::getAllOnesValue(R2->getType());
}

if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
A = R11;
D = R12;
E = R1;
Ok = true;
} else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
A = R12;
D = R11;
E = R1;
Ok = true;
} else {
return std::nullopt;
}

assert(Ok && "Failed to find AND on the right side of the RHS icmp.");
}

if (L11 == A) {
B = L12;
C = L2;
Expand All @@ -333,8 +343,8 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
/// (icmp (A & 12) != 0) & (icmp (A & 15) == 8) -> (icmp (A & 15) == 8).
/// Also used for logical and/or, must be poison safe.
static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *D,
Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
Value *LHS, Value *RHS, bool IsAnd, Value *A, Value *B, Value *D, Value *E,
ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
InstCombiner::BuilderTy &Builder) {
// We are given the canonical form:
// (icmp ne (A & B), 0) & (icmp eq (A & D), E).
Expand Down Expand Up @@ -457,7 +467,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
// (icmp ne (A & 15), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8).
if (IsSuperSetOrEqual(BCst, DCst)) {
// We can't guarantee that samesign hold after this fold.
RHS->setSameSign(false);
if (auto *ICmp = dyn_cast<ICmpInst>(RHS))
ICmp->setSameSign(false);
return RHS;
}
// Otherwise, B is a subset of D. If B and E have a common bit set,
Expand All @@ -466,7 +477,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
assert(IsSubSetOrEqual(BCst, DCst) && "Precondition due to above code");
if ((*BCst & ECst) != 0) {
// We can't guarantee that samesign hold after this fold.
RHS->setSameSign(false);
if (auto *ICmp = dyn_cast<ICmpInst>(RHS))
ICmp->setSameSign(false);
return RHS;
}
// Otherwise, LHS and RHS contradict and the whole expression becomes false
Expand All @@ -481,8 +493,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
/// aren't of the common mask pattern type.
/// Also used for logical and/or, must be poison safe.
static Value *foldLogOpOfMaskedICmpsAsymmetric(
ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *C,
Value *D, Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
Value *LHS, Value *RHS, bool IsAnd, Value *A, Value *B, Value *C, Value *D,
Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
unsigned LHSMask, unsigned RHSMask, InstCombiner::BuilderTy &Builder) {
assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) &&
"Expected equality predicates for masked type of icmps.");
Expand Down Expand Up @@ -511,12 +523,12 @@ static Value *foldLogOpOfMaskedICmpsAsymmetric(

/// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E)
/// into a single (icmp(A & X) ==/!= Y).
static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
static Value *foldLogOpOfMaskedICmps(Value *LHS, Value *RHS, bool IsAnd,
bool IsLogical,
InstCombiner::BuilderTy &Builder,
const SimplifyQuery &Q) {
Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
ICmpInst::Predicate PredL, PredR;
std::optional<std::pair<unsigned, unsigned>> MaskPair =
getMaskedTypeForICmpPair(A, B, C, D, E, LHS, RHS, PredL, PredR);
if (!MaskPair)
Expand Down Expand Up @@ -1066,8 +1078,7 @@ static Value *foldPowerOf2AndShiftedMask(ICmpInst *Cmp0, ICmpInst *Cmp1,
if (!JoinedByAnd)
return nullptr;
Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
ICmpInst::Predicate CmpPred0 = Cmp0->getPredicate(),
CmpPred1 = Cmp1->getPredicate();
ICmpInst::Predicate CmpPred0, CmpPred1;
// Assuming P is a 2^n, getMaskedTypeForICmpPair will normalize (icmp X u<
// 2^n) into (icmp (X & ~(2^n-1)) == 0) and (icmp X s> -1) into (icmp (X &
// SignMask) == 0).
Expand Down
Loading