@@ -253,6 +253,21 @@ static FixedVectorType *getWidenedType(Type *ScalarTy, unsigned VF) {
253
253
VF * getNumElements(ScalarTy));
254
254
}
255
255
256
+ static void transformScalarShuffleIndiciesToVector(unsigned VecTyNumElements,
257
+ SmallVectorImpl<int> &Mask) {
258
+ // The ShuffleBuilder implementation use shufflevector to splat an "element".
259
+ // But the element have different meaning for SLP (scalar) and REVEC
260
+ // (vector). We need to expand Mask into masks which shufflevector can use
261
+ // directly.
262
+ SmallVector<int> NewMask(Mask.size() * VecTyNumElements);
263
+ for (unsigned I : seq<unsigned>(Mask.size()))
264
+ for (auto [J, MaskV] : enumerate(MutableArrayRef(NewMask).slice(
265
+ I * VecTyNumElements, VecTyNumElements)))
266
+ MaskV = Mask[I] == PoisonMaskElem ? PoisonMaskElem
267
+ : Mask[I] * VecTyNumElements + J;
268
+ Mask.swap(NewMask);
269
+ }
270
+
256
271
/// \returns True if the value is a constant (but not globals/constant
257
272
/// expressions).
258
273
static bool isConstant(Value *V) {
@@ -7772,6 +7787,31 @@ namespace {
7772
7787
/// The base class for shuffle instruction emission and shuffle cost estimation.
7773
7788
class BaseShuffleAnalysis {
7774
7789
protected:
7790
+ Type *ScalarTy = nullptr;
7791
+
7792
+ BaseShuffleAnalysis(Type *ScalarTy) : ScalarTy(ScalarTy) {}
7793
+
7794
+ /// V is expected to be a vectorized value.
7795
+ /// When REVEC is disabled, there is no difference between VF and
7796
+ /// VNumElements.
7797
+ /// When REVEC is enabled, VF is VNumElements / ScalarTyNumElements.
7798
+ /// e.g., if ScalarTy is <4 x Ty> and V1 is <8 x Ty>, 2 is returned instead
7799
+ /// of 8.
7800
+ unsigned getVF(Value *V) const {
7801
+ assert(V && "V cannot be nullptr");
7802
+ assert(isa<FixedVectorType>(V->getType()) &&
7803
+ "V does not have FixedVectorType");
7804
+ assert(ScalarTy && "ScalarTy cannot be nullptr");
7805
+ unsigned ScalarTyNumElements = getNumElements(ScalarTy);
7806
+ unsigned VNumElements =
7807
+ cast<FixedVectorType>(V->getType())->getNumElements();
7808
+ assert(VNumElements > ScalarTyNumElements &&
7809
+ "the number of elements of V is not large enough");
7810
+ assert(VNumElements % ScalarTyNumElements == 0 &&
7811
+ "the number of elements of V is not a vectorized value");
7812
+ return VNumElements / ScalarTyNumElements;
7813
+ }
7814
+
7775
7815
/// Checks if the mask is an identity mask.
7776
7816
/// \param IsStrict if is true the function returns false if mask size does
7777
7817
/// not match vector size.
@@ -8265,7 +8305,6 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
8265
8305
bool IsFinalized = false;
8266
8306
SmallVector<int> CommonMask;
8267
8307
SmallVector<PointerUnion<Value *, const TreeEntry *>, 2> InVectors;
8268
- Type *ScalarTy = nullptr;
8269
8308
const TargetTransformInfo &TTI;
8270
8309
InstructionCost Cost = 0;
8271
8310
SmallDenseSet<Value *> VectorizedVals;
@@ -8847,14 +8886,14 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
8847
8886
} else if (V1 && P2.isNull()) {
8848
8887
// Shuffle single vector.
8849
8888
ExtraCost += GetValueMinBWAffectedCost(V1);
8850
- CommonVF = cast<FixedVectorType> (V1->getType())->getNumElements( );
8889
+ CommonVF = getVF (V1);
8851
8890
assert(
8852
8891
all_of(Mask,
8853
8892
[=](int Idx) { return Idx < static_cast<int>(CommonVF); }) &&
8854
8893
"All elements in mask must be less than CommonVF.");
8855
8894
} else if (V1 && !V2) {
8856
8895
// Shuffle vector and tree node.
8857
- unsigned VF = cast<FixedVectorType> (V1->getType())->getNumElements( );
8896
+ unsigned VF = getVF (V1);
8858
8897
const TreeEntry *E2 = P2.get<const TreeEntry *>();
8859
8898
CommonVF = std::max(VF, E2->getVectorFactor());
8860
8899
assert(all_of(Mask,
@@ -8880,7 +8919,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
8880
8919
V2 = getAllOnesValue(*R.DL, getWidenedType(ScalarTy, CommonVF));
8881
8920
} else if (!V1 && V2) {
8882
8921
// Shuffle vector and tree node.
8883
- unsigned VF = cast<FixedVectorType> (V2->getType())->getNumElements( );
8922
+ unsigned VF = getVF (V2);
8884
8923
const TreeEntry *E1 = P1.get<const TreeEntry *>();
8885
8924
CommonVF = std::max(VF, E1->getVectorFactor());
8886
8925
assert(all_of(Mask,
@@ -8908,9 +8947,8 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
8908
8947
V2 = getAllOnesValue(*R.DL, getWidenedType(ScalarTy, CommonVF));
8909
8948
} else {
8910
8949
assert(V1 && V2 && "Expected both vectors.");
8911
- unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements();
8912
- CommonVF =
8913
- std::max(VF, cast<FixedVectorType>(V2->getType())->getNumElements());
8950
+ unsigned VF = getVF(V1);
8951
+ CommonVF = std::max(VF, getVF(V2));
8914
8952
assert(all_of(Mask,
8915
8953
[=](int Idx) {
8916
8954
return Idx < 2 * static_cast<int>(CommonVF);
@@ -8928,6 +8966,11 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
8928
8966
V2 = getAllOnesValue(*R.DL, getWidenedType(ScalarTy, CommonVF));
8929
8967
}
8930
8968
}
8969
+ if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy)) {
8970
+ assert(SLPReVec && "FixedVectorType is not expected.");
8971
+ transformScalarShuffleIndiciesToVector(VecTy->getNumElements(),
8972
+ CommonMask);
8973
+ }
8931
8974
InVectors.front() =
8932
8975
Constant::getNullValue(getWidenedType(ScalarTy, CommonMask.size()));
8933
8976
if (InVectors.size() == 2)
@@ -8940,7 +8983,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
8940
8983
ShuffleCostEstimator(Type *ScalarTy, TargetTransformInfo &TTI,
8941
8984
ArrayRef<Value *> VectorizedVals, BoUpSLP &R,
8942
8985
SmallPtrSetImpl<Value *> &CheckedExtracts)
8943
- : ScalarTy (ScalarTy), TTI(TTI),
8986
+ : BaseShuffleAnalysis (ScalarTy), TTI(TTI),
8944
8987
VectorizedVals(VectorizedVals.begin(), VectorizedVals.end()), R(R),
8945
8988
CheckedExtracts(CheckedExtracts) {}
8946
8989
Value *adjustExtracts(const TreeEntry *E, MutableArrayRef<int> Mask,
@@ -9145,7 +9188,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
9145
9188
}
9146
9189
assert(!InVectors.empty() && !CommonMask.empty() &&
9147
9190
"Expected only tree entries from extracts/reused buildvectors.");
9148
- unsigned VF = cast<FixedVectorType> (V1->getType())->getNumElements( );
9191
+ unsigned VF = getVF (V1);
9149
9192
if (InVectors.size() == 2) {
9150
9193
Cost += createShuffle(InVectors.front(), InVectors.back(), CommonMask);
9151
9194
transformMaskAfterShuffle(CommonMask, CommonMask);
@@ -9179,12 +9222,32 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
9179
9222
}
9180
9223
Vals.push_back(Constant::getNullValue(V->getType()));
9181
9224
}
9225
+ if (auto *VecTy = dyn_cast<FixedVectorType>(Vals.front()->getType())) {
9226
+ assert(SLPReVec && "FixedVectorType is not expected.");
9227
+ // When REVEC is enabled, we need to expand vector types into scalar
9228
+ // types.
9229
+ unsigned VecTyNumElements = VecTy->getNumElements();
9230
+ SmallVector<Constant *> NewVals(VF * VecTyNumElements, nullptr);
9231
+ for (auto [I, V] : enumerate(Vals)) {
9232
+ Type *ScalarTy = V->getType()->getScalarType();
9233
+ Constant *NewVal;
9234
+ if (isa<PoisonValue>(V))
9235
+ NewVal = PoisonValue::get(ScalarTy);
9236
+ else if (isa<UndefValue>(V))
9237
+ NewVal = UndefValue::get(ScalarTy);
9238
+ else
9239
+ NewVal = Constant::getNullValue(ScalarTy);
9240
+ std::fill_n(NewVals.begin() + I * VecTyNumElements, VecTyNumElements,
9241
+ NewVal);
9242
+ }
9243
+ Vals.swap(NewVals);
9244
+ }
9182
9245
return ConstantVector::get(Vals);
9183
9246
}
9184
9247
return ConstantVector::getSplat(
9185
9248
ElementCount::getFixed(
9186
9249
cast<FixedVectorType>(Root->getType())->getNumElements()),
9187
- getAllOnesValue(*R.DL, ScalarTy));
9250
+ getAllOnesValue(*R.DL, ScalarTy->getScalarType() ));
9188
9251
}
9189
9252
InstructionCost createFreeze(InstructionCost Cost) { return Cost; }
9190
9253
/// Finalize emission of the shuffles.
@@ -11685,8 +11748,8 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL, Value *Root, Type *ScalarTy) {
11685
11748
Type *Ty) {
11686
11749
Value *Scalar = V;
11687
11750
if (Scalar->getType() != Ty) {
11688
- assert(Scalar->getType()->isIntegerTy() && Ty->isIntegerTy () &&
11689
- "Expected integer types only.");
11751
+ assert(Scalar->getType()->isIntOrIntVectorTy () &&
11752
+ Ty->isIntOrIntVectorTy() && "Expected integer types only.");
11690
11753
Value *V = Scalar;
11691
11754
if (auto *CI = dyn_cast<CastInst>(Scalar);
11692
11755
isa_and_nonnull<SExtInst, ZExtInst>(CI)) {
@@ -11699,10 +11762,21 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL, Value *Root, Type *ScalarTy) {
11699
11762
V, Ty, !isKnownNonNegative(Scalar, SimplifyQuery(*DL)));
11700
11763
}
11701
11764
11702
- Vec = Builder.CreateInsertElement(Vec, Scalar, Builder.getInt32(Pos));
11703
- auto *InsElt = dyn_cast<InsertElementInst>(Vec);
11704
- if (!InsElt)
11705
- return Vec;
11765
+ Instruction *InsElt;
11766
+ if (auto *VecTy = dyn_cast<FixedVectorType>(Scalar->getType())) {
11767
+ assert(SLPReVec && "FixedVectorType is not expected.");
11768
+ Vec = InsElt = Builder.CreateInsertVector(
11769
+ Vec->getType(), Vec, V,
11770
+ Builder.getInt64(Pos * VecTy->getNumElements()));
11771
+ auto *II = dyn_cast<IntrinsicInst>(InsElt);
11772
+ if (!II || II->getIntrinsicID() != Intrinsic::vector_insert)
11773
+ return Vec;
11774
+ } else {
11775
+ Vec = Builder.CreateInsertElement(Vec, Scalar, Builder.getInt32(Pos));
11776
+ InsElt = dyn_cast<InsertElementInst>(Vec);
11777
+ if (!InsElt)
11778
+ return Vec;
11779
+ }
11706
11780
GatherShuffleExtractSeq.insert(InsElt);
11707
11781
CSEBlocks.insert(InsElt->getParent());
11708
11782
// Add to our 'need-to-extract' list.
@@ -11803,7 +11877,6 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
11803
11877
/// resulting shuffle and the second operand sets to be the newly added
11804
11878
/// operand. The \p CommonMask is transformed in the proper way after that.
11805
11879
SmallVector<Value *, 2> InVectors;
11806
- Type *ScalarTy = nullptr;
11807
11880
IRBuilderBase &Builder;
11808
11881
BoUpSLP &R;
11809
11882
@@ -11929,7 +12002,7 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
11929
12002
11930
12003
public:
11931
12004
ShuffleInstructionBuilder(Type *ScalarTy, IRBuilderBase &Builder, BoUpSLP &R)
11932
- : ScalarTy (ScalarTy), Builder(Builder), R(R) {}
12005
+ : BaseShuffleAnalysis (ScalarTy), Builder(Builder), R(R) {}
11933
12006
11934
12007
/// Adjusts extractelements after reusing them.
11935
12008
Value *adjustExtracts(const TreeEntry *E, MutableArrayRef<int> Mask,
@@ -12186,7 +12259,7 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
12186
12259
break;
12187
12260
}
12188
12261
}
12189
- int VF = cast<FixedVectorType> (V1->getType())->getNumElements( );
12262
+ int VF = getVF (V1);
12190
12263
for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx)
12191
12264
if (Mask[Idx] != PoisonMaskElem && CommonMask[Idx] == PoisonMaskElem)
12192
12265
CommonMask[Idx] = Mask[Idx] + (It == InVectors.begin() ? 0 : VF);
@@ -12209,6 +12282,15 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
12209
12282
finalize(ArrayRef<int> ExtMask, unsigned VF = 0,
12210
12283
function_ref<void(Value *&, SmallVectorImpl<int> &)> Action = {}) {
12211
12284
IsFinalized = true;
12285
+ SmallVector<int> NewExtMask(ExtMask);
12286
+ if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy)) {
12287
+ assert(SLPReVec && "FixedVectorType is not expected.");
12288
+ transformScalarShuffleIndiciesToVector(VecTy->getNumElements(),
12289
+ CommonMask);
12290
+ transformScalarShuffleIndiciesToVector(VecTy->getNumElements(),
12291
+ NewExtMask);
12292
+ ExtMask = NewExtMask;
12293
+ }
12212
12294
if (Action) {
12213
12295
Value *Vec = InVectors.front();
12214
12296
if (InVectors.size() == 2) {
@@ -13992,6 +14074,17 @@ Value *BoUpSLP::vectorizeTree(
13992
14074
if (GEP->hasName())
13993
14075
CloneGEP->takeName(GEP);
13994
14076
Ex = CloneGEP;
14077
+ } else if (auto *VecTy =
14078
+ dyn_cast<FixedVectorType>(Scalar->getType())) {
14079
+ assert(SLPReVec && "FixedVectorType is not expected.");
14080
+ unsigned VecTyNumElements = VecTy->getNumElements();
14081
+ // When REVEC is enabled, we need to extract a vector.
14082
+ // Note: The element size of Scalar may be different from the
14083
+ // element size of Vec.
14084
+ Ex = Builder.CreateExtractVector(
14085
+ FixedVectorType::get(Vec->getType()->getScalarType(),
14086
+ VecTyNumElements),
14087
+ Vec, Builder.getInt64(ExternalUse.Lane * VecTyNumElements));
13995
14088
} else {
13996
14089
Ex = Builder.CreateExtractElement(Vec, Lane);
13997
14090
}
0 commit comments