Skip to content

Commit 47050e6

Browse files
committed
[RISCV][LoopIdiomVectorize] Support VP intrinsics in LoopIdiomVectorize
Teach LoopIdiomVectorize to use VP intrinsics to replace the byte compare loops. Right now only RISC-V uses LoopIdiomVectorize of this style.
1 parent 174d02e commit 47050e6

File tree

6 files changed

+1939
-14
lines changed

6 files changed

+1939
-14
lines changed

llvm/include/llvm/Transforms/Vectorize/LoopIdiomVectorize.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,22 @@
1313
#include "llvm/Transforms/Scalar/LoopPassManager.h"
1414

1515
namespace llvm {
16-
struct LoopIdiomVectorizePass : PassInfoMixin<LoopIdiomVectorizePass> {
16+
enum class LoopIdiomVectorizeStyle { Masked, Predicated };
17+
18+
class LoopIdiomVectorizePass : public PassInfoMixin<LoopIdiomVectorizePass> {
19+
LoopIdiomVectorizeStyle VectorizeStyle = LoopIdiomVectorizeStyle::Masked;
20+
21+
// The VF used in vectorizing the byte compare pattern.
22+
unsigned ByteCompareVF = 16;
23+
24+
public:
25+
LoopIdiomVectorizePass() = default;
26+
explicit LoopIdiomVectorizePass(LoopIdiomVectorizeStyle S)
27+
: VectorizeStyle(S) {}
28+
29+
LoopIdiomVectorizePass(LoopIdiomVectorizeStyle S, unsigned BCVF)
30+
: VectorizeStyle(S), ByteCompareVF(BCVF) {}
31+
1732
PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM,
1833
LoopStandardAnalysisResults &AR, LPMUpdater &U);
1934
};

llvm/lib/Target/RISCV/RISCVTargetMachine.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,12 @@
3333
#include "llvm/CodeGen/TargetPassConfig.h"
3434
#include "llvm/InitializePasses.h"
3535
#include "llvm/MC/TargetRegistry.h"
36+
#include "llvm/Passes/PassBuilder.h"
3637
#include "llvm/Support/FormattedStream.h"
3738
#include "llvm/Target/TargetOptions.h"
3839
#include "llvm/Transforms/IPO.h"
3940
#include "llvm/Transforms/Scalar.h"
41+
#include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h"
4042
#include <optional>
4143
using namespace llvm;
4244

@@ -573,6 +575,14 @@ void RISCVPassConfig::addPostRegAlloc() {
573575
addPass(createRISCVRedundantCopyEliminationPass());
574576
}
575577

578+
void RISCVTargetMachine::registerPassBuilderCallbacks(
579+
PassBuilder &PB, bool PopulateClassToPassNames) {
580+
PB.registerLateLoopOptimizationsEPCallback([=](LoopPassManager &LPM,
581+
OptimizationLevel Level) {
582+
LPM.addPass(LoopIdiomVectorizePass(LoopIdiomVectorizeStyle::Predicated));
583+
});
584+
}
585+
576586
yaml::MachineFunctionInfo *
577587
RISCVTargetMachine::createDefaultFuncInfoYAML() const {
578588
return new yaml::RISCVMachineFunctionInfo();

llvm/lib/Target/RISCV/RISCVTargetMachine.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ class RISCVTargetMachine : public LLVMTargetMachine {
5959
PerFunctionMIParsingState &PFS,
6060
SMDiagnostic &Error,
6161
SMRange &SourceRange) const override;
62+
void registerPassBuilderCallbacks(PassBuilder &PB,
63+
bool PopulateClassToPassNames) override;
6264
};
6365
} // namespace llvm
6466

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,8 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
397397
bool shouldFoldTerminatingConditionAfterLSR() const {
398398
return true;
399399
}
400+
401+
std::optional<unsigned> getMinPageSize() const { return 4096; }
400402
};
401403

402404
} // end namespace llvm

llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp

Lines changed: 158 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,19 +60,34 @@ static cl::opt<bool> DisableAll("disable-loop-idiom-vectorize-all", cl::Hidden,
6060
cl::init(false),
6161
cl::desc("Disable Loop Idiom Vectorize Pass."));
6262

63+
static cl::opt<LoopIdiomVectorizeStyle>
64+
LITVecStyle("loop-idiom-vectorize-style", cl::Hidden,
65+
cl::desc("The vectorization style for loop idiom transform."),
66+
cl::values(clEnumValN(LoopIdiomVectorizeStyle::Masked, "masked",
67+
"Use masked vector intrinsics"),
68+
clEnumValN(LoopIdiomVectorizeStyle::Predicated,
69+
"predicated", "Use VP intrinsics")),
70+
cl::init(LoopIdiomVectorizeStyle::Masked));
71+
6372
static cl::opt<bool>
6473
DisableByteCmp("disable-loop-idiom-vectorize-bytecmp", cl::Hidden,
6574
cl::init(false),
6675
cl::desc("Proceed with Loop Idiom Vectorize Pass, but do "
6776
"not convert byte-compare loop(s)."));
6877

78+
static cl::opt<unsigned>
79+
ByteCmpVF("loop-idiom-vectorize-bytecmp-vf", cl::Hidden,
80+
cl::desc("The vectorization factor for byte-compare patterns."),
81+
cl::init(16));
82+
6983
static cl::opt<bool>
7084
VerifyLoops("loop-idiom-vectorize-verify", cl::Hidden, cl::init(false),
7185
cl::desc("Verify loops generated Loop Idiom Vectorize Pass."));
7286

7387
namespace {
74-
7588
class LoopIdiomVectorize {
89+
LoopIdiomVectorizeStyle VectorizeStyle;
90+
unsigned ByteCompareVF;
7691
Loop *CurLoop = nullptr;
7792
DominatorTree *DT;
7893
LoopInfo *LI;
@@ -87,10 +102,11 @@ class LoopIdiomVectorize {
87102
BasicBlock *VectorLoopIncBlock = nullptr;
88103

89104
public:
90-
explicit LoopIdiomVectorize(DominatorTree *DT, LoopInfo *LI,
91-
const TargetTransformInfo *TTI,
92-
const DataLayout *DL)
93-
: DT(DT), LI(LI), TTI(TTI), DL(DL) {}
105+
LoopIdiomVectorize(LoopIdiomVectorizeStyle S, unsigned VF, DominatorTree *DT,
106+
LoopInfo *LI, const TargetTransformInfo *TTI,
107+
const DataLayout *DL)
108+
: VectorizeStyle(S), ByteCompareVF(VF), DT(DT), LI(LI), TTI(TTI), DL(DL) {
109+
}
94110

95111
bool run(Loop *L);
96112

@@ -111,6 +127,10 @@ class LoopIdiomVectorize {
111127
Value *createMaskedFindMismatch(IRBuilder<> &Builder, GetElementPtrInst *GEPA,
112128
GetElementPtrInst *GEPB, Value *ExtStart,
113129
Value *ExtEnd);
130+
Value *createPredicatedFindMismatch(IRBuilder<> &Builder,
131+
GetElementPtrInst *GEPA,
132+
GetElementPtrInst *GEPB, Value *ExtStart,
133+
Value *ExtEnd);
114134

115135
void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
116136
PHINode *IndPhi, Value *MaxLen, Instruction *Index,
@@ -128,8 +148,16 @@ PreservedAnalyses LoopIdiomVectorizePass::run(Loop &L, LoopAnalysisManager &AM,
128148

129149
const auto *DL = &L.getHeader()->getModule()->getDataLayout();
130150

131-
LoopIdiomVectorize LIT(&AR.DT, &AR.LI, &AR.TTI, DL);
132-
if (!LIT.run(&L))
151+
LoopIdiomVectorizeStyle VecStyle = VectorizeStyle;
152+
if (LITVecStyle.getNumOccurrences())
153+
VecStyle = LITVecStyle;
154+
155+
unsigned BCVF = ByteCompareVF;
156+
if (ByteCmpVF.getNumOccurrences())
157+
BCVF = ByteCmpVF;
158+
159+
LoopIdiomVectorize LIV(VecStyle, BCVF, &AR.DT, &AR.LI, &AR.TTI, DL);
160+
if (!LIV.run(&L))
133161
return PreservedAnalyses::all();
134162

135163
return PreservedAnalyses::none();
@@ -362,14 +390,15 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch(IRBuilder<> &Builder,
362390
// Therefore, we know that we can use a 64-bit induction variable that
363391
// starts from 0 -> ExtMaxLen and it will not overflow.
364392
ScalableVectorType *PredVTy =
365-
ScalableVectorType::get(Builder.getInt1Ty(), 16);
393+
ScalableVectorType::get(Builder.getInt1Ty(), ByteCompareVF);
366394

367395
Value *InitialPred = Builder.CreateIntrinsic(
368396
Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
369397

370398
Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
371-
VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, 16), "",
372-
/*HasNUW=*/true, /*HasNSW=*/true);
399+
VecLen =
400+
Builder.CreateMul(VecLen, ConstantInt::get(I64Type, ByteCompareVF), "",
401+
/*HasNUW=*/true, /*HasNSW=*/true);
373402

374403
Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
375404
Builder.getInt1(false));
@@ -384,7 +413,8 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch(IRBuilder<> &Builder,
384413
LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock);
385414
PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index");
386415
VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
387-
Type *VectorLoadType = ScalableVectorType::get(Builder.getInt8Ty(), 16);
416+
Type *VectorLoadType =
417+
ScalableVectorType::get(Builder.getInt8Ty(), ByteCompareVF);
388418
Value *Passthru = ConstantInt::getNullValue(VectorLoadType);
389419

390420
Value *VectorLhsGep =
@@ -445,6 +475,112 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch(IRBuilder<> &Builder,
445475
return Builder.CreateTrunc(VectorLoopRes64, ResType);
446476
}
447477

478+
Value *LoopIdiomVectorize::createPredicatedFindMismatch(IRBuilder<> &Builder,
479+
GetElementPtrInst *GEPA,
480+
GetElementPtrInst *GEPB,
481+
Value *ExtStart,
482+
Value *ExtEnd) {
483+
Type *I64Type = Builder.getInt64Ty();
484+
Type *I32Type = Builder.getInt32Ty();
485+
Type *ResType = I32Type;
486+
Type *LoadType = Builder.getInt8Ty();
487+
Value *PtrA = GEPA->getPointerOperand();
488+
Value *PtrB = GEPB->getPointerOperand();
489+
490+
// At this point we know two things must be true:
491+
// 1. Start <= End
492+
// 2. ExtMaxLen <= 4096 due to the page checks.
493+
// Therefore, we know that we can use a 64-bit induction variable that
494+
// starts from 0 -> ExtMaxLen and it will not overflow.
495+
auto *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
496+
Builder.Insert(JumpToVectorLoop);
497+
498+
// Set up the first Vector loop block by creating the PHIs, doing the vector
499+
// loads and comparing the vectors.
500+
Builder.SetInsertPoint(VectorLoopStartBlock);
501+
auto *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vector_index");
502+
VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
503+
504+
// Calculate AVL by subtracting the vector loop index from the trip count
505+
Value *AVL = Builder.CreateSub(ExtEnd, VectorIndexPhi, "avl", /*HasNUW=*/true,
506+
/*HasNSW=*/true);
507+
508+
auto *VectorLoadType = ScalableVectorType::get(LoadType, ByteCompareVF);
509+
auto *VF = ConstantInt::get(
510+
I32Type, VectorLoadType->getElementCount().getKnownMinValue());
511+
auto *IsScalable = ConstantInt::getBool(
512+
Builder.getContext(), VectorLoadType->getElementCount().isScalable());
513+
514+
Value *VL = Builder.CreateIntrinsic(Intrinsic::experimental_get_vector_length,
515+
{I64Type}, {AVL, VF, IsScalable});
516+
Value *GepOffset = VectorIndexPhi;
517+
518+
Value *VectorLhsGep = Builder.CreateGEP(LoadType, PtrA, GepOffset);
519+
if (GEPA->isInBounds())
520+
cast<GetElementPtrInst>(VectorLhsGep)->setIsInBounds(true);
521+
VectorType *TrueMaskTy =
522+
VectorType::get(Builder.getInt1Ty(), VectorLoadType->getElementCount());
523+
Value *AllTrueMask = Constant::getAllOnesValue(TrueMaskTy);
524+
Value *VectorLhsLoad = Builder.CreateIntrinsic(
525+
Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType()},
526+
{VectorLhsGep, AllTrueMask, VL}, nullptr, "lhs.load");
527+
528+
Value *VectorRhsGep = Builder.CreateGEP(LoadType, PtrB, GepOffset);
529+
if (GEPB->isInBounds())
530+
cast<GetElementPtrInst>(VectorRhsGep)->setIsInBounds(true);
531+
Value *VectorRhsLoad = Builder.CreateIntrinsic(
532+
Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType()},
533+
{VectorRhsGep, AllTrueMask, VL}, nullptr, "rhs.load");
534+
535+
StringRef PredicateStr = CmpInst::getPredicateName(CmpInst::ICMP_NE);
536+
auto *PredicateMDS = MDString::get(VectorLhsLoad->getContext(), PredicateStr);
537+
Value *Pred = MetadataAsValue::get(VectorLhsLoad->getContext(), PredicateMDS);
538+
Value *VectorMatchCmp = Builder.CreateIntrinsic(
539+
Intrinsic::vp_icmp, {VectorLhsLoad->getType()},
540+
{VectorLhsLoad, VectorRhsLoad, Pred, AllTrueMask, VL}, nullptr,
541+
"mismatch.cmp");
542+
Value *CTZ = Builder.CreateIntrinsic(
543+
Intrinsic::vp_cttz_elts, {ResType, VectorMatchCmp->getType()},
544+
{VectorMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(true), AllTrueMask,
545+
VL});
546+
// RISC-V refines/lowers the poison returned by vp.cttz.elts to -1.
547+
Value *MismatchFound =
548+
Builder.CreateICmpSGE(CTZ, ConstantInt::get(ResType, 0));
549+
auto *VectorEarlyExit = BranchInst::Create(VectorLoopMismatchBlock,
550+
VectorLoopIncBlock, MismatchFound);
551+
Builder.Insert(VectorEarlyExit);
552+
553+
// Increment the index counter and calculate the predicate for the next
554+
// iteration of the loop. We branch back to the start of the loop if there
555+
// is at least one active lane.
556+
Builder.SetInsertPoint(VectorLoopIncBlock);
557+
Value *VL64 = Builder.CreateZExt(VL, I64Type);
558+
Value *NewVectorIndexPhi =
559+
Builder.CreateAdd(VectorIndexPhi, VL64, "",
560+
/*HasNUW=*/true, /*HasNSW=*/true);
561+
VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
562+
Value *ExitCond = Builder.CreateICmpNE(NewVectorIndexPhi, ExtEnd);
563+
auto *VectorLoopBranchBack =
564+
BranchInst::Create(VectorLoopStartBlock, EndBlock, ExitCond);
565+
Builder.Insert(VectorLoopBranchBack);
566+
567+
// If we found a mismatch then we need to calculate which lane in the vector
568+
// had a mismatch and add that on to the current loop index.
569+
Builder.SetInsertPoint(VectorLoopMismatchBlock);
570+
571+
// Add LCSSA phis for CTZ and VectorIndexPhi.
572+
auto *CTZLCSSAPhi = Builder.CreatePHI(CTZ->getType(), 1, "ctz");
573+
CTZLCSSAPhi->addIncoming(CTZ, VectorLoopStartBlock);
574+
auto *VectorIndexLCSSAPhi =
575+
Builder.CreatePHI(VectorIndexPhi->getType(), 1, "mismatch_vector_index");
576+
VectorIndexLCSSAPhi->addIncoming(VectorIndexPhi, VectorLoopStartBlock);
577+
578+
Value *CTZI64 = Builder.CreateZExt(CTZLCSSAPhi, I64Type);
579+
Value *VectorLoopRes64 = Builder.CreateAdd(VectorIndexLCSSAPhi, CTZI64, "",
580+
/*HasNUW=*/true, /*HasNSW=*/true);
581+
return Builder.CreateTrunc(VectorLoopRes64, ResType);
582+
}
583+
448584
Value *LoopIdiomVectorize::expandFindMismatch(
449585
IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
450586
GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
@@ -604,8 +740,17 @@ Value *LoopIdiomVectorize::expandFindMismatch(
604740
// processed in each iteration, etc.
605741
Builder.SetInsertPoint(VectorLoopPreheaderBlock);
606742

607-
Value *VectorLoopRes =
608-
createMaskedFindMismatch(Builder, GEPA, GEPB, ExtStart, ExtEnd);
743+
Value *VectorLoopRes = nullptr;
744+
switch (VectorizeStyle) {
745+
case LoopIdiomVectorizeStyle::Masked:
746+
VectorLoopRes =
747+
createMaskedFindMismatch(Builder, GEPA, GEPB, ExtStart, ExtEnd);
748+
break;
749+
case LoopIdiomVectorizeStyle::Predicated:
750+
VectorLoopRes =
751+
createPredicatedFindMismatch(Builder, GEPA, GEPB, ExtStart, ExtEnd);
752+
break;
753+
}
609754

610755
Builder.Insert(BranchInst::Create(EndBlock));
611756

0 commit comments

Comments
 (0)