Skip to content

Commit 70501ed

Browse files
authored
[LoopVectorizer] Prune VFs based on plan register pressure (#132190)
This PR moves the register usage checking to after the plans are created, so that any recipes that optimise register usage (such as partial reductions) can be properly costed and not have their VF pruned unnecessarily. Depends on #137746.
1 parent e8a3074 commit 70501ed

38 files changed

+416
-826
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 46 additions & 240 deletions
Original file line numberDiff line numberDiff line change
@@ -996,12 +996,15 @@ class LoopVectorizationCostModel {
996996
/// Holds the maximum number of concurrent live intervals in the loop.
997997
/// The key is ClassID of target-provided register class.
998998
SmallMapVector<unsigned, unsigned, 4> MaxLocalUsers;
999-
};
1000999

1001-
/// \return Returns information about the register usages of the loop for the
1002-
/// given vectorization factors.
1003-
SmallVector<RegisterUsage, 8>
1004-
calculateRegisterUsage(ArrayRef<ElementCount> VFs);
1000+
/// Check if any of the tracked live intervals exceeds the number of
1001+
/// available registers for the target.
///
/// \param TTI queried through getNumberOfRegisters() with each key of
/// MaxLocalUsers (the ClassID of a target-provided register class).
/// \return true if at least one entry of MaxLocalUsers records a count that
/// is strictly greater than the number TTI reports for that class; false
/// otherwise, including when MaxLocalUsers has no entries.
1002+
bool exceedsMaxNumRegs(const TargetTransformInfo &TTI) const {
1003+
return any_of(MaxLocalUsers, [&TTI](auto &LU) {
1004+
// Strict comparison: a count equal to the register count does not exceed it.
return LU.second > TTI.getNumberOfRegisters(LU.first);
1005+
});
1006+
}
1007+
};
10051008

10061009
/// Collect values we want to ignore in the cost model.
10071010
void collectValuesToIgnore();
@@ -4013,29 +4016,8 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget(
40134016
auto MaxVectorElementCountMaxBW = ElementCount::get(
40144017
llvm::bit_floor(WidestRegister.getKnownMinValue() / SmallestType),
40154018
ComputeScalableMaxVF);
4016-
MaxVectorElementCountMaxBW = MinVF(MaxVectorElementCountMaxBW, MaxSafeVF);
4017-
4018-
// Collect all viable vectorization factors larger than the default MaxVF
4019-
// (i.e. MaxVectorElementCount).
4020-
SmallVector<ElementCount, 8> VFs;
4021-
for (ElementCount VS = MaxVectorElementCount * 2;
4022-
ElementCount::isKnownLE(VS, MaxVectorElementCountMaxBW); VS *= 2)
4023-
VFs.push_back(VS);
4024-
4025-
// For each VF calculate its register usage.
4026-
auto RUs = calculateRegisterUsage(VFs);
4027-
4028-
// Select the largest VF which doesn't require more registers than existing
4029-
// ones.
4030-
for (int I = RUs.size() - 1; I >= 0; --I) {
4031-
const auto &MLU = RUs[I].MaxLocalUsers;
4032-
if (all_of(MLU, [&](decltype(MLU.front()) &LU) {
4033-
return LU.second <= TTI.getNumberOfRegisters(LU.first);
4034-
})) {
4035-
MaxVF = VFs[I];
4036-
break;
4037-
}
4038-
}
4019+
MaxVF = MinVF(MaxVectorElementCountMaxBW, MaxSafeVF);
4020+
40394021
if (ElementCount MinVF =
40404022
TTI.getMinimumVF(SmallestType, ComputeScalableMaxVF)) {
40414023
if (ElementCount::isKnownLT(MaxVF, MinVF)) {
@@ -4360,6 +4342,15 @@ static bool hasReplicatorRegion(VPlan &Plan) {
43604342
}
43614343

43624344
#ifndef NDEBUG
4345+
/// Estimate the register usage for \p Plan and vectorization factors in \p VFs
4346+
/// by calculating the highest number of values that are live at a single
4347+
/// location as a rough estimate. Returns the register usage for each VF in \p
4348+
/// VFs.
4349+
static SmallVector<LoopVectorizationCostModel::RegisterUsage, 8>
4350+
calculateRegisterUsage(VPlan &Plan, ArrayRef<ElementCount> VFs,
4351+
const TargetTransformInfo &TTI,
4352+
const SmallPtrSetImpl<const Value *> &ValuesToIgnore);
4353+
43634354
VectorizationFactor LoopVectorizationPlanner::selectVectorizationFactor() {
43644355
InstructionCost ExpectedCost = CM.expectedCost(ElementCount::getFixed(1));
43654356
LLVM_DEBUG(dbgs() << "LV: Scalar loop costs: " << ExpectedCost << ".\n");
@@ -4383,11 +4374,19 @@ VectorizationFactor LoopVectorizationPlanner::selectVectorizationFactor() {
43834374
}
43844375

43854376
for (auto &P : VPlans) {
4386-
for (ElementCount VF : P->vectorFactors()) {
4377+
ArrayRef<ElementCount> VFs(P->vectorFactors().begin(),
4378+
P->vectorFactors().end());
4379+
auto RUs = ::calculateRegisterUsage(*P, VFs, TTI, CM.ValuesToIgnore);
4380+
for (auto [VF, RU] : zip_equal(VFs, RUs)) {
43874381
// The cost for scalar VF=1 is already calculated, so ignore it.
43884382
if (VF.isScalar())
43894383
continue;
43904384

4385+
/// Don't consider the VF if it exceeds the number of registers for the
4386+
/// target.
4387+
if (RU.exceedsMaxNumRegs(TTI))
4388+
continue;
4389+
43914390
InstructionCost C = CM.expectedCost(VF);
43924391

43934392
// Add on other costs that are modelled in VPlan, but not in the legacy
@@ -4859,9 +4858,13 @@ calculateRegisterUsage(VPlan &Plan, ArrayRef<ElementCount> VFs,
48594858
isa<VPCanonicalIVPHIRecipe, VPReplicateRecipe, VPDerivedIVRecipe,
48604859
VPScalarIVStepsRecipe>(R) ||
48614860
(isa<VPInstruction>(R) &&
4862-
all_of(cast<VPSingleDefRecipe>(R)->users(), [&](VPUser *U) {
4863-
return cast<VPRecipeBase>(U)->usesScalars(R->getVPSingleValue());
4864-
}))) {
4861+
all_of(cast<VPSingleDefRecipe>(R)->users(),
4862+
[&](VPUser *U) {
4863+
return cast<VPRecipeBase>(U)->usesScalars(
4864+
R->getVPSingleValue());
4865+
})) ||
4866+
(isa<VPReductionPHIRecipe>(R) &&
4867+
(cast<VPReductionPHIRecipe>(R))->isInLoop())) {
48654868
unsigned ClassID = TTI.getRegisterClassForType(
48664869
false, TypeInfo.inferScalarType(R->getVPSingleValue()));
48674870
// FIXME: The target might use more than one register for the type
@@ -5234,213 +5237,6 @@ LoopVectorizationCostModel::selectInterleaveCount(VPlan &Plan, ElementCount VF,
52345237
return 1;
52355238
}
52365239

5237-
SmallVector<LoopVectorizationCostModel::RegisterUsage, 8>
5238-
LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<ElementCount> VFs) {
5239-
// This function calculates the register usage by measuring the highest number
5240-
// of values that are alive at a single location. Obviously, this is a very
5241-
// rough estimation. We scan the loop in topological order and
5242-
// assign a number to each instruction. We use RPO to ensure that defs are
5243-
// met before their users. We assume that each instruction that has in-loop
5244-
// users starts an interval. We record every time that an in-loop value is
5245-
// used, so we have a list of the first and last occurrences of each
5246-
// instruction. Next, we transpose this data structure into a multi map that
5247-
// holds the list of intervals that *end* at a specific location. This multi
5248-
// map allows us to perform a linear search. We scan the instructions linearly
5249-
// and record each time that a new interval starts, by placing it in a set.
5250-
// If we find this value in the multi-map then we remove it from the set.
5251-
// The max register usage is the maximum size of the set.
5252-
// We also search for instructions that are defined outside the loop, but are
5253-
// used inside the loop. We need this number separately from the max-interval
5254-
// usage number because when we unroll, loop-invariant values do not take
5255-
// more registers.
5256-
LoopBlocksDFS DFS(TheLoop);
5257-
DFS.perform(LI);
5258-
5259-
RegisterUsage RU;
5260-
5261-
// Each 'key' in the map opens a new interval. The values
5262-
// of the map are the index of the 'last seen' usage of the
5263-
// instruction that is the key.
5264-
using IntervalMap = SmallDenseMap<Instruction *, unsigned, 16>;
5265-
5266-
// Maps instruction to its index.
5267-
SmallVector<Instruction *, 64> IdxToInstr;
5268-
// Marks the end of each interval.
5269-
IntervalMap EndPoint;
5270-
// Saves the list of instruction indices that are used in the loop.
5271-
SmallPtrSet<Instruction *, 8> Ends;
5272-
// Saves the list of values that are used in the loop but are defined outside
5273-
// the loop (not including non-instruction values such as arguments and
5274-
// constants).
5275-
SmallSetVector<Instruction *, 8> LoopInvariants;
5276-
5277-
for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) {
5278-
for (Instruction &I : BB->instructionsWithoutDebug()) {
5279-
IdxToInstr.push_back(&I);
5280-
5281-
// Save the end location of each USE.
5282-
for (Value *U : I.operands()) {
5283-
auto *Instr = dyn_cast<Instruction>(U);
5284-
5285-
// Ignore non-instruction values such as arguments, constants, etc.
5286-
// FIXME: Might need some motivation why these values are ignored. If
5287-
// for example an argument is used inside the loop it will increase the
5288-
// register pressure (so shouldn't we add it to LoopInvariants).
5289-
if (!Instr)
5290-
continue;
5291-
5292-
// If this instruction is outside the loop then record it and continue.
5293-
if (!TheLoop->contains(Instr)) {
5294-
LoopInvariants.insert(Instr);
5295-
continue;
5296-
}
5297-
5298-
// Overwrite previous end points.
5299-
EndPoint[Instr] = IdxToInstr.size();
5300-
Ends.insert(Instr);
5301-
}
5302-
}
5303-
}
5304-
5305-
// Saves the list of intervals that end with the index in 'key'.
5306-
using InstrList = SmallVector<Instruction *, 2>;
5307-
SmallDenseMap<unsigned, InstrList, 16> TransposeEnds;
5308-
5309-
// Transpose the EndPoints to a list of values that end at each index.
5310-
for (auto &Interval : EndPoint)
5311-
TransposeEnds[Interval.second].push_back(Interval.first);
5312-
5313-
SmallPtrSet<Instruction *, 8> OpenIntervals;
5314-
SmallVector<RegisterUsage, 8> RUs(VFs.size());
5315-
SmallVector<SmallMapVector<unsigned, unsigned, 4>, 8> MaxUsages(VFs.size());
5316-
5317-
LLVM_DEBUG(dbgs() << "LV(REG): Calculating max register usage:\n");
5318-
5319-
const auto &TTICapture = TTI;
5320-
auto GetRegUsage = [&TTICapture](Type *Ty, ElementCount VF) -> unsigned {
5321-
if (Ty->isTokenTy() || !VectorType::isValidElementType(Ty) ||
5322-
(VF.isScalable() &&
5323-
!TTICapture.isElementTypeLegalForScalableVector(Ty)))
5324-
return 0;
5325-
return TTICapture.getRegUsageForType(VectorType::get(Ty, VF));
5326-
};
5327-
5328-
collectInLoopReductions();
5329-
5330-
for (unsigned int Idx = 0, Sz = IdxToInstr.size(); Idx < Sz; ++Idx) {
5331-
Instruction *I = IdxToInstr[Idx];
5332-
5333-
// Remove all of the instructions that end at this location.
5334-
InstrList &List = TransposeEnds[Idx];
5335-
for (Instruction *ToRemove : List)
5336-
OpenIntervals.erase(ToRemove);
5337-
5338-
// Ignore instructions that are never used within the loop and do not have
5339-
// side-effects.
5340-
if (!Ends.count(I) && !I->mayHaveSideEffects())
5341-
continue;
5342-
5343-
// Skip ignored values.
5344-
if (ValuesToIgnore.count(I))
5345-
continue;
5346-
5347-
// For each VF find the maximum usage of registers.
5348-
for (unsigned J = 0, E = VFs.size(); J < E; ++J) {
5349-
// Count the number of registers used, per register class, given all open
5350-
// intervals.
5351-
// Note that elements in this SmallMapVector will be default constructed
5352-
// as 0. So we can use "RegUsage[ClassID] += n" in the code below even if
5353-
// there is no previous entry for ClassID.
5354-
SmallMapVector<unsigned, unsigned, 4> RegUsage;
5355-
5356-
if (VFs[J].isScalar()) {
5357-
for (auto *Inst : OpenIntervals) {
5358-
unsigned ClassID =
5359-
TTI.getRegisterClassForType(false, Inst->getType());
5360-
// FIXME: The target might use more than one register for the type
5361-
// even in the scalar case.
5362-
RegUsage[ClassID] += 1;
5363-
}
5364-
} else {
5365-
collectNonVectorizedAndSetWideningDecisions(VFs[J]);
5366-
for (auto *Inst : OpenIntervals) {
5367-
// Skip ignored values for VF > 1.
5368-
if (VecValuesToIgnore.count(Inst))
5369-
continue;
5370-
if (isScalarAfterVectorization(Inst, VFs[J])) {
5371-
unsigned ClassID =
5372-
TTI.getRegisterClassForType(false, Inst->getType());
5373-
// FIXME: The target might use more than one register for the type
5374-
// even in the scalar case.
5375-
RegUsage[ClassID] += 1;
5376-
} else {
5377-
unsigned ClassID =
5378-
TTI.getRegisterClassForType(true, Inst->getType());
5379-
RegUsage[ClassID] += GetRegUsage(Inst->getType(), VFs[J]);
5380-
}
5381-
}
5382-
}
5383-
5384-
for (const auto &Pair : RegUsage) {
5385-
auto &Entry = MaxUsages[J][Pair.first];
5386-
Entry = std::max(Entry, Pair.second);
5387-
}
5388-
}
5389-
5390-
LLVM_DEBUG(dbgs() << "LV(REG): At #" << Idx << " Interval # "
5391-
<< OpenIntervals.size() << '\n');
5392-
5393-
// Add the current instruction to the list of open intervals.
5394-
OpenIntervals.insert(I);
5395-
}
5396-
5397-
for (unsigned Idx = 0, End = VFs.size(); Idx < End; ++Idx) {
5398-
// Note that elements in this SmallMapVector will be default constructed
5399-
// as 0. So we can use "Invariant[ClassID] += n" in the code below even if
5400-
// there is no previous entry for ClassID.
5401-
SmallMapVector<unsigned, unsigned, 4> Invariant;
5402-
5403-
for (auto *Inst : LoopInvariants) {
5404-
// FIXME: The target might use more than one register for the type
5405-
// even in the scalar case.
5406-
bool IsScalar = all_of(Inst->users(), [&](User *U) {
5407-
auto *I = cast<Instruction>(U);
5408-
return TheLoop != LI->getLoopFor(I->getParent()) ||
5409-
isScalarAfterVectorization(I, VFs[Idx]);
5410-
});
5411-
5412-
ElementCount VF = IsScalar ? ElementCount::getFixed(1) : VFs[Idx];
5413-
unsigned ClassID =
5414-
TTI.getRegisterClassForType(VF.isVector(), Inst->getType());
5415-
Invariant[ClassID] += GetRegUsage(Inst->getType(), VF);
5416-
}
5417-
5418-
LLVM_DEBUG({
5419-
dbgs() << "LV(REG): VF = " << VFs[Idx] << '\n';
5420-
dbgs() << "LV(REG): Found max usage: " << MaxUsages[Idx].size()
5421-
<< " item\n";
5422-
for (const auto &pair : MaxUsages[Idx]) {
5423-
dbgs() << "LV(REG): RegisterClass: "
5424-
<< TTI.getRegisterClassName(pair.first) << ", " << pair.second
5425-
<< " registers\n";
5426-
}
5427-
dbgs() << "LV(REG): Found invariant usage: " << Invariant.size()
5428-
<< " item\n";
5429-
for (const auto &pair : Invariant) {
5430-
dbgs() << "LV(REG): RegisterClass: "
5431-
<< TTI.getRegisterClassName(pair.first) << ", " << pair.second
5432-
<< " registers\n";
5433-
}
5434-
});
5435-
5436-
RU.LoopInvariantRegs = Invariant;
5437-
RU.MaxLocalUsers = MaxUsages[Idx];
5438-
RUs[Idx] = RU;
5439-
}
5440-
5441-
return RUs;
5442-
}
5443-
54445240
bool LoopVectorizationCostModel::useEmulatedMaskMemRefHack(Instruction *I,
54455241
ElementCount VF) {
54465242
// TODO: Cost model for emulated masked load/store is completely
@@ -7621,7 +7417,10 @@ VectorizationFactor LoopVectorizationPlanner::computeBestVF() {
76217417
}
76227418

76237419
for (auto &P : VPlans) {
7624-
for (ElementCount VF : P->vectorFactors()) {
7420+
ArrayRef<ElementCount> VFs(P->vectorFactors().begin(),
7421+
P->vectorFactors().end());
7422+
auto RUs = ::calculateRegisterUsage(*P, VFs, TTI, CM.ValuesToIgnore);
7423+
for (auto [VF, RU] : zip_equal(VFs, RUs)) {
76257424
if (VF.isScalar())
76267425
continue;
76277426
if (!ForceVectorization && !willGenerateVectors(*P, VF, TTI)) {
@@ -7642,6 +7441,13 @@ VectorizationFactor LoopVectorizationPlanner::computeBestVF() {
76427441

76437442
InstructionCost Cost = cost(*P, VF);
76447443
VectorizationFactor CurrentFactor(VF, Cost, ScalarCost);
7444+
7445+
if (RU.exceedsMaxNumRegs(TTI)) {
7446+
LLVM_DEBUG(dbgs() << "LV(REG): Not considering vector loop of width "
7447+
<< VF << " because it uses too many registers\n");
7448+
continue;
7449+
}
7450+
76457451
if (isMoreProfitable(CurrentFactor, BestFactor, P->hasScalarTail()))
76467452
BestFactor = CurrentFactor;
76477453

0 commit comments

Comments
 (0)