Skip to content

Commit 39ac95d

Browse files
committed
[LoopInterchange] Fix overflow in cost calculation
If the iteration count is really large, e.g. UINT_MAX, then the cost calculation can overflows and trigger an assert. So saturate the cost to INT_MAX if this is the case (the cost value is kept in a signed integer). This fixes #104761
1 parent 60f3e67 commit 39ac95d

File tree

2 files changed

+53
-2
lines changed

2 files changed

+53
-2
lines changed

llvm/lib/Analysis/LoopCacheAnalysis.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,8 @@ CacheCostTy IndexedReference::computeRefCost(const Loop &L,
328328
const SCEV *TripCount =
329329
computeTripCount(*AR->getLoop(), *Sizes.back(), SE);
330330
Type *WiderType = SE.getWiderType(RefCost->getType(), TripCount->getType());
331+
// For the multiplication result to fit, request a type twice as wide.
332+
WiderType = WiderType->getExtendedType();
331333
RefCost = SE.getMulExpr(SE.getNoopOrZeroExtend(RefCost, WiderType),
332334
SE.getNoopOrZeroExtend(TripCount, WiderType));
333335
}
@@ -338,8 +340,12 @@ CacheCostTy IndexedReference::computeRefCost(const Loop &L,
338340
assert(RefCost && "Expecting a valid RefCost");
339341

340342
// Attempt to fold RefCost into a constant.
343+
// CacheCostTy is a signed integer, but the tripcount value can be large
344+
// and may not fit, so saturate/limit the value to the maximum signed
345+
// integer value.
341346
if (auto ConstantCost = dyn_cast<SCEVConstant>(RefCost))
342-
return ConstantCost->getValue()->getZExtValue();
347+
return (CacheCostTy)ConstantCost->getValue()->getLimitedValue(
348+
std::numeric_limits<CacheCostTy>::max());
343349

344350
LLVM_DEBUG(dbgs().indent(4)
345351
<< "RefCost is not a constant! Setting to RefCost=InvalidCost "
@@ -712,7 +718,13 @@ CacheCost::computeLoopCacheCost(const Loop &L,
712718
CacheCostTy LoopCost = 0;
713719
for (const ReferenceGroupTy &RG : RefGroups) {
714720
CacheCostTy RefGroupCost = computeRefGroupCacheCost(RG, L);
715-
LoopCost += RefGroupCost * TripCountsProduct;
721+
722+
// Saturate the cost to INT MAX if the value can overflow.
723+
if (RefGroupCost >
724+
(std::numeric_limits<CacheCostTy>::max() / TripCountsProduct))
725+
LoopCost = std::numeric_limits<CacheCostTy>::max();
726+
else
727+
LoopCost += RefGroupCost * TripCountsProduct;
716728
}
717729

718730
LLVM_DEBUG(dbgs().indent(2) << "Loop '" << L.getName()
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
; RUN: opt < %s -passes='print<loop-cache-cost>' -disable-output 2>&1 | FileCheck %s
2+
3+
; For a loop with a very large iteration count, make sure the cost
4+
; calculation does not overflow:
5+
;
6+
; void a(int b) {
7+
; for (int c;; c += b)
8+
; for (long d = 0; d < -3ULL; d += 2ULL)
9+
; A[c][d][d] = 0;
10+
; }
11+
12+
; CHECK: Loop 'for.cond' has cost = 9223372036854775807
13+
; CHECK: Loop 'for.body' has cost = 9223372036854775807
14+
15+
target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128-Fn32"
16+
17+
@A = local_unnamed_addr global [11 x [11 x [11 x i32]]] zeroinitializer, align 16
18+
19+
define void @foo(i32 noundef %b) {
20+
entry:
21+
%0 = sext i32 %b to i64
22+
br label %for.cond
23+
24+
for.cond:
25+
%indvars.iv = phi i64 [ %indvars.iv.next, %for.cond.cleanup ], [ 0, %entry ]
26+
br label %for.body
27+
28+
for.cond.cleanup:
29+
%indvars.iv.next = add nsw i64 %indvars.iv, %0
30+
br label %for.cond
31+
32+
for.body:
33+
%d.010 = phi i64 [ 0, %for.cond ], [ %add, %for.body ]
34+
%arrayidx3 = getelementptr inbounds [11 x [11 x [11 x i32]]], ptr @A, i64 0, i64 %indvars.iv, i64 %d.010, i64 %d.010
35+
store i32 0, ptr %arrayidx3, align 4
36+
%add = add nuw i64 %d.010, 2
37+
%cmp = icmp ult i64 %d.010, -5
38+
br i1 %cmp, label %for.body, label %for.cond.cleanup
39+
}

0 commit comments

Comments
 (0)