Skip to content

Commit c9bf976

Browse files
committed
[SLP][REVEC] Make ShuffleCostEstimator::createShuffle support vector instructions.
The VF is relative to the number of elements in ScalarTy instead of the size of the mask.
1 parent e21df02 commit c9bf976

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8703,6 +8703,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
87038703
Value *V1 = P1.dyn_cast<Value *>(), *V2 = P2.dyn_cast<Value *>();
87048704
unsigned CommonVF = Mask.size();
87058705
InstructionCost ExtraCost = 0;
8706+
unsigned ScalarTyNumElements = getNumElements(ScalarTy);
87068707
auto GetNodeMinBWAffectedCost = [&](const TreeEntry &E,
87078708
unsigned VF) -> InstructionCost {
87088709
if (E.isGather() && allConstant(E.Scalars))
@@ -8743,6 +8744,15 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
87438744
}
87448745
return TTI::TCC_Free;
87458746
};
8747+
// Returns the vectorization factor of V measured in units of ScalarTy:
// the element count of V's FixedVectorType divided by ScalarTyNumElements
// (getNumElements(ScalarTy), captured by reference from the enclosing scope).
auto GetVF = [&](Value *V) {
8748+
unsigned VNumElements =
8749+
cast<FixedVectorType>(V->getType())->getNumElements();
8750+
// V must hold strictly more elements than a single ScalarTy.
// NOTE(review): the strict '>' rejects VNumElements == ScalarTyNumElements
// (which would yield a VF of 1) -- confirm this is intended.
assert(VNumElements > ScalarTyNumElements &&
8751+
"the number of elements of V is not large enough");
8752+
// V's element count must be a whole multiple of ScalarTyNumElements so the
// division below is exact.
assert(VNumElements % ScalarTyNumElements == 0 &&
8753+
"the number of elements of V is not a vectorized value");
8754+
return VNumElements / ScalarTyNumElements;
8755+
};
87468756
if (!V1 && !V2 && !P2.isNull()) {
87478757
// Shuffle 2 entry nodes.
87488758
const TreeEntry *E = P1.get<const TreeEntry *>();
@@ -8814,14 +8824,14 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
88148824
} else if (V1 && P2.isNull()) {
88158825
// Shuffle single vector.
88168826
ExtraCost += GetValueMinBWAffectedCost(V1);
8817-
CommonVF = cast<FixedVectorType>(V1->getType())->getNumElements();
8827+
CommonVF = GetVF(V1);
88188828
assert(
88198829
all_of(Mask,
88208830
[=](int Idx) { return Idx < static_cast<int>(CommonVF); }) &&
88218831
"All elements in mask must be less than CommonVF.");
88228832
} else if (V1 && !V2) {
88238833
// Shuffle vector and tree node.
8824-
unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements();
8834+
unsigned VF = GetVF(V1);
88258835
const TreeEntry *E2 = P2.get<const TreeEntry *>();
88268836
CommonVF = std::max(VF, E2->getVectorFactor());
88278837
assert(all_of(Mask,
@@ -8847,7 +8857,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
88478857
V2 = getAllOnesValue(*R.DL, getWidenedType(ScalarTy, CommonVF));
88488858
} else if (!V1 && V2) {
88498859
// Shuffle vector and tree node.
8850-
unsigned VF = cast<FixedVectorType>(V2->getType())->getNumElements();
8860+
unsigned VF = GetVF(V2);
88518861
const TreeEntry *E1 = P1.get<const TreeEntry *>();
88528862
CommonVF = std::max(VF, E1->getVectorFactor());
88538863
assert(all_of(Mask,
@@ -8875,9 +8885,8 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
88758885
V2 = getAllOnesValue(*R.DL, getWidenedType(ScalarTy, CommonVF));
88768886
} else {
88778887
assert(V1 && V2 && "Expected both vectors.");
8878-
unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements();
8879-
CommonVF =
8880-
std::max(VF, cast<FixedVectorType>(V2->getType())->getNumElements());
8888+
unsigned VF = GetVF(V1);
8889+
CommonVF = std::max(VF, GetVF(V2));
88818890
assert(all_of(Mask,
88828891
[=](int Idx) {
88838892
return Idx < 2 * static_cast<int>(CommonVF);
@@ -8895,6 +8904,9 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
88958904
V2 = getAllOnesValue(*R.DL, getWidenedType(ScalarTy, CommonVF));
88968905
}
88978906
}
8907+
if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy))
8908+
transformScalarShuffleIndiciesToVector(VecTy->getNumElements(),
8909+
CommonMask);
88988910
InVectors.front() =
88998911
Constant::getNullValue(getWidenedType(ScalarTy, CommonMask.size()));
89008912
if (InVectors.size() == 2)

0 commit comments

Comments
 (0)