diff --git a/llvm/include/llvm/Analysis/LoopCacheAnalysis.h b/llvm/include/llvm/Analysis/LoopCacheAnalysis.h index 4fd2485e39d6d..3e22487e5e349 100644 --- a/llvm/include/llvm/Analysis/LoopCacheAnalysis.h +++ b/llvm/include/llvm/Analysis/LoopCacheAnalysis.h @@ -16,6 +16,7 @@ #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/IR/PassManager.h" +#include "llvm/Support/InstructionCost.h" #include namespace llvm { @@ -31,7 +32,7 @@ class ScalarEvolution; class SCEV; class TargetTransformInfo; -using CacheCostTy = int64_t; +using CacheCostTy = InstructionCost; using LoopVectorTy = SmallVector; /// Represents a memory reference as a base pointer and a set of indexing @@ -192,8 +193,6 @@ class CacheCost { using LoopCacheCostTy = std::pair; public: - static CacheCostTy constexpr InvalidCost = -1; - /// Construct a CacheCost object for the loop nest described by \p Loops. /// The optional parameter \p TRT can be used to specify the max. distance /// between array elements accessed in a loop so that the elements are diff --git a/llvm/lib/Analysis/LoopCacheAnalysis.cpp b/llvm/lib/Analysis/LoopCacheAnalysis.cpp index 7ca9f15ad5fca..2897b922f61e4 100644 --- a/llvm/lib/Analysis/LoopCacheAnalysis.cpp +++ b/llvm/lib/Analysis/LoopCacheAnalysis.cpp @@ -328,6 +328,8 @@ CacheCostTy IndexedReference::computeRefCost(const Loop &L, const SCEV *TripCount = computeTripCount(*AR->getLoop(), *Sizes.back(), SE); Type *WiderType = SE.getWiderType(RefCost->getType(), TripCount->getType()); + // For the multiplication result to fit, request a type twice as wide. + WiderType = WiderType->getExtendedType(); RefCost = SE.getMulExpr(SE.getNoopOrZeroExtend(RefCost, WiderType), SE.getNoopOrZeroExtend(TripCount, WiderType)); } @@ -338,14 +340,18 @@ CacheCostTy IndexedReference::computeRefCost(const Loop &L, assert(RefCost && "Expecting a valid RefCost"); // Attempt to fold RefCost into a constant. 
+ // CacheCostTy (InstructionCost) stores a signed 64-bit value, but the constant + // RefCost is read as an unsigned value that may be too large to fit, so + // saturate it to the maximum signed 64-bit value. if (auto ConstantCost = dyn_cast(RefCost)) - return ConstantCost->getValue()->getZExtValue(); + return ConstantCost->getValue()->getLimitedValue( + std::numeric_limits::max()); LLVM_DEBUG(dbgs().indent(4) << "RefCost is not a constant! Setting to RefCost=InvalidCost " "(invalid value).\n"); - return CacheCost::InvalidCost; + return CacheCostTy::getInvalid(); } bool IndexedReference::tryDelinearizeFixedSize( @@ -696,7 +702,7 @@ CacheCostTy CacheCost::computeLoopCacheCost(const Loop &L, const ReferenceGroupsTy &RefGroups) const { if (!L.isLoopSimplifyForm()) - return InvalidCost; + return CacheCostTy::getInvalid(); LLVM_DEBUG(dbgs() << "Considering loop '" << L.getName() << "' as innermost loop.\n"); diff --git a/llvm/test/Analysis/LoopCacheAnalysis/interchange-refcost-overflow.ll b/llvm/test/Analysis/LoopCacheAnalysis/interchange-refcost-overflow.ll new file mode 100644 index 0000000000000..7b6529601da32 --- /dev/null +++ b/llvm/test/Analysis/LoopCacheAnalysis/interchange-refcost-overflow.ll @@ -0,0 +1,37 @@ +; RUN: opt < %s -passes='print' -disable-output 2>&1 | FileCheck %s + +; For a loop with a very large iteration count, make sure the cost +; calculation does not overflow: +; +; void a(int b) { +; for (int c;; c += b) +; for (long d = 0; d < -3ULL; d += 2ULL) +; A[c][d][d] = 0; +; } + +; CHECK: Loop 'outer.loop' has cost = 9223372036854775807 +; CHECK: Loop 'inner.loop' has cost = 9223372036854775807 + +@A = local_unnamed_addr global [11 x [11 x [11 x i32]]] zeroinitializer, align 16 + +define void @foo(i32 noundef %b) { +entry: + %0 = sext i32 %b to i64 + br label %outer.loop + +outer.loop: + %indvars.iv = phi i64 [ %indvars.iv.next, %outer.loop.cleanup ], [ 0, %entry ] + br label %inner.loop + +outer.loop.cleanup: + %indvars.iv.next = add nsw i64 %indvars.iv, %0 + br label %outer.loop + 
+inner.loop: + %inner.iv = phi i64 [ 0, %outer.loop ], [ %add, %inner.loop ] + %arrayidx3 = getelementptr inbounds [11 x [11 x [11 x i32]]], ptr @A, i64 0, i64 %indvars.iv, i64 %inner.iv, i64 %inner.iv + store i32 0, ptr %arrayidx3, align 4 + %add = add nuw i64 %inner.iv, 2 + %cmp = icmp ult i64 %inner.iv, -5 + br i1 %cmp, label %inner.loop, label %outer.loop.cleanup +}