@@ -227,12 +227,10 @@ static bool checkOrderedReduction(RecurKind Kind, Instruction *ExactFPMathInst,
227
227
return true ;
228
228
}
229
229
230
- bool RecurrenceDescriptor::AddReductionVar (PHINode *Phi, RecurKind Kind,
231
- Loop *TheLoop, FastMathFlags FuncFMF,
232
- RecurrenceDescriptor &RedDes,
233
- DemandedBits *DB,
234
- AssumptionCache *AC,
235
- DominatorTree *DT) {
230
+ bool RecurrenceDescriptor::AddReductionVar (
231
+ PHINode *Phi, RecurKind Kind, Loop *TheLoop, FastMathFlags FuncFMF,
232
+ RecurrenceDescriptor &RedDes, DemandedBits *DB, AssumptionCache *AC,
233
+ DominatorTree *DT, ScalarEvolution *SE) {
236
234
if (Phi->getNumIncomingValues () != 2 )
237
235
return false ;
238
236
@@ -249,6 +247,12 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind,
249
247
// This includes users of the reduction, variables (which form a cycle
250
248
// which ends in the phi node).
251
249
Instruction *ExitInstruction = nullptr ;
250
+
251
+ // Variable to keep last visited store instruction. By the end of the
252
+ // algorithm this variable will be either null or the store that writes the
253
+ // intermediate reduction value to a loop-invariant address.
254
+ StoreInst *IntermediateStore = nullptr ;
255
+
252
256
// Indicates that we found a reduction operation in our scan.
253
257
bool FoundReduxOp = false ;
254
258
@@ -314,13 +318,54 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind,
314
318
// - By instructions outside of the loop (safe).
315
319
// * One value may have several outside users, but all outside
316
320
// uses must be of the same value.
321
+ // - By store instructions with a loop invariant address (safe with
322
+ // the following restrictions):
323
+ // * If there are several stores, all must have the same address.
324
+ // * Final value should be stored in that loop invariant address.
317
325
// - By an instruction that is not part of the reduction (not safe).
318
326
// This is either:
319
327
// * An instruction type other than PHI or the reduction operation.
320
328
// * A PHI in the header other than the initial PHI.
321
329
while (!Worklist.empty ()) {
322
330
Instruction *Cur = Worklist.pop_back_val ();
323
331
332
+ // Store instructions are allowed iff it is the store of the reduction
333
+ // value to the same loop invariant memory location.
334
+ if (auto *SI = dyn_cast<StoreInst>(Cur)) {
335
+ if (!SE) {
336
+ LLVM_DEBUG (dbgs () << " Store instructions are not processed without "
337
+ << " Scalar Evolution Analysis\n " );
338
+ return false ;
339
+ }
340
+
341
+ const SCEV *PtrScev = SE->getSCEV (SI->getPointerOperand ());
342
+ // Check it is the same address as previous stores
343
+ if (IntermediateStore) {
344
+ const SCEV *OtherScev =
345
+ SE->getSCEV (IntermediateStore->getPointerOperand ());
346
+
347
+ if (OtherScev != PtrScev) {
348
+ LLVM_DEBUG (dbgs () << " Storing reduction value to different addresses "
349
+ << " inside the loop: " << *SI->getPointerOperand ()
350
+ << " and "
351
+ << *IntermediateStore->getPointerOperand () << ' \n ' );
352
+ return false ;
353
+ }
354
+ }
355
+
356
+ // Check the pointer is loop invariant
357
+ if (!SE->isLoopInvariant (PtrScev, TheLoop)) {
358
+ LLVM_DEBUG (dbgs () << " Storing reduction value to non-uniform address "
359
+ << " inside the loop: " << *SI->getPointerOperand ()
360
+ << ' \n ' );
361
+ return false ;
362
+ }
363
+
364
+ // IntermediateStore is always the last store in the loop.
365
+ IntermediateStore = SI;
366
+ continue ;
367
+ }
368
+
324
369
// No Users.
325
370
// If the instruction has no users then this is a broken chain and can't be
326
371
// a reduction variable.
@@ -443,10 +488,17 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind,
443
488
// reductions which are represented as a cmp followed by a select.
444
489
InstDesc IgnoredVal (false , nullptr );
445
490
if (VisitedInsts.insert (UI).second ) {
446
- if (isa<PHINode>(UI))
491
+ if (isa<PHINode>(UI)) {
447
492
PHIs.push_back (UI);
448
- else
493
+ } else {
494
+ StoreInst *SI = dyn_cast<StoreInst>(UI);
495
+ if (SI && SI->getPointerOperand () == Cur) {
496
+ // Reduction variable chain can only be stored somewhere but it
497
+ // can't be used as an address.
498
+ return false ;
499
+ }
449
500
NonPHIs.push_back (UI);
501
+ }
450
502
} else if (!isa<PHINode>(UI) &&
451
503
((!isa<FCmpInst>(UI) && !isa<ICmpInst>(UI) &&
452
504
!isa<SelectInst>(UI)) ||
@@ -474,6 +526,32 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind,
474
526
if (isSelectCmpRecurrenceKind (Kind) && NumCmpSelectPatternInst != 1 )
475
527
return false ;
476
528
529
+ if (IntermediateStore) {
530
+ // Check that stored value goes to the phi node again. This way we make sure
531
+ // that the value stored in IntermediateStore is indeed the final reduction
532
+ // value.
533
+ if (!is_contained (Phi->operands (), IntermediateStore->getValueOperand ())) {
534
+ LLVM_DEBUG (dbgs () << " Not a final reduction value stored: "
535
+ << *IntermediateStore << ' \n ' );
536
+ return false ;
537
+ }
538
+
539
+ // If there is an exit instruction its value should be stored in
540
+ // IntermediateStore
541
+ if (ExitInstruction &&
542
+ IntermediateStore->getValueOperand () != ExitInstruction) {
543
+ LLVM_DEBUG (dbgs () << " Last store Instruction of reduction value does not "
544
+ " store last calculated value of the reduction: "
545
+ << *IntermediateStore << ' \n ' );
546
+ return false ;
547
+ }
548
+
549
+ // If all uses are inside the loop (intermediate stores), then the
550
+ // reduction value after the loop will be the one used in the last store.
551
+ if (!ExitInstruction)
552
+ ExitInstruction = cast<Instruction>(IntermediateStore->getValueOperand ());
553
+ }
554
+
477
555
if (!FoundStartPHI || !FoundReduxOp || !ExitInstruction)
478
556
return false ;
479
557
@@ -535,9 +613,9 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurKind Kind,
535
613
// is saved as part of the RecurrenceDescriptor.
536
614
537
615
// Save the description of this reduction variable.
538
- RecurrenceDescriptor RD (RdxStart, ExitInstruction, Kind, FMF, ExactFPMathInst ,
539
- RecurrenceType, IsSigned, IsOrdered, CastInsts ,
540
- MinWidthCastToRecurrenceType);
616
+ RecurrenceDescriptor RD (RdxStart, ExitInstruction, IntermediateStore, Kind ,
617
+ FMF, ExactFPMathInst, RecurrenceType, IsSigned ,
618
+ IsOrdered, CastInsts, MinWidthCastToRecurrenceType);
541
619
RedDes = RD;
542
620
543
621
return true ;
@@ -761,7 +839,8 @@ bool RecurrenceDescriptor::hasMultipleUsesOf(
761
839
bool RecurrenceDescriptor::isReductionPHI (PHINode *Phi, Loop *TheLoop,
762
840
RecurrenceDescriptor &RedDes,
763
841
DemandedBits *DB, AssumptionCache *AC,
764
- DominatorTree *DT) {
842
+ DominatorTree *DT,
843
+ ScalarEvolution *SE) {
765
844
BasicBlock *Header = TheLoop->getHeader ();
766
845
Function &F = *Header->getParent ();
767
846
FastMathFlags FMF;
@@ -770,72 +849,85 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
770
849
FMF.setNoSignedZeros (
771
850
F.getFnAttribute (" no-signed-zeros-fp-math" ).getValueAsBool ());
772
851
773
- if (AddReductionVar (Phi, RecurKind::Add, TheLoop, FMF, RedDes, DB, AC, DT)) {
852
+ if (AddReductionVar (Phi, RecurKind::Add, TheLoop, FMF, RedDes, DB, AC, DT,
853
+ SE)) {
774
854
LLVM_DEBUG (dbgs () << " Found an ADD reduction PHI." << *Phi << " \n " );
775
855
return true ;
776
856
}
777
- if (AddReductionVar (Phi, RecurKind::Mul, TheLoop, FMF, RedDes, DB, AC, DT)) {
857
+ if (AddReductionVar (Phi, RecurKind::Mul, TheLoop, FMF, RedDes, DB, AC, DT,
858
+ SE)) {
778
859
LLVM_DEBUG (dbgs () << " Found a MUL reduction PHI." << *Phi << " \n " );
779
860
return true ;
780
861
}
781
- if (AddReductionVar (Phi, RecurKind::Or, TheLoop, FMF, RedDes, DB, AC, DT)) {
862
+ if (AddReductionVar (Phi, RecurKind::Or, TheLoop, FMF, RedDes, DB, AC, DT,
863
+ SE)) {
782
864
LLVM_DEBUG (dbgs () << " Found an OR reduction PHI." << *Phi << " \n " );
783
865
return true ;
784
866
}
785
- if (AddReductionVar (Phi, RecurKind::And, TheLoop, FMF, RedDes, DB, AC, DT)) {
867
+ if (AddReductionVar (Phi, RecurKind::And, TheLoop, FMF, RedDes, DB, AC, DT,
868
+ SE)) {
786
869
LLVM_DEBUG (dbgs () << " Found an AND reduction PHI." << *Phi << " \n " );
787
870
return true ;
788
871
}
789
- if (AddReductionVar (Phi, RecurKind::Xor, TheLoop, FMF, RedDes, DB, AC, DT)) {
872
+ if (AddReductionVar (Phi, RecurKind::Xor, TheLoop, FMF, RedDes, DB, AC, DT,
873
+ SE)) {
790
874
LLVM_DEBUG (dbgs () << " Found a XOR reduction PHI." << *Phi << " \n " );
791
875
return true ;
792
876
}
793
- if (AddReductionVar (Phi, RecurKind::SMax, TheLoop, FMF, RedDes, DB, AC, DT)) {
877
+ if (AddReductionVar (Phi, RecurKind::SMax, TheLoop, FMF, RedDes, DB, AC, DT,
878
+ SE)) {
794
879
LLVM_DEBUG (dbgs () << " Found a SMAX reduction PHI." << *Phi << " \n " );
795
880
return true ;
796
881
}
797
- if (AddReductionVar (Phi, RecurKind::SMin, TheLoop, FMF, RedDes, DB, AC, DT)) {
882
+ if (AddReductionVar (Phi, RecurKind::SMin, TheLoop, FMF, RedDes, DB, AC, DT,
883
+ SE)) {
798
884
LLVM_DEBUG (dbgs () << " Found a SMIN reduction PHI." << *Phi << " \n " );
799
885
return true ;
800
886
}
801
- if (AddReductionVar (Phi, RecurKind::UMax, TheLoop, FMF, RedDes, DB, AC, DT)) {
887
+ if (AddReductionVar (Phi, RecurKind::UMax, TheLoop, FMF, RedDes, DB, AC, DT,
888
+ SE)) {
802
889
LLVM_DEBUG (dbgs () << " Found a UMAX reduction PHI." << *Phi << " \n " );
803
890
return true ;
804
891
}
805
- if (AddReductionVar (Phi, RecurKind::UMin, TheLoop, FMF, RedDes, DB, AC, DT)) {
892
+ if (AddReductionVar (Phi, RecurKind::UMin, TheLoop, FMF, RedDes, DB, AC, DT,
893
+ SE)) {
806
894
LLVM_DEBUG (dbgs () << " Found a UMIN reduction PHI." << *Phi << " \n " );
807
895
return true ;
808
896
}
809
897
if (AddReductionVar (Phi, RecurKind::SelectICmp, TheLoop, FMF, RedDes, DB, AC,
810
- DT)) {
898
+ DT, SE )) {
811
899
LLVM_DEBUG (dbgs () << " Found an integer conditional select reduction PHI."
812
900
<< *Phi << " \n " );
813
901
return true ;
814
902
}
815
- if (AddReductionVar (Phi, RecurKind::FMul, TheLoop, FMF, RedDes, DB, AC, DT)) {
903
+ if (AddReductionVar (Phi, RecurKind::FMul, TheLoop, FMF, RedDes, DB, AC, DT,
904
+ SE)) {
816
905
LLVM_DEBUG (dbgs () << " Found an FMult reduction PHI." << *Phi << " \n " );
817
906
return true ;
818
907
}
819
- if (AddReductionVar (Phi, RecurKind::FAdd, TheLoop, FMF, RedDes, DB, AC, DT)) {
908
+ if (AddReductionVar (Phi, RecurKind::FAdd, TheLoop, FMF, RedDes, DB, AC, DT,
909
+ SE)) {
820
910
LLVM_DEBUG (dbgs () << " Found an FAdd reduction PHI." << *Phi << " \n " );
821
911
return true ;
822
912
}
823
- if (AddReductionVar (Phi, RecurKind::FMax, TheLoop, FMF, RedDes, DB, AC, DT)) {
913
+ if (AddReductionVar (Phi, RecurKind::FMax, TheLoop, FMF, RedDes, DB, AC, DT,
914
+ SE)) {
824
915
LLVM_DEBUG (dbgs () << " Found a float MAX reduction PHI." << *Phi << " \n " );
825
916
return true ;
826
917
}
827
- if (AddReductionVar (Phi, RecurKind::FMin, TheLoop, FMF, RedDes, DB, AC, DT)) {
918
+ if (AddReductionVar (Phi, RecurKind::FMin, TheLoop, FMF, RedDes, DB, AC, DT,
919
+ SE)) {
828
920
LLVM_DEBUG (dbgs () << " Found a float MIN reduction PHI." << *Phi << " \n " );
829
921
return true ;
830
922
}
831
923
if (AddReductionVar (Phi, RecurKind::SelectFCmp, TheLoop, FMF, RedDes, DB, AC,
832
- DT)) {
924
+ DT, SE )) {
833
925
LLVM_DEBUG (dbgs () << " Found a float conditional select reduction PHI."
834
926
<< " PHI." << *Phi << " \n " );
835
927
return true ;
836
928
}
837
- if (AddReductionVar (Phi, RecurKind::FMulAdd, TheLoop, FMF, RedDes, DB, AC,
838
- DT )) {
929
+ if (AddReductionVar (Phi, RecurKind::FMulAdd, TheLoop, FMF, RedDes, DB, AC, DT,
930
+ SE )) {
839
931
LLVM_DEBUG (dbgs () << " Found an FMulAdd reduction PHI." << *Phi << " \n " );
840
932
return true ;
841
933
}
0 commit comments