Skip to content

Commit b7286db

Browse files
authored
Reland "[LoopVectorize] Add support for reverse loops in isDereferenceableAndAlignedInLoop #96752" (#123616)
The previous attempt failed a sanitiser build because isDereferenceableAndAlignedInLoop dereferenced a null Predicates pointer in order to bind it to a reference. This was exposed by the unit test IsDerefReadOnlyLoop in unittests/Analysis/LoadsTest.cpp. I fixed this by falling back on getConstantMaxBackedgeTakenCount when Predicates is null, and only calling getPredicatedConstantMaxBackedgeTakenCount when it is non-null - see line 316 in llvm/lib/Analysis/Loads.cpp. There are no other changes.
1 parent ddbfe6f commit b7286db

File tree

5 files changed

+191
-267
lines changed

5 files changed

+191
-267
lines changed

llvm/include/llvm/Analysis/LoopAccessAnalysis.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,25 @@ bool sortPtrAccesses(ArrayRef<Value *> VL, Type *ElemTy, const DataLayout &DL,
853853
bool isConsecutiveAccess(Value *A, Value *B, const DataLayout &DL,
854854
ScalarEvolution &SE, bool CheckType = true);
855855

856+
/// Calculate Start and End points of memory access.
857+
/// Let's assume A is the first access and B is a memory access on N-th loop
858+
/// iteration. Then B is calculated as:
859+
/// B = A + Step*N .
860+
/// Step value may be positive or negative.
861+
/// N is a calculated back-edge taken count:
862+
/// N = (TripCount > 0) ? RoundDown(TripCount -1 , VF) : 0
863+
/// Start and End points are calculated in the following way:
864+
/// Start = UMIN(A, B) ; End = UMAX(A, B) + SizeOfElt,
865+
/// where SizeOfElt is the size of single memory access in bytes.
866+
///
867+
/// There is no conflict when the intervals are disjoint:
868+
/// NoConflict = (P2.Start >= P1.End) || (P1.Start >= P2.End)
869+
std::pair<const SCEV *, const SCEV *> getStartAndEndForAccess(
870+
const Loop *Lp, const SCEV *PtrExpr, Type *AccessTy, const SCEV *MaxBECount,
871+
ScalarEvolution *SE,
872+
DenseMap<std::pair<const SCEV *, Type *>,
873+
std::pair<const SCEV *, const SCEV *>> *PointerBounds);
874+
856875
class LoopAccessInfoManager {
857876
/// The cache.
858877
DenseMap<Loop *, std::unique_ptr<LoopAccessInfo>> LoopAccessInfoMap;

llvm/lib/Analysis/Loads.cpp

Lines changed: 60 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "llvm/Analysis/Loads.h"
1414
#include "llvm/Analysis/AliasAnalysis.h"
1515
#include "llvm/Analysis/AssumeBundleQueries.h"
16+
#include "llvm/Analysis/LoopAccessAnalysis.h"
1617
#include "llvm/Analysis/LoopInfo.h"
1718
#include "llvm/Analysis/MemoryBuiltins.h"
1819
#include "llvm/Analysis/MemoryLocation.h"
@@ -277,84 +278,90 @@ static bool AreEquivalentAddressValues(const Value *A, const Value *B) {
277278
bool llvm::isDereferenceableAndAlignedInLoop(
278279
LoadInst *LI, Loop *L, ScalarEvolution &SE, DominatorTree &DT,
279280
AssumptionCache *AC, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
281+
const Align Alignment = LI->getAlign();
280282
auto &DL = LI->getDataLayout();
281283
Value *Ptr = LI->getPointerOperand();
282-
283284
APInt EltSize(DL.getIndexTypeSizeInBits(Ptr->getType()),
284285
DL.getTypeStoreSize(LI->getType()).getFixedValue());
285-
const Align Alignment = LI->getAlign();
286-
287-
Instruction *HeaderFirstNonPHI = &*L->getHeader()->getFirstNonPHIIt();
288286

289287
// If given a uniform (i.e. non-varying) address, see if we can prove the
290288
// access is safe within the loop w/o needing predication.
291289
if (L->isLoopInvariant(Ptr))
292-
return isDereferenceableAndAlignedPointer(Ptr, Alignment, EltSize, DL,
293-
HeaderFirstNonPHI, AC, &DT);
290+
return isDereferenceableAndAlignedPointer(
291+
Ptr, Alignment, EltSize, DL, &*L->getHeader()->getFirstNonPHIIt(), AC,
292+
&DT);
293+
294+
const SCEV *PtrScev = SE.getSCEV(Ptr);
295+
auto *AddRec = dyn_cast<SCEVAddRecExpr>(PtrScev);
294296

295-
// Otherwise, check to see if we have a repeating access pattern where we can
296-
// prove that all accesses are well aligned and dereferenceable.
297-
auto *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(Ptr));
297+
// Check to see if we have a repeating access pattern and it's possible
298+
// to prove all accesses are well aligned.
298299
if (!AddRec || AddRec->getLoop() != L || !AddRec->isAffine())
299300
return false;
300-
auto* Step = dyn_cast<SCEVConstant>(AddRec->getStepRecurrence(SE));
301+
302+
auto *Step = dyn_cast<SCEVConstant>(AddRec->getStepRecurrence(SE));
301303
if (!Step)
302304
return false;
303305

304-
auto TC = SE.getSmallConstantMaxTripCount(L, Predicates);
305-
if (!TC)
306+
// For the moment, restrict ourselves to the case where the access size is a
307+
// multiple of the requested alignment and the base is aligned.
308+
// TODO: generalize if a case found which warrants
309+
if (EltSize.urem(Alignment.value()) != 0)
306310
return false;
307311

308312
// TODO: Handle overlapping accesses.
309-
// We should be computing AccessSize as (TC - 1) * Step + EltSize.
310-
if (EltSize.sgt(Step->getAPInt()))
313+
if (EltSize.ugt(Step->getAPInt().abs()))
314+
return false;
315+
316+
const SCEV *MaxBECount =
317+
Predicates ? SE.getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
318+
: SE.getConstantMaxBackedgeTakenCount(L);
319+
if (isa<SCEVCouldNotCompute>(MaxBECount))
320+
return false;
321+
322+
const auto &[AccessStart, AccessEnd] = getStartAndEndForAccess(
323+
L, PtrScev, LI->getType(), MaxBECount, &SE, nullptr);
324+
if (isa<SCEVCouldNotCompute>(AccessStart) ||
325+
isa<SCEVCouldNotCompute>(AccessEnd))
311326
return false;
312327

313-
// Compute the total access size for access patterns with unit stride and
314-
// patterns with gaps. For patterns with unit stride, Step and EltSize are the
315-
// same.
316-
// For patterns with gaps (i.e. non unit stride), we are
317-
// accessing EltSize bytes at every Step.
318-
APInt AccessSize = TC * Step->getAPInt();
328+
// Try to get the access size.
329+
const SCEV *PtrDiff = SE.getMinusSCEV(AccessEnd, AccessStart);
330+
APInt MaxPtrDiff = SE.getUnsignedRangeMax(PtrDiff);
319331

320-
assert(SE.isLoopInvariant(AddRec->getStart(), L) &&
321-
"implied by addrec definition");
322332
Value *Base = nullptr;
323-
if (auto *StartS = dyn_cast<SCEVUnknown>(AddRec->getStart())) {
324-
Base = StartS->getValue();
325-
} else if (auto *StartS = dyn_cast<SCEVAddExpr>(AddRec->getStart())) {
326-
// Handle (NewBase + offset) as start value.
327-
const auto *Offset = dyn_cast<SCEVConstant>(StartS->getOperand(0));
328-
const auto *NewBase = dyn_cast<SCEVUnknown>(StartS->getOperand(1));
329-
if (StartS->getNumOperands() == 2 && Offset && NewBase) {
330-
// The following code below assumes the offset is unsigned, but GEP
331-
// offsets are treated as signed so we can end up with a signed value
332-
// here too. For example, suppose the initial PHI value is (i8 255),
333-
// the offset will be treated as (i8 -1) and sign-extended to (i64 -1).
334-
if (Offset->getAPInt().isNegative())
335-
return false;
333+
APInt AccessSize;
334+
if (const SCEVUnknown *NewBase = dyn_cast<SCEVUnknown>(AccessStart)) {
335+
Base = NewBase->getValue();
336+
AccessSize = MaxPtrDiff;
337+
} else if (auto *MinAdd = dyn_cast<SCEVAddExpr>(AccessStart)) {
338+
if (MinAdd->getNumOperands() != 2)
339+
return false;
336340

337-
// For the moment, restrict ourselves to the case where the offset is a
338-
// multiple of the requested alignment and the base is aligned.
339-
// TODO: generalize if a case found which warrants
340-
if (Offset->getAPInt().urem(Alignment.value()) != 0)
341-
return false;
342-
Base = NewBase->getValue();
343-
bool Overflow = false;
344-
AccessSize = AccessSize.uadd_ov(Offset->getAPInt(), Overflow);
345-
if (Overflow)
346-
return false;
347-
}
348-
}
341+
const auto *Offset = dyn_cast<SCEVConstant>(MinAdd->getOperand(0));
342+
const auto *NewBase = dyn_cast<SCEVUnknown>(MinAdd->getOperand(1));
343+
if (!Offset || !NewBase)
344+
return false;
349345

350-
if (!Base)
351-
return false;
346+
// The following code below assumes the offset is unsigned, but GEP
347+
// offsets are treated as signed so we can end up with a signed value
348+
// here too. For example, suppose the initial PHI value is (i8 255),
349+
// the offset will be treated as (i8 -1) and sign-extended to (i64 -1).
350+
if (Offset->getAPInt().isNegative())
351+
return false;
352352

353-
// For the moment, restrict ourselves to the case where the access size is a
354-
// multiple of the requested alignment and the base is aligned.
355-
// TODO: generalize if a case found which warrants
356-
if (EltSize.urem(Alignment.value()) != 0)
353+
// For the moment, restrict ourselves to the case where the offset is a
354+
// multiple of the requested alignment and the base is aligned.
355+
// TODO: generalize if a case found which warrants
356+
if (Offset->getAPInt().urem(Alignment.value()) != 0)
357+
return false;
358+
359+
AccessSize = MaxPtrDiff + Offset->getAPInt();
360+
Base = NewBase->getValue();
361+
} else
357362
return false;
363+
364+
Instruction *HeaderFirstNonPHI = L->getHeader()->getFirstNonPHI();
358365
return isDereferenceableAndAlignedPointer(Base, Alignment, AccessSize, DL,
359366
HeaderFirstNonPHI, AC, &DT);
360367
}

llvm/lib/Analysis/LoopAccessAnalysis.cpp

Lines changed: 26 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -190,42 +190,29 @@ RuntimeCheckingPtrGroup::RuntimeCheckingPtrGroup(
190190
Members.push_back(Index);
191191
}
192192

193-
/// Calculate Start and End points of memory access.
194-
/// Let's assume A is the first access and B is a memory access on N-th loop
195-
/// iteration. Then B is calculated as:
196-
/// B = A + Step*N .
197-
/// Step value may be positive or negative.
198-
/// N is a calculated back-edge taken count:
199-
/// N = (TripCount > 0) ? RoundDown(TripCount -1 , VF) : 0
200-
/// Start and End points are calculated in the following way:
201-
/// Start = UMIN(A, B) ; End = UMAX(A, B) + SizeOfElt,
202-
/// where SizeOfElt is the size of single memory access in bytes.
203-
///
204-
/// There is no conflict when the intervals are disjoint:
205-
/// NoConflict = (P2.Start >= P1.End) || (P1.Start >= P2.End)
206-
static std::pair<const SCEV *, const SCEV *> getStartAndEndForAccess(
207-
const Loop *Lp, const SCEV *PtrExpr, Type *AccessTy,
208-
PredicatedScalarEvolution &PSE,
193+
std::pair<const SCEV *, const SCEV *> llvm::getStartAndEndForAccess(
194+
const Loop *Lp, const SCEV *PtrExpr, Type *AccessTy, const SCEV *MaxBECount,
195+
ScalarEvolution *SE,
209196
DenseMap<std::pair<const SCEV *, Type *>,
210-
std::pair<const SCEV *, const SCEV *>> &PointerBounds) {
211-
ScalarEvolution *SE = PSE.getSE();
212-
213-
auto [Iter, Ins] = PointerBounds.insert(
214-
{{PtrExpr, AccessTy},
215-
{SE->getCouldNotCompute(), SE->getCouldNotCompute()}});
216-
if (!Ins)
217-
return Iter->second;
197+
std::pair<const SCEV *, const SCEV *>> *PointerBounds) {
198+
std::pair<const SCEV *, const SCEV *> *PtrBoundsPair;
199+
if (PointerBounds) {
200+
auto [Iter, Ins] = PointerBounds->insert(
201+
{{PtrExpr, AccessTy},
202+
{SE->getCouldNotCompute(), SE->getCouldNotCompute()}});
203+
if (!Ins)
204+
return Iter->second;
205+
PtrBoundsPair = &Iter->second;
206+
}
218207

219208
const SCEV *ScStart;
220209
const SCEV *ScEnd;
221210

222211
if (SE->isLoopInvariant(PtrExpr, Lp)) {
223212
ScStart = ScEnd = PtrExpr;
224213
} else if (auto *AR = dyn_cast<SCEVAddRecExpr>(PtrExpr)) {
225-
const SCEV *Ex = PSE.getSymbolicMaxBackedgeTakenCount();
226-
227214
ScStart = AR->getStart();
228-
ScEnd = AR->evaluateAtIteration(Ex, *SE);
215+
ScEnd = AR->evaluateAtIteration(MaxBECount, *SE);
229216
const SCEV *Step = AR->getStepRecurrence(*SE);
230217

231218
// For expressions with negative step, the upper bound is ScStart and the
@@ -244,16 +231,18 @@ static std::pair<const SCEV *, const SCEV *> getStartAndEndForAccess(
244231
return {SE->getCouldNotCompute(), SE->getCouldNotCompute()};
245232

246233
assert(SE->isLoopInvariant(ScStart, Lp) && "ScStart needs to be invariant");
247-
assert(SE->isLoopInvariant(ScEnd, Lp)&& "ScEnd needs to be invariant");
234+
assert(SE->isLoopInvariant(ScEnd, Lp) && "ScEnd needs to be invariant");
248235

249236
// Add the size of the pointed element to ScEnd.
250237
auto &DL = Lp->getHeader()->getDataLayout();
251238
Type *IdxTy = DL.getIndexType(PtrExpr->getType());
252239
const SCEV *EltSizeSCEV = SE->getStoreSizeOfExpr(IdxTy, AccessTy);
253240
ScEnd = SE->getAddExpr(ScEnd, EltSizeSCEV);
254241

255-
Iter->second = {ScStart, ScEnd};
256-
return Iter->second;
242+
std::pair<const SCEV *, const SCEV *> Res = {ScStart, ScEnd};
243+
if (PointerBounds)
244+
*PtrBoundsPair = Res;
245+
return Res;
257246
}
258247

259248
/// Calculate Start and End points of memory access using
@@ -263,8 +252,9 @@ void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, const SCEV *PtrExpr,
263252
unsigned DepSetId, unsigned ASId,
264253
PredicatedScalarEvolution &PSE,
265254
bool NeedsFreeze) {
255+
const SCEV *MaxBECount = PSE.getSymbolicMaxBackedgeTakenCount();
266256
const auto &[ScStart, ScEnd] = getStartAndEndForAccess(
267-
Lp, PtrExpr, AccessTy, PSE, DC.getPointerBounds());
257+
Lp, PtrExpr, AccessTy, MaxBECount, PSE.getSE(), &DC.getPointerBounds());
268258
assert(!isa<SCEVCouldNotCompute>(ScStart) &&
269259
!isa<SCEVCouldNotCompute>(ScEnd) &&
270260
"must be able to compute both start and end expressions");
@@ -1938,10 +1928,11 @@ MemoryDepChecker::getDependenceDistanceStrideAndSize(
19381928
// required for correctness.
19391929
if (SE.isLoopInvariant(Src, InnermostLoop) ||
19401930
SE.isLoopInvariant(Sink, InnermostLoop)) {
1941-
const auto &[SrcStart_, SrcEnd_] =
1942-
getStartAndEndForAccess(InnermostLoop, Src, ATy, PSE, PointerBounds);
1943-
const auto &[SinkStart_, SinkEnd_] =
1944-
getStartAndEndForAccess(InnermostLoop, Sink, BTy, PSE, PointerBounds);
1931+
const SCEV *MaxBECount = PSE.getSymbolicMaxBackedgeTakenCount();
1932+
const auto &[SrcStart_, SrcEnd_] = getStartAndEndForAccess(
1933+
InnermostLoop, Src, ATy, MaxBECount, PSE.getSE(), &PointerBounds);
1934+
const auto &[SinkStart_, SinkEnd_] = getStartAndEndForAccess(
1935+
InnermostLoop, Sink, BTy, MaxBECount, PSE.getSE(), &PointerBounds);
19451936
if (!isa<SCEVCouldNotCompute>(SrcStart_) &&
19461937
!isa<SCEVCouldNotCompute>(SrcEnd_) &&
19471938
!isa<SCEVCouldNotCompute>(SinkStart_) &&

0 commit comments

Comments
 (0)