Skip to content

Commit 97743b8

Browse files
authored
[SLP][REVEC] Make ShuffleCostEstimator and ShuffleInstructionBuilder support vector instructions. (#99499)
1. When REVEC is enabled, we need to expand vector types into scalar types. 2. When REVEC is enabled, CreateInsertVector (and CreateExtractVector) is used because the scalar type may be a FixedVectorType. 3. Since the mask indices which are used by processBuildVector expect the source is scalar type, we need to transform the mask indices into a form which can be used when REVEC is enabled. The transform is only called when the mask is really used.
1 parent 13c61dd commit 97743b8

File tree

2 files changed

+148
-19
lines changed

2 files changed

+148
-19
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 112 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,21 @@ static FixedVectorType *getWidenedType(Type *ScalarTy, unsigned VF) {
253253
VF * getNumElements(ScalarTy));
254254
}
255255

256+
static void transformScalarShuffleIndiciesToVector(unsigned VecTyNumElements,
257+
SmallVectorImpl<int> &Mask) {
258+
// The ShuffleBuilder implementation use shufflevector to splat an "element".
259+
// But the element have different meaning for SLP (scalar) and REVEC
260+
// (vector). We need to expand Mask into masks which shufflevector can use
261+
// directly.
262+
SmallVector<int> NewMask(Mask.size() * VecTyNumElements);
263+
for (unsigned I : seq<unsigned>(Mask.size()))
264+
for (auto [J, MaskV] : enumerate(MutableArrayRef(NewMask).slice(
265+
I * VecTyNumElements, VecTyNumElements)))
266+
MaskV = Mask[I] == PoisonMaskElem ? PoisonMaskElem
267+
: Mask[I] * VecTyNumElements + J;
268+
Mask.swap(NewMask);
269+
}
270+
256271
/// \returns True if the value is a constant (but not globals/constant
257272
/// expressions).
258273
static bool isConstant(Value *V) {
@@ -7772,6 +7787,31 @@ namespace {
77727787
/// The base class for shuffle instruction emission and shuffle cost estimation.
77737788
class BaseShuffleAnalysis {
77747789
protected:
7790+
Type *ScalarTy = nullptr;
7791+
7792+
BaseShuffleAnalysis(Type *ScalarTy) : ScalarTy(ScalarTy) {}
7793+
7794+
/// V is expected to be a vectorized value.
7795+
/// When REVEC is disabled, there is no difference between VF and
7796+
/// VNumElements.
7797+
/// When REVEC is enabled, VF is VNumElements / ScalarTyNumElements.
7798+
/// e.g., if ScalarTy is <4 x Ty> and V1 is <8 x Ty>, 2 is returned instead
7799+
/// of 8.
7800+
unsigned getVF(Value *V) const {
7801+
assert(V && "V cannot be nullptr");
7802+
assert(isa<FixedVectorType>(V->getType()) &&
7803+
"V does not have FixedVectorType");
7804+
assert(ScalarTy && "ScalarTy cannot be nullptr");
7805+
unsigned ScalarTyNumElements = getNumElements(ScalarTy);
7806+
unsigned VNumElements =
7807+
cast<FixedVectorType>(V->getType())->getNumElements();
7808+
assert(VNumElements > ScalarTyNumElements &&
7809+
"the number of elements of V is not large enough");
7810+
assert(VNumElements % ScalarTyNumElements == 0 &&
7811+
"the number of elements of V is not a vectorized value");
7812+
return VNumElements / ScalarTyNumElements;
7813+
}
7814+
77757815
/// Checks if the mask is an identity mask.
77767816
/// \param IsStrict if is true the function returns false if mask size does
77777817
/// not match vector size.
@@ -8265,7 +8305,6 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
82658305
bool IsFinalized = false;
82668306
SmallVector<int> CommonMask;
82678307
SmallVector<PointerUnion<Value *, const TreeEntry *>, 2> InVectors;
8268-
Type *ScalarTy = nullptr;
82698308
const TargetTransformInfo &TTI;
82708309
InstructionCost Cost = 0;
82718310
SmallDenseSet<Value *> VectorizedVals;
@@ -8847,14 +8886,14 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
88478886
} else if (V1 && P2.isNull()) {
88488887
// Shuffle single vector.
88498888
ExtraCost += GetValueMinBWAffectedCost(V1);
8850-
CommonVF = cast<FixedVectorType>(V1->getType())->getNumElements();
8889+
CommonVF = getVF(V1);
88518890
assert(
88528891
all_of(Mask,
88538892
[=](int Idx) { return Idx < static_cast<int>(CommonVF); }) &&
88548893
"All elements in mask must be less than CommonVF.");
88558894
} else if (V1 && !V2) {
88568895
// Shuffle vector and tree node.
8857-
unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements();
8896+
unsigned VF = getVF(V1);
88588897
const TreeEntry *E2 = P2.get<const TreeEntry *>();
88598898
CommonVF = std::max(VF, E2->getVectorFactor());
88608899
assert(all_of(Mask,
@@ -8880,7 +8919,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
88808919
V2 = getAllOnesValue(*R.DL, getWidenedType(ScalarTy, CommonVF));
88818920
} else if (!V1 && V2) {
88828921
// Shuffle vector and tree node.
8883-
unsigned VF = cast<FixedVectorType>(V2->getType())->getNumElements();
8922+
unsigned VF = getVF(V2);
88848923
const TreeEntry *E1 = P1.get<const TreeEntry *>();
88858924
CommonVF = std::max(VF, E1->getVectorFactor());
88868925
assert(all_of(Mask,
@@ -8908,9 +8947,8 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
89088947
V2 = getAllOnesValue(*R.DL, getWidenedType(ScalarTy, CommonVF));
89098948
} else {
89108949
assert(V1 && V2 && "Expected both vectors.");
8911-
unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements();
8912-
CommonVF =
8913-
std::max(VF, cast<FixedVectorType>(V2->getType())->getNumElements());
8950+
unsigned VF = getVF(V1);
8951+
CommonVF = std::max(VF, getVF(V2));
89148952
assert(all_of(Mask,
89158953
[=](int Idx) {
89168954
return Idx < 2 * static_cast<int>(CommonVF);
@@ -8928,6 +8966,11 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
89288966
V2 = getAllOnesValue(*R.DL, getWidenedType(ScalarTy, CommonVF));
89298967
}
89308968
}
8969+
if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy)) {
8970+
assert(SLPReVec && "FixedVectorType is not expected.");
8971+
transformScalarShuffleIndiciesToVector(VecTy->getNumElements(),
8972+
CommonMask);
8973+
}
89318974
InVectors.front() =
89328975
Constant::getNullValue(getWidenedType(ScalarTy, CommonMask.size()));
89338976
if (InVectors.size() == 2)
@@ -8940,7 +8983,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
89408983
ShuffleCostEstimator(Type *ScalarTy, TargetTransformInfo &TTI,
89418984
ArrayRef<Value *> VectorizedVals, BoUpSLP &R,
89428985
SmallPtrSetImpl<Value *> &CheckedExtracts)
8943-
: ScalarTy(ScalarTy), TTI(TTI),
8986+
: BaseShuffleAnalysis(ScalarTy), TTI(TTI),
89448987
VectorizedVals(VectorizedVals.begin(), VectorizedVals.end()), R(R),
89458988
CheckedExtracts(CheckedExtracts) {}
89468989
Value *adjustExtracts(const TreeEntry *E, MutableArrayRef<int> Mask,
@@ -9145,7 +9188,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
91459188
}
91469189
assert(!InVectors.empty() && !CommonMask.empty() &&
91479190
"Expected only tree entries from extracts/reused buildvectors.");
9148-
unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements();
9191+
unsigned VF = getVF(V1);
91499192
if (InVectors.size() == 2) {
91509193
Cost += createShuffle(InVectors.front(), InVectors.back(), CommonMask);
91519194
transformMaskAfterShuffle(CommonMask, CommonMask);
@@ -9179,12 +9222,32 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
91799222
}
91809223
Vals.push_back(Constant::getNullValue(V->getType()));
91819224
}
9225+
if (auto *VecTy = dyn_cast<FixedVectorType>(Vals.front()->getType())) {
9226+
assert(SLPReVec && "FixedVectorType is not expected.");
9227+
// When REVEC is enabled, we need to expand vector types into scalar
9228+
// types.
9229+
unsigned VecTyNumElements = VecTy->getNumElements();
9230+
SmallVector<Constant *> NewVals(VF * VecTyNumElements, nullptr);
9231+
for (auto [I, V] : enumerate(Vals)) {
9232+
Type *ScalarTy = V->getType()->getScalarType();
9233+
Constant *NewVal;
9234+
if (isa<PoisonValue>(V))
9235+
NewVal = PoisonValue::get(ScalarTy);
9236+
else if (isa<UndefValue>(V))
9237+
NewVal = UndefValue::get(ScalarTy);
9238+
else
9239+
NewVal = Constant::getNullValue(ScalarTy);
9240+
std::fill_n(NewVals.begin() + I * VecTyNumElements, VecTyNumElements,
9241+
NewVal);
9242+
}
9243+
Vals.swap(NewVals);
9244+
}
91829245
return ConstantVector::get(Vals);
91839246
}
91849247
return ConstantVector::getSplat(
91859248
ElementCount::getFixed(
91869249
cast<FixedVectorType>(Root->getType())->getNumElements()),
9187-
getAllOnesValue(*R.DL, ScalarTy));
9250+
getAllOnesValue(*R.DL, ScalarTy->getScalarType()));
91889251
}
91899252
InstructionCost createFreeze(InstructionCost Cost) { return Cost; }
91909253
/// Finalize emission of the shuffles.
@@ -11685,8 +11748,8 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL, Value *Root, Type *ScalarTy) {
1168511748
Type *Ty) {
1168611749
Value *Scalar = V;
1168711750
if (Scalar->getType() != Ty) {
11688-
assert(Scalar->getType()->isIntegerTy() && Ty->isIntegerTy() &&
11689-
"Expected integer types only.");
11751+
assert(Scalar->getType()->isIntOrIntVectorTy() &&
11752+
Ty->isIntOrIntVectorTy() && "Expected integer types only.");
1169011753
Value *V = Scalar;
1169111754
if (auto *CI = dyn_cast<CastInst>(Scalar);
1169211755
isa_and_nonnull<SExtInst, ZExtInst>(CI)) {
@@ -11699,10 +11762,21 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL, Value *Root, Type *ScalarTy) {
1169911762
V, Ty, !isKnownNonNegative(Scalar, SimplifyQuery(*DL)));
1170011763
}
1170111764

11702-
Vec = Builder.CreateInsertElement(Vec, Scalar, Builder.getInt32(Pos));
11703-
auto *InsElt = dyn_cast<InsertElementInst>(Vec);
11704-
if (!InsElt)
11705-
return Vec;
11765+
Instruction *InsElt;
11766+
if (auto *VecTy = dyn_cast<FixedVectorType>(Scalar->getType())) {
11767+
assert(SLPReVec && "FixedVectorType is not expected.");
11768+
Vec = InsElt = Builder.CreateInsertVector(
11769+
Vec->getType(), Vec, V,
11770+
Builder.getInt64(Pos * VecTy->getNumElements()));
11771+
auto *II = dyn_cast<IntrinsicInst>(InsElt);
11772+
if (!II || II->getIntrinsicID() != Intrinsic::vector_insert)
11773+
return Vec;
11774+
} else {
11775+
Vec = Builder.CreateInsertElement(Vec, Scalar, Builder.getInt32(Pos));
11776+
InsElt = dyn_cast<InsertElementInst>(Vec);
11777+
if (!InsElt)
11778+
return Vec;
11779+
}
1170611780
GatherShuffleExtractSeq.insert(InsElt);
1170711781
CSEBlocks.insert(InsElt->getParent());
1170811782
// Add to our 'need-to-extract' list.
@@ -11803,7 +11877,6 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
1180311877
/// resulting shuffle and the second operand sets to be the newly added
1180411878
/// operand. The \p CommonMask is transformed in the proper way after that.
1180511879
SmallVector<Value *, 2> InVectors;
11806-
Type *ScalarTy = nullptr;
1180711880
IRBuilderBase &Builder;
1180811881
BoUpSLP &R;
1180911882

@@ -11929,7 +12002,7 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
1192912002

1193012003
public:
1193112004
ShuffleInstructionBuilder(Type *ScalarTy, IRBuilderBase &Builder, BoUpSLP &R)
11932-
: ScalarTy(ScalarTy), Builder(Builder), R(R) {}
12005+
: BaseShuffleAnalysis(ScalarTy), Builder(Builder), R(R) {}
1193312006

1193412007
/// Adjusts extractelements after reusing them.
1193512008
Value *adjustExtracts(const TreeEntry *E, MutableArrayRef<int> Mask,
@@ -12186,7 +12259,7 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
1218612259
break;
1218712260
}
1218812261
}
12189-
int VF = cast<FixedVectorType>(V1->getType())->getNumElements();
12262+
int VF = getVF(V1);
1219012263
for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx)
1219112264
if (Mask[Idx] != PoisonMaskElem && CommonMask[Idx] == PoisonMaskElem)
1219212265
CommonMask[Idx] = Mask[Idx] + (It == InVectors.begin() ? 0 : VF);
@@ -12209,6 +12282,15 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
1220912282
finalize(ArrayRef<int> ExtMask, unsigned VF = 0,
1221012283
function_ref<void(Value *&, SmallVectorImpl<int> &)> Action = {}) {
1221112284
IsFinalized = true;
12285+
SmallVector<int> NewExtMask(ExtMask);
12286+
if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy)) {
12287+
assert(SLPReVec && "FixedVectorType is not expected.");
12288+
transformScalarShuffleIndiciesToVector(VecTy->getNumElements(),
12289+
CommonMask);
12290+
transformScalarShuffleIndiciesToVector(VecTy->getNumElements(),
12291+
NewExtMask);
12292+
ExtMask = NewExtMask;
12293+
}
1221212294
if (Action) {
1221312295
Value *Vec = InVectors.front();
1221412296
if (InVectors.size() == 2) {
@@ -13992,6 +14074,17 @@ Value *BoUpSLP::vectorizeTree(
1399214074
if (GEP->hasName())
1399314075
CloneGEP->takeName(GEP);
1399414076
Ex = CloneGEP;
14077+
} else if (auto *VecTy =
14078+
dyn_cast<FixedVectorType>(Scalar->getType())) {
14079+
assert(SLPReVec && "FixedVectorType is not expected.");
14080+
unsigned VecTyNumElements = VecTy->getNumElements();
14081+
// When REVEC is enabled, we need to extract a vector.
14082+
// Note: The element size of Scalar may be different from the
14083+
// element size of Vec.
14084+
Ex = Builder.CreateExtractVector(
14085+
FixedVectorType::get(Vec->getType()->getScalarType(),
14086+
VecTyNumElements),
14087+
Vec, Builder.getInt64(ExternalUse.Lane * VecTyNumElements));
1399514088
} else {
1399614089
Ex = Builder.CreateExtractElement(Vec, Lane);
1399714090
}

llvm/test/Transforms/SLPVectorizer/revec.ll

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,39 @@ entry:
8888
store <4 x i32> %9, ptr %10, align 4
8989
ret void
9090
}
91+
92+
define void @test4(ptr %in, ptr %out) {
93+
; CHECK-LABEL: @test4(
94+
; CHECK-NEXT: entry:
95+
; CHECK-NEXT: [[TMP0:%.*]] = load <8 x float>, ptr [[IN:%.*]], align 4
96+
; CHECK-NEXT: [[TMP1:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> poison, <8 x float> poison, i64 8)
97+
; CHECK-NEXT: [[TMP2:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> [[TMP1]], <8 x float> [[TMP0]], i64 0)
98+
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <16 x float> [[TMP2]], <16 x float> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
99+
; CHECK-NEXT: [[TMP4:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> poison, <8 x float> zeroinitializer, i64 0)
100+
; CHECK-NEXT: [[TMP5:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> [[TMP4]], <8 x float> zeroinitializer, i64 8)
101+
; CHECK-NEXT: [[TMP6:%.*]] = fmul <16 x float> [[TMP3]], [[TMP5]]
102+
; CHECK-NEXT: [[TMP7:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> poison, <8 x float> poison, i64 0)
103+
; CHECK-NEXT: [[TMP8:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> [[TMP7]], <8 x float> zeroinitializer, i64 8)
104+
; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <16 x float> [[TMP2]], <16 x float> [[TMP8]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
105+
; CHECK-NEXT: [[TMP10:%.*]] = fadd <16 x float> [[TMP9]], [[TMP6]]
106+
; CHECK-NEXT: [[TMP11:%.*]] = fcmp ogt <16 x float> [[TMP10]], [[TMP5]]
107+
; CHECK-NEXT: [[TMP12:%.*]] = getelementptr i1, ptr [[OUT:%.*]], i64 8
108+
; CHECK-NEXT: [[TMP13:%.*]] = call <8 x i1> @llvm.vector.extract.v8i1.v16i1(<16 x i1> [[TMP11]], i64 8)
109+
; CHECK-NEXT: store <8 x i1> [[TMP13]], ptr [[OUT]], align 1
110+
; CHECK-NEXT: [[TMP14:%.*]] = call <8 x i1> @llvm.vector.extract.v8i1.v16i1(<16 x i1> [[TMP11]], i64 0)
111+
; CHECK-NEXT: store <8 x i1> [[TMP14]], ptr [[TMP12]], align 1
112+
; CHECK-NEXT: ret void
113+
;
114+
entry:
115+
%0 = load <8 x float>, ptr %in, align 4
116+
%1 = fmul <8 x float> %0, zeroinitializer
117+
%2 = fmul <8 x float> %0, zeroinitializer
118+
%3 = fadd <8 x float> zeroinitializer, %1
119+
%4 = fadd <8 x float> %0, %2
120+
%5 = fcmp ogt <8 x float> %3, zeroinitializer
121+
%6 = fcmp ogt <8 x float> %4, zeroinitializer
122+
%7 = getelementptr i1, ptr %out, i64 8
123+
store <8 x i1> %5, ptr %out, align 1
124+
store <8 x i1> %6, ptr %7, align 1
125+
ret void
126+
}

0 commit comments

Comments
 (0)