Skip to content

Commit 4e5e042

Browse files
committed
[LoopVectorize] Support reductions that store intermediary result
Adds ability to vectorize loops containing a store to a loop-invariant address as part of a reduction that isn't converted to SSA form due to lack of aliasing info. Runtime checks are generated to ensure the store does not alias any other accesses in the loop. Ordered fadd reductions are not yet supported. Differential Revision: https://reviews.llvm.org/D110235
1 parent c819dce commit 4e5e042

File tree

11 files changed

+509
-65
lines changed

11 files changed

+509
-65
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class Loop;
2929
class PredicatedScalarEvolution;
3030
class ScalarEvolution;
3131
class SCEV;
32+
class StoreInst;
3233

3334
/// These are the kinds of recurrences that we support.
3435
enum class RecurKind {
@@ -69,14 +70,14 @@ class RecurrenceDescriptor {
6970
public:
7071
RecurrenceDescriptor() = default;
7172

72-
RecurrenceDescriptor(Value *Start, Instruction *Exit, RecurKind K,
73-
FastMathFlags FMF, Instruction *ExactFP, Type *RT,
74-
bool Signed, bool Ordered,
73+
RecurrenceDescriptor(Value *Start, Instruction *Exit, StoreInst *Store,
74+
RecurKind K, FastMathFlags FMF, Instruction *ExactFP,
75+
Type *RT, bool Signed, bool Ordered,
7576
SmallPtrSetImpl<Instruction *> &CI,
7677
unsigned MinWidthCastToRecurTy)
77-
: StartValue(Start), LoopExitInstr(Exit), Kind(K), FMF(FMF),
78-
ExactFPMathInst(ExactFP), RecurrenceType(RT), IsSigned(Signed),
79-
IsOrdered(Ordered),
78+
: IntermediateStore(Store), StartValue(Start), LoopExitInstr(Exit),
79+
Kind(K), FMF(FMF), ExactFPMathInst(ExactFP), RecurrenceType(RT),
80+
IsSigned(Signed), IsOrdered(Ordered),
8081
MinWidthCastToRecurrenceType(MinWidthCastToRecurTy) {
8182
CastInsts.insert(CI.begin(), CI.end());
8283
}
@@ -163,22 +164,21 @@ class RecurrenceDescriptor {
163164
/// RecurrenceDescriptor. If either \p DB is non-null or \p AC and \p DT are
164165
/// non-null, the minimal bit width needed to compute the reduction will be
165166
/// computed.
166-
static bool AddReductionVar(PHINode *Phi, RecurKind Kind, Loop *TheLoop,
167-
FastMathFlags FuncFMF,
168-
RecurrenceDescriptor &RedDes,
169-
DemandedBits *DB = nullptr,
170-
AssumptionCache *AC = nullptr,
171-
DominatorTree *DT = nullptr);
167+
static bool
168+
AddReductionVar(PHINode *Phi, RecurKind Kind, Loop *TheLoop,
169+
FastMathFlags FuncFMF, RecurrenceDescriptor &RedDes,
170+
DemandedBits *DB = nullptr, AssumptionCache *AC = nullptr,
171+
DominatorTree *DT = nullptr, ScalarEvolution *SE = nullptr);
172172

173173
/// Returns true if Phi is a reduction in TheLoop. The RecurrenceDescriptor
174174
/// is returned in RedDes. If either \p DB is non-null or \p AC and \p DT are
175175
/// non-null, the minimal bit width needed to compute the reduction will be
176-
/// computed.
177-
static bool isReductionPHI(PHINode *Phi, Loop *TheLoop,
178-
RecurrenceDescriptor &RedDes,
179-
DemandedBits *DB = nullptr,
180-
AssumptionCache *AC = nullptr,
181-
DominatorTree *DT = nullptr);
176+
/// computed. If \p SE is non-null, store instructions to loop invariant
177+
/// addresses are processed.
178+
static bool
179+
isReductionPHI(PHINode *Phi, Loop *TheLoop, RecurrenceDescriptor &RedDes,
180+
DemandedBits *DB = nullptr, AssumptionCache *AC = nullptr,
181+
DominatorTree *DT = nullptr, ScalarEvolution *SE = nullptr);
182182

183183
/// Returns true if Phi is a first-order recurrence. A first-order recurrence
184184
/// is a non-reduction recurrence relation in which the value of the
@@ -270,6 +270,11 @@ class RecurrenceDescriptor {
270270
cast<IntrinsicInst>(I)->getIntrinsicID() == Intrinsic::fmuladd;
271271
}
272272

273+
/// Reductions may store temporary or final result to an invariant address.
274+
/// If there is such a store in the loop then, after successfull run of
275+
/// AddReductionVar method, this field will be assigned the last met store.
276+
StoreInst *IntermediateStore = nullptr;
277+
273278
private:
274279
// The starting value of the recurrence.
275280
// It does not have to be zero!

llvm/include/llvm/Analysis/LoopAccessAnalysis.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,11 @@ class LoopAccessInfo {
575575
return HasDependenceInvolvingLoopInvariantAddress;
576576
}
577577

578+
/// Return the list of stores to invariant addresses.
579+
const ArrayRef<StoreInst *> getStoresToInvariantAddresses() const {
580+
return StoresToInvariantAddresses;
581+
}
582+
578583
/// Used to add runtime SCEV checks. Simplifies SCEV expressions and converts
579584
/// them to a more usable form. All SCEV expressions during the analysis
580585
/// should be re-written (and therefore simplified) according to PSE.
@@ -634,6 +639,9 @@ class LoopAccessInfo {
634639
/// Indicator that there are non vectorizable stores to a uniform address.
635640
bool HasDependenceInvolvingLoopInvariantAddress = false;
636641

642+
/// List of stores to invariant addresses.
643+
SmallVector<StoreInst *> StoresToInvariantAddresses;
644+
637645
/// The diagnostics report generated for the analysis. E.g. why we
638646
/// couldn't analyze the loop.
639647
std::unique_ptr<OptimizationRemarkAnalysis> Report;

llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,14 @@ class LoopVectorizationLegality {
308308
/// Returns the widest induction type.
309309
Type *getWidestInductionType() { return WidestIndTy; }
310310

311+
/// Returns True if given store is a final invariant store of one of the
312+
/// reductions found in the loop.
313+
bool isInvariantStoreOfReduction(StoreInst *SI);
314+
315+
/// Returns True if given address is invariant and is used to store recurrent
316+
/// expression
317+
bool isInvariantAddressOfReduction(Value *V);
318+
311319
/// Returns True if V is a Phi node of an induction variable in this loop.
312320
bool isInductionPhi(const Value *V) const;
313321

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 121 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -227,12 +227,10 @@ static bool checkOrderedReduction(RecurKind Kind, Instruction *ExactFPMathInst,
227227
return true;
228228
}
229229

230-
bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind,
231-
Loop *TheLoop, FastMathFlags FuncFMF,
232-
RecurrenceDescriptor &RedDes,
233-
DemandedBits *DB,
234-
AssumptionCache *AC,
235-
DominatorTree *DT) {
230+
bool RecurrenceDescriptor::AddReductionVar(
231+
PHINode *Phi, RecurKind Kind, Loop *TheLoop, FastMathFlags FuncFMF,
232+
RecurrenceDescriptor &RedDes, DemandedBits *DB, AssumptionCache *AC,
233+
DominatorTree *DT, ScalarEvolution *SE) {
236234
if (Phi->getNumIncomingValues() != 2)
237235
return false;
238236

@@ -249,6 +247,12 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind,
249247
// This includes users of the reduction, variables (which form a cycle
250248
// which ends in the phi node).
251249
Instruction *ExitInstruction = nullptr;
250+
251+
// Variable to keep last visited store instruction. By the end of the
252+
// algorithm this variable will be either empty or having intermediate
253+
// reduction value stored in invariant address.
254+
StoreInst *IntermediateStore = nullptr;
255+
252256
// Indicates that we found a reduction operation in our scan.
253257
bool FoundReduxOp = false;
254258

@@ -314,13 +318,54 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind,
314318
// - By instructions outside of the loop (safe).
315319
// * One value may have several outside users, but all outside
316320
// uses must be of the same value.
321+
// - By store instructions with a loop invariant address (safe with
322+
// the following restrictions):
323+
// * If there are several stores, all must have the same address.
324+
// * Final value should be stored in that loop invariant address.
317325
// - By an instruction that is not part of the reduction (not safe).
318326
// This is either:
319327
// * An instruction type other than PHI or the reduction operation.
320328
// * A PHI in the header other than the initial PHI.
321329
while (!Worklist.empty()) {
322330
Instruction *Cur = Worklist.pop_back_val();
323331

332+
// Store instructions are allowed iff it is the store of the reduction
333+
// value to the same loop invariant memory location.
334+
if (auto *SI = dyn_cast<StoreInst>(Cur)) {
335+
if (!SE) {
336+
LLVM_DEBUG(dbgs() << "Store instructions are not processed without "
337+
<< "Scalar Evolution Analysis\n");
338+
return false;
339+
}
340+
341+
const SCEV *PtrScev = SE->getSCEV(SI->getPointerOperand());
342+
// Check it is the same address as previous stores
343+
if (IntermediateStore) {
344+
const SCEV *OtherScev =
345+
SE->getSCEV(IntermediateStore->getPointerOperand());
346+
347+
if (OtherScev != PtrScev) {
348+
LLVM_DEBUG(dbgs() << "Storing reduction value to different addresses "
349+
<< "inside the loop: " << *SI->getPointerOperand()
350+
<< " and "
351+
<< *IntermediateStore->getPointerOperand() << '\n');
352+
return false;
353+
}
354+
}
355+
356+
// Check the pointer is loop invariant
357+
if (!SE->isLoopInvariant(PtrScev, TheLoop)) {
358+
LLVM_DEBUG(dbgs() << "Storing reduction value to non-uniform address "
359+
<< "inside the loop: " << *SI->getPointerOperand()
360+
<< '\n');
361+
return false;
362+
}
363+
364+
// IntermediateStore is always the last store in the loop.
365+
IntermediateStore = SI;
366+
continue;
367+
}
368+
324369
// No Users.
325370
// If the instruction has no users then this is a broken chain and can't be
326371
// a reduction variable.
@@ -443,10 +488,17 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind,
443488
// reductions which are represented as a cmp followed by a select.
444489
InstDesc IgnoredVal(false, nullptr);
445490
if (VisitedInsts.insert(UI).second) {
446-
if (isa<PHINode>(UI))
491+
if (isa<PHINode>(UI)) {
447492
PHIs.push_back(UI);
448-
else
493+
} else {
494+
StoreInst *SI = dyn_cast<StoreInst>(UI);
495+
if (SI && SI->getPointerOperand() == Cur) {
496+
// Reduction variable chain can only be stored somewhere but it
497+
// can't be used as an address.
498+
return false;
499+
}
449500
NonPHIs.push_back(UI);
501+
}
450502
} else if (!isa<PHINode>(UI) &&
451503
((!isa<FCmpInst>(UI) && !isa<ICmpInst>(UI) &&
452504
!isa<SelectInst>(UI)) ||
@@ -474,6 +526,32 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind,
474526
if (isSelectCmpRecurrenceKind(Kind) && NumCmpSelectPatternInst != 1)
475527
return false;
476528

529+
if (IntermediateStore) {
530+
// Check that stored value goes to the phi node again. This way we make sure
531+
// that the value stored in IntermediateStore is indeed the final reduction
532+
// value.
533+
if (!is_contained(Phi->operands(), IntermediateStore->getValueOperand())) {
534+
LLVM_DEBUG(dbgs() << "Not a final reduction value stored: "
535+
<< *IntermediateStore << '\n');
536+
return false;
537+
}
538+
539+
// If there is an exit instruction it's value should be stored in
540+
// IntermediateStore
541+
if (ExitInstruction &&
542+
IntermediateStore->getValueOperand() != ExitInstruction) {
543+
LLVM_DEBUG(dbgs() << "Last store Instruction of reduction value does not "
544+
"store last calculated value of the reduction: "
545+
<< *IntermediateStore << '\n');
546+
return false;
547+
}
548+
549+
// If all uses are inside the loop (intermediate stores), then the
550+
// reduction value after the loop will be the one used in the last store.
551+
if (!ExitInstruction)
552+
ExitInstruction = cast<Instruction>(IntermediateStore->getValueOperand());
553+
}
554+
477555
if (!FoundStartPHI || !FoundReduxOp || !ExitInstruction)
478556
return false;
479557

@@ -535,9 +613,9 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind,
535613
// is saved as part of the RecurrenceDescriptor.
536614

537615
// Save the description of this reduction variable.
538-
RecurrenceDescriptor RD(RdxStart, ExitInstruction, Kind, FMF, ExactFPMathInst,
539-
RecurrenceType, IsSigned, IsOrdered, CastInsts,
540-
MinWidthCastToRecurrenceType);
616+
RecurrenceDescriptor RD(RdxStart, ExitInstruction, IntermediateStore, Kind,
617+
FMF, ExactFPMathInst, RecurrenceType, IsSigned,
618+
IsOrdered, CastInsts, MinWidthCastToRecurrenceType);
541619
RedDes = RD;
542620

543621
return true;
@@ -761,7 +839,8 @@ bool RecurrenceDescriptor::hasMultipleUsesOf(
761839
bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
762840
RecurrenceDescriptor &RedDes,
763841
DemandedBits *DB, AssumptionCache *AC,
764-
DominatorTree *DT) {
842+
DominatorTree *DT,
843+
ScalarEvolution *SE) {
765844
BasicBlock *Header = TheLoop->getHeader();
766845
Function &F = *Header->getParent();
767846
FastMathFlags FMF;
@@ -770,72 +849,85 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
770849
FMF.setNoSignedZeros(
771850
F.getFnAttribute("no-signed-zeros-fp-math").getValueAsBool());
772851

773-
if (AddReductionVar(Phi, RecurKind::Add, TheLoop, FMF, RedDes, DB, AC, DT)) {
852+
if (AddReductionVar(Phi, RecurKind::Add, TheLoop, FMF, RedDes, DB, AC, DT,
853+
SE)) {
774854
LLVM_DEBUG(dbgs() << "Found an ADD reduction PHI." << *Phi << "\n");
775855
return true;
776856
}
777-
if (AddReductionVar(Phi, RecurKind::Mul, TheLoop, FMF, RedDes, DB, AC, DT)) {
857+
if (AddReductionVar(Phi, RecurKind::Mul, TheLoop, FMF, RedDes, DB, AC, DT,
858+
SE)) {
778859
LLVM_DEBUG(dbgs() << "Found a MUL reduction PHI." << *Phi << "\n");
779860
return true;
780861
}
781-
if (AddReductionVar(Phi, RecurKind::Or, TheLoop, FMF, RedDes, DB, AC, DT)) {
862+
if (AddReductionVar(Phi, RecurKind::Or, TheLoop, FMF, RedDes, DB, AC, DT,
863+
SE)) {
782864
LLVM_DEBUG(dbgs() << "Found an OR reduction PHI." << *Phi << "\n");
783865
return true;
784866
}
785-
if (AddReductionVar(Phi, RecurKind::And, TheLoop, FMF, RedDes, DB, AC, DT)) {
867+
if (AddReductionVar(Phi, RecurKind::And, TheLoop, FMF, RedDes, DB, AC, DT,
868+
SE)) {
786869
LLVM_DEBUG(dbgs() << "Found an AND reduction PHI." << *Phi << "\n");
787870
return true;
788871
}
789-
if (AddReductionVar(Phi, RecurKind::Xor, TheLoop, FMF, RedDes, DB, AC, DT)) {
872+
if (AddReductionVar(Phi, RecurKind::Xor, TheLoop, FMF, RedDes, DB, AC, DT,
873+
SE)) {
790874
LLVM_DEBUG(dbgs() << "Found a XOR reduction PHI." << *Phi << "\n");
791875
return true;
792876
}
793-
if (AddReductionVar(Phi, RecurKind::SMax, TheLoop, FMF, RedDes, DB, AC, DT)) {
877+
if (AddReductionVar(Phi, RecurKind::SMax, TheLoop, FMF, RedDes, DB, AC, DT,
878+
SE)) {
794879
LLVM_DEBUG(dbgs() << "Found a SMAX reduction PHI." << *Phi << "\n");
795880
return true;
796881
}
797-
if (AddReductionVar(Phi, RecurKind::SMin, TheLoop, FMF, RedDes, DB, AC, DT)) {
882+
if (AddReductionVar(Phi, RecurKind::SMin, TheLoop, FMF, RedDes, DB, AC, DT,
883+
SE)) {
798884
LLVM_DEBUG(dbgs() << "Found a SMIN reduction PHI." << *Phi << "\n");
799885
return true;
800886
}
801-
if (AddReductionVar(Phi, RecurKind::UMax, TheLoop, FMF, RedDes, DB, AC, DT)) {
887+
if (AddReductionVar(Phi, RecurKind::UMax, TheLoop, FMF, RedDes, DB, AC, DT,
888+
SE)) {
802889
LLVM_DEBUG(dbgs() << "Found a UMAX reduction PHI." << *Phi << "\n");
803890
return true;
804891
}
805-
if (AddReductionVar(Phi, RecurKind::UMin, TheLoop, FMF, RedDes, DB, AC, DT)) {
892+
if (AddReductionVar(Phi, RecurKind::UMin, TheLoop, FMF, RedDes, DB, AC, DT,
893+
SE)) {
806894
LLVM_DEBUG(dbgs() << "Found a UMIN reduction PHI." << *Phi << "\n");
807895
return true;
808896
}
809897
if (AddReductionVar(Phi, RecurKind::SelectICmp, TheLoop, FMF, RedDes, DB, AC,
810-
DT)) {
898+
DT, SE)) {
811899
LLVM_DEBUG(dbgs() << "Found an integer conditional select reduction PHI."
812900
<< *Phi << "\n");
813901
return true;
814902
}
815-
if (AddReductionVar(Phi, RecurKind::FMul, TheLoop, FMF, RedDes, DB, AC, DT)) {
903+
if (AddReductionVar(Phi, RecurKind::FMul, TheLoop, FMF, RedDes, DB, AC, DT,
904+
SE)) {
816905
LLVM_DEBUG(dbgs() << "Found an FMult reduction PHI." << *Phi << "\n");
817906
return true;
818907
}
819-
if (AddReductionVar(Phi, RecurKind::FAdd, TheLoop, FMF, RedDes, DB, AC, DT)) {
908+
if (AddReductionVar(Phi, RecurKind::FAdd, TheLoop, FMF, RedDes, DB, AC, DT,
909+
SE)) {
820910
LLVM_DEBUG(dbgs() << "Found an FAdd reduction PHI." << *Phi << "\n");
821911
return true;
822912
}
823-
if (AddReductionVar(Phi, RecurKind::FMax, TheLoop, FMF, RedDes, DB, AC, DT)) {
913+
if (AddReductionVar(Phi, RecurKind::FMax, TheLoop, FMF, RedDes, DB, AC, DT,
914+
SE)) {
824915
LLVM_DEBUG(dbgs() << "Found a float MAX reduction PHI." << *Phi << "\n");
825916
return true;
826917
}
827-
if (AddReductionVar(Phi, RecurKind::FMin, TheLoop, FMF, RedDes, DB, AC, DT)) {
918+
if (AddReductionVar(Phi, RecurKind::FMin, TheLoop, FMF, RedDes, DB, AC, DT,
919+
SE)) {
828920
LLVM_DEBUG(dbgs() << "Found a float MIN reduction PHI." << *Phi << "\n");
829921
return true;
830922
}
831923
if (AddReductionVar(Phi, RecurKind::SelectFCmp, TheLoop, FMF, RedDes, DB, AC,
832-
DT)) {
924+
DT, SE)) {
833925
LLVM_DEBUG(dbgs() << "Found a float conditional select reduction PHI."
834926
<< " PHI." << *Phi << "\n");
835927
return true;
836928
}
837-
if (AddReductionVar(Phi, RecurKind::FMulAdd, TheLoop, FMF, RedDes, DB, AC,
838-
DT)) {
929+
if (AddReductionVar(Phi, RecurKind::FMulAdd, TheLoop, FMF, RedDes, DB, AC, DT,
930+
SE)) {
839931
LLVM_DEBUG(dbgs() << "Found an FMulAdd reduction PHI." << *Phi << "\n");
840932
return true;
841933
}

llvm/lib/Analysis/LoopAccessAnalysis.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1993,9 +1993,12 @@ void LoopAccessInfo::analyzeLoop(AAResults *AA, LoopInfo *LI,
19931993
for (StoreInst *ST : Stores) {
19941994
Value *Ptr = ST->getPointerOperand();
19951995

1996-
if (isUniform(Ptr))
1996+
if (isUniform(Ptr)) {
1997+
// Record store instructions to loop invariant addresses
1998+
StoresToInvariantAddresses.push_back(ST);
19971999
HasDependenceInvolvingLoopInvariantAddress |=
19982000
!UniformStores.insert(Ptr).second;
2001+
}
19992002

20002003
// If we did *not* see this pointer before, insert it to the read-write
20012004
// list. At this phase it is only a 'write' list.

0 commit comments

Comments
 (0)