Skip to content

Commit d881bac

Browse files
authored
[BasicAA] Consider 'nneg' flag when comparing CastedValues (#94129)
Any of the `zext` bits in a `zext nneg` can be converted to `sext` but when checking if casts are compatible `BasicAA` fails to take into account `nneg`. This change adds tracking of `nneg` to the `CastedValue` struct and ensures that `sext` and `zext` bits are treated as interchangeable when either `CastedValue` has a `nneg`. When distributing casted values in `GetLinearExpression` we conservatively discard the `nneg` from the `CastedValue`, except in the case of `shl nsw`, where we know the sign has not changed to negative.
1 parent b9f1fdc commit d881bac

File tree

2 files changed

+224
-19
lines changed

2 files changed

+224
-19
lines changed

llvm/lib/Analysis/BasicAliasAnalysis.cpp

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -268,43 +268,60 @@ struct CastedValue {
268268
unsigned ZExtBits = 0;
269269
unsigned SExtBits = 0;
270270
unsigned TruncBits = 0;
271+
/// Whether trunc(V) is non-negative.
272+
bool IsNonNegative = false;
271273

272274
explicit CastedValue(const Value *V) : V(V) {}
273275
explicit CastedValue(const Value *V, unsigned ZExtBits, unsigned SExtBits,
274-
unsigned TruncBits)
275-
: V(V), ZExtBits(ZExtBits), SExtBits(SExtBits), TruncBits(TruncBits) {}
276+
unsigned TruncBits, bool IsNonNegative)
277+
: V(V), ZExtBits(ZExtBits), SExtBits(SExtBits), TruncBits(TruncBits),
278+
IsNonNegative(IsNonNegative) {}
276279

277280
unsigned getBitWidth() const {
278281
return V->getType()->getPrimitiveSizeInBits() - TruncBits + ZExtBits +
279282
SExtBits;
280283
}
281284

282-
CastedValue withValue(const Value *NewV) const {
283-
return CastedValue(NewV, ZExtBits, SExtBits, TruncBits);
285+
CastedValue withValue(const Value *NewV, bool PreserveNonNeg) const {
286+
return CastedValue(NewV, ZExtBits, SExtBits, TruncBits,
287+
IsNonNegative && PreserveNonNeg);
284288
}
285289

286290
/// Replace V with zext(NewV)
287-
CastedValue withZExtOfValue(const Value *NewV) const {
291+
CastedValue withZExtOfValue(const Value *NewV, bool ZExtNonNegative) const {
288292
unsigned ExtendBy = V->getType()->getPrimitiveSizeInBits() -
289293
NewV->getType()->getPrimitiveSizeInBits();
290294
if (ExtendBy <= TruncBits)
291-
return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy);
295+
// zext<nneg>(trunc(zext(NewV))) == zext<nneg>(trunc(NewV))
296+
// The nneg can be preserved on the outer zext here.
297+
return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy,
298+
IsNonNegative);
292299

293300
// zext(sext(zext(NewV))) == zext(zext(zext(NewV)))
294301
ExtendBy -= TruncBits;
295-
return CastedValue(NewV, ZExtBits + SExtBits + ExtendBy, 0, 0);
302+
// zext<nneg>(zext(NewV)) == zext(NewV)
303+
// zext(zext<nneg>(NewV)) == zext<nneg>(NewV)
304+
// The nneg can be preserved from the inner zext here but must be dropped
305+
// from the outer.
306+
return CastedValue(NewV, ZExtBits + SExtBits + ExtendBy, 0, 0,
307+
ZExtNonNegative);
296308
}
297309

298310
/// Replace V with sext(NewV)
299311
CastedValue withSExtOfValue(const Value *NewV) const {
300312
unsigned ExtendBy = V->getType()->getPrimitiveSizeInBits() -
301313
NewV->getType()->getPrimitiveSizeInBits();
302314
if (ExtendBy <= TruncBits)
303-
return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy);
315+
// zext<nneg>(trunc(sext(NewV))) == zext<nneg>(trunc(NewV))
316+
// The nneg can be preserved on the outer zext here
317+
return CastedValue(NewV, ZExtBits, SExtBits, TruncBits - ExtendBy,
318+
IsNonNegative);
304319

305320
// zext(sext(sext(NewV)))
306321
ExtendBy -= TruncBits;
307-
return CastedValue(NewV, ZExtBits, SExtBits + ExtendBy, 0);
322+
// zext<nneg>(sext(sext(NewV))) = zext<nneg>(sext(NewV))
323+
// The nneg can be preserved on the outer zext here
324+
return CastedValue(NewV, ZExtBits, SExtBits + ExtendBy, 0, IsNonNegative);
308325
}
309326

310327
APInt evaluateWith(APInt N) const {
@@ -333,8 +350,15 @@ struct CastedValue {
333350
}
334351

335352
bool hasSameCastsAs(const CastedValue &Other) const {
336-
return ZExtBits == Other.ZExtBits && SExtBits == Other.SExtBits &&
337-
TruncBits == Other.TruncBits;
353+
if (ZExtBits == Other.ZExtBits && SExtBits == Other.SExtBits &&
354+
TruncBits == Other.TruncBits)
355+
return true;
356+
// If either CastedValue has a nneg zext then the sext/zext bits are
357+
// interchangable for that value.
358+
if (IsNonNegative || Other.IsNonNegative)
359+
return (ZExtBits + SExtBits == Other.ZExtBits + Other.SExtBits &&
360+
TruncBits == Other.TruncBits);
361+
return false;
338362
}
339363
};
340364

@@ -410,21 +434,21 @@ static LinearExpression GetLinearExpression(
410434

411435
[[fallthrough]];
412436
case Instruction::Add: {
413-
E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
437+
E = GetLinearExpression(Val.withValue(BOp->getOperand(0), false), DL,
414438
Depth + 1, AC, DT);
415439
E.Offset += RHS;
416440
E.IsNSW &= NSW;
417441
break;
418442
}
419443
case Instruction::Sub: {
420-
E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
444+
E = GetLinearExpression(Val.withValue(BOp->getOperand(0), false), DL,
421445
Depth + 1, AC, DT);
422446
E.Offset -= RHS;
423447
E.IsNSW &= NSW;
424448
break;
425449
}
426450
case Instruction::Mul:
427-
E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
451+
E = GetLinearExpression(Val.withValue(BOp->getOperand(0), false), DL,
428452
Depth + 1, AC, DT)
429453
.mul(RHS, NSW);
430454
break;
@@ -437,7 +461,7 @@ static LinearExpression GetLinearExpression(
437461
if (RHS.getLimitedValue() > Val.getBitWidth())
438462
return Val;
439463

440-
E = GetLinearExpression(Val.withValue(BOp->getOperand(0)), DL,
464+
E = GetLinearExpression(Val.withValue(BOp->getOperand(0), NSW), DL,
441465
Depth + 1, AC, DT);
442466
E.Offset <<= RHS.getLimitedValue();
443467
E.Scale <<= RHS.getLimitedValue();
@@ -448,10 +472,10 @@ static LinearExpression GetLinearExpression(
448472
}
449473
}
450474

451-
if (isa<ZExtInst>(Val.V))
475+
if (const auto *ZExt = dyn_cast<ZExtInst>(Val.V))
452476
return GetLinearExpression(
453-
Val.withZExtOfValue(cast<CastInst>(Val.V)->getOperand(0)),
454-
DL, Depth + 1, AC, DT);
477+
Val.withZExtOfValue(ZExt->getOperand(0), ZExt->hasNonNeg()), DL,
478+
Depth + 1, AC, DT);
455479

456480
if (isa<SExtInst>(Val.V))
457481
return GetLinearExpression(
@@ -673,7 +697,7 @@ BasicAAResult::DecomposeGEPExpression(const Value *V, const DataLayout &DL,
673697
unsigned SExtBits = IndexSize > Width ? IndexSize - Width : 0;
674698
unsigned TruncBits = IndexSize < Width ? Width - IndexSize : 0;
675699
LinearExpression LE = GetLinearExpression(
676-
CastedValue(Index, 0, SExtBits, TruncBits), DL, 0, AC, DT);
700+
CastedValue(Index, 0, SExtBits, TruncBits, false), DL, 0, AC, DT);
677701

678702
// Scale by the type size.
679703
unsigned TypeSize = AllocTypeSize.getFixedValue();
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
; RUN: opt < %s -aa-pipeline=basic-aa -passes=aa-eval -print-all-alias-modref-info -disable-output 2>&1 | FileCheck %s
2+
3+
;; Simple case: a zext nneg can be replaced with a sext. Make sure BasicAA
4+
;; understands that.
5+
define void @t1(i32 %a, i32 %b) {
6+
; CHECK-LABEL: Function: t1
7+
; CHECK: NoAlias: float* %gep1, float* %gep2
8+
9+
%1 = alloca [8 x float], align 4
10+
%or1 = or i32 %a, 1
11+
%2 = sext i32 %or1 to i64
12+
%gep1 = getelementptr inbounds float, ptr %1, i64 %2
13+
14+
%shl1 = shl i32 %b, 1
15+
%3 = zext nneg i32 %shl1 to i64
16+
%gep2 = getelementptr inbounds float, ptr %1, i64 %3
17+
18+
load float, ptr %gep1
19+
load float, ptr %gep2
20+
ret void
21+
}
22+
23+
;; A (zext nneg (sext V)) is equivalent to a (zext (sext V)) as long as the
24+
;; total number of zext+sext bits is the same for both.
25+
define void @t2(i8 %a, i8 %b) {
26+
; CHECK-LABEL: Function: t2
27+
; CHECK: NoAlias: float* %gep1, float* %gep2
28+
%1 = alloca [8 x float], align 4
29+
%or1 = or i8 %a, 1
30+
%2 = sext i8 %or1 to i32
31+
%3 = zext i32 %2 to i64
32+
%gep1 = getelementptr inbounds float, ptr %1, i64 %3
33+
34+
%shl1 = shl i8 %b, 1
35+
%4 = sext i8 %shl1 to i16
36+
%5 = zext nneg i16 %4 to i64
37+
%gep2 = getelementptr inbounds float, ptr %1, i64 %5
38+
39+
load float, ptr %gep1
40+
load float, ptr %gep2
41+
ret void
42+
}
43+
44+
;; Here the %a and %b are knowably non-equal. In this cases we can distribute
45+
;; the zext, preserving the nneg flag, through the shl because it has a nsw flag
46+
define void @t3(i8 %v) {
47+
; CHECK-LABEL: Function: t3
48+
; CHECK: NoAlias: <2 x float>* %gep1, <2 x float>* %gep2
49+
%a = or i8 %v, 1
50+
%b = and i8 %v, 2
51+
52+
%1 = alloca [8 x float], align 4
53+
%or1 = shl nuw nsw i8 %a, 1
54+
%2 = zext nneg i8 %or1 to i64
55+
%gep1 = getelementptr inbounds float, ptr %1, i64 %2
56+
57+
%m = mul nsw nuw i8 %b, 2
58+
%3 = sext i8 %m to i16
59+
%4 = zext i16 %3 to i64
60+
%gep2 = getelementptr inbounds float, ptr %1, i64 %4
61+
62+
load <2 x float>, ptr %gep1
63+
load <2 x float>, ptr %gep2
64+
ret void
65+
}
66+
67+
;; This is the same as above, but this time the shl does not have the nsw flag.
68+
;; the nneg cannot be kept on the zext.
69+
define void @t4(i8 %v) {
70+
; CHECK-LABEL: Function: t4
71+
; CHECK: MayAlias: <2 x float>* %gep1, <2 x float>* %gep2
72+
%a = or i8 %v, 1
73+
%b = and i8 %v, 2
74+
75+
%1 = alloca [8 x float], align 4
76+
%or1 = shl nuw i8 %a, 1
77+
%2 = zext nneg i8 %or1 to i64
78+
%gep1 = getelementptr inbounds float, ptr %1, i64 %2
79+
80+
%m = mul nsw nuw i8 %b, 2
81+
%3 = sext i8 %m to i16
82+
%4 = zext i16 %3 to i64
83+
%gep2 = getelementptr inbounds float, ptr %1, i64 %4
84+
85+
load <2 x float>, ptr %gep1
86+
load <2 x float>, ptr %gep2
87+
ret void
88+
}
89+
90+
;; Verify a zext nneg and a zext are understood as the same
91+
define void @t5(ptr %p, i16 %i) {
92+
; CHECK-LABEL: Function: t5
93+
; CHECK: NoAlias: i32* %pi, i32* %pi.next
94+
%i1 = zext nneg i16 %i to i32
95+
%pi = getelementptr i32, ptr %p, i32 %i1
96+
97+
%i.next = add i16 %i, 1
98+
%i.next2 = zext i16 %i.next to i32
99+
%pi.next = getelementptr i32, ptr %p, i32 %i.next2
100+
101+
load i32, ptr %pi
102+
load i32, ptr %pi.next
103+
ret void
104+
}
105+
106+
;; This is not very idiomatic, but still possible, verify the nneg is propagated
107+
;; outward. and that no alias is correctly identified.
108+
define void @t6(i8 %a) {
109+
; CHECK-LABEL: Function: t6
110+
; CHECK: NoAlias: float* %gep1, float* %gep2
111+
%1 = alloca [8 x float], align 4
112+
%a.add = add i8 %a, 1
113+
%2 = zext nneg i8 %a.add to i16
114+
%3 = sext i16 %2 to i32
115+
%4 = zext i32 %3 to i64
116+
%gep1 = getelementptr inbounds float, ptr %1, i64 %4
117+
118+
%5 = sext i8 %a to i64
119+
%gep2 = getelementptr inbounds float, ptr %1, i64 %5
120+
121+
load float, ptr %gep1
122+
load float, ptr %gep2
123+
ret void
124+
}
125+
126+
;; This is even less idiomatic, but still possible, verify the nneg is not
127+
;; propagated inward. and that may alias is correctly identified.
128+
define void @t7(i8 %a) {
129+
; CHECK-LABEL: Function: t7
130+
; CHECK: MayAlias: float* %gep1, float* %gep2
131+
%1 = alloca [8 x float], align 4
132+
%a.add = add i8 %a, 1
133+
%2 = zext i8 %a.add to i16
134+
%3 = sext i16 %2 to i32
135+
%4 = zext nneg i32 %3 to i64
136+
%gep1 = getelementptr inbounds float, ptr %1, i64 %4
137+
138+
%5 = sext i8 %a to i64
139+
%gep2 = getelementptr inbounds float, ptr %1, i64 %5
140+
141+
load float, ptr %gep1
142+
load float, ptr %gep2
143+
ret void
144+
}
145+
146+
;; Verify the nneg survives an implicit trunc of fewer bits then the zext.
147+
define void @t8(i8 %a) {
148+
; CHECK-LABEL: Function: t8
149+
; CHECK: NoAlias: float* %gep1, float* %gep2
150+
%1 = alloca [8 x float], align 4
151+
%a.add = add i8 %a, 1
152+
%2 = zext nneg i8 %a.add to i128
153+
%gep1 = getelementptr inbounds float, ptr %1, i128 %2
154+
155+
%3 = sext i8 %a to i64
156+
%gep2 = getelementptr inbounds float, ptr %1, i64 %3
157+
158+
load float, ptr %gep1
159+
load float, ptr %gep2
160+
ret void
161+
}
162+
163+
;; Ensure that the nneg is never propagated past this trunc and that these
164+
;; casted values are understood as non-equal.
165+
define void @t9(i8 %a) {
166+
; CHECK-LABEL: Function: t9
167+
; CHECK: MayAlias: float* %gep1, float* %gep2
168+
%1 = alloca [8 x float], align 4
169+
%a.add = add i8 %a, 1
170+
%2 = zext i8 %a.add to i16
171+
%3 = trunc i16 %2 to i1
172+
%4 = zext nneg i1 %3 to i64
173+
%gep1 = getelementptr inbounds float, ptr %1, i64 %4
174+
175+
%5 = sext i8 %a to i64
176+
%gep2 = getelementptr inbounds float, ptr %1, i64 %5
177+
178+
load float, ptr %gep1
179+
load float, ptr %gep2
180+
ret void
181+
}

0 commit comments

Comments
 (0)