Skip to content

Commit a15be5a

Browse files
committed
[InstCombine] Prepare foldLogOpOfMaskedICmps to handle trunc to i1. (NFC)
1 parent 07a1847 commit a15be5a

File tree

3 files changed

+86
-55
lines changed

3 files changed

+86
-55
lines changed

llvm/include/llvm/Analysis/CmpInstAnalysis.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,12 @@ namespace llvm {
108108
bool LookThroughTrunc = true,
109109
bool AllowNonZeroC = false);
110110

111+
/// Decompose an icmp into the form ((X & Mask) pred C) if
112+
/// possible. Unless \p AllowNonZeroC is true, C will always be 0.
113+
std::optional<DecomposedBitTest>
114+
decomposeBitTest(Value *Cond, bool LookThroughTrunc = true,
115+
bool AllowNonZeroC = false);
116+
111117
} // end namespace llvm
112118

113119
#endif

llvm/lib/Analysis/CmpInstAnalysis.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,17 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
165165

166166
return Result;
167167
}
168+
169+
std::optional<DecomposedBitTest>
170+
llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
171+
if (auto *ICmp = dyn_cast<ICmpInst>(Cond)) {
172+
// Don't allow pointers. Splat vectors are fine.
173+
if (!ICmp->getOperand(0)->getType()->isIntOrIntVectorTy())
174+
return std::nullopt;
175+
return decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
176+
ICmp->getPredicate(), LookThruTrunc,
177+
AllowNonZeroC);
178+
}
179+
180+
return std::nullopt;
181+
}

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 66 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,10 @@ static unsigned conjugateICmpMask(unsigned Mask) {
179179
}
180180

181181
// Adapts the external decomposeBitTestICmp for local use.
182-
static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pred,
182+
static bool decomposeBitTestICmp(Value *Cond, CmpInst::Predicate &Pred,
183183
Value *&X, Value *&Y, Value *&Z) {
184-
auto Res = llvm::decomposeBitTestICmp(
185-
LHS, RHS, Pred, /*LookThroughTrunc=*/true, /*AllowNonZeroC=*/true);
184+
auto Res = llvm::decomposeBitTest(Cond, /*LookThroughTrunc=*/true,
185+
/*AllowNonZeroC=*/true);
186186
if (!Res)
187187
return false;
188188

@@ -198,27 +198,34 @@ static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pre
198198
/// the right hand side as a pair.
199199
/// LHS and RHS are the left hand side and the right hand side ICmps and PredL
200200
/// and PredR are their predicates, respectively.
201-
static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
202-
Value *&A, Value *&B, Value *&C, Value *&D, Value *&E, ICmpInst *LHS,
203-
ICmpInst *RHS, ICmpInst::Predicate &PredL, ICmpInst::Predicate &PredR) {
204-
// Don't allow pointers. Splat vectors are fine.
205-
if (!LHS->getOperand(0)->getType()->isIntOrIntVectorTy() ||
206-
!RHS->getOperand(0)->getType()->isIntOrIntVectorTy())
207-
return std::nullopt;
201+
static std::optional<std::pair<unsigned, unsigned>>
202+
getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, Value *&D, Value *&E,
203+
Value *LHS, Value *RHS, ICmpInst::Predicate &PredL,
204+
ICmpInst::Predicate &PredR) {
208205

209206
// Here comes the tricky part:
210207
// LHS might be of the form L11 & L12 == X, X == L21 & L22,
211208
// and L11 & L12 == L21 & L22. The same goes for RHS.
212209
// Now we must find those components L** and R**, that are equal, so
213210
// that we can extract the parameters A, B, C, D, and E for the canonical
214211
// above.
215-
Value *L1 = LHS->getOperand(0);
216-
Value *L2 = LHS->getOperand(1);
217-
Value *L11, *L12, *L21, *L22;
212+
218213
// Check whether the icmp can be decomposed into a bit test.
219-
if (decomposeBitTestICmp(L1, L2, PredL, L11, L12, L2)) {
214+
Value *L1, *L11, *L12, *L2, *L21, *L22;
215+
if (decomposeBitTestICmp(LHS, PredL, L11, L12, L2)) {
220216
L21 = L22 = L1 = nullptr;
221217
} else {
218+
auto *LHSCMP = dyn_cast<ICmpInst>(LHS);
219+
if (!LHSCMP)
220+
return std::nullopt;
221+
222+
// Don't allow pointers. Splat vectors are fine.
223+
if (!LHSCMP->getOperand(0)->getType()->isIntOrIntVectorTy())
224+
return std::nullopt;
225+
226+
PredL = LHSCMP->getPredicate();
227+
L1 = LHSCMP->getOperand(0);
228+
L2 = LHSCMP->getOperand(1);
222229
// Look for ANDs in the LHS icmp.
223230
if (!match(L1, m_And(m_Value(L11), m_Value(L12)))) {
224231
// Any icmp can be viewed as being trivially masked; if it allows us to
@@ -237,11 +244,8 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
237244
if (!ICmpInst::isEquality(PredL))
238245
return std::nullopt;
239246

240-
Value *R1 = RHS->getOperand(0);
241-
Value *R2 = RHS->getOperand(1);
242-
Value *R11, *R12;
243-
bool Ok = false;
244-
if (decomposeBitTestICmp(R1, R2, PredR, R11, R12, R2)) {
247+
Value *R11, *R12, *R2;
248+
if (decomposeBitTestICmp(RHS, PredR, R11, R12, R2)) {
245249
if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
246250
A = R11;
247251
D = R12;
@@ -252,9 +256,19 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
252256
return std::nullopt;
253257
}
254258
E = R2;
255-
R1 = nullptr;
256-
Ok = true;
257259
} else {
260+
auto *RHSCMP = dyn_cast<ICmpInst>(RHS);
261+
if (!RHSCMP)
262+
return std::nullopt;
263+
// Don't allow pointers. Splat vectors are fine.
264+
if (!RHSCMP->getOperand(0)->getType()->isIntOrIntVectorTy())
265+
return std::nullopt;
266+
267+
PredR = RHSCMP->getPredicate();
268+
269+
Value *R1 = RHSCMP->getOperand(0);
270+
R2 = RHSCMP->getOperand(1);
271+
bool Ok = false;
258272
if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) {
259273
// As before, model no mask as a trivial mask if it'll let us do an
260274
// optimization.
@@ -277,36 +291,32 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
277291
// Avoid matching against the -1 value we created for unmasked operand.
278292
if (Ok && match(A, m_AllOnes()))
279293
Ok = false;
294+
295+
// Look for ANDs on the right side of the RHS icmp.
296+
if (!Ok) {
297+
if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) {
298+
R11 = R2;
299+
R12 = Constant::getAllOnesValue(R2->getType());
300+
}
301+
302+
if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
303+
A = R11;
304+
D = R12;
305+
E = R1;
306+
} else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
307+
A = R12;
308+
D = R11;
309+
E = R1;
310+
} else {
311+
return std::nullopt;
312+
}
313+
}
280314
}
281315

282316
// Bail if RHS was a icmp that can't be decomposed into an equality.
283317
if (!ICmpInst::isEquality(PredR))
284318
return std::nullopt;
285319

286-
// Look for ANDs on the right side of the RHS icmp.
287-
if (!Ok) {
288-
if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) {
289-
R11 = R2;
290-
R12 = Constant::getAllOnesValue(R2->getType());
291-
}
292-
293-
if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
294-
A = R11;
295-
D = R12;
296-
E = R1;
297-
Ok = true;
298-
} else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
299-
A = R12;
300-
D = R11;
301-
E = R1;
302-
Ok = true;
303-
} else {
304-
return std::nullopt;
305-
}
306-
307-
assert(Ok && "Failed to find AND on the right side of the RHS icmp.");
308-
}
309-
310320
if (L11 == A) {
311321
B = L12;
312322
C = L2;
@@ -333,8 +343,8 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
333343
/// (icmp (A & 12) != 0) & (icmp (A & 15) == 8) -> (icmp (A & 15) == 8).
334344
/// Also used for logical and/or, must be poison safe.
335345
static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
336-
ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *D,
337-
Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
346+
Value *LHS, Value *RHS, bool IsAnd, Value *A, Value *B, Value *D, Value *E,
347+
ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
338348
InstCombiner::BuilderTy &Builder) {
339349
// We are given the canonical form:
340350
// (icmp ne (A & B), 0) & (icmp eq (A & D), E).
@@ -457,7 +467,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
457467
// (icmp ne (A & 15), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8).
458468
if (IsSuperSetOrEqual(BCst, DCst)) {
459469
// We can't guarantee that samesign hold after this fold.
460-
RHS->setSameSign(false);
470+
if (auto *ICmp = dyn_cast<ICmpInst>(RHS))
471+
ICmp->setSameSign(false);
461472
return RHS;
462473
}
463474
// Otherwise, B is a subset of D. If B and E have a common bit set,
@@ -466,7 +477,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
466477
assert(IsSubSetOrEqual(BCst, DCst) && "Precondition due to above code");
467478
if ((*BCst & ECst) != 0) {
468479
// We can't guarantee that samesign hold after this fold.
469-
RHS->setSameSign(false);
480+
if (auto *ICmp = dyn_cast<ICmpInst>(RHS))
481+
ICmp->setSameSign(false);
470482
return RHS;
471483
}
472484
// Otherwise, LHS and RHS contradict and the whole expression becomes false
@@ -481,8 +493,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
481493
/// aren't of the common mask pattern type.
482494
/// Also used for logical and/or, must be poison safe.
483495
static Value *foldLogOpOfMaskedICmpsAsymmetric(
484-
ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *C,
485-
Value *D, Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
496+
Value *LHS, Value *RHS, bool IsAnd, Value *A, Value *B, Value *C, Value *D,
497+
Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
486498
unsigned LHSMask, unsigned RHSMask, InstCombiner::BuilderTy &Builder) {
487499
assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) &&
488500
"Expected equality predicates for masked type of icmps.");
@@ -511,12 +523,12 @@ static Value *foldLogOpOfMaskedICmpsAsymmetric(
511523

512524
/// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E)
513525
/// into a single (icmp(A & X) ==/!= Y).
514-
static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
526+
static Value *foldLogOpOfMaskedICmps(Value *LHS, Value *RHS, bool IsAnd,
515527
bool IsLogical,
516528
InstCombiner::BuilderTy &Builder,
517529
const SimplifyQuery &Q) {
518530
Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
519-
ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
531+
ICmpInst::Predicate PredL, PredR;
520532
std::optional<std::pair<unsigned, unsigned>> MaskPair =
521533
getMaskedTypeForICmpPair(A, B, C, D, E, LHS, RHS, PredL, PredR);
522534
if (!MaskPair)
@@ -1066,8 +1078,7 @@ static Value *foldPowerOf2AndShiftedMask(ICmpInst *Cmp0, ICmpInst *Cmp1,
10661078
if (!JoinedByAnd)
10671079
return nullptr;
10681080
Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
1069-
ICmpInst::Predicate CmpPred0 = Cmp0->getPredicate(),
1070-
CmpPred1 = Cmp1->getPredicate();
1081+
ICmpInst::Predicate CmpPred0, CmpPred1;
10711082
// Assuming P is a 2^n, getMaskedTypeForICmpPair will normalize (icmp X u<
10721083
// 2^n) into (icmp (X & ~(2^n-1)) == 0) and (icmp X s> -1) into (icmp (X &
10731084
// SignMask) == 0).

0 commit comments

Comments
 (0)