@@ -8876,14 +8876,14 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
8876
8876
} else if (V1 && P2.isNull()) {
8877
8877
// Shuffle single vector.
8878
8878
ExtraCost += GetValueMinBWAffectedCost(V1);
8879
- CommonVF = cast<FixedVectorType> (V1->getType())->getNumElements( );
8879
+ CommonVF = getVF (V1);
8880
8880
assert(
8881
8881
all_of(Mask,
8882
8882
[=](int Idx) { return Idx < static_cast<int>(CommonVF); }) &&
8883
8883
"All elements in mask must be less than CommonVF.");
8884
8884
} else if (V1 && !V2) {
8885
8885
// Shuffle vector and tree node.
8886
- unsigned VF = cast<FixedVectorType> (V1->getType())->getNumElements( );
8886
+ unsigned VF = getVF (V1);
8887
8887
const TreeEntry *E2 = P2.get<const TreeEntry *>();
8888
8888
CommonVF = std::max(VF, E2->getVectorFactor());
8889
8889
assert(all_of(Mask,
@@ -8909,7 +8909,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
8909
8909
V2 = getAllOnesValue(*R.DL, getWidenedType(ScalarTy, CommonVF));
8910
8910
} else if (!V1 && V2) {
8911
8911
// Shuffle vector and tree node.
8912
- unsigned VF = cast<FixedVectorType> (V2->getType())->getNumElements( );
8912
+ unsigned VF = getVF (V2);
8913
8913
const TreeEntry *E1 = P1.get<const TreeEntry *>();
8914
8914
CommonVF = std::max(VF, E1->getVectorFactor());
8915
8915
assert(all_of(Mask,
@@ -8937,9 +8937,8 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
8937
8937
V2 = getAllOnesValue(*R.DL, getWidenedType(ScalarTy, CommonVF));
8938
8938
} else {
8939
8939
assert(V1 && V2 && "Expected both vectors.");
8940
- unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements();
8941
- CommonVF =
8942
- std::max(VF, cast<FixedVectorType>(V2->getType())->getNumElements());
8940
+ unsigned VF = getVF(V1);
8941
+ CommonVF = std::max(VF, getVF(V2));
8943
8942
assert(all_of(Mask,
8944
8943
[=](int Idx) {
8945
8944
return Idx < 2 * static_cast<int>(CommonVF);
@@ -8957,6 +8956,9 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
8957
8956
V2 = getAllOnesValue(*R.DL, getWidenedType(ScalarTy, CommonVF));
8958
8957
}
8959
8958
}
8959
+ if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy))
8960
+ transformScalarShuffleIndiciesToVector(VecTy->getNumElements(),
8961
+ CommonMask);
8960
8962
InVectors.front() =
8961
8963
Constant::getNullValue(getWidenedType(ScalarTy, CommonMask.size()));
8962
8964
if (InVectors.size() == 2)
0 commit comments