Commit 8982786

[SLP][NFC]Make canVectorizeLoads member of BoUpSLP class, NFC.
1 parent 13a78fd commit 8982786
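
The change is a mechanical refactor: a file-local helper that had to be handed the BoUpSLP object and every analysis it consults (TTI, DL, SE, LI, TLI) becomes a const member function that reads those analyses from the class's own fields. The following is a minimal, self-contained C++ sketch of that pattern only; the type and function names below are hypothetical stand-ins, not the actual LLVM types or API.

#include <vector>

// Hypothetical stand-ins for analyses that the real BoUpSLP keeps as
// member pointers (DataLayout, ScalarEvolution, TargetTransformInfo, ...).
struct DataLayoutStub {};
struct ScalarEvolutionStub {};

// Before the refactor, the helper was a free function and every caller had
// to thread the tree object and all analyses through explicitly:
//
//   static bool canVectorizeLoadsFree(const Tree &R,
//                                     const std::vector<int> &Loads,
//                                     const DataLayoutStub &DL,
//                                     const ScalarEvolutionStub &SE);

class Tree {
public:
  // After the refactor, the helper is a const member: callers pass only the
  // per-query data, and the analyses come from the members below.
  bool canVectorizeLoads(const std::vector<int> &Loads) const {
    // The real implementation would consult *DL, *SE, ... here.
    return DL != nullptr && SE != nullptr && !Loads.empty();
  }

private:
  DataLayoutStub *DL = nullptr;
  ScalarEvolutionStub *SE = nullptr;
};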

File tree

1 file changed (+43, -35 lines)


llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 43 additions & 35 deletions
@@ -980,6 +980,14 @@ class BoUpSLP {
   class ShuffleInstructionBuilder;
 
 public:
+  /// Tracks the state we can represent the loads in the given sequence.
+  enum class LoadsState {
+    Gather,
+    Vectorize,
+    ScatterVectorize,
+    StridedVectorize
+  };
+
   using ValueList = SmallVector<Value *, 8>;
   using InstrList = SmallVector<Instruction *, 16>;
   using ValueSet = SmallPtrSet<Value *, 16>;
@@ -1184,6 +1192,19 @@ class BoUpSLP {
   /// may not be necessary.
   bool isLoadCombineCandidate() const;
 
+  /// Checks if the given array of loads can be represented as a vectorized,
+  /// scatter or just simple gather.
+  /// \param VL list of loads.
+  /// \param VL0 main load value.
+  /// \param Order returned order of load instructions.
+  /// \param PointerOps returned list of pointer operands.
+  /// \param TryRecursiveCheck used to check if long masked gather can be
+  /// represented as a serie of loads/insert subvector, if profitable.
+  LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
+                               SmallVectorImpl<unsigned> &Order,
+                               SmallVectorImpl<Value *> &PointerOps,
+                               bool TryRecursiveCheck = true) const;
+
   OptimizationRemarkEmitter *getORE() { return ORE; }
 
   /// This structure holds any data we need about the edges being traversed
@@ -3957,11 +3978,6 @@ BoUpSLP::findReusedOrderedScalars(const BoUpSLP::TreeEntry &TE) {
   return std::move(CurrentOrder);
 }
 
-namespace {
-/// Tracks the state we can represent the loads in the given sequence.
-enum class LoadsState { Gather, Vectorize, ScatterVectorize, StridedVectorize };
-} // anonymous namespace
-
 static bool arePointersCompatible(Value *Ptr1, Value *Ptr2,
                                   const TargetLibraryInfo &TLI,
                                   bool CompareOpcodes = true) {
@@ -3998,16 +4014,9 @@ static bool isReverseOrder(ArrayRef<unsigned> Order) {
   });
 }
 
-/// Checks if the given array of loads can be represented as a vectorized,
-/// scatter or just simple gather.
-static LoadsState canVectorizeLoads(const BoUpSLP &R, ArrayRef<Value *> VL,
-                                    const Value *VL0,
-                                    const TargetTransformInfo &TTI,
-                                    const DataLayout &DL, ScalarEvolution &SE,
-                                    LoopInfo &LI, const TargetLibraryInfo &TLI,
-                                    SmallVectorImpl<unsigned> &Order,
-                                    SmallVectorImpl<Value *> &PointerOps,
-                                    bool TryRecursiveCheck = true) {
+BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads(
+    ArrayRef<Value *> VL, const Value *VL0, SmallVectorImpl<unsigned> &Order,
+    SmallVectorImpl<Value *> &PointerOps, bool TryRecursiveCheck) const {
   // Check that a vectorized load would load the same memory as a scalar
   // load. For example, we don't want to vectorize loads that are smaller
   // than 8-bit. Even though we have a packed struct {<i2, i2, i2, i2>} LLVM
@@ -4016,7 +4025,7 @@ static LoadsState canVectorizeLoads(const BoUpSLP &R, ArrayRef<Value *> VL,
   // unvectorized version.
   Type *ScalarTy = VL0->getType();
 
-  if (DL.getTypeSizeInBits(ScalarTy) != DL.getTypeAllocSizeInBits(ScalarTy))
+  if (DL->getTypeSizeInBits(ScalarTy) != DL->getTypeAllocSizeInBits(ScalarTy))
     return LoadsState::Gather;
 
   // Make sure all loads in the bundle are simple - we can't vectorize
@@ -4036,9 +4045,9 @@ static LoadsState canVectorizeLoads(const BoUpSLP &R, ArrayRef<Value *> VL,
   Order.clear();
   auto *VecTy = FixedVectorType::get(ScalarTy, Sz);
   // Check the order of pointer operands or that all pointers are the same.
-  bool IsSorted = sortPtrAccesses(PointerOps, ScalarTy, DL, SE, Order);
+  bool IsSorted = sortPtrAccesses(PointerOps, ScalarTy, *DL, *SE, Order);
   if (IsSorted || all_of(PointerOps, [&](Value *P) {
-        return arePointersCompatible(P, PointerOps.front(), TLI);
+        return arePointersCompatible(P, PointerOps.front(), *TLI);
       })) {
     if (IsSorted) {
       Value *Ptr0;
@@ -4051,7 +4060,7 @@ static LoadsState canVectorizeLoads(const BoUpSLP &R, ArrayRef<Value *> VL,
         PtrN = PointerOps[Order.back()];
       }
       std::optional<int> Diff =
-          getPointersDiff(ScalarTy, Ptr0, ScalarTy, PtrN, DL, SE);
+          getPointersDiff(ScalarTy, Ptr0, ScalarTy, PtrN, *DL, *SE);
       // Check that the sorted loads are consecutive.
       if (static_cast<unsigned>(*Diff) == Sz - 1)
         return LoadsState::Vectorize;
@@ -4078,7 +4087,7 @@ static LoadsState canVectorizeLoads(const BoUpSLP &R, ArrayRef<Value *> VL,
         Align Alignment =
             cast<LoadInst>(Order.empty() ? VL.front() : VL[Order.front()])
                 ->getAlign();
-        if (TTI.isLegalStridedLoadStore(VecTy, Alignment)) {
+        if (TTI->isLegalStridedLoadStore(VecTy, Alignment)) {
          // Iterate through all pointers and check if all distances are
          // unique multiple of Dist.
          SmallSet<int, 4> Dists;
@@ -4087,7 +4096,8 @@ static LoadsState canVectorizeLoads(const BoUpSLP &R, ArrayRef<Value *> VL,
            if (Ptr == PtrN)
              Dist = *Diff;
            else if (Ptr != Ptr0)
-              Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, DL, SE);
+              Dist =
+                  *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE);
            // If the strides are not the same or repeated, we can't
            // vectorize.
            if (((Dist / Stride) * Stride) != Dist ||
@@ -4100,11 +4110,11 @@ static LoadsState canVectorizeLoads(const BoUpSLP &R, ArrayRef<Value *> VL,
       }
     }
   }
-  auto CheckForShuffledLoads = [&](Align CommonAlignment) {
-    unsigned Sz = DL.getTypeSizeInBits(ScalarTy);
-    unsigned MinVF = R.getMinVF(Sz);
+  auto CheckForShuffledLoads = [&, &TTI = *TTI](Align CommonAlignment) {
+    unsigned Sz = DL->getTypeSizeInBits(ScalarTy);
+    unsigned MinVF = getMinVF(Sz);
     unsigned MaxVF = std::max<unsigned>(bit_floor(VL.size() / 2), MinVF);
-    MaxVF = std::min(R.getMaximumVF(Sz, Instruction::Load), MaxVF);
+    MaxVF = std::min(getMaximumVF(Sz, Instruction::Load), MaxVF);
     for (unsigned VF = MaxVF; VF >= MinVF; VF /= 2) {
       unsigned VectorizedCnt = 0;
       SmallVector<LoadsState> States;
@@ -4114,8 +4124,8 @@ static LoadsState canVectorizeLoads(const BoUpSLP &R, ArrayRef<Value *> VL,
        SmallVector<unsigned> Order;
        SmallVector<Value *> PointerOps;
        LoadsState LS =
-            canVectorizeLoads(R, Slice, Slice.front(), TTI, DL, SE, LI, TLI,
-                              Order, PointerOps, /*TryRecursiveCheck=*/false);
+            canVectorizeLoads(Slice, Slice.front(), Order, PointerOps,
+                              /*TryRecursiveCheck=*/false);
        // Check that the sorted loads are consecutive.
        if (LS == LoadsState::Gather)
          break;
@@ -4175,7 +4185,7 @@ static LoadsState canVectorizeLoads(const BoUpSLP &R, ArrayRef<Value *> VL,
   // TODO: need to improve analysis of the pointers, if not all of them are
   // GEPs or have > 2 operands, we end up with a gather node, which just
   // increases the cost.
-  Loop *L = LI.getLoopFor(cast<LoadInst>(VL0)->getParent());
+  Loop *L = LI->getLoopFor(cast<LoadInst>(VL0)->getParent());
   bool ProfitableGatherPointers =
       L && Sz > 2 && count_if(PointerOps, [L](Value *V) {
        return L->isLoopInvariant(V);
@@ -4187,8 +4197,8 @@ static LoadsState canVectorizeLoads(const BoUpSLP &R, ArrayRef<Value *> VL,
                isa<Constant, Instruction>(GEP->getOperand(1)));
      })) {
    Align CommonAlignment = computeCommonAlignment<LoadInst>(VL);
-    if (TTI.isLegalMaskedGather(VecTy, CommonAlignment) &&
-        !TTI.forceScalarizeMaskedGather(VecTy, CommonAlignment)) {
+    if (TTI->isLegalMaskedGather(VecTy, CommonAlignment) &&
+        !TTI->forceScalarizeMaskedGather(VecTy, CommonAlignment)) {
      // Check if potential masked gather can be represented as series
      // of loads + insertsubvectors.
      if (TryRecursiveCheck && CheckForShuffledLoads(CommonAlignment)) {
@@ -5635,8 +5645,7 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
    // treats loading/storing it as an i8 struct. If we vectorize loads/stores
    // from such a struct, we read/write packed bits disagreeing with the
    // unvectorized version.
-    switch (canVectorizeLoads(*this, VL, VL0, *TTI, *DL, *SE, *LI, *TLI,
-                              CurrentOrder, PointerOps)) {
+    switch (canVectorizeLoads(VL, VL0, CurrentOrder, PointerOps)) {
    case LoadsState::Vectorize:
      return TreeEntry::Vectorize;
    case LoadsState::ScatterVectorize:
@@ -7416,9 +7425,8 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
              !VectorizedLoads.count(Slice.back()) && allSameBlock(Slice)) {
            SmallVector<Value *> PointerOps;
            OrdersType CurrentOrder;
-            LoadsState LS =
-                canVectorizeLoads(R, Slice, Slice.front(), TTI, *R.DL, *R.SE,
-                                  *R.LI, *R.TLI, CurrentOrder, PointerOps);
+            LoadsState LS = R.canVectorizeLoads(Slice, Slice.front(),
+                                                CurrentOrder, PointerOps);
            switch (LS) {
            case LoadsState::Vectorize:
            case LoadsState::ScatterVectorize:
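
The net effect at the call sites, taken directly from the hunks above: inside BoUpSLP the analysis arguments simply disappear, and the out-of-class caller in ShuffleCostEstimator now goes through its BoUpSLP reference R.

// Before (free function; caller threads the tree object and analyses through):
//   LoadsState LS =
//       canVectorizeLoads(R, Slice, Slice.front(), TTI, *R.DL, *R.SE,
//                         *R.LI, *R.TLI, CurrentOrder, PointerOps);
// After (member function; analyses are read from the object itself):
//   LoadsState LS = R.canVectorizeLoads(Slice, Slice.front(),
//                                       CurrentOrder, PointerOps);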

0 commit comments
