@@ -505,6 +505,7 @@ class LoopTreeOptimization {
   llvm::DenseMap<SILLoop *, std::unique_ptr<LoopNestSummary>>
       LoopNestSummaryMap;
   SmallVector<SILLoop *, 8> BotUpWorkList;
+  InstSet toDelete;
   SILLoopInfo *LoopInfo;
   AliasAnalysis *AA;
   SideEffectAnalysis *SEA;
@@ -525,6 +526,8 @@ class LoopTreeOptimization {
   InstVector SinkDown;
 
   /// Load and store instructions that we may be able to move out of the loop.
+  /// All loads and stores within a block must be in instruction order to
+  /// simplify replacement of values after SSA update.
   InstVector LoadsAndStores;
 
   /// All access paths of the \p LoadsAndStores instructions.
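The ordering requirement above matters because hoistLoadsAndStores walks LoadsAndStores block by block and replaces each reload with the value of the most recent preceding store. A minimal Swift-level sketch of the pattern, with invented identifiers (the pass itself operates on SIL, not on source):

var cell = 0

func lastStoreWins(_ n: Int) -> Int {
  var observed = 0
  for i in 0..<n {
    cell = i          // earlier store in the block
    cell = i &+ 1     // later store in the same block
    observed = cell   // this reload must see i &+ 1, the latest store
  }
  return observed
}

If the two stores were recorded out of order, the SSA update would wire the reload to the stale value, which is exactly what the in-order invariant rules out.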
@@ -561,12 +564,22 @@ class LoopTreeOptimization {
   /// Collect a set of instructions that can be hoisted
   void analyzeCurrentLoop(std::unique_ptr<LoopNestSummary> &CurrSummary);
 
+  SingleValueInstruction *splitLoad(SILValue splitAddress,
+                                    ArrayRef<AccessPath::Index> remainingPath,
+                                    SILBuilder &builder,
+                                    SmallVectorImpl<LoadInst *> &Loads,
+                                    unsigned ldStIdx);
+
+  /// Given an \p accessPath that is only loaded and stored, split loads that
+  /// are wider than \p accessPath.
+  bool splitLoads(SmallVectorImpl<LoadInst *> &Loads, AccessPath accessPath,
+                  SILValue storeAddr);
+
   /// Optimize the current loop nest.
   bool optimizeLoop(std::unique_ptr<LoopNestSummary> &CurrSummary);
 
-  /// Move all loads and stores from/to \p access out of the \p loop.
-  void hoistLoadsAndStores(AccessPath accessPath, SILLoop *loop,
-                           InstVector &toDelete);
+  /// Move all loads and stores from/to \p accessPath out of the \p loop.
+  void hoistLoadsAndStores(AccessPath accessPath, SILLoop *loop);
 
   /// Move all loads and stores from all addresses in LoadAndStoreAddrs out of
   /// the \p loop.
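For orientation, here is a minimal Swift-level sketch of the situation the new splitLoad/splitLoads entry points target, with invented types and names (the actual input is SIL, and the exact SIL depends on earlier passes): the loop stores through a narrow access path but reloads the whole aggregate, so the wide load overlaps the stored path.

struct Inner { var x: Int; var y: Int }
struct Outer { var a: Inner; var b: Int }

var global = Outer(a: Inner(x: 0, y: 0), b: 0)

func lastSnapshot(_ n: Int) -> Outer {
  var snapshot = global
  for i in 0..<n {
    global.a.x = i      // store to the narrow access path global.a.x
    snapshot = global   // wide load of the whole Outer value
  }
  return snapshot
}

splitLoads rewrites such a wide load into one load that exactly matches the stored access path plus disjoint loads of the remaining fields, so hoistLoadsAndStores can then treat global.a.x as an address that is only loaded and stored inside the loop.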
@@ -799,6 +812,8 @@ static bool analyzeBeginAccess(BeginAccessInst *BI,
 // We *need* to discover all SideEffectInsts -
 // even if the loop is otherwise skipped!
 // This is because outer loops will depend on the inner loop's writes.
+//
+// This may split some loads into smaller loads.
 void LoopTreeOptimization::analyzeCurrentLoop(
     std::unique_ptr<LoopNestSummary> &CurrSummary) {
   InstSet &sideEffects = CurrSummary->SideEffectInsts;
@@ -915,15 +930,22 @@ void LoopTreeOptimization::analyzeCurrentLoop(
 
   // Collect memory locations for which we can move all loads and stores out
   // of the loop.
+  //
+  // Note: The Loads set and LoadsAndStores set may mutate during this loop.
   for (StoreInst *SI : Stores) {
     // Use AccessPathWithBase to recover a base address that can be used for
     // newly inserted memory operations. If we instead teach hoistLoadsAndStores
     // how to rematerialize global_addr, then we don't need this base.
     auto access = AccessPathWithBase::compute(SI->getDest());
-    if (access.accessPath.isValid() && isLoopInvariant(access.base, Loop)) {
+    auto accessPath = access.accessPath;
+    if (accessPath.isValid() && isLoopInvariant(access.base, Loop)) {
       if (isOnlyLoadedAndStored(AA, sideEffects, Loads, Stores, SI->getDest(),
-                                access.accessPath)) {
-        LoadAndStoreAddrs.insert(access.accessPath);
+                                accessPath)) {
+        if (!LoadAndStoreAddrs.count(accessPath)) {
+          if (splitLoads(Loads, accessPath, SI->getDest())) {
+            LoadAndStoreAddrs.insert(accessPath);
+          }
+        }
       }
     }
   }
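To make the gating concrete, a hedged Swift-level example with invented identifiers: a store's access path becomes a candidate only when its base is loop-invariant and the path is touched by nothing but loads and stores inside the loop, and splitLoads is attempted at most once per unique path.

var stats = (hits: 0, misses: 0)

func qualifies(_ n: Int) {
  for i in 0..<n {
    stats.hits += i     // stats.hits is only loaded and stored in the loop,
  }                     // and the base is loop-invariant: a candidate path
}

func disqualified(_ n: Int, _ report: () -> Void) {
  for i in 0..<n {
    stats.hits += i
    report()            // an opaque call may read or write stats.hits, so the
  }                     // path is no longer only loaded and stored in the loop
}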
@@ -950,6 +972,172 @@ void LoopTreeOptimization::analyzeCurrentLoop(
   }
 }
 
+// Recursively determine whether the innerAddress is a direct tuple or struct
+// projection chain from outerPath. Populate \p reversePathIndices with the path
+// difference.
+static bool
+computeInnerAccessPath(AccessPath::PathNode outerPath,
+                       AccessPath::PathNode innerPath, SILValue innerAddress,
+                       SmallVectorImpl<AccessPath::Index> &reversePathIndices) {
+  if (outerPath == innerPath)
+    return true;
+
+  if (!isa<StructElementAddrInst>(innerAddress)
+      && !isa<TupleElementAddrInst>(innerAddress)) {
+    return false;
+  }
+  assert(ProjectionIndex(innerAddress).Index
+         == innerPath.getIndex().getSubObjectIndex());
+
+  reversePathIndices.push_back(innerPath.getIndex());
+  SILValue srcAddr = cast<SingleValueInstruction>(innerAddress)->getOperand(0);
+  if (!computeInnerAccessPath(outerPath, innerPath.getParent(), srcAddr,
+                              reversePathIndices)) {
+    return false;
+  }
+  return true;
+}
+
+/// Split a load from \p outerAddress recursively following remainingPath.
+///
+/// Creates a load with identical \p accessPath and a set of
+/// non-overlapping loads. Add the new non-overlapping loads to HoistUp.
+///
+/// \p ldstIdx is the index into LoadsAndStores of the original outer load.
+///
+/// Return the aggregate produced by merging the loads.
+SingleValueInstruction *LoopTreeOptimization::splitLoad(
+    SILValue splitAddress, ArrayRef<AccessPath::Index> remainingPath,
+    SILBuilder &builder, SmallVectorImpl<LoadInst *> &Loads, unsigned ldstIdx) {
+  auto loc = LoadsAndStores[ldstIdx]->getLoc();
+  // Recurse until we have a load that matches accessPath.
+  if (remainingPath.empty()) {
+    // Create a load that matches the stored access path.
+    LoadInst *load = builder.createLoad(loc, splitAddress,
+                                        LoadOwnershipQualifier::Unqualified);
+    Loads.push_back(load);
+    // Replace the outer load in the list of loads and stores to hoist and
+    // sink. LoadsAndStores must remain in instruction order.
+    LoadsAndStores[ldstIdx] = load;
+    LLVM_DEBUG(llvm::dbgs() << "Created load from stored path: " << *load);
+    return load;
+  }
+  auto recordDisjointLoad = [&](LoadInst *newLoad) {
+    Loads.push_back(newLoad);
+    LoadsAndStores.insert(LoadsAndStores.begin() + ldstIdx + 1, newLoad);
+  };
+  auto subIndex = remainingPath.back().getSubObjectIndex();
+  SILType loadTy = splitAddress->getType();
+  if (CanTupleType tupleTy = loadTy.getAs<TupleType>()) {
+    SmallVector<SILValue, 4> elements;
+    for (int tupleIdx : range(tupleTy->getNumElements())) {
+      auto *projection = builder.createTupleElementAddr(
+          loc, splitAddress, tupleIdx, loadTy.getTupleElementType(tupleIdx));
+      SILValue elementVal;
+      if (tupleIdx == subIndex) {
+        elementVal = splitLoad(projection, remainingPath.drop_back(), builder,
+                               Loads, ldstIdx);
+      } else {
+        elementVal = builder.createLoad(loc, projection,
+                                        LoadOwnershipQualifier::Unqualified);
+        recordDisjointLoad(cast<LoadInst>(elementVal));
+      }
+      elements.push_back(elementVal);
+    }
+    return builder.createTuple(loc, elements);
+  }
+  auto structTy = loadTy.getStructOrBoundGenericStruct();
+  assert(structTy && "tuple and struct elements are checked earlier");
+  auto &module = builder.getModule();
+  auto expansionContext = builder.getFunction().getTypeExpansionContext();
+
+  SmallVector<SILValue, 4> elements;
+  int fieldIdx = 0;
+  for (auto *field : structTy->getStoredProperties()) {
+    SILType fieldTy = loadTy.getFieldType(field, module, expansionContext);
+    auto *projection =
+        builder.createStructElementAddr(loc, splitAddress, field, fieldTy);
+    SILValue fieldVal;
+    if (fieldIdx++ == subIndex)
+      fieldVal = splitLoad(projection, remainingPath.drop_back(), builder,
+                           Loads, ldstIdx);
+    else {
+      fieldVal = builder.createLoad(loc, projection,
+                                    LoadOwnershipQualifier::Unqualified);
+      recordDisjointLoad(cast<LoadInst>(fieldVal));
+    }
+    elements.push_back(fieldVal);
+  }
+  return builder.createStruct(loc, loadTy.getObjectType(), elements);
+}
+
+/// Find all loads that contain \p accessPath. Split them into a load with
+/// identical accessPath and a set of non-overlapping loads. Add the new
+/// non-overlapping loads to LoadsAndStores and HoistUp.
+///
+/// TODO: The \p storeAddr parameter is only needed until we have an
+/// AliasAnalysis interface that handles AccessPath.
+bool LoopTreeOptimization::splitLoads(SmallVectorImpl<LoadInst *> &Loads,
+                                      AccessPath accessPath,
+                                      SILValue storeAddr) {
+  // The Loads set may mutate during this loop, but we only want to visit the
+  // original set.
+  for (unsigned loadsIdx = 0, endIdx = Loads.size(); loadsIdx != endIdx;
+       ++loadsIdx) {
+    auto *load = Loads[loadsIdx];
+    if (toDelete.count(load))
+      continue;
+
+    if (!AA->mayReadFromMemory(load, storeAddr))
+      continue;
+
+    AccessPath loadAccessPath = AccessPath::compute(load->getOperand());
+    if (accessPath.contains(loadAccessPath))
+      continue;
+
+    assert(loadAccessPath.contains(accessPath));
+    LLVM_DEBUG(llvm::dbgs() << "Overlaps with loop stores: " << *load);
+    SmallVector<AccessPath::Index, 4> reversePathIndices;
+    if (!computeInnerAccessPath(loadAccessPath.getPathNode(),
+                                accessPath.getPathNode(), storeAddr,
+                                reversePathIndices)) {
+      return false;
+    }
+    // Found a load wider than the store to accessPath.
+    //
+    // SplitLoads is called for each unique access path in the loop that is
+    // only loaded from and stored to and this loop takes time proportional to:
+    //   num-wide-loads x num-fields x num-loop-memops
+    //
+    // For each load wider than the store, it creates a new load for each field
+    // in that type. Each new load is inserted in the LoadsAndStores vector. To
+    // avoid super-linear behavior for large types (e.g. giant tuples), limit
+    // growth of new loads to an arbitrary constant factor per access path.
+    if (Loads.size() >= endIdx + 6) {
+      LLVM_DEBUG(llvm::dbgs() << "...Refusing to split more loads\n");
+      return false;
+    }
+    LLVM_DEBUG(llvm::dbgs() << "...Splitting load\n");
+
+    unsigned ldstIdx = [this, load]() {
+      auto ldstIter = llvm::find(LoadsAndStores, load);
+      assert(ldstIter != LoadsAndStores.end() && "outerLoad missing");
+      return std::distance(LoadsAndStores.begin(), ldstIter);
+    }();
+
+    SILBuilderWithScope builder(load);
+
+    SILValue aggregateVal = splitLoad(load->getOperand(), reversePathIndices,
+                                      builder, Loads, ldstIdx);
+
+    load->replaceAllUsesWith(aggregateVal);
+    auto iterAndInserted = toDelete.insert(load);
+    (void)iterAndInserted;
+    assert(iterAndInserted.second && "the same load should only be split once");
+  }
+  return true;
+}
+
 bool LoopTreeOptimization::optimizeLoop(
     std::unique_ptr<LoopNestSummary> &CurrSummary) {
   auto *CurrentLoop = CurrSummary->Loop;
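To visualize the recursion in splitLoad, a hedged Swift-level before/after sketch with invented types (the rewrite shown in comments is conceptual; the pass builds tuple_element_addr/struct_element_addr projections, element loads, and a re-assembled aggregate in SIL, and caps the growth of Loads at a small constant per access path):

struct Pair { var x: Int; var y: Int }
struct Box  { var p: Pair; var tag: Int }

var box = Box(p: Pair(x: 0, y: 0), tag: 0)

func widen(_ n: Int) -> Box {
  var last = box
  for i in 0..<n {
    box.p.x = i
    last = box
    // After splitLoads, the wide load of `box` is conceptually:
    //   let px  = box.p.x    // load matching the stored access path
    //   let py  = box.p.y    // disjoint load, inserted right after the original
    //   let tag = box.tag    // disjoint load, inserted right after the original
    //   last = Box(p: Pair(x: px, y: py), tag: tag)
    // The original wide load is replaced by the rebuilt aggregate and queued
    // in toDelete.
  }
  return last
}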
@@ -964,6 +1152,8 @@ bool LoopTreeOptimization::optimizeLoop(
   currChanged |= sinkInstructions(CurrSummary, DomTree, LoopInfo, SinkDown);
   currChanged |=
       hoistSpecialInstruction(CurrSummary, DomTree, LoopInfo, SpecialHoist);
+
+  assert(toDelete.empty() && "only hostAllLoadsAndStores deletes");
   return currChanged;
 }
 
@@ -1089,8 +1279,8 @@ storesCommonlyDominateLoopExits(AccessPath accessPath,
   return true;
 }
 
-void LoopTreeOptimization::hoistLoadsAndStores(
-    AccessPath accessPath, SILLoop *loop, InstVector &toDelete) {
+void LoopTreeOptimization::
+hoistLoadsAndStores(AccessPath accessPath, SILLoop *loop) {
   SmallVector<SILBasicBlock *, 4> exitingAndLatchBlocks;
   loop->getExitingAndLatchBlocks(exitingAndLatchBlocks);
 
@@ -1171,7 +1361,7 @@ void LoopTreeOptimization::hoistLoadsAndStores(
     if (auto *SI = isStoreToAccess(I, accessPath)) {
       LLVM_DEBUG(llvm::dbgs() << "Deleting reloaded store " << *SI);
       currentVal = SI->getSrc();
-      toDelete.push_back(SI);
+      toDelete.insert(SI);
       continue;
     }
     auto loadWithAccess = isLoadWithinAccess(I, accessPath);
@@ -1190,7 +1380,7 @@ void LoopTreeOptimization::hoistLoadsAndStores(
     LLVM_DEBUG(llvm::dbgs() << "Replacing stored load " << *load << " with "
                             << projectedValue);
     load->replaceAllUsesWith(projectedValue);
-    toDelete.push_back(load);
+    toDelete.insert(load);
   }
 
   // Store back the value at all loop exits.
@@ -1215,17 +1405,20 @@ void LoopTreeOptimization::hoistLoadsAndStores(
 }
 
 bool LoopTreeOptimization::hoistAllLoadsAndStores(SILLoop *loop) {
-  InstVector toDelete;
   for (AccessPath accessPath : LoadAndStoreAddrs) {
-    hoistLoadsAndStores(accessPath, loop, toDelete);
+    hoistLoadsAndStores(accessPath, loop);
   }
   LoadsAndStores.clear();
   LoadAndStoreAddrs.clear();
 
+  if (toDelete.empty())
+    return false;
+
   for (SILInstruction *I : toDelete) {
-    I->eraseFromParent();
+    recursivelyDeleteTriviallyDeadInstructions(I, /*force*/ true);
   }
-  return !toDelete.empty();
+  toDelete.clear();
+  return true;
 }
 
 namespace {