[InstCombine] Prepare foldLogOpOfMaskedICmps to handle trunc to i1. (NFC) #122179

Merged
merged 1 commit into from
Jan 15, 2025

Conversation

andjo403
Contributor

@andjo403 andjo403 commented Jan 8, 2025

Making the move of foldLogOpOfMaskedICmps in a separate change, before adding the handling of trunc to i1 in getMaskedTypeForICmpPair, because the move caused some diffs in llvm-opt-benchmark. Most of the diffs that I looked at were improvements, but there were some regressions.

@llvmbot
Member

llvmbot commented Jan 8, 2025

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-llvm-transforms

Author: Andreas Jonson (andjo403)

Changes

Making the move of foldLogOpOfMaskedICmps in a separate change, before adding the handling of trunc to i1 in getMaskedTypeForICmpPair, because the move caused some diffs in llvm-opt-benchmark.


Full diff: https://github.com/llvm/llvm-project/pull/122179.diff

1 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+126-106)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 8bfa3d0f6c5ea1..0aeb025ea44840 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -199,113 +199,132 @@ static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pre
 /// the right hand side as a pair.
 /// LHS and RHS are the left hand side and the right hand side ICmps and PredL
 /// and PredR are their predicates, respectively.
-static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
-    Value *&A, Value *&B, Value *&C, Value *&D, Value *&E, ICmpInst *LHS,
-    ICmpInst *RHS, ICmpInst::Predicate &PredL, ICmpInst::Predicate &PredR) {
-  // Don't allow pointers. Splat vectors are fine.
-  if (!LHS->getOperand(0)->getType()->isIntOrIntVectorTy() ||
-      !RHS->getOperand(0)->getType()->isIntOrIntVectorTy())
-    return std::nullopt;
+static std::optional<std::pair<unsigned, unsigned>>
+getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, Value *&D, Value *&E,
+                         Value *LHS, Value *RHS, ICmpInst::Predicate &PredL,
+                         ICmpInst::Predicate &PredR) {
 
-  // Here comes the tricky part:
-  // LHS might be of the form L11 & L12 == X, X == L21 & L22,
-  // and L11 & L12 == L21 & L22. The same goes for RHS.
-  // Now we must find those components L** and R**, that are equal, so
-  // that we can extract the parameters A, B, C, D, and E for the canonical
-  // above.
-  Value *L1 = LHS->getOperand(0);
-  Value *L2 = LHS->getOperand(1);
-  Value *L11, *L12, *L21, *L22;
-  // Check whether the icmp can be decomposed into a bit test.
-  if (decomposeBitTestICmp(L1, L2, PredL, L11, L12, L2)) {
-    L21 = L22 = L1 = nullptr;
-  } else {
-    // Look for ANDs in the LHS icmp.
-    if (!match(L1, m_And(m_Value(L11), m_Value(L12)))) {
-      // Any icmp can be viewed as being trivially masked; if it allows us to
-      // remove one, it's worth it.
-      L11 = L1;
-      L12 = Constant::getAllOnesValue(L1->getType());
-    }
+  Value *L1, *L11, *L12, *L2, *L21, *L22;
+  if (auto *LHSCMP = dyn_cast<ICmpInst>(LHS)) {
+
+    // Don't allow pointers. Splat vectors are fine.
+    if (!LHSCMP->getOperand(0)->getType()->isIntOrIntVectorTy())
+      return std::nullopt;
+
+    PredL = LHSCMP->getPredicate();
+
+    // Here comes the tricky part:
+    // LHS might be of the form L11 & L12 == X, X == L21 & L22,
+    // and L11 & L12 == L21 & L22. The same goes for RHS.
+    // Now we must find those components L** and R**, that are equal, so
+    // that we can extract the parameters A, B, C, D, and E for the canonical
+    // above.
+    L1 = LHSCMP->getOperand(0);
+    L2 = LHSCMP->getOperand(1);
+    // Check whether the icmp can be decomposed into a bit test.
+    if (decomposeBitTestICmp(L1, L2, PredL, L11, L12, L2)) {
+      L21 = L22 = L1 = nullptr;
+    } else {
+      // Look for ANDs in the LHS icmp.
+      if (!match(L1, m_And(m_Value(L11), m_Value(L12)))) {
+        // Any icmp can be viewed as being trivially masked; if it allows us to
+        // remove one, it's worth it.
+        L11 = L1;
+        L12 = Constant::getAllOnesValue(L1->getType());
+      }
 
-    if (!match(L2, m_And(m_Value(L21), m_Value(L22)))) {
-      L21 = L2;
-      L22 = Constant::getAllOnesValue(L2->getType());
+      if (!match(L2, m_And(m_Value(L21), m_Value(L22)))) {
+        L21 = L2;
+        L22 = Constant::getAllOnesValue(L2->getType());
+      }
     }
-  }
+    // Bail if LHS was a icmp that can't be decomposed into an equality.
+    if (!ICmpInst::isEquality(PredL))
+      return std::nullopt;
 
-  // Bail if LHS was a icmp that can't be decomposed into an equality.
-  if (!ICmpInst::isEquality(PredL))
+  } else {
     return std::nullopt;
+  }
 
-  Value *R1 = RHS->getOperand(0);
-  Value *R2 = RHS->getOperand(1);
   Value *R11, *R12;
-  bool Ok = false;
-  if (decomposeBitTestICmp(R1, R2, PredR, R11, R12, R2)) {
-    if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
-      A = R11;
-      D = R12;
-    } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
-      A = R12;
-      D = R11;
-    } else {
+  if (auto *RHSCMP = dyn_cast<ICmpInst>(RHS)) {
+
+    // Don't allow pointers. Splat vectors are fine.
+    if (!RHSCMP->getOperand(0)->getType()->isIntOrIntVectorTy())
       return std::nullopt;
-    }
-    E = R2;
-    R1 = nullptr;
-    Ok = true;
-  } else {
-    if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) {
-      // As before, model no mask as a trivial mask if it'll let us do an
-      // optimization.
-      R11 = R1;
-      R12 = Constant::getAllOnesValue(R1->getType());
-    }
 
-    if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
-      A = R11;
-      D = R12;
-      E = R2;
-      Ok = true;
-    } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
-      A = R12;
-      D = R11;
+    PredR = RHSCMP->getPredicate();
+
+    Value *R1 = RHSCMP->getOperand(0);
+    Value *R2 = RHSCMP->getOperand(1);
+    bool Ok = false;
+    if (decomposeBitTestICmp(R1, R2, PredR, R11, R12, R2)) {
+      if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
+        A = R11;
+        D = R12;
+      } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
+        A = R12;
+        D = R11;
+      } else {
+        return std::nullopt;
+      }
       E = R2;
+      R1 = nullptr;
       Ok = true;
-    }
-
-    // Avoid matching against the -1 value we created for unmasked operand.
-    if (Ok && match(A, m_AllOnes()))
-      Ok = false;
-  }
+    } else {
+      if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) {
+        // As before, model no mask as a trivial mask if it'll let us do an
+        // optimization.
+        R11 = R1;
+        R12 = Constant::getAllOnesValue(R1->getType());
+      }
 
-  // Bail if RHS was a icmp that can't be decomposed into an equality.
-  if (!ICmpInst::isEquality(PredR))
-    return std::nullopt;
+      if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
+        A = R11;
+        D = R12;
+        E = R2;
+        Ok = true;
+      } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
+        A = R12;
+        D = R11;
+        E = R2;
+        Ok = true;
+      }
 
-  // Look for ANDs on the right side of the RHS icmp.
-  if (!Ok) {
-    if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) {
-      R11 = R2;
-      R12 = Constant::getAllOnesValue(R2->getType());
+      // Avoid matching against the -1 value we created for unmasked operand.
+      if (Ok && match(A, m_AllOnes()))
+        Ok = false;
     }
 
-    if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
-      A = R11;
-      D = R12;
-      E = R1;
-      Ok = true;
-    } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
-      A = R12;
-      D = R11;
-      E = R1;
-      Ok = true;
-    } else {
+    // Bail if RHS was a icmp that can't be decomposed into an equality.
+    if (!ICmpInst::isEquality(PredR))
       return std::nullopt;
-    }
 
-    assert(Ok && "Failed to find AND on the right side of the RHS icmp.");
+    // Look for ANDs on the right side of the RHS icmp.
+    if (!Ok) {
+      if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) {
+        R11 = R2;
+        R12 = Constant::getAllOnesValue(R2->getType());
+      }
+
+      if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
+        A = R11;
+        D = R12;
+        E = R1;
+        Ok = true;
+      } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
+        A = R12;
+        D = R11;
+        E = R1;
+        Ok = true;
+      } else {
+        return std::nullopt;
+      }
+
+      assert(Ok && "Failed to find AND on the right side of the RHS icmp.");
+    }
+  } else {
+    return std::nullopt;
   }
 
   if (L11 == A) {
@@ -334,8 +353,8 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
 /// (icmp (A & 12) != 0) & (icmp (A & 15) == 8) -> (icmp (A & 15) == 8).
 /// Also used for logical and/or, must be poison safe.
 static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
-    ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *D,
-    Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
+    Value *LHS, Value *RHS, bool IsAnd, Value *A, Value *B, Value *D, Value *E,
+    ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
     InstCombiner::BuilderTy &Builder) {
   // We are given the canonical form:
   //   (icmp ne (A & B), 0) & (icmp eq (A & D), E).
@@ -458,7 +477,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
   // (icmp ne (A & 15), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8).
   if (IsSuperSetOrEqual(BCst, DCst)) {
     // We can't guarantee that samesign hold after this fold.
-    RHS->setSameSign(false);
+    if (auto *ICmp = dyn_cast<ICmpInst>(RHS))
+      ICmp->setSameSign(false);
     return RHS;
   }
   // Otherwise, B is a subset of D. If B and E have a common bit set,
@@ -467,7 +487,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
   assert(IsSubSetOrEqual(BCst, DCst) && "Precondition due to above code");
   if ((*BCst & ECst) != 0) {
     // We can't guarantee that samesign hold after this fold.
-    RHS->setSameSign(false);
+    if (auto *ICmp = dyn_cast<ICmpInst>(RHS))
+      ICmp->setSameSign(false);
     return RHS;
   }
   // Otherwise, LHS and RHS contradict and the whole expression becomes false
@@ -482,8 +503,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
 /// aren't of the common mask pattern type.
 /// Also used for logical and/or, must be poison safe.
 static Value *foldLogOpOfMaskedICmpsAsymmetric(
-    ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *C,
-    Value *D, Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
+    Value *LHS, Value *RHS, bool IsAnd, Value *A, Value *B, Value *C, Value *D,
+    Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
     unsigned LHSMask, unsigned RHSMask, InstCombiner::BuilderTy &Builder) {
   assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) &&
          "Expected equality predicates for masked type of icmps.");
@@ -512,12 +533,12 @@ static Value *foldLogOpOfMaskedICmpsAsymmetric(
 
 /// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E)
 /// into a single (icmp(A & X) ==/!= Y).
-static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
+static Value *foldLogOpOfMaskedICmps(Value *LHS, Value *RHS, bool IsAnd,
                                      bool IsLogical,
                                      InstCombiner::BuilderTy &Builder,
                                      const SimplifyQuery &Q) {
   Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
-  ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
+  ICmpInst::Predicate PredL, PredR;
   std::optional<std::pair<unsigned, unsigned>> MaskPair =
       getMaskedTypeForICmpPair(A, B, C, D, E, LHS, RHS, PredL, PredR);
   if (!MaskPair)
@@ -1067,8 +1088,7 @@ static Value *foldPowerOf2AndShiftedMask(ICmpInst *Cmp0, ICmpInst *Cmp1,
   if (!JoinedByAnd)
     return nullptr;
   Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
-  ICmpInst::Predicate CmpPred0 = Cmp0->getPredicate(),
-                      CmpPred1 = Cmp1->getPredicate();
+  ICmpInst::Predicate CmpPred0, CmpPred1;
   // Assuming P is a 2^n, getMaskedTypeForICmpPair will normalize (icmp X u<
   // 2^n) into (icmp (X & ~(2^n-1)) == 0) and (icmp X s> -1) into (icmp (X &
   // SignMask) == 0).
@@ -3325,12 +3345,6 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
     }
   }
 
-  // handle (roughly):
-  // (icmp ne (A & B), C) | (icmp ne (A & D), E)
-  // (icmp eq (A & B), C) & (icmp eq (A & D), E)
-  if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, IsAnd, IsLogical, Builder, Q))
-    return V;
-
   if (Value *V =
           foldAndOrOfICmpEqConstantAndICmp(LHS, RHS, IsAnd, IsLogical, Builder))
     return V;
@@ -3510,6 +3524,12 @@ Value *InstCombinerImpl::foldBooleanAndOr(Value *LHS, Value *RHS,
       if (Value *Res = foldAndOrOfICmps(LHSCmp, RHSCmp, I, IsAnd, IsLogical))
         return Res;
 
+  /// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E)
+  /// into a single (icmp(A & X) ==/!= Y).
+  if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, IsAnd, IsLogical, Builder,
+                                        SQ.getWithInstruction(&I)))
+    return V;
+
   if (auto *LHSCmp = dyn_cast<FCmpInst>(LHS))
     if (auto *RHSCmp = dyn_cast<FCmpInst>(RHS))
       if (Value *Res = foldLogicOfFCmps(LHSCmp, RHSCmp, IsAnd, IsLogical))

@dtcxzyw
Member

dtcxzyw commented Jan 9, 2025

most of the diffs that I looked at were improvements, but there were some regressions.

Can you add some tests for these improvements and regressions?

  /// into a single (icmp(A & X) ==/!= Y).
  if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, IsAnd, IsLogical, Builder,
                                        SQ.getWithInstruction(&I)))
    return V;
Contributor

Maybe keep the original place of the call in this commit so this change can be NFC. Would simplify review.

@andjo403 andjo403 force-pushed the moveFoldLogOpOfMaskedICmps branch from 8bf52d5 to 3c0c7b0 Compare January 9, 2025 22:02
@andjo403 andjo403 changed the title [InstCombine] Move foldLogOpOfMaskedICmps to make it possible to handle trunc to i1. [InstCombine] Prepare foldLogOpOfMaskedICmps to handle trunc to i1. (NFC) Jan 9, 2025
@andjo403
Contributor Author

andjo403 commented Jan 9, 2025

Changed this to only do the preparation needed to support trunc, without moving the call to a place where either of the values can be a trunc, so this is NFC now.
Also added a decomposeBitTest that takes only a single value, in which the trunc or another bit test can be matched.

Member

@dtcxzyw dtcxzyw left a comment

LGTM. Thank you!


std::optional<DecomposedBitTest>
llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
  using namespace PatternMatch;
Contributor

Unused?

Contributor Author

Yes, the using is unused. I forgot to remove it when I removed the trunc handling that I had here. I will add it back soon, but have removed it for now.

@andjo403
Contributor Author

@nikic do you have any more comments or is this good to go?

@andjo403 andjo403 force-pushed the moveFoldLogOpOfMaskedICmps branch from b409b04 to a15be5a Compare January 15, 2025 16:44
@andjo403 andjo403 merged commit 06499f3 into llvm:main Jan 15, 2025
5 of 7 checks passed
@andjo403 andjo403 deleted the moveFoldLogOpOfMaskedICmps branch January 15, 2025 17:15