Skip to content

Commit 4b703ed

Browse files
committed
fixup! Address some of the review comments
- Make interleaveLeafValues generic to any power-of-two factors - Use a simpler logic when visiting deinterleave trees - Only generate masked segmented load/store intrinsics and optimize cases with all-ones masks away later.
1 parent cbab9a5 commit 4b703ed

File tree

2 files changed

+53
-84
lines changed

2 files changed

+53
-84
lines changed

llvm/lib/CodeGen/InterleavedAccessPass.cpp

Lines changed: 35 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -262,31 +262,29 @@ static bool isReInterleaveMask(ShuffleVectorInst *SVI, unsigned &Factor,
262262
// getVectorInterleaveFactor / getVectorDeinterleaveFactor. But TLI
263263
// hooks (e.g. lowerInterleavedScalableLoad) expect ABCD, so we need
264264
// to reorder them by interleaving these values.
265-
static void interleaveLeafValues(SmallVectorImpl<Value *> &Leaves) {
266-
unsigned Factor = Leaves.size();
267-
assert(isPowerOf2_32(Factor) && Factor <= 8 && Factor > 1);
268-
269-
if (Factor == 2)
265+
static void interleaveLeafValues(MutableArrayRef<Value *> SubLeaves) {
266+
int NumLeaves = SubLeaves.size();
267+
if (NumLeaves == 2)
270268
return;
271269

272-
SmallVector<Value *, 8> Buffer;
273-
if (Factor == 4) {
274-
for (unsigned SrcIdx : {0, 2, 1, 3})
275-
Buffer.push_back(Leaves[SrcIdx]);
276-
} else {
277-
// Factor of 8.
278-
//
279-
// A E C G B F D H
280-
// |_| |_| |_| |_|
281-
// |___| |___|
282-
// |_______|
283-
// |
284-
// A B C D E F G H
285-
for (unsigned SrcIdx : {0, 4, 2, 6, 1, 5, 3, 7})
286-
Buffer.push_back(Leaves[SrcIdx]);
287-
}
270+
assert(isPowerOf2_32(NumLeaves) && NumLeaves > 1);
271+
272+
const int HalfLeaves = NumLeaves / 2;
273+
// Visit the sub-trees.
274+
interleaveLeafValues(SubLeaves.take_front(HalfLeaves));
275+
interleaveLeafValues(SubLeaves.drop_front(HalfLeaves));
288276

289-
llvm::copy(Buffer, Leaves.begin());
277+
SmallVector<Value *, 8> Buffer;
278+
// The step is alternating between +half and -half+1. We exit the
279+
// loop right before the last element because given the fact that
280+
// SubLeaves always has an even number of elements, the last element
281+
// will never be moved and will be the last one visited. This simplifies
282+
// the exit condition.
283+
for (int i = 0; i < NumLeaves - 1;
284+
(i < HalfLeaves) ? i += HalfLeaves : i += (1 - HalfLeaves))
285+
Buffer.push_back(SubLeaves[i]);
286+
287+
llvm::copy(Buffer, SubLeaves.begin());
290288
}
291289

292290
static unsigned getVectorInterleaveFactor(IntrinsicInst *II,
@@ -353,7 +351,7 @@ static std::optional<Value *> getMask(Value *WideMask, unsigned Factor) {
353351
return std::nullopt;
354352
}
355353

356-
static unsigned getVectorDeInterleaveFactor(IntrinsicInst *II,
354+
static unsigned getVectorDeinterleaveFactor(IntrinsicInst *II,
357355
SmallVectorImpl<Value *> &Results) {
358356
using namespace PatternMatch;
359357
if (II->getIntrinsicID() != Intrinsic::vector_deinterleave2 ||
@@ -370,7 +368,7 @@ static unsigned getVectorDeInterleaveFactor(IntrinsicInst *II,
370368
Queue.erase(Queue.begin());
371369
assert(Current->hasNUses(2));
372370

373-
unsigned VisitedIdx = 0;
371+
ExtractValueInst *LHS = nullptr, *RHS = nullptr;
374372
for (User *Usr : Current->users()) {
375373
// We're playing safe here and matching only the expression
376374
// consisting of a perfectly balanced binary tree in which all
@@ -380,38 +378,26 @@ static unsigned getVectorDeInterleaveFactor(IntrinsicInst *II,
380378

381379
auto *EV = cast<ExtractValueInst>(Usr);
382380
ArrayRef<unsigned> Indices = EV->getIndices();
383-
if (Indices.size() != 1 || Indices[0] >= 2)
381+
if (Indices.size() != 1)
384382
return 0;
385383

386-
// The idea is that we don't want to have two extractvalue
387-
// on the same index. So we XOR (1 << index) onto VisitedIdx
388-
// such that if there is any duplication, VisitedIdx will be
389-
// zero.
390-
VisitedIdx ^= (1 << Indices[0]);
391-
if (!VisitedIdx)
384+
if (Indices[0] == 0 && !LHS)
385+
LHS = EV;
386+
else if (Indices[0] == 1 && !RHS)
387+
RHS = EV;
388+
else
392389
return 0;
393-
// We have a legal index. At this point we're either going
394-
// to continue the traversal or push the leaf values into Results.
395-
// But in either cases we need to follow the order imposed by
396-
// ExtractValue's indices and swap with the last element pushed
397-
// into Queue/Results if necessary (This is also one of the main
398-
// reasons using BFS instead of DFS here, btw).
399-
400-
// When VisitedIdx equals to 0b11, we're the last visted ExtractValue.
401-
// So if the current index is 0, we need to swap. Conversely, when
402-
// we're either the first visited ExtractValue or the last operand
403-
// in Queue/Results is of index 0, there is no need to swap.
404-
bool SwapWithLast = VisitedIdx == 0b11 && Indices[0] == 0;
390+
}
405391

392+
// We have legal indices. At this point we're either going
393+
// to continue the traversal or push the leaf values into Results.
394+
for (ExtractValueInst *EV : {LHS, RHS}) {
406395
// Continue the traversal.
407396
if (match(EV->user_back(),
408397
m_Intrinsic<Intrinsic::vector_deinterleave2>()) &&
409398
EV->user_back()->hasNUses(2)) {
410399
auto *EVUsr = cast<IntrinsicInst>(EV->user_back());
411-
if (SwapWithLast && !Queue.empty())
412-
Queue.insert(Queue.end() - 1, EVUsr);
413-
else
414-
Queue.push_back(EVUsr);
400+
Queue.push_back(EVUsr);
415401
continue;
416402
}
417403

@@ -421,10 +407,7 @@ static unsigned getVectorDeInterleaveFactor(IntrinsicInst *II,
421407
return 0;
422408

423409
// Save the leaf value.
424-
if (SwapWithLast && !Results.empty())
425-
Results.insert(Results.end() - 1, EV);
426-
else
427-
Results.push_back(EV);
410+
Results.push_back(EV);
428411

429412
++Factor;
430413
}
@@ -673,7 +656,7 @@ bool InterleavedAccessImpl::lowerDeinterleaveIntrinsic(
673656
IntrinsicInst *DI, SmallSetVector<Instruction *, 32> &DeadInsts) {
674657
if (auto *VPLoad = dyn_cast<VPIntrinsic>(DI->getOperand(0))) {
675658
SmallVector<Value *, 8> DeInterleaveResults;
676-
unsigned Factor = getVectorDeInterleaveFactor(DI, DeInterleaveResults);
659+
unsigned Factor = getVectorDeinterleaveFactor(DI, DeInterleaveResults);
677660
if (!Factor)
678661
return false;
679662

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 18 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -22400,11 +22400,6 @@ bool RISCVTargetLowering::lowerInterleavedScalableLoad(
2240022400
Intrinsic::riscv_vlseg6_mask, Intrinsic::riscv_vlseg7_mask,
2240122401
Intrinsic::riscv_vlseg8_mask,
2240222402
};
22403-
static const Intrinsic::ID IntrIds[] = {
22404-
Intrinsic::riscv_vlseg2, Intrinsic::riscv_vlseg3, Intrinsic::riscv_vlseg4,
22405-
Intrinsic::riscv_vlseg5, Intrinsic::riscv_vlseg6, Intrinsic::riscv_vlseg7,
22406-
Intrinsic::riscv_vlseg8,
22407-
};
2240822403

2240922404
unsigned SEW = DL.getTypeSizeInBits(VTy->getElementType());
2241022405
unsigned NumElts = VTy->getElementCount().getKnownMinValue();
@@ -22418,22 +22413,20 @@ bool RISCVTargetLowering::lowerInterleavedScalableLoad(
2241822413
SmallVector<Value *> Operands;
2241922414
Operands.append({PoisonVal, Load->getArgOperand(0)});
2242022415

22421-
Function *VlsegNFunc;
22422-
if (Mask) {
22423-
VlsegNFunc = Intrinsic::getOrInsertDeclaration(
22424-
Load->getModule(), IntrMaskIds[Factor - 2],
22425-
{VecTupTy, Mask->getType(), EVL->getType()});
22426-
Operands.push_back(Mask);
22427-
} else {
22428-
VlsegNFunc = Intrinsic::getOrInsertDeclaration(
22429-
Load->getModule(), IntrIds[Factor - 2], {VecTupTy, EVL->getType()});
22430-
}
22416+
if (!Mask)
22417+
Mask = ConstantVector::getSplat(VTy->getElementCount(),
22418+
ConstantInt::getTrue(Load->getContext()));
22419+
22420+
Function *VlsegNFunc = Intrinsic::getOrInsertDeclaration(
22421+
Load->getModule(), IntrMaskIds[Factor - 2],
22422+
{VecTupTy, Mask->getType(), EVL->getType()});
22423+
22424+
Operands.push_back(Mask);
2243122425

2243222426
Operands.push_back(EVL);
2243322427

2243422428
// Tail-policy
22435-
if (Mask)
22436-
Operands.push_back(ConstantInt::get(XLenTy, 1));
22429+
Operands.push_back(ConstantInt::get(XLenTy, RISCVII::TAIL_AGNOSTIC));
2243722430

2243822431
Operands.push_back(ConstantInt::get(XLenTy, Log2_64(SEW)));
2243922432

@@ -22565,11 +22558,6 @@ bool RISCVTargetLowering::lowerInterleavedScalableStore(
2256522558
Intrinsic::riscv_vsseg6_mask, Intrinsic::riscv_vsseg7_mask,
2256622559
Intrinsic::riscv_vsseg8_mask,
2256722560
};
22568-
static const Intrinsic::ID IntrIds[] = {
22569-
Intrinsic::riscv_vsseg2, Intrinsic::riscv_vsseg3, Intrinsic::riscv_vsseg4,
22570-
Intrinsic::riscv_vsseg5, Intrinsic::riscv_vsseg6, Intrinsic::riscv_vsseg7,
22571-
Intrinsic::riscv_vsseg8,
22572-
};
2257322561

2257422562
unsigned SEW = DL.getTypeSizeInBits(VTy->getElementType());
2257522563
unsigned NumElts = VTy->getElementCount().getKnownMinValue();
@@ -22590,17 +22578,15 @@ bool RISCVTargetLowering::lowerInterleavedScalableStore(
2259022578
Operands.push_back(StoredVal);
2259122579
Operands.push_back(Store->getArgOperand(1));
2259222580

22593-
Function *VssegNFunc;
22594-
if (Mask) {
22595-
VssegNFunc = Intrinsic::getOrInsertDeclaration(
22596-
Store->getModule(), IntrMaskIds[Factor - 2],
22597-
{VecTupTy, Mask->getType(), EVL->getType()});
22598-
Operands.push_back(Mask);
22599-
} else {
22600-
VssegNFunc = Intrinsic::getOrInsertDeclaration(
22601-
Store->getModule(), IntrIds[Factor - 2], {VecTupTy, EVL->getType()});
22602-
}
22581+
if (!Mask)
22582+
Mask = ConstantVector::getSplat(VTy->getElementCount(),
22583+
ConstantInt::getTrue(Store->getContext()));
22584+
22585+
Function *VssegNFunc = Intrinsic::getOrInsertDeclaration(
22586+
Store->getModule(), IntrMaskIds[Factor - 2],
22587+
{VecTupTy, Mask->getType(), EVL->getType()});
2260322588

22589+
Operands.push_back(Mask);
2260422590
Operands.push_back(EVL);
2260522591
Operands.push_back(ConstantInt::get(XLenTy, Log2_64(SEW)));
2260622592

0 commit comments

Comments
 (0)