Skip to content

Commit 28a5e6b

Browse files
authored
[InstCombine] Remove over-generalization from computeKnownBitsFromCmp() (#72637)
For most practical purposes, the only KnownBits patterns we care about are those involving a constant comparison RHS and constant mask. However, the actual implementation is written in a very general way -- and of course, with basically no test coverage of those generalizations. This patch reduces the implementation to only handle cases with constant operands. The test changes are all in "make sure we don't crash" tests. The motivation for this change is an upcoming patch to handling dominating conditions in computeKnownBits(). Handling non-constant RHS would add significant additional compile-time overhead in that case, without any significant impact on optimization quality.
1 parent 936180a commit 28a5e6b

File tree

4 files changed

+75
-97
lines changed

4 files changed

+75
-97
lines changed

llvm/lib/Analysis/AssumptionCache.cpp

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -92,29 +92,19 @@ findAffectedValues(CallBase *CI, TargetTransformInfo *TTI,
9292
AddAffected(B);
9393

9494
if (Pred == ICmpInst::ICMP_EQ) {
95-
// For equality comparisons, we handle the case of bit inversion.
96-
auto AddAffectedFromEq = [&AddAffected](Value *V) {
97-
Value *A, *B;
98-
// (A & B) or (A | B) or (A ^ B).
99-
if (match(V, m_BitwiseLogic(m_Value(A), m_Value(B)))) {
100-
AddAffected(A);
101-
AddAffected(B);
102-
// (A << C) or (A >>_s C) or (A >>_u C) where C is some constant.
103-
} else if (match(V, m_Shift(m_Value(A), m_ConstantInt()))) {
104-
AddAffected(A);
105-
}
106-
};
107-
108-
AddAffectedFromEq(A);
109-
AddAffectedFromEq(B);
95+
if (match(B, m_ConstantInt())) {
96+
Value *X;
97+
// (X & C) or (X | C) or (X ^ C).
98+
// (X << C) or (X >>_s C) or (X >>_u C).
99+
if (match(A, m_BitwiseLogic(m_Value(X), m_ConstantInt())) ||
100+
match(A, m_Shift(m_Value(X), m_ConstantInt())))
101+
AddAffected(X);
102+
}
110103
} else if (Pred == ICmpInst::ICMP_NE) {
111-
Value *X, *Y;
112-
// Handle (a & b != 0). If a/b is a power of 2 we can use this
113-
// information.
114-
if (match(A, m_And(m_Value(X), m_Value(Y))) && match(B, m_Zero())) {
104+
Value *X;
105+
// Handle (X & pow2 != 0).
106+
if (match(A, m_And(m_Value(X), m_Power2())) && match(B, m_Zero()))
115107
AddAffected(X);
116-
AddAffected(Y);
117-
}
118108
} else if (Pred == ICmpInst::ICMP_ULT) {
119109
Value *X;
120110
// Handle (A + C1) u< C2, which is the canonical form of A > C3 && A < C4,

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 61 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -633,101 +633,89 @@ static bool isKnownNonZeroFromAssume(const Value *V, const SimplifyQuery &Q) {
633633
static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
634634
KnownBits &Known, unsigned Depth,
635635
const SimplifyQuery &Q) {
636+
if (Cmp->getOperand(1)->getType()->isPointerTy()) {
637+
// Handle comparison of pointer to null explicitly, as it will not be
638+
// covered by the m_APInt() logic below.
639+
if (match(Cmp->getOperand(1), m_Zero())) {
640+
switch (Cmp->getPredicate()) {
641+
case ICmpInst::ICMP_EQ:
642+
Known.setAllZero();
643+
break;
644+
case ICmpInst::ICMP_SGE:
645+
case ICmpInst::ICMP_SGT:
646+
Known.makeNonNegative();
647+
break;
648+
case ICmpInst::ICMP_SLT:
649+
Known.makeNegative();
650+
break;
651+
default:
652+
break;
653+
}
654+
}
655+
return;
656+
}
657+
636658
unsigned BitWidth = Known.getBitWidth();
637-
// We are attempting to compute known bits for the operands of an assume.
638-
// Do not try to use other assumptions for those recursive calls because
639-
// that can lead to mutual recursion and a compile-time explosion.
640-
// An example of the mutual recursion: computeKnownBits can call
641-
// isKnownNonZero which calls computeKnownBitsFromAssume (this function)
642-
// and so on.
643-
SimplifyQuery QueryNoAC = Q;
644-
QueryNoAC.AC = nullptr;
645-
646-
// Note that ptrtoint may change the bitwidth.
647-
Value *A, *B;
648659
auto m_V =
649660
m_CombineOr(m_Specific(V), m_PtrToIntSameSize(Q.DL, m_Specific(V)));
650661

651662
CmpInst::Predicate Pred;
652-
uint64_t C;
663+
const APInt *Mask, *C;
664+
uint64_t ShAmt;
653665
switch (Cmp->getPredicate()) {
654666
case ICmpInst::ICMP_EQ:
655-
// assume(v = a)
656-
if (match(Cmp, m_c_ICmp(Pred, m_V, m_Value(A)))) {
657-
KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
658-
Known = Known.unionWith(RHSKnown);
659-
// assume(v & b = a)
660-
} else if (match(Cmp,
661-
m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A)))) {
662-
KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
663-
KnownBits MaskKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
664-
665-
// For those bits in the mask that are known to be one, we can propagate
666-
// known bits from the RHS to V.
667-
Known.Zero |= RHSKnown.Zero & MaskKnown.One;
668-
Known.One |= RHSKnown.One & MaskKnown.One;
669-
// assume(v | b = a)
667+
// assume(V = C)
668+
if (match(Cmp, m_ICmp(Pred, m_V, m_APInt(C)))) {
669+
Known = Known.unionWith(KnownBits::makeConstant(*C));
670+
// assume(V & Mask = C)
670671
} else if (match(Cmp,
671-
m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A)))) {
672-
KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
673-
KnownBits BKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
674-
675-
// For those bits in B that are known to be zero, we can propagate known
676-
// bits from the RHS to V.
677-
Known.Zero |= RHSKnown.Zero & BKnown.Zero;
678-
Known.One |= RHSKnown.One & BKnown.Zero;
679-
// assume(v ^ b = a)
672+
m_ICmp(Pred, m_And(m_V, m_APInt(Mask)), m_APInt(C)))) {
673+
// For one bits in Mask, we can propagate bits from C to V.
674+
Known.Zero |= ~*C & *Mask;
675+
Known.One |= *C & *Mask;
676+
// assume(V | Mask = C)
677+
} else if (match(Cmp, m_ICmp(Pred, m_Or(m_V, m_APInt(Mask)), m_APInt(C)))) {
678+
// For zero bits in Mask, we can propagate bits from C to V.
679+
Known.Zero |= ~*C & ~*Mask;
680+
Known.One |= *C & ~*Mask;
681+
// assume(V ^ Mask = C)
680682
} else if (match(Cmp,
681-
m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A)))) {
682-
KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
683-
KnownBits BKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
684-
685-
// For those bits in B that are known to be zero, we can propagate known
686-
// bits from the RHS to V. For those bits in B that are known to be one,
687-
// we can propagate inverted known bits from the RHS to V.
688-
Known.Zero |= RHSKnown.Zero & BKnown.Zero;
689-
Known.One |= RHSKnown.One & BKnown.Zero;
690-
Known.Zero |= RHSKnown.One & BKnown.One;
691-
Known.One |= RHSKnown.Zero & BKnown.One;
692-
// assume(v << c = a)
693-
} else if (match(Cmp, m_c_ICmp(Pred, m_Shl(m_V, m_ConstantInt(C)),
694-
m_Value(A))) &&
695-
C < BitWidth) {
696-
KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
697-
698-
// For those bits in RHS that are known, we can propagate them to known
699-
// bits in V shifted to the right by C.
700-
RHSKnown.Zero.lshrInPlace(C);
701-
RHSKnown.One.lshrInPlace(C);
683+
m_ICmp(Pred, m_Xor(m_V, m_APInt(Mask)), m_APInt(C)))) {
684+
// Equivalent to assume(V == Mask ^ C)
685+
Known = Known.unionWith(KnownBits::makeConstant(*C ^ *Mask));
686+
// assume(V << ShAmt = C)
687+
} else if (match(Cmp, m_ICmp(Pred, m_Shl(m_V, m_ConstantInt(ShAmt)),
688+
m_APInt(C))) &&
689+
ShAmt < BitWidth) {
690+
// For those bits in C that are known, we can propagate them to known
691+
// bits in V shifted to the right by ShAmt.
692+
KnownBits RHSKnown = KnownBits::makeConstant(*C);
693+
RHSKnown.Zero.lshrInPlace(ShAmt);
694+
RHSKnown.One.lshrInPlace(ShAmt);
702695
Known = Known.unionWith(RHSKnown);
703-
// assume(v >> c = a)
704-
} else if (match(Cmp, m_c_ICmp(Pred, m_Shr(m_V, m_ConstantInt(C)),
705-
m_Value(A))) &&
706-
C < BitWidth) {
707-
KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
696+
// assume(V >> ShAmt = C)
697+
} else if (match(Cmp, m_ICmp(Pred, m_Shr(m_V, m_ConstantInt(ShAmt)),
698+
m_APInt(C))) &&
699+
ShAmt < BitWidth) {
700+
KnownBits RHSKnown = KnownBits::makeConstant(*C);
708701
// For those bits in RHS that are known, we can propagate them to known
709702
// bits in V shifted to the right by C.
710-
Known.Zero |= RHSKnown.Zero << C;
711-
Known.One |= RHSKnown.One << C;
703+
Known.Zero |= RHSKnown.Zero << ShAmt;
704+
Known.One |= RHSKnown.One << ShAmt;
712705
}
713706
break;
714707
case ICmpInst::ICMP_NE: {
715-
// assume (v & b != 0) where b is a power of 2
708+
// assume (V & B != 0) where B is a power of 2
716709
const APInt *BPow2;
717-
if (match(Cmp, m_ICmp(Pred, m_c_And(m_V, m_Power2(BPow2)), m_Zero()))) {
710+
if (match(Cmp, m_ICmp(Pred, m_And(m_V, m_Power2(BPow2)), m_Zero())))
718711
Known.One |= *BPow2;
719-
}
720712
break;
721713
}
722714
default:
723715
const APInt *Offset = nullptr;
724716
if (match(Cmp, m_ICmp(Pred, m_CombineOr(m_V, m_Add(m_V, m_APInt(Offset))),
725-
m_Value(A)))) {
726-
KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
727-
ConstantRange RHSRange =
728-
ConstantRange::fromKnownBits(RHSKnown, Cmp->isSigned());
729-
ConstantRange LHSRange =
730-
ConstantRange::makeAllowedICmpRegion(Pred, RHSRange);
717+
m_APInt(C)))) {
718+
ConstantRange LHSRange = ConstantRange::makeAllowedICmpRegion(Pred, *C);
731719
if (Offset)
732720
LHSRange = LHSRange.sub(*Offset);
733721
Known = Known.unionWith(LHSRange.toKnownBits());

llvm/test/Transforms/InstCombine/assume.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ define i32 @bundle2(ptr %P) {
268268

269269
define i1 @nonnull1(ptr %a) {
270270
; CHECK-LABEL: @nonnull1(
271-
; CHECK-NEXT: [[LOAD:%.*]] = load ptr, ptr [[A:%.*]], align 8, !nonnull !6, !noundef !6
271+
; CHECK-NEXT: [[LOAD:%.*]] = load ptr, ptr [[A:%.*]], align 8, !nonnull [[META6:![0-9]+]], !noundef [[META6]]
272272
; CHECK-NEXT: tail call void @escape(ptr nonnull [[LOAD]])
273273
; CHECK-NEXT: ret i1 false
274274
;
@@ -386,7 +386,7 @@ define i1 @nonnull5(ptr %a) {
386386
define i32 @assumption_conflicts_with_known_bits(i32 %a, i32 %b) {
387387
; CHECK-LABEL: @assumption_conflicts_with_known_bits(
388388
; CHECK-NEXT: store i1 true, ptr poison, align 1
389-
; CHECK-NEXT: ret i32 poison
389+
; CHECK-NEXT: ret i32 1
390390
;
391391
%and1 = and i32 %b, 3
392392
%B1 = lshr i32 %and1, %and1

llvm/test/Transforms/InstCombine/zext-or-icmp.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ define i1 @PR51762(ptr %i, i32 %t0, i16 %t1, ptr %p, ptr %d, ptr %f, i32 %p2, i1
243243
; CHECK-NEXT: store i32 [[ADD]], ptr [[F]], align 4
244244
; CHECK-NEXT: [[REM18:%.*]] = srem i32 [[LOR_EXT]], [[ADD]]
245245
; CHECK-NEXT: [[CONV19:%.*]] = zext nneg i32 [[REM18]] to i64
246-
; CHECK-NEXT: store i32 0, ptr [[D]], align 8
246+
; CHECK-NEXT: store i32 [[SROA38]], ptr [[D]], align 8
247247
; CHECK-NEXT: [[R:%.*]] = icmp ult i64 [[INSERT_INSERT41]], [[CONV19]]
248248
; CHECK-NEXT: call void @llvm.assume(i1 [[R]])
249249
; CHECK-NEXT: ret i1 [[R]]

0 commit comments

Comments
 (0)