Skip to content

Commit b5a7d3b

Browse files
authored
[SLP][REVEC] Make Instruction::Select support vector instructions. (#100507)
1 parent 7ab6433 commit b5a7d3b

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9754,6 +9754,23 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
97549754
}
97559755
VecCost = std::min(VecCost, IntrinsicCost);
97569756
}
9757+
if (auto *SI = dyn_cast<SelectInst>(VL0)) {
9758+
auto *CondType =
9759+
getWidenedType(SI->getCondition()->getType(), VL.size());
9760+
unsigned CondNumElements = CondType->getNumElements();
9761+
unsigned VecTyNumElements = getNumElements(VecTy);
9762+
assert(VecTyNumElements >= CondNumElements &&
9763+
VecTyNumElements % CondNumElements == 0 &&
9764+
"Cannot vectorize Instruction::Select");
9765+
if (CondNumElements != VecTyNumElements) {
9766+
// When the return type is i1 but the source is fixed vector type, we
9767+
// need to duplicate the condition value.
9768+
VecCost += TTI->getShuffleCost(
9769+
TTI::SK_PermuteSingleSrc, CondType,
9770+
createReplicatedMask(VecTyNumElements / CondNumElements,
9771+
CondNumElements));
9772+
}
9773+
}
97579774
return VecCost + CommonCost;
97589775
};
97599776
return GetCostDiff(GetScalarCost, GetVectorCost);
@@ -13237,6 +13254,22 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1323713254
False = Builder.CreateIntCast(False, VecTy, GetOperandSignedness(2));
1323813255
}
1323913256

13257+
unsigned CondNumElements = getNumElements(Cond->getType());
13258+
unsigned TrueNumElements = getNumElements(True->getType());
13259+
assert(TrueNumElements >= CondNumElements &&
13260+
TrueNumElements % CondNumElements == 0 &&
13261+
"Cannot vectorize Instruction::Select");
13262+
assert(TrueNumElements == getNumElements(False->getType()) &&
13263+
"Cannot vectorize Instruction::Select");
13264+
if (CondNumElements != TrueNumElements) {
13265+
// When the return type is i1 but the source is fixed vector type, we
13266+
// need to duplicate the condition value.
13267+
Cond = Builder.CreateShuffleVector(
13268+
Cond, createReplicatedMask(TrueNumElements / CondNumElements,
13269+
CondNumElements));
13270+
}
13271+
assert(getNumElements(Cond->getType()) == TrueNumElements &&
13272+
"Cannot vectorize Instruction::Select");
1324013273
Value *V = Builder.CreateSelect(Cond, True, False);
1324113274
V = FinalShuffle(V, E, VecTy);
1324213275

llvm/test/Transforms/SLPVectorizer/revec.ll

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,33 @@ entry:
5858
store <8 x i16> %4, ptr %5, align 2
5959
ret void
6060
}
61+
62+
define void @test3(ptr %x, ptr %y, ptr %z) {
63+
; CHECK-LABEL: @test3(
64+
; CHECK-NEXT: entry:
65+
; CHECK-NEXT: [[TMP0:%.*]] = insertelement <2 x ptr> poison, ptr [[X:%.*]], i32 0
66+
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x ptr> [[TMP0]], ptr [[Y:%.*]], i32 1
67+
; CHECK-NEXT: [[TMP2:%.*]] = icmp eq <2 x ptr> [[TMP1]], zeroinitializer
68+
; CHECK-NEXT: [[TMP3:%.*]] = load <8 x i32>, ptr [[X]], align 4
69+
; CHECK-NEXT: [[TMP4:%.*]] = load <8 x i32>, ptr [[Y]], align 4
70+
; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <2 x i1> [[TMP2]], <2 x i1> poison, <8 x i32> <i32 0, i32 0, i32 0, i32 0, i32 1, i32 1, i32 1, i32 1>
71+
; CHECK-NEXT: [[TMP6:%.*]] = select <8 x i1> [[TMP5]], <8 x i32> [[TMP3]], <8 x i32> [[TMP4]]
72+
; CHECK-NEXT: store <8 x i32> [[TMP6]], ptr [[Z:%.*]], align 4
73+
; CHECK-NEXT: ret void
74+
;
75+
entry:
76+
%0 = getelementptr inbounds i32, ptr %x, i64 4
77+
%1 = getelementptr inbounds i32, ptr %y, i64 4
78+
%2 = load <4 x i32>, ptr %x, align 4
79+
%3 = load <4 x i32>, ptr %0, align 4
80+
%4 = load <4 x i32>, ptr %y, align 4
81+
%5 = load <4 x i32>, ptr %1, align 4
82+
%6 = icmp eq ptr %x, null
83+
%7 = icmp eq ptr %y, null
84+
%8 = select i1 %6, <4 x i32> %2, <4 x i32> %4
85+
%9 = select i1 %7, <4 x i32> %3, <4 x i32> %5
86+
%10 = getelementptr inbounds i32, ptr %z, i64 4
87+
store <4 x i32> %8, ptr %z, align 4
88+
store <4 x i32> %9, ptr %10, align 4
89+
ret void
90+
}

0 commit comments

Comments
 (0)