Skip to content

Commit 8b55d34

Browse files
authored
[RISCV][LoopIdiomVectorize] Support VP intrinsics in LoopIdiomVectorize (#94082)
Teach LoopIdiomVectorize to use VP intrinsics to replace the byte compare loops. Right now only RISC-V uses LoopIdiomVectorize of this style.
1 parent a355c2d commit 8b55d34

File tree

7 files changed

+2674
-19
lines changed

7 files changed

+2674
-19
lines changed

llvm/include/llvm/Transforms/Vectorize/LoopIdiomVectorize.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,22 @@
1313
#include "llvm/Transforms/Scalar/LoopPassManager.h"
1414

1515
namespace llvm {
16-
struct LoopIdiomVectorizePass : PassInfoMixin<LoopIdiomVectorizePass> {
16+
enum class LoopIdiomVectorizeStyle { Masked, Predicated };
17+
18+
class LoopIdiomVectorizePass : public PassInfoMixin<LoopIdiomVectorizePass> {
19+
LoopIdiomVectorizeStyle VectorizeStyle = LoopIdiomVectorizeStyle::Masked;
20+
21+
// The VF used in vectorizing the byte compare pattern.
22+
unsigned ByteCompareVF = 16;
23+
24+
public:
25+
LoopIdiomVectorizePass() = default;
26+
explicit LoopIdiomVectorizePass(LoopIdiomVectorizeStyle S)
27+
: VectorizeStyle(S) {}
28+
29+
LoopIdiomVectorizePass(LoopIdiomVectorizeStyle S, unsigned BCVF)
30+
: VectorizeStyle(S), ByteCompareVF(BCVF) {}
31+
1732
PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM,
1833
LoopStandardAnalysisResults &AR, LPMUpdater &U);
1934
};

llvm/lib/Target/RISCV/RISCVTargetMachine.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,12 @@
3333
#include "llvm/CodeGen/TargetPassConfig.h"
3434
#include "llvm/InitializePasses.h"
3535
#include "llvm/MC/TargetRegistry.h"
36+
#include "llvm/Passes/PassBuilder.h"
3637
#include "llvm/Support/FormattedStream.h"
3738
#include "llvm/Target/TargetOptions.h"
3839
#include "llvm/Transforms/IPO.h"
3940
#include "llvm/Transforms/Scalar.h"
41+
#include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h"
4042
#include <optional>
4143
using namespace llvm;
4244

@@ -572,6 +574,13 @@ void RISCVPassConfig::addPostRegAlloc() {
572574
addPass(createRISCVRedundantCopyEliminationPass());
573575
}
574576

577+
void RISCVTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
578+
PB.registerLateLoopOptimizationsEPCallback([=](LoopPassManager &LPM,
579+
OptimizationLevel Level) {
580+
LPM.addPass(LoopIdiomVectorizePass(LoopIdiomVectorizeStyle::Predicated));
581+
});
582+
}
583+
575584
yaml::MachineFunctionInfo *
576585
RISCVTargetMachine::createDefaultFuncInfoYAML() const {
577586
return new yaml::RISCVMachineFunctionInfo();

llvm/lib/Target/RISCV/RISCVTargetMachine.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class RISCVTargetMachine : public LLVMTargetMachine {
5959
PerFunctionMIParsingState &PFS,
6060
SMDiagnostic &Error,
6161
SMRange &SourceRange) const override;
62+
void registerPassBuilderCallbacks(PassBuilder &PB) override;
6263
};
6364
} // namespace llvm
6465

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,8 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
397397
bool shouldFoldTerminatingConditionAfterLSR() const {
398398
return true;
399399
}
400+
401+
std::optional<unsigned> getMinPageSize() const { return 4096; }
400402
};
401403

402404
} // end namespace llvm

llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp

Lines changed: 160 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -59,19 +59,34 @@ static cl::opt<bool> DisableAll("disable-loop-idiom-vectorize-all", cl::Hidden,
5959
cl::init(false),
6060
cl::desc("Disable Loop Idiom Vectorize Pass."));
6161

62+
static cl::opt<LoopIdiomVectorizeStyle>
63+
LITVecStyle("loop-idiom-vectorize-style", cl::Hidden,
64+
cl::desc("The vectorization style for loop idiom transform."),
65+
cl::values(clEnumValN(LoopIdiomVectorizeStyle::Masked, "masked",
66+
"Use masked vector intrinsics"),
67+
clEnumValN(LoopIdiomVectorizeStyle::Predicated,
68+
"predicated", "Use VP intrinsics")),
69+
cl::init(LoopIdiomVectorizeStyle::Masked));
70+
6271
static cl::opt<bool>
6372
DisableByteCmp("disable-loop-idiom-vectorize-bytecmp", cl::Hidden,
6473
cl::init(false),
6574
cl::desc("Proceed with Loop Idiom Vectorize Pass, but do "
6675
"not convert byte-compare loop(s)."));
6776

77+
static cl::opt<unsigned>
78+
ByteCmpVF("loop-idiom-vectorize-bytecmp-vf", cl::Hidden,
79+
cl::desc("The vectorization factor for byte-compare patterns."),
80+
cl::init(16));
81+
6882
static cl::opt<bool>
6983
VerifyLoops("loop-idiom-vectorize-verify", cl::Hidden, cl::init(false),
7084
cl::desc("Verify loops generated Loop Idiom Vectorize Pass."));
7185

7286
namespace {
73-
7487
class LoopIdiomVectorize {
88+
LoopIdiomVectorizeStyle VectorizeStyle;
89+
unsigned ByteCompareVF;
7590
Loop *CurLoop = nullptr;
7691
DominatorTree *DT;
7792
LoopInfo *LI;
@@ -86,10 +101,11 @@ class LoopIdiomVectorize {
86101
BasicBlock *VectorLoopIncBlock = nullptr;
87102

88103
public:
89-
explicit LoopIdiomVectorize(DominatorTree *DT, LoopInfo *LI,
90-
const TargetTransformInfo *TTI,
91-
const DataLayout *DL)
92-
: DT(DT), LI(LI), TTI(TTI), DL(DL) {}
104+
LoopIdiomVectorize(LoopIdiomVectorizeStyle S, unsigned VF, DominatorTree *DT,
105+
LoopInfo *LI, const TargetTransformInfo *TTI,
106+
const DataLayout *DL)
107+
: VectorizeStyle(S), ByteCompareVF(VF), DT(DT), LI(LI), TTI(TTI), DL(DL) {
108+
}
93109

94110
bool run(Loop *L);
95111

@@ -111,6 +127,10 @@ class LoopIdiomVectorize {
111127
GetElementPtrInst *GEPA,
112128
GetElementPtrInst *GEPB, Value *ExtStart,
113129
Value *ExtEnd);
130+
Value *createPredicatedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
131+
GetElementPtrInst *GEPA,
132+
GetElementPtrInst *GEPB, Value *ExtStart,
133+
Value *ExtEnd);
114134

115135
void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
116136
PHINode *IndPhi, Value *MaxLen, Instruction *Index,
@@ -128,8 +148,16 @@ PreservedAnalyses LoopIdiomVectorizePass::run(Loop &L, LoopAnalysisManager &AM,
128148

129149
const auto *DL = &L.getHeader()->getDataLayout();
130150

131-
LoopIdiomVectorize LIT(&AR.DT, &AR.LI, &AR.TTI, DL);
132-
if (!LIT.run(&L))
151+
LoopIdiomVectorizeStyle VecStyle = VectorizeStyle;
152+
if (LITVecStyle.getNumOccurrences())
153+
VecStyle = LITVecStyle;
154+
155+
unsigned BCVF = ByteCompareVF;
156+
if (ByteCmpVF.getNumOccurrences())
157+
BCVF = ByteCmpVF;
158+
159+
LoopIdiomVectorize LIV(VecStyle, BCVF, &AR.DT, &AR.LI, &AR.TTI, DL);
160+
if (!LIV.run(&L))
133161
return PreservedAnalyses::all();
134162

135163
return PreservedAnalyses::none();
@@ -354,20 +382,16 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch(
354382
Value *PtrA = GEPA->getPointerOperand();
355383
Value *PtrB = GEPB->getPointerOperand();
356384

357-
// At this point we know two things must be true:
358-
// 1. Start <= End
359-
// 2. ExtMaxLen <= MinPageSize due to the page checks.
360-
// Therefore, we know that we can use a 64-bit induction variable that
361-
// starts from 0 -> ExtMaxLen and it will not overflow.
362385
ScalableVectorType *PredVTy =
363-
ScalableVectorType::get(Builder.getInt1Ty(), 16);
386+
ScalableVectorType::get(Builder.getInt1Ty(), ByteCompareVF);
364387

365388
Value *InitialPred = Builder.CreateIntrinsic(
366389
Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
367390

368391
Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
369-
VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, 16), "",
370-
/*HasNUW=*/true, /*HasNSW=*/true);
392+
VecLen =
393+
Builder.CreateMul(VecLen, ConstantInt::get(I64Type, ByteCompareVF), "",
394+
/*HasNUW=*/true, /*HasNSW=*/true);
371395

372396
Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
373397
Builder.getInt1(false));
@@ -385,7 +409,8 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch(
385409
LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock);
386410
PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index");
387411
VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
388-
Type *VectorLoadType = ScalableVectorType::get(Builder.getInt8Ty(), 16);
412+
Type *VectorLoadType =
413+
ScalableVectorType::get(Builder.getInt8Ty(), ByteCompareVF);
389414
Value *Passthru = ConstantInt::getNullValue(VectorLoadType);
390415

391416
Value *VectorLhsGep =
@@ -454,6 +479,109 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch(
454479
return Builder.CreateTrunc(VectorLoopRes64, ResType);
455480
}
456481

482+
Value *LoopIdiomVectorize::createPredicatedFindMismatch(
483+
IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
484+
GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
485+
Type *I64Type = Builder.getInt64Ty();
486+
Type *I32Type = Builder.getInt32Ty();
487+
Type *ResType = I32Type;
488+
Type *LoadType = Builder.getInt8Ty();
489+
Value *PtrA = GEPA->getPointerOperand();
490+
Value *PtrB = GEPB->getPointerOperand();
491+
492+
auto *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
493+
Builder.Insert(JumpToVectorLoop);
494+
495+
DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock,
496+
VectorLoopStartBlock}});
497+
498+
// Set up the first Vector loop block by creating the PHIs, doing the vector
499+
// loads and comparing the vectors.
500+
Builder.SetInsertPoint(VectorLoopStartBlock);
501+
auto *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vector_index");
502+
VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
503+
504+
// Calculate AVL by subtracting the vector loop index from the trip count
505+
Value *AVL = Builder.CreateSub(ExtEnd, VectorIndexPhi, "avl", /*HasNUW=*/true,
506+
/*HasNSW=*/true);
507+
508+
auto *VectorLoadType = ScalableVectorType::get(LoadType, ByteCompareVF);
509+
auto *VF = ConstantInt::get(I32Type, ByteCompareVF);
510+
511+
Value *VL = Builder.CreateIntrinsic(Intrinsic::experimental_get_vector_length,
512+
{I64Type}, {AVL, VF, Builder.getTrue()});
513+
Value *GepOffset = VectorIndexPhi;
514+
515+
Value *VectorLhsGep =
516+
Builder.CreateGEP(LoadType, PtrA, GepOffset, "", GEPA->isInBounds());
517+
VectorType *TrueMaskTy =
518+
VectorType::get(Builder.getInt1Ty(), VectorLoadType->getElementCount());
519+
Value *AllTrueMask = Constant::getAllOnesValue(TrueMaskTy);
520+
Value *VectorLhsLoad = Builder.CreateIntrinsic(
521+
Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType()},
522+
{VectorLhsGep, AllTrueMask, VL}, nullptr, "lhs.load");
523+
524+
Value *VectorRhsGep =
525+
Builder.CreateGEP(LoadType, PtrB, GepOffset, "", GEPB->isInBounds());
526+
Value *VectorRhsLoad = Builder.CreateIntrinsic(
527+
Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType()},
528+
{VectorRhsGep, AllTrueMask, VL}, nullptr, "rhs.load");
529+
530+
StringRef PredicateStr = CmpInst::getPredicateName(CmpInst::ICMP_NE);
531+
auto *PredicateMDS = MDString::get(VectorLhsLoad->getContext(), PredicateStr);
532+
Value *Pred = MetadataAsValue::get(VectorLhsLoad->getContext(), PredicateMDS);
533+
Value *VectorMatchCmp = Builder.CreateIntrinsic(
534+
Intrinsic::vp_icmp, {VectorLhsLoad->getType()},
535+
{VectorLhsLoad, VectorRhsLoad, Pred, AllTrueMask, VL}, nullptr,
536+
"mismatch.cmp");
537+
Value *CTZ = Builder.CreateIntrinsic(
538+
Intrinsic::vp_cttz_elts, {ResType, VectorMatchCmp->getType()},
539+
{VectorMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(false), AllTrueMask,
540+
VL});
541+
Value *MismatchFound = Builder.CreateICmpNE(CTZ, VL);
542+
auto *VectorEarlyExit = BranchInst::Create(VectorLoopMismatchBlock,
543+
VectorLoopIncBlock, MismatchFound);
544+
Builder.Insert(VectorEarlyExit);
545+
546+
DTU.applyUpdates(
547+
{{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
548+
{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
549+
550+
// Increment the index counter and calculate the predicate for the next
551+
// iteration of the loop. We branch back to the start of the loop if there
552+
// is at least one active lane.
553+
Builder.SetInsertPoint(VectorLoopIncBlock);
554+
Value *VL64 = Builder.CreateZExt(VL, I64Type);
555+
Value *NewVectorIndexPhi =
556+
Builder.CreateAdd(VectorIndexPhi, VL64, "",
557+
/*HasNUW=*/true, /*HasNSW=*/true);
558+
VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
559+
Value *ExitCond = Builder.CreateICmpNE(NewVectorIndexPhi, ExtEnd);
560+
auto *VectorLoopBranchBack =
561+
BranchInst::Create(VectorLoopStartBlock, EndBlock, ExitCond);
562+
Builder.Insert(VectorLoopBranchBack);
563+
564+
DTU.applyUpdates(
565+
{{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
566+
{DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
567+
568+
// If we found a mismatch then we need to calculate which lane in the vector
569+
// had a mismatch and add that on to the current loop index.
570+
Builder.SetInsertPoint(VectorLoopMismatchBlock);
571+
572+
// Add LCSSA phis for CTZ and VectorIndexPhi.
573+
auto *CTZLCSSAPhi = Builder.CreatePHI(CTZ->getType(), 1, "ctz");
574+
CTZLCSSAPhi->addIncoming(CTZ, VectorLoopStartBlock);
575+
auto *VectorIndexLCSSAPhi =
576+
Builder.CreatePHI(VectorIndexPhi->getType(), 1, "mismatch_vector_index");
577+
VectorIndexLCSSAPhi->addIncoming(VectorIndexPhi, VectorLoopStartBlock);
578+
579+
Value *CTZI64 = Builder.CreateZExt(CTZLCSSAPhi, I64Type);
580+
Value *VectorLoopRes64 = Builder.CreateAdd(VectorIndexLCSSAPhi, CTZI64, "",
581+
/*HasNUW=*/true, /*HasNSW=*/true);
582+
return Builder.CreateTrunc(VectorLoopRes64, ResType);
583+
}
584+
457585
Value *LoopIdiomVectorize::expandFindMismatch(
458586
IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
459587
GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
@@ -613,8 +741,22 @@ Value *LoopIdiomVectorize::expandFindMismatch(
613741
// processed in each iteration, etc.
614742
Builder.SetInsertPoint(VectorLoopPreheaderBlock);
615743

616-
Value *VectorLoopRes =
617-
createMaskedFindMismatch(Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd);
744+
// At this point we know two things must be true:
745+
// 1. Start <= End
746+
// 2. ExtMaxLen <= MinPageSize due to the page checks.
747+
// Therefore, we know that we can use a 64-bit induction variable that
748+
// starts from 0 -> ExtMaxLen and it will not overflow.
749+
Value *VectorLoopRes = nullptr;
750+
switch (VectorizeStyle) {
751+
case LoopIdiomVectorizeStyle::Masked:
752+
VectorLoopRes =
753+
createMaskedFindMismatch(Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd);
754+
break;
755+
case LoopIdiomVectorizeStyle::Predicated:
756+
VectorLoopRes = createPredicatedFindMismatch(Builder, DTU, GEPA, GEPB,
757+
ExtStart, ExtEnd);
758+
break;
759+
}
618760

619761
Builder.Insert(BranchInst::Create(EndBlock));
620762

0 commit comments

Comments
 (0)