-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[InstCombine] Prepare foldLogOpOfMaskedICmps to handle trunc to i1. (NFC) #122179
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-llvm-analysis @llvm/pr-subscribers-llvm-transforms Author: Andreas Jonson (andjo403) Changes: Making the move of foldLogOpOfMaskedICmps before adding the handling of `trunc to i1` in getMaskedTypeForICmpPair, as there were some diffs in llvm-opt-benchmark due to the move. Full diff: https://github.com/llvm/llvm-project/pull/122179.diff 1 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 8bfa3d0f6c5ea1..0aeb025ea44840 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -199,113 +199,132 @@ static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pre
/// the right hand side as a pair.
/// LHS and RHS are the left hand side and the right hand side ICmps and PredL
/// and PredR are their predicates, respectively.
-static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
- Value *&A, Value *&B, Value *&C, Value *&D, Value *&E, ICmpInst *LHS,
- ICmpInst *RHS, ICmpInst::Predicate &PredL, ICmpInst::Predicate &PredR) {
- // Don't allow pointers. Splat vectors are fine.
- if (!LHS->getOperand(0)->getType()->isIntOrIntVectorTy() ||
- !RHS->getOperand(0)->getType()->isIntOrIntVectorTy())
- return std::nullopt;
+static std::optional<std::pair<unsigned, unsigned>>
+getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, Value *&D, Value *&E,
+ Value *LHS, Value *RHS, ICmpInst::Predicate &PredL,
+ ICmpInst::Predicate &PredR) {
- // Here comes the tricky part:
- // LHS might be of the form L11 & L12 == X, X == L21 & L22,
- // and L11 & L12 == L21 & L22. The same goes for RHS.
- // Now we must find those components L** and R**, that are equal, so
- // that we can extract the parameters A, B, C, D, and E for the canonical
- // above.
- Value *L1 = LHS->getOperand(0);
- Value *L2 = LHS->getOperand(1);
- Value *L11, *L12, *L21, *L22;
- // Check whether the icmp can be decomposed into a bit test.
- if (decomposeBitTestICmp(L1, L2, PredL, L11, L12, L2)) {
- L21 = L22 = L1 = nullptr;
- } else {
- // Look for ANDs in the LHS icmp.
- if (!match(L1, m_And(m_Value(L11), m_Value(L12)))) {
- // Any icmp can be viewed as being trivially masked; if it allows us to
- // remove one, it's worth it.
- L11 = L1;
- L12 = Constant::getAllOnesValue(L1->getType());
- }
+ Value *L1, *L11, *L12, *L2, *L21, *L22;
+ if (auto *LHSCMP = dyn_cast<ICmpInst>(LHS)) {
+
+ // Don't allow pointers. Splat vectors are fine.
+ if (!LHSCMP->getOperand(0)->getType()->isIntOrIntVectorTy())
+ return std::nullopt;
+
+ PredL = LHSCMP->getPredicate();
+
+ // Here comes the tricky part:
+ // LHS might be of the form L11 & L12 == X, X == L21 & L22,
+ // and L11 & L12 == L21 & L22. The same goes for RHS.
+ // Now we must find those components L** and R**, that are equal, so
+ // that we can extract the parameters A, B, C, D, and E for the canonical
+ // above.
+ L1 = LHSCMP->getOperand(0);
+ L2 = LHSCMP->getOperand(1);
+ // Check whether the icmp can be decomposed into a bit test.
+ if (decomposeBitTestICmp(L1, L2, PredL, L11, L12, L2)) {
+ L21 = L22 = L1 = nullptr;
+ } else {
+ // Look for ANDs in the LHS icmp.
+ if (!match(L1, m_And(m_Value(L11), m_Value(L12)))) {
+ // Any icmp can be viewed as being trivially masked; if it allows us to
+ // remove one, it's worth it.
+ L11 = L1;
+ L12 = Constant::getAllOnesValue(L1->getType());
+ }
- if (!match(L2, m_And(m_Value(L21), m_Value(L22)))) {
- L21 = L2;
- L22 = Constant::getAllOnesValue(L2->getType());
+ if (!match(L2, m_And(m_Value(L21), m_Value(L22)))) {
+ L21 = L2;
+ L22 = Constant::getAllOnesValue(L2->getType());
+ }
}
- }
+ // Bail if LHS was a icmp that can't be decomposed into an equality.
+ if (!ICmpInst::isEquality(PredL))
+ return std::nullopt;
- // Bail if LHS was a icmp that can't be decomposed into an equality.
- if (!ICmpInst::isEquality(PredL))
+ } else {
return std::nullopt;
+ }
- Value *R1 = RHS->getOperand(0);
- Value *R2 = RHS->getOperand(1);
Value *R11, *R12;
- bool Ok = false;
- if (decomposeBitTestICmp(R1, R2, PredR, R11, R12, R2)) {
- if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
- A = R11;
- D = R12;
- } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
- A = R12;
- D = R11;
- } else {
+ if (auto *RHSCMP = dyn_cast<ICmpInst>(RHS)) {
+
+ // Don't allow pointers. Splat vectors are fine.
+ if (!RHSCMP->getOperand(0)->getType()->isIntOrIntVectorTy())
return std::nullopt;
- }
- E = R2;
- R1 = nullptr;
- Ok = true;
- } else {
- if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) {
- // As before, model no mask as a trivial mask if it'll let us do an
- // optimization.
- R11 = R1;
- R12 = Constant::getAllOnesValue(R1->getType());
- }
- if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
- A = R11;
- D = R12;
- E = R2;
- Ok = true;
- } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
- A = R12;
- D = R11;
+ PredR = RHSCMP->getPredicate();
+
+ Value *R1 = RHSCMP->getOperand(0);
+ Value *R2 = RHSCMP->getOperand(1);
+ bool Ok = false;
+ if (decomposeBitTestICmp(R1, R2, PredR, R11, R12, R2)) {
+ if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
+ A = R11;
+ D = R12;
+ } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
+ A = R12;
+ D = R11;
+ } else {
+ return std::nullopt;
+ }
E = R2;
+ R1 = nullptr;
Ok = true;
- }
-
- // Avoid matching against the -1 value we created for unmasked operand.
- if (Ok && match(A, m_AllOnes()))
- Ok = false;
- }
+ } else {
+ if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) {
+ // As before, model no mask as a trivial mask if it'll let us do an
+ // optimization.
+ R11 = R1;
+ R12 = Constant::getAllOnesValue(R1->getType());
+ }
- // Bail if RHS was a icmp that can't be decomposed into an equality.
- if (!ICmpInst::isEquality(PredR))
- return std::nullopt;
+ if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
+ A = R11;
+ D = R12;
+ E = R2;
+ Ok = true;
+ } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
+ A = R12;
+ D = R11;
+ E = R2;
+ Ok = true;
+ }
- // Look for ANDs on the right side of the RHS icmp.
- if (!Ok) {
- if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) {
- R11 = R2;
- R12 = Constant::getAllOnesValue(R2->getType());
+ // Avoid matching against the -1 value we created for unmasked operand.
+ if (Ok && match(A, m_AllOnes()))
+ Ok = false;
}
- if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
- A = R11;
- D = R12;
- E = R1;
- Ok = true;
- } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
- A = R12;
- D = R11;
- E = R1;
- Ok = true;
- } else {
+ // Bail if RHS was a icmp that can't be decomposed into an equality.
+ if (!ICmpInst::isEquality(PredR))
return std::nullopt;
- }
- assert(Ok && "Failed to find AND on the right side of the RHS icmp.");
+ // Look for ANDs on the right side of the RHS icmp.
+ if (!Ok) {
+ if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) {
+ R11 = R2;
+ R12 = Constant::getAllOnesValue(R2->getType());
+ }
+
+ if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
+ A = R11;
+ D = R12;
+ E = R1;
+ Ok = true;
+ } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
+ A = R12;
+ D = R11;
+ E = R1;
+ Ok = true;
+ } else {
+ return std::nullopt;
+ }
+
+ assert(Ok && "Failed to find AND on the right side of the RHS icmp.");
+ }
+ } else {
+ return std::nullopt;
}
if (L11 == A) {
@@ -334,8 +353,8 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
/// (icmp (A & 12) != 0) & (icmp (A & 15) == 8) -> (icmp (A & 15) == 8).
/// Also used for logical and/or, must be poison safe.
static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
- ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *D,
- Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
+ Value *LHS, Value *RHS, bool IsAnd, Value *A, Value *B, Value *D, Value *E,
+ ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
InstCombiner::BuilderTy &Builder) {
// We are given the canonical form:
// (icmp ne (A & B), 0) & (icmp eq (A & D), E).
@@ -458,7 +477,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
// (icmp ne (A & 15), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8).
if (IsSuperSetOrEqual(BCst, DCst)) {
// We can't guarantee that samesign hold after this fold.
- RHS->setSameSign(false);
+ if (auto *ICmp = dyn_cast<ICmpInst>(RHS))
+ ICmp->setSameSign(false);
return RHS;
}
// Otherwise, B is a subset of D. If B and E have a common bit set,
@@ -467,7 +487,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
assert(IsSubSetOrEqual(BCst, DCst) && "Precondition due to above code");
if ((*BCst & ECst) != 0) {
// We can't guarantee that samesign hold after this fold.
- RHS->setSameSign(false);
+ if (auto *ICmp = dyn_cast<ICmpInst>(RHS))
+ ICmp->setSameSign(false);
return RHS;
}
// Otherwise, LHS and RHS contradict and the whole expression becomes false
@@ -482,8 +503,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
/// aren't of the common mask pattern type.
/// Also used for logical and/or, must be poison safe.
static Value *foldLogOpOfMaskedICmpsAsymmetric(
- ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *C,
- Value *D, Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
+ Value *LHS, Value *RHS, bool IsAnd, Value *A, Value *B, Value *C, Value *D,
+ Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
unsigned LHSMask, unsigned RHSMask, InstCombiner::BuilderTy &Builder) {
assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) &&
"Expected equality predicates for masked type of icmps.");
@@ -512,12 +533,12 @@ static Value *foldLogOpOfMaskedICmpsAsymmetric(
/// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E)
/// into a single (icmp(A & X) ==/!= Y).
-static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
+static Value *foldLogOpOfMaskedICmps(Value *LHS, Value *RHS, bool IsAnd,
bool IsLogical,
InstCombiner::BuilderTy &Builder,
const SimplifyQuery &Q) {
Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
- ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
+ ICmpInst::Predicate PredL, PredR;
std::optional<std::pair<unsigned, unsigned>> MaskPair =
getMaskedTypeForICmpPair(A, B, C, D, E, LHS, RHS, PredL, PredR);
if (!MaskPair)
@@ -1067,8 +1088,7 @@ static Value *foldPowerOf2AndShiftedMask(ICmpInst *Cmp0, ICmpInst *Cmp1,
if (!JoinedByAnd)
return nullptr;
Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
- ICmpInst::Predicate CmpPred0 = Cmp0->getPredicate(),
- CmpPred1 = Cmp1->getPredicate();
+ ICmpInst::Predicate CmpPred0, CmpPred1;
// Assuming P is a 2^n, getMaskedTypeForICmpPair will normalize (icmp X u<
// 2^n) into (icmp (X & ~(2^n-1)) == 0) and (icmp X s> -1) into (icmp (X &
// SignMask) == 0).
@@ -3325,12 +3345,6 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
}
}
- // handle (roughly):
- // (icmp ne (A & B), C) | (icmp ne (A & D), E)
- // (icmp eq (A & B), C) & (icmp eq (A & D), E)
- if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, IsAnd, IsLogical, Builder, Q))
- return V;
-
if (Value *V =
foldAndOrOfICmpEqConstantAndICmp(LHS, RHS, IsAnd, IsLogical, Builder))
return V;
@@ -3510,6 +3524,12 @@ Value *InstCombinerImpl::foldBooleanAndOr(Value *LHS, Value *RHS,
if (Value *Res = foldAndOrOfICmps(LHSCmp, RHSCmp, I, IsAnd, IsLogical))
return Res;
+ /// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E)
+ /// into a single (icmp(A & X) ==/!= Y).
+ if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, IsAnd, IsLogical, Builder,
+ SQ.getWithInstruction(&I)))
+ return V;
+
if (auto *LHSCmp = dyn_cast<FCmpInst>(LHS))
if (auto *RHSCmp = dyn_cast<FCmpInst>(RHS))
if (Value *Res = foldLogicOfFCmps(LHSCmp, RHSCmp, IsAnd, IsLogical))
|
Can you add some tests for these improvements and regressions? |
/// into a single (icmp(A & X) ==/!= Y). | ||
if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, IsAnd, IsLogical, Builder, | ||
SQ.getWithInstruction(&I))) | ||
return V; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe keep the original place of the call in this commit so this change can be NFC. Would simplify review.
8bf52d5
to
3c0c7b0
Compare
Changed to only do the preparation needed to support trunc, without moving the call to a place where any of the values can be a trunc, so this is NFC now. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thank you!
|
||
std::optional<DecomposedBitTest> | ||
llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) { | ||
using namespace PatternMatch; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unused?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, the `using` is unused; I forgot to remove it when I removed the trunc handling that I had here. I will add it back soon, but I have removed it for now.
@nikic do you have any more comments or is this good to go? |
b409b04
to
a15be5a
Compare
This makes the move of foldLogOpOfMaskedICmps before adding the handling of `trunc to i1` in getMaskedTypeForICmpPair, as there were some diffs in llvm-opt-benchmark due to the move. Most of the diffs that I looked at were for the better, but there were some regressions.