Skip to content

Commit edcbd4a

Browse files
[SLP][NFC] Extract a check for strided loads into a separate function
Reviewers: hiraditya, RKSimon Reviewed By: RKSimon Pull Request: #134876
1 parent 02a708b commit edcbd4a

File tree

1 file changed

+69
-46
lines changed

1 file changed

+69
-46
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 69 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -5597,6 +5597,71 @@ static bool isMaskedLoadCompress(
55975597
return TotalVecCost < GatherCost;
55985598
}
55995599

5600+
/// Checks if strided loads can be generated out of \p VL loads with pointers \p
5601+
/// PointerOps:
5602+
/// 1. Target with strided load support is detected.
5603+
/// 2. The number of loads is greater than MinProfitableStridedLoads, or the
5604+
/// potential stride <= MaxProfitableLoadStride and the potential stride is
5605+
/// power-of-2 (to avoid perf regressions for the very small number of loads)
5606+
/// and max distance > number of loads, or potential stride is -1.
5607+
/// 3. The loads are ordered, or number of unordered loads <=
5608+
/// MaxProfitableUnorderedLoads, or loads are in reversed order. (this check is
5609+
/// to avoid extra costs for very expensive shuffles).
5610+
/// 4. Any pointer operand is an instruction with the users outside of the
5611+
/// current graph (for masked gathers extra extractelement instructions
5612+
/// might be required).
5613+
static bool isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
5614+
ArrayRef<unsigned> Order,
5615+
const TargetTransformInfo &TTI, const DataLayout &DL,
5616+
ScalarEvolution &SE,
5617+
const bool IsAnyPointerUsedOutGraph, const int Diff) {
5618+
const unsigned Sz = VL.size();
5619+
const unsigned AbsoluteDiff = std::abs(Diff);
5620+
Type *ScalarTy = VL.front()->getType();
5621+
auto *VecTy = getWidenedType(ScalarTy, Sz);
5622+
if (IsAnyPointerUsedOutGraph ||
5623+
(AbsoluteDiff > Sz &&
5624+
(Sz > MinProfitableStridedLoads ||
5625+
(AbsoluteDiff <= MaxProfitableLoadStride * Sz &&
5626+
AbsoluteDiff % Sz == 0 && has_single_bit(AbsoluteDiff / Sz)))) ||
5627+
Diff == -(static_cast<int>(Sz) - 1)) {
5628+
int Stride = Diff / static_cast<int>(Sz - 1);
5629+
if (Diff != Stride * static_cast<int>(Sz - 1))
5630+
return false;
5631+
Align Alignment =
5632+
cast<LoadInst>(Order.empty() ? VL.front() : VL[Order.front()])
5633+
->getAlign();
5634+
if (!TTI.isLegalStridedLoadStore(VecTy, Alignment))
5635+
return false;
5636+
Value *Ptr0;
5637+
Value *PtrN;
5638+
if (Order.empty()) {
5639+
Ptr0 = PointerOps.front();
5640+
PtrN = PointerOps.back();
5641+
} else {
5642+
Ptr0 = PointerOps[Order.front()];
5643+
PtrN = PointerOps[Order.back()];
5644+
}
5645+
// Iterate through all pointers and check if all distances are
5646+
// unique multiple of Dist.
5647+
SmallSet<int, 4> Dists;
5648+
for (Value *Ptr : PointerOps) {
5649+
int Dist = 0;
5650+
if (Ptr == PtrN)
5651+
Dist = Diff;
5652+
else if (Ptr != Ptr0)
5653+
Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, DL, SE);
5654+
// If the strides are not the same or repeated, we can't
5655+
// vectorize.
5656+
if (((Dist / Stride) * Stride) != Dist || !Dists.insert(Dist).second)
5657+
break;
5658+
}
5659+
if (Dists.size() == Sz)
5660+
return true;
5661+
}
5662+
return false;
5663+
}
5664+
56005665
BoUpSLP::LoadsState
56015666
BoUpSLP::canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
56025667
SmallVectorImpl<unsigned> &Order,
@@ -5670,59 +5735,17 @@ BoUpSLP::canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
56705735
return LoadsState::Vectorize;
56715736
// Simple check if not a strided access - clear order.
56725737
bool IsPossibleStrided = *Diff % (Sz - 1) == 0;
5673-
// Try to generate strided load node if:
5674-
// 1. Target with strided load support is detected.
5675-
// 2. The number of loads is greater than MinProfitableStridedLoads,
5676-
// or the potential stride <= MaxProfitableLoadStride and the
5677-
// potential stride is power-of-2 (to avoid perf regressions for the very
5678-
// small number of loads) and max distance > number of loads, or potential
5679-
// stride is -1.
5680-
// 3. The loads are ordered, or number of unordered loads <=
5681-
// MaxProfitableUnorderedLoads, or loads are in reversed order.
5682-
// (this check is to avoid extra costs for very expensive shuffles).
5683-
// 4. Any pointer operand is an instruction with the users outside of the
5684-
// current graph (for masked gathers extra extractelement instructions
5685-
// might be required).
5738+
// Try to generate strided load node.
56865739
auto IsAnyPointerUsedOutGraph =
56875740
IsPossibleStrided && any_of(PointerOps, [&](Value *V) {
56885741
return isa<Instruction>(V) && any_of(V->users(), [&](User *U) {
56895742
return !isVectorized(U) && !MustGather.contains(U);
56905743
});
56915744
});
5692-
const unsigned AbsoluteDiff = std::abs(*Diff);
56935745
if (IsPossibleStrided &&
5694-
(IsAnyPointerUsedOutGraph ||
5695-
(AbsoluteDiff > Sz &&
5696-
(Sz > MinProfitableStridedLoads ||
5697-
(AbsoluteDiff <= MaxProfitableLoadStride * Sz &&
5698-
AbsoluteDiff % Sz == 0 && has_single_bit(AbsoluteDiff / Sz)))) ||
5699-
*Diff == -(static_cast<int>(Sz) - 1))) {
5700-
int Stride = *Diff / static_cast<int>(Sz - 1);
5701-
if (*Diff == Stride * static_cast<int>(Sz - 1)) {
5702-
Align Alignment =
5703-
cast<LoadInst>(Order.empty() ? VL.front() : VL[Order.front()])
5704-
->getAlign();
5705-
if (TTI->isLegalStridedLoadStore(VecTy, Alignment)) {
5706-
// Iterate through all pointers and check if all distances are
5707-
// unique multiple of Dist.
5708-
SmallSet<int, 4> Dists;
5709-
for (Value *Ptr : PointerOps) {
5710-
int Dist = 0;
5711-
if (Ptr == PtrN)
5712-
Dist = *Diff;
5713-
else if (Ptr != Ptr0)
5714-
Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE);
5715-
// If the strides are not the same or repeated, we can't
5716-
// vectorize.
5717-
if (((Dist / Stride) * Stride) != Dist ||
5718-
!Dists.insert(Dist).second)
5719-
break;
5720-
}
5721-
if (Dists.size() == Sz)
5722-
return LoadsState::StridedVectorize;
5723-
}
5724-
}
5725-
}
5746+
isStridedLoad(VL, PointerOps, Order, *TTI, *DL, *SE,
5747+
IsAnyPointerUsedOutGraph, *Diff))
5748+
return LoadsState::StridedVectorize;
57265749
bool IsMasked;
57275750
unsigned InterleaveFactor;
57285751
SmallVector<int> CompressMask;

0 commit comments

Comments
 (0)