Skip to content

Commit 958b507

Browse files
committed
[IA] Generalize the support for power-of-two (de)interleave intrinsics
Previously, AArch64 used pattern matching to support llvm.vector.(de)interleave with factors of 2 and 4, while RISC-V only supported a (de)interleave factor of 2. This patch consolidates the logic of these two targets by factoring the common factor calculation out into the InterleavedAccess pass.
1 parent 6518b12 commit 958b507

File tree

11 files changed

+816
-359
lines changed

11 files changed

+816
-359
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3157,12 +3157,11 @@ class TargetLoweringBase {
31573157
/// llvm.vector.deinterleave2
31583158
///
31593159
/// \p DI is the deinterleave intrinsic.
3160-
/// \p LI is the accompanying load instruction
3161-
/// \p DeadInsts is a reference to a vector that keeps track of dead
3162-
/// instruction during transformations.
3163-
virtual bool lowerDeinterleaveIntrinsicToLoad(
3164-
IntrinsicInst *DI, LoadInst *LI,
3165-
SmallVectorImpl<Instruction *> &DeadInsts) const {
3160+
/// \p LI is the accompanying load instruction.
3161+
/// \p DeinterleaveValues contains the deinterleaved values.
3162+
virtual bool
3163+
lowerDeinterleaveIntrinsicToLoad(IntrinsicInst *DI, LoadInst *LI,
3164+
ArrayRef<Value *> DeinterleaveValues) const {
31663165
return false;
31673166
}
31683167

@@ -3172,11 +3171,10 @@ class TargetLoweringBase {
31723171
///
31733172
/// \p II is the interleave intrinsic.
31743173
/// \p SI is the accompanying store instruction
3175-
/// \p DeadInsts is a reference to a vector that keeps track of dead
3176-
/// instruction during transformations.
3177-
virtual bool lowerInterleaveIntrinsicToStore(
3178-
IntrinsicInst *II, StoreInst *SI,
3179-
SmallVectorImpl<Instruction *> &DeadInsts) const {
3174+
/// \p InterleaveValues contains the interleaved values.
3175+
virtual bool
3176+
lowerInterleaveIntrinsicToStore(IntrinsicInst *II, StoreInst *SI,
3177+
ArrayRef<Value *> InterleaveValues) const {
31803178
return false;
31813179
}
31823180

llvm/lib/CodeGen/InterleavedAccessPass.cpp

Lines changed: 174 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
#include "llvm/IR/Instruction.h"
6161
#include "llvm/IR/Instructions.h"
6262
#include "llvm/IR/IntrinsicInst.h"
63+
#include "llvm/IR/PatternMatch.h"
6364
#include "llvm/InitializePasses.h"
6465
#include "llvm/Pass.h"
6566
#include "llvm/Support/Casting.h"
@@ -478,23 +479,184 @@ bool InterleavedAccessImpl::lowerInterleavedStore(
478479
return true;
479480
}
480481

482+
// For an (de)interleave tree like this:
483+
//
484+
// A C B D
485+
// |___| |___|
486+
// |_____|
487+
// |
488+
// A B C D
489+
//
490+
// We will get ABCD at the end while the leaf operands/results
491+
// are ACBD, which are also what we initially collected in
492+
// getVectorInterleaveFactor / getVectorDeinterleaveFactor. But TLI
493+
// hooks (e.g. lowerDeinterleaveIntrinsicToLoad) expect ABCD, so we need
494+
// to reorder them by interleaving these values.
495+
static void interleaveLeafValues(MutableArrayRef<Value *> SubLeaves) {
496+
int NumLeaves = SubLeaves.size();
497+
if (NumLeaves == 2)
498+
return;
499+
500+
assert(isPowerOf2_32(NumLeaves) && NumLeaves > 1);
501+
502+
const int HalfLeaves = NumLeaves / 2;
503+
// Visit the sub-trees.
504+
interleaveLeafValues(SubLeaves.take_front(HalfLeaves));
505+
interleaveLeafValues(SubLeaves.drop_front(HalfLeaves));
506+
507+
SmallVector<Value *, 8> Buffer;
508+
// The step is alternating between +half and -half+1. We exit the
509+
// loop right before the last element because given the fact that
510+
// SubLeaves always has an even number of elements, the last element
511+
// will never be moved and the last to be visited. This simplifies
512+
// the exit condition.
513+
for (int i = 0; i < NumLeaves - 1;
514+
(i < HalfLeaves) ? i += HalfLeaves : i += (1 - HalfLeaves))
515+
Buffer.push_back(SubLeaves[i]);
516+
517+
llvm::copy(Buffer, SubLeaves.begin());
518+
}
519+
520+
static bool
521+
getVectorInterleaveFactor(IntrinsicInst *II, SmallVectorImpl<Value *> &Operands,
522+
SmallVectorImpl<Instruction *> &DeadInsts) {
523+
if (II->getIntrinsicID() != Intrinsic::vector_interleave2)
524+
return false;
525+
526+
// Visit with BFS
527+
SmallVector<IntrinsicInst *, 8> Queue;
528+
Queue.push_back(II);
529+
while (!Queue.empty()) {
530+
IntrinsicInst *Current = Queue.front();
531+
Queue.erase(Queue.begin());
532+
533+
// All the intermediate intrinsics will be deleted.
534+
DeadInsts.push_back(Current);
535+
536+
for (unsigned I = 0; I < 2; ++I) {
537+
Value *Op = Current->getOperand(I);
538+
if (auto *OpII = dyn_cast<IntrinsicInst>(Op))
539+
if (OpII->getIntrinsicID() == Intrinsic::vector_interleave2) {
540+
Queue.push_back(OpII);
541+
continue;
542+
}
543+
544+
// If this is not a perfectly balanced tree, the leaf
545+
// result types would be different.
546+
if (!Operands.empty() && Op->getType() != Operands.back()->getType())
547+
return false;
548+
549+
Operands.push_back(Op);
550+
}
551+
}
552+
553+
const unsigned Factor = Operands.size();
554+
// Currently we only recognize power-of-two factors.
555+
// FIXME: should we assert here instead?
556+
if (Factor <= 1 || !isPowerOf2_32(Factor))
557+
return false;
558+
559+
interleaveLeafValues(Operands);
560+
return true;
561+
}
562+
563+
static bool
564+
getVectorDeinterleaveFactor(IntrinsicInst *II,
565+
SmallVectorImpl<Value *> &Results,
566+
SmallVectorImpl<Instruction *> &DeadInsts) {
567+
using namespace PatternMatch;
568+
if (II->getIntrinsicID() != Intrinsic::vector_deinterleave2 ||
569+
!II->hasNUses(2))
570+
return false;
571+
572+
// Visit with BFS
573+
SmallVector<IntrinsicInst *, 8> Queue;
574+
Queue.push_back(II);
575+
while (!Queue.empty()) {
576+
IntrinsicInst *Current = Queue.front();
577+
Queue.erase(Queue.begin());
578+
assert(Current->hasNUses(2));
579+
580+
// All the intermediate intrinsics will be deleted from the bottom-up.
581+
DeadInsts.insert(DeadInsts.begin(), Current);
582+
583+
ExtractValueInst *LHS = nullptr, *RHS = nullptr;
584+
for (User *Usr : Current->users()) {
585+
if (!isa<ExtractValueInst>(Usr))
586+
return 0;
587+
588+
auto *EV = cast<ExtractValueInst>(Usr);
589+
// Intermediate ExtractValue instructions will also be deleted.
590+
DeadInsts.insert(DeadInsts.begin(), EV);
591+
ArrayRef<unsigned> Indices = EV->getIndices();
592+
if (Indices.size() != 1)
593+
return false;
594+
595+
if (Indices[0] == 0 && !LHS)
596+
LHS = EV;
597+
else if (Indices[0] == 1 && !RHS)
598+
RHS = EV;
599+
else
600+
return false;
601+
}
602+
603+
// We have legal indices. At this point we're either going
604+
// to continue the traversal or push the leaf values into Results.
605+
for (ExtractValueInst *EV : {LHS, RHS}) {
606+
// Continue the traversal. We're playing safe here and matching only the
607+
// expression consisting of a perfectly balanced binary tree in which all
608+
// intermediate values are only used once.
609+
if (EV->hasOneUse() &&
610+
match(EV->user_back(),
611+
m_Intrinsic<Intrinsic::vector_deinterleave2>()) &&
612+
EV->user_back()->hasNUses(2)) {
613+
auto *EVUsr = cast<IntrinsicInst>(EV->user_back());
614+
Queue.push_back(EVUsr);
615+
continue;
616+
}
617+
618+
// If this is not a perfectly balanced tree, the leaf
619+
// result types would be different.
620+
if (!Results.empty() && EV->getType() != Results.back()->getType())
621+
return false;
622+
623+
// Save the leaf value.
624+
Results.push_back(EV);
625+
}
626+
}
627+
628+
const unsigned Factor = Results.size();
629+
// Currently we only recognize power-of-two factors.
630+
// FIXME: should we assert here instead?
631+
if (Factor <= 1 || !isPowerOf2_32(Factor))
632+
return 0;
633+
634+
interleaveLeafValues(Results);
635+
return true;
636+
}
637+
481638
bool InterleavedAccessImpl::lowerDeinterleaveIntrinsic(
482639
IntrinsicInst *DI, SmallSetVector<Instruction *, 32> &DeadInsts) {
483640
LoadInst *LI = dyn_cast<LoadInst>(DI->getOperand(0));
484641

485642
if (!LI || !LI->hasOneUse() || !LI->isSimple())
486643
return false;
487644

488-
LLVM_DEBUG(dbgs() << "IA: Found a deinterleave intrinsic: " << *DI << "\n");
645+
SmallVector<Value *, 8> DeinterleaveValues;
646+
SmallVector<Instruction *, 8> DeinterleaveDeadInsts;
647+
if (!getVectorDeinterleaveFactor(DI, DeinterleaveValues,
648+
DeinterleaveDeadInsts))
649+
return false;
650+
651+
LLVM_DEBUG(dbgs() << "IA: Found a deinterleave intrinsic: " << *DI
652+
<< " with factor = " << DeinterleaveValues.size() << "\n");
489653

490654
// Try and match this with target specific intrinsics.
491-
SmallVector<Instruction *, 4> DeinterleaveDeadInsts;
492-
if (!TLI->lowerDeinterleaveIntrinsicToLoad(DI, LI, DeinterleaveDeadInsts))
655+
if (!TLI->lowerDeinterleaveIntrinsicToLoad(DI, LI, DeinterleaveValues))
493656
return false;
494657

495658
DeadInsts.insert(DeinterleaveDeadInsts.begin(), DeinterleaveDeadInsts.end());
496659
// We now have a target-specific load, so delete the old one.
497-
DeadInsts.insert(DI);
498660
DeadInsts.insert(LI);
499661
return true;
500662
}
@@ -509,16 +671,20 @@ bool InterleavedAccessImpl::lowerInterleaveIntrinsic(
509671
if (!SI || !SI->isSimple())
510672
return false;
511673

512-
LLVM_DEBUG(dbgs() << "IA: Found an interleave intrinsic: " << *II << "\n");
674+
SmallVector<Value *, 8> InterleaveValues;
675+
SmallVector<Instruction *, 8> InterleaveDeadInsts;
676+
if (!getVectorInterleaveFactor(II, InterleaveValues, InterleaveDeadInsts))
677+
return false;
678+
679+
LLVM_DEBUG(dbgs() << "IA: Found an interleave intrinsic: " << *II
680+
<< " with factor = " << InterleaveValues.size() << "\n");
513681

514-
SmallVector<Instruction *, 4> InterleaveDeadInsts;
515682
// Try and match this with target specific intrinsics.
516-
if (!TLI->lowerInterleaveIntrinsicToStore(II, SI, InterleaveDeadInsts))
683+
if (!TLI->lowerInterleaveIntrinsicToStore(II, SI, InterleaveValues))
517684
return false;
518685

519686
// We now have a target-specific store, so delete the old one.
520687
DeadInsts.insert(SI);
521-
DeadInsts.insert(II);
522688
DeadInsts.insert(InterleaveDeadInsts.begin(), InterleaveDeadInsts.end());
523689
return true;
524690
}

0 commit comments

Comments
 (0)