-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[SLP][REVEC] Make ShuffleCostEstimator and ShuffleInstructionBuilder support vector instructions. #99499
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SLP][REVEC] Make ShuffleCostEstimator and ShuffleInstructionBuilder support vector instructions. #99499
Conversation
@llvm/pr-subscribers-llvm-transforms Author: Han-Kuan Chen (HanKuanChen) Changes: This PR will try to make ShuffleCostEstimator and ShuffleInstructionBuilder able to vectorize vector instructions. In addition, when REVEC is enabled, CreateInsertVector and CreateExtractVector are used because the scalar type may be a FixedVectorType. Full diff: https://github.com/llvm/llvm-project/pull/99499.diff 2 Files Affected:
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index d8c3bae06e932..aefec86d332fe 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -253,6 +253,21 @@ static FixedVectorType *getWidenedType(Type *ScalarTy, unsigned VF) {
VF * getNumElements(ScalarTy));
}
+static void transformScalarShuffleIndiciesToVector(unsigned VecTyNumElements,
+ SmallVectorImpl<int> &Mask) {
+  // The ShuffleBuilder implementation uses shufflevector to splat an
+  // "element". But an element has a different meaning for SLP (scalar) and
+  // REVEC (vector). We need to expand Mask into masks which shufflevector can
+  // use directly.
+ SmallVector<int> NewMask(Mask.size() * VecTyNumElements);
+ for (size_t I = 0, E = Mask.size(); I != E; ++I)
+ for (unsigned J = 0; J != VecTyNumElements; ++J)
+ NewMask[I * VecTyNumElements + J] = Mask[I] == PoisonMaskElem
+ ? PoisonMaskElem
+ : Mask[I] * VecTyNumElements + J;
+ Mask.swap(NewMask);
+}
+
/// \returns True if the value is a constant (but not globals/constant
/// expressions).
static bool isConstant(Value *V) {
@@ -8800,6 +8815,8 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
// Shuffle single vector.
ExtraCost += GetValueMinBWAffectedCost(V1);
CommonVF = cast<FixedVectorType>(V1->getType())->getNumElements();
+ if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy))
+ CommonVF /= VecTy->getNumElements();
assert(
all_of(Mask,
[=](int Idx) { return Idx < static_cast<int>(CommonVF); }) &&
@@ -8807,6 +8824,8 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
} else if (V1 && !V2) {
// Shuffle vector and tree node.
unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements();
+ if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy))
+ VF /= VecTy->getNumElements();
const TreeEntry *E2 = P2.get<const TreeEntry *>();
CommonVF = std::max(VF, E2->getVectorFactor());
assert(all_of(Mask,
@@ -8833,6 +8852,8 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
} else if (!V1 && V2) {
// Shuffle vector and tree node.
unsigned VF = cast<FixedVectorType>(V2->getType())->getNumElements();
+ if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy))
+ VF /= VecTy->getNumElements();
const TreeEntry *E1 = P1.get<const TreeEntry *>();
CommonVF = std::max(VF, E1->getVectorFactor());
assert(all_of(Mask,
@@ -8863,6 +8884,8 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements();
CommonVF =
std::max(VF, cast<FixedVectorType>(V2->getType())->getNumElements());
+ if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy))
+ CommonVF /= VecTy->getNumElements();
assert(all_of(Mask,
[=](int Idx) {
return Idx < 2 * static_cast<int>(CommonVF);
@@ -8880,6 +8903,9 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
V2 = getAllOnesValue(*R.DL, getWidenedType(ScalarTy, CommonVF));
}
}
+ if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy))
+ transformScalarShuffleIndiciesToVector(VecTy->getNumElements(),
+ CommonMask);
InVectors.front() =
Constant::getNullValue(getWidenedType(ScalarTy, CommonMask.size()));
if (InVectors.size() == 2)
@@ -9098,6 +9124,8 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
assert(!InVectors.empty() && !CommonMask.empty() &&
"Expected only tree entries from extracts/reused buildvectors.");
unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements();
+ if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy))
+ VF /= VecTy->getNumElements();
if (InVectors.size() == 2) {
Cost += createShuffle(InVectors.front(), InVectors.back(), CommonMask);
transformMaskAfterShuffle(CommonMask, CommonMask);
@@ -9125,18 +9153,27 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
if (MaskVF != 0)
VF = std::min(VF, MaskVF);
for (Value *V : VL.take_front(VF)) {
- if (isa<UndefValue>(V)) {
- Vals.push_back(cast<Constant>(V));
- continue;
+ Type *Ty = V->getType();
+ Type *ScalarTy = Ty->getScalarType();
+ unsigned VNumElements = getNumElements(Ty);
+ for (unsigned I = 0; I != VNumElements; ++I) {
+ if (isa<PoisonValue>(V)) {
+ Vals.push_back(PoisonValue::get(ScalarTy));
+ continue;
+ }
+ if (isa<UndefValue>(V)) {
+ Vals.push_back(UndefValue::get(ScalarTy));
+ continue;
+ }
+ Vals.push_back(Constant::getNullValue(ScalarTy));
}
- Vals.push_back(Constant::getNullValue(V->getType()));
}
return ConstantVector::get(Vals);
}
return ConstantVector::getSplat(
ElementCount::getFixed(
cast<FixedVectorType>(Root->getType())->getNumElements()),
- getAllOnesValue(*R.DL, ScalarTy));
+ getAllOnesValue(*R.DL, ScalarTy->getScalarType()));
}
InstructionCost createFreeze(InstructionCost Cost) { return Cost; }
/// Finalize emission of the shuffles.
@@ -11618,7 +11655,8 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL, Value *Root, Type *ScalarTy) {
Type *Ty) {
Value *Scalar = V;
if (Scalar->getType() != Ty) {
- assert(Scalar->getType()->isIntegerTy() && Ty->isIntegerTy() &&
+ assert(Scalar->getType()->getScalarType()->isIntegerTy() &&
+ Ty->getScalarType()->isIntegerTy() &&
"Expected integer types only.");
Value *V = Scalar;
if (auto *CI = dyn_cast<CastInst>(Scalar);
@@ -11632,10 +11670,20 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL, Value *Root, Type *ScalarTy) {
V, Ty, !isKnownNonNegative(Scalar, SimplifyQuery(*DL)));
}
- Vec = Builder.CreateInsertElement(Vec, Scalar, Builder.getInt32(Pos));
- auto *InsElt = dyn_cast<InsertElementInst>(Vec);
- if (!InsElt)
- return Vec;
+ Instruction *InsElt;
+ if (auto *VecTy = dyn_cast<FixedVectorType>(Scalar->getType())) {
+ Vec = InsElt = Builder.CreateInsertVector(
+ Vec->getType(), Vec, V,
+ Builder.getInt64(Pos * VecTy->getNumElements()));
+ auto *II = dyn_cast<IntrinsicInst>(InsElt);
+ if (!(II && II->getIntrinsicID() == Intrinsic::vector_insert))
+ return Vec;
+ } else {
+ Vec = Builder.CreateInsertElement(Vec, Scalar, Builder.getInt32(Pos));
+ InsElt = dyn_cast<InsertElementInst>(Vec);
+ if (!InsElt)
+ return Vec;
+ }
GatherShuffleExtractSeq.insert(InsElt);
CSEBlocks.insert(InsElt->getParent());
// Add to our 'need-to-extract' list.
@@ -12123,6 +12171,8 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
int VF = CommonMask.size();
if (auto *FTy = dyn_cast<FixedVectorType>(V1->getType()))
VF = FTy->getNumElements();
+ if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy))
+ VF /= VecTy->getNumElements();
for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx)
if (Mask[Idx] != PoisonMaskElem && CommonMask[Idx] == PoisonMaskElem)
CommonMask[Idx] = Mask[Idx] + (It == InVectors.begin() ? 0 : VF);
@@ -12145,6 +12195,14 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
finalize(ArrayRef<int> ExtMask, unsigned VF = 0,
function_ref<void(Value *&, SmallVectorImpl<int> &)> Action = {}) {
IsFinalized = true;
+ SmallVector<int> NewExtMask(ExtMask);
+ if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy)) {
+ transformScalarShuffleIndiciesToVector(VecTy->getNumElements(),
+ CommonMask);
+ transformScalarShuffleIndiciesToVector(VecTy->getNumElements(),
+ NewExtMask);
+ ExtMask = NewExtMask;
+ }
if (Action) {
Value *Vec = InVectors.front();
if (InVectors.size() == 2) {
@@ -13906,7 +13964,17 @@ Value *BoUpSLP::vectorizeTree(
CloneGEP->takeName(GEP);
Ex = CloneGEP;
} else {
- Ex = Builder.CreateExtractElement(Vec, Lane);
+ if (auto *VecTy = dyn_cast<FixedVectorType>(Scalar->getType())) {
+ unsigned VecTyNumElements = VecTy->getNumElements();
+ // When REVEC is enabled, we need to extract a vector.
+ // Note: The element size of Scalar may be different from the
+ // element size of Vec.
+ Ex = Builder.CreateExtractVector(
+ FixedVectorType::get(Vec->getType()->getScalarType(),
+ VecTyNumElements),
+ Vec, Builder.getInt64(ExternalUse.Lane * VecTyNumElements));
+ } else
+ Ex = Builder.CreateExtractElement(Vec, Lane);
}
// If necessary, sign-extend or zero-extend ScalarRoot
// to the larger type.
diff --git a/llvm/test/Transforms/SLPVectorizer/revec.ll b/llvm/test/Transforms/SLPVectorizer/revec.ll
index c2dc6d0ab73b7..84426ce6e96bf 100644
--- a/llvm/test/Transforms/SLPVectorizer/revec.ll
+++ b/llvm/test/Transforms/SLPVectorizer/revec.ll
@@ -58,3 +58,39 @@ entry:
store <8 x i16> %4, ptr %5, align 2
ret void
}
+
+define void @test3(ptr %in, ptr %out) {
+; CHECK-LABEL: @test3(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[TMP0:%.*]] = load <8 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> poison, <8 x float> poison, i64 8)
+; CHECK-NEXT: [[TMP2:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> [[TMP1]], <8 x float> [[TMP0]], i64 0)
+; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <16 x float> [[TMP2]], <16 x float> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT: [[TMP4:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> poison, <8 x float> zeroinitializer, i64 0)
+; CHECK-NEXT: [[TMP5:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> [[TMP4]], <8 x float> zeroinitializer, i64 8)
+; CHECK-NEXT: [[TMP6:%.*]] = fmul <16 x float> [[TMP3]], [[TMP5]]
+; CHECK-NEXT: [[TMP7:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> poison, <8 x float> poison, i64 0)
+; CHECK-NEXT: [[TMP8:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> [[TMP7]], <8 x float> zeroinitializer, i64 8)
+; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <16 x float> [[TMP2]], <16 x float> [[TMP8]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
+; CHECK-NEXT: [[TMP10:%.*]] = fadd <16 x float> [[TMP9]], [[TMP6]]
+; CHECK-NEXT: [[TMP11:%.*]] = fcmp ogt <16 x float> [[TMP10]], [[TMP5]]
+; CHECK-NEXT: [[TMP12:%.*]] = getelementptr i1, ptr [[OUT:%.*]], i64 8
+; CHECK-NEXT: [[TMP13:%.*]] = call <8 x i1> @llvm.vector.extract.v8i1.v16i1(<16 x i1> [[TMP11]], i64 8)
+; CHECK-NEXT: store <8 x i1> [[TMP13]], ptr [[OUT]], align 1
+; CHECK-NEXT: [[TMP14:%.*]] = call <8 x i1> @llvm.vector.extract.v8i1.v16i1(<16 x i1> [[TMP11]], i64 0)
+; CHECK-NEXT: store <8 x i1> [[TMP14]], ptr [[TMP12]], align 1
+; CHECK-NEXT: ret void
+;
+entry:
+ %0 = load <8 x float>, ptr %in, align 4
+ %1 = fmul <8 x float> %0, zeroinitializer
+ %2 = fmul <8 x float> %0, zeroinitializer
+ %3 = fadd <8 x float> zeroinitializer, %1
+ %4 = fadd <8 x float> %0, %2
+ %5 = fcmp ogt <8 x float> %3, zeroinitializer
+ %6 = fcmp ogt <8 x float> %4, zeroinitializer
+ %7 = getelementptr i1, ptr %out, i64 8
+ store <8 x i1> %5, ptr %out, align 1
+ store <8 x i1> %6, ptr %7, align 1
+ ret void
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please split this series of patches into separate patches
Type *Ty = V->getType(); | ||
Type *ScalarTy = Ty->getScalarType(); | ||
unsigned VNumElements = getNumElements(Ty); | ||
for (unsigned I = 0; I != VNumElements; ++I) { | ||
if (isa<PoisonValue>(V)) { | ||
Vals.push_back(PoisonValue::get(ScalarTy)); | ||
continue; | ||
} | ||
if (isa<UndefValue>(V)) { | ||
Vals.push_back(UndefValue::get(ScalarTy)); | ||
continue; | ||
} | ||
Vals.push_back(Constant::getNullValue(ScalarTy)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Better to keep the original code for scalar type
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix. Add the expand logic in another loop.
auto *II = dyn_cast<IntrinsicInst>(InsElt); | ||
if (!(II && II->getIntrinsicID() == Intrinsic::vector_insert)) | ||
return Vec; | ||
} else { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need for else after return
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The else is for FixedVectorType.
Or are you saying that it always returns Intrinsic::vector_insert?
3e3942d
to
40306f1
Compare
instructions. When REVEC is enabled, we need to expand vector types into scalar types.
ShuffleCostEstimator.
ShuffleInstructionBuilder::add support vector instructions.
instructions. The VF is relative to the number of elements in ScalarTy instead of the size of mask.
40306f1
to
c8750b6
Compare
SmallVector<Constant *> NewVals; | ||
NewVals.reserve(VL.size() * VecTyNumElements); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SmallVector<Constant *> NewVals(VL.size() * VecTyNumElements, PoisonValue::get());
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- It should be VF because
for (Value *V : VL.take_front(VF))
- I use nullptr here.
c8750b6
to
ccd6f2b
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG
Uh oh!
There was an error while loading. Please reload this page.