-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[RISCV][LoopIdiomVectorize] Support VP intrinsics in LoopIdiomVectorize #94082
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1e15813
c3700ca
3babd98
49d491d
f4fc8b8
d4fb4c0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -59,19 +59,34 @@ static cl::opt<bool> DisableAll("disable-loop-idiom-vectorize-all", cl::Hidden, | |
cl::init(false), | ||
cl::desc("Disable Loop Idiom Vectorize Pass.")); | ||
|
||
static cl::opt<LoopIdiomVectorizeStyle> | ||
LITVecStyle("loop-idiom-vectorize-style", cl::Hidden, | ||
cl::desc("The vectorization style for loop idiom transform."), | ||
cl::values(clEnumValN(LoopIdiomVectorizeStyle::Masked, "masked", | ||
"Use masked vector intrinsics"), | ||
clEnumValN(LoopIdiomVectorizeStyle::Predicated, | ||
"predicated", "Use VP intrinsics")), | ||
cl::init(LoopIdiomVectorizeStyle::Masked)); | ||
|
||
static cl::opt<bool> | ||
DisableByteCmp("disable-loop-idiom-vectorize-bytecmp", cl::Hidden, | ||
cl::init(false), | ||
cl::desc("Proceed with Loop Idiom Vectorize Pass, but do " | ||
"not convert byte-compare loop(s).")); | ||
|
||
static cl::opt<unsigned> | ||
ByteCmpVF("loop-idiom-vectorize-bytecmp-vf", cl::Hidden, | ||
cl::desc("The vectorization factor for byte-compare patterns."), | ||
cl::init(16)); | ||
|
||
static cl::opt<bool> | ||
VerifyLoops("loop-idiom-vectorize-verify", cl::Hidden, cl::init(false), | ||
cl::desc("Verify loops generated Loop Idiom Vectorize Pass.")); | ||
|
||
namespace { | ||
|
||
class LoopIdiomVectorize { | ||
LoopIdiomVectorizeStyle VectorizeStyle; | ||
unsigned ByteCompareVF; | ||
Loop *CurLoop = nullptr; | ||
DominatorTree *DT; | ||
LoopInfo *LI; | ||
|
@@ -86,10 +101,11 @@ class LoopIdiomVectorize { | |
BasicBlock *VectorLoopIncBlock = nullptr; | ||
|
||
public: | ||
explicit LoopIdiomVectorize(DominatorTree *DT, LoopInfo *LI, | ||
const TargetTransformInfo *TTI, | ||
const DataLayout *DL) | ||
: DT(DT), LI(LI), TTI(TTI), DL(DL) {} | ||
LoopIdiomVectorize(LoopIdiomVectorizeStyle S, unsigned VF, DominatorTree *DT, | ||
LoopInfo *LI, const TargetTransformInfo *TTI, | ||
const DataLayout *DL) | ||
: VectorizeStyle(S), ByteCompareVF(VF), DT(DT), LI(LI), TTI(TTI), DL(DL) { | ||
} | ||
|
||
bool run(Loop *L); | ||
|
||
|
@@ -111,6 +127,10 @@ class LoopIdiomVectorize { | |
GetElementPtrInst *GEPA, | ||
GetElementPtrInst *GEPB, Value *ExtStart, | ||
Value *ExtEnd); | ||
Value *createPredicatedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU, | ||
GetElementPtrInst *GEPA, | ||
GetElementPtrInst *GEPB, Value *ExtStart, | ||
Value *ExtEnd); | ||
|
||
void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB, | ||
PHINode *IndPhi, Value *MaxLen, Instruction *Index, | ||
|
@@ -128,8 +148,16 @@ PreservedAnalyses LoopIdiomVectorizePass::run(Loop &L, LoopAnalysisManager &AM, | |
|
||
const auto *DL = &L.getHeader()->getDataLayout(); | ||
|
||
LoopIdiomVectorize LIT(&AR.DT, &AR.LI, &AR.TTI, DL); | ||
if (!LIT.run(&L)) | ||
LoopIdiomVectorizeStyle VecStyle = VectorizeStyle; | ||
if (LITVecStyle.getNumOccurrences()) | ||
VecStyle = LITVecStyle; | ||
|
||
unsigned BCVF = ByteCompareVF; | ||
if (ByteCmpVF.getNumOccurrences()) | ||
BCVF = ByteCmpVF; | ||
|
||
LoopIdiomVectorize LIV(VecStyle, BCVF, &AR.DT, &AR.LI, &AR.TTI, DL); | ||
if (!LIV.run(&L)) | ||
return PreservedAnalyses::all(); | ||
|
||
return PreservedAnalyses::none(); | ||
|
@@ -354,20 +382,16 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch( | |
Value *PtrA = GEPA->getPointerOperand(); | ||
Value *PtrB = GEPB->getPointerOperand(); | ||
|
||
// At this point we know two things must be true: | ||
// 1. Start <= End | ||
// 2. ExtMaxLen <= MinPageSize due to the page checks. | ||
// Therefore, we know that we can use a 64-bit induction variable that | ||
// starts from 0 -> ExtMaxLen and it will not overflow. | ||
ScalableVectorType *PredVTy = | ||
ScalableVectorType::get(Builder.getInt1Ty(), 16); | ||
ScalableVectorType::get(Builder.getInt1Ty(), ByteCompareVF); | ||
|
||
Value *InitialPred = Builder.CreateIntrinsic( | ||
Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd}); | ||
|
||
Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {}); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See IRBuilders, CreateElementCount There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This part of the patch is an extraction of the existing code into a function. We shouldn't make any changes here in this patch. |
||
VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, 16), "", | ||
/*HasNUW=*/true, /*HasNSW=*/true); | ||
VecLen = | ||
Builder.CreateMul(VecLen, ConstantInt::get(I64Type, ByteCompareVF), "", | ||
/*HasNUW=*/true, /*HasNSW=*/true); | ||
|
||
Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(), | ||
Builder.getInt1(false)); | ||
|
@@ -385,7 +409,8 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch( | |
LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock); | ||
PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index"); | ||
VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock); | ||
Type *VectorLoadType = ScalableVectorType::get(Builder.getInt8Ty(), 16); | ||
Type *VectorLoadType = | ||
ScalableVectorType::get(Builder.getInt8Ty(), ByteCompareVF); | ||
Value *Passthru = ConstantInt::getNullValue(VectorLoadType); | ||
|
||
Value *VectorLhsGep = | ||
|
@@ -454,6 +479,109 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch( | |
return Builder.CreateTrunc(VectorLoopRes64, ResType); | ||
} | ||
|
||
Value *LoopIdiomVectorize::createPredicatedFindMismatch( | ||
IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA, | ||
GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) { | ||
Type *I64Type = Builder.getInt64Ty(); | ||
Type *I32Type = Builder.getInt32Ty(); | ||
Type *ResType = I32Type; | ||
Type *LoadType = Builder.getInt8Ty(); | ||
Value *PtrA = GEPA->getPointerOperand(); | ||
Value *PtrB = GEPB->getPointerOperand(); | ||
|
||
auto *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock); | ||
Builder.Insert(JumpToVectorLoop); | ||
|
||
DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock, | ||
VectorLoopStartBlock}}); | ||
|
||
// Set up the first Vector loop block by creating the PHIs, doing the vector | ||
// loads and comparing the vectors. | ||
Builder.SetInsertPoint(VectorLoopStartBlock); | ||
auto *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vector_index"); | ||
VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock); | ||
|
||
// Calculate AVL by subtracting the vector loop index from the trip count | ||
Value *AVL = Builder.CreateSub(ExtEnd, VectorIndexPhi, "avl", /*HasNUW=*/true, | ||
/*HasNSW=*/true); | ||
|
||
auto *VectorLoadType = ScalableVectorType::get(LoadType, ByteCompareVF); | ||
auto *VF = ConstantInt::get(I32Type, ByteCompareVF); | ||
|
||
Value *VL = Builder.CreateIntrinsic(Intrinsic::experimental_get_vector_length, | ||
{I64Type}, {AVL, VF, Builder.getTrue()}); | ||
Value *GepOffset = VectorIndexPhi; | ||
|
||
Value *VectorLhsGep = | ||
Builder.CreateGEP(LoadType, PtrA, GepOffset, "", GEPA->isInBounds()); | ||
VectorType *TrueMaskTy = | ||
VectorType::get(Builder.getInt1Ty(), VectorLoadType->getElementCount()); | ||
Value *AllTrueMask = Constant::getAllOnesValue(TrueMaskTy); | ||
Value *VectorLhsLoad = Builder.CreateIntrinsic( | ||
Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType()}, | ||
{VectorLhsGep, AllTrueMask, VL}, nullptr, "lhs.load"); | ||
|
||
Value *VectorRhsGep = | ||
Builder.CreateGEP(LoadType, PtrB, GepOffset, "", GEPB->isInBounds()); | ||
Value *VectorRhsLoad = Builder.CreateIntrinsic( | ||
Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType()}, | ||
{VectorRhsGep, AllTrueMask, VL}, nullptr, "rhs.load"); | ||
|
||
StringRef PredicateStr = CmpInst::getPredicateName(CmpInst::ICMP_NE); | ||
auto *PredicateMDS = MDString::get(VectorLhsLoad->getContext(), PredicateStr); | ||
Value *Pred = MetadataAsValue::get(VectorLhsLoad->getContext(), PredicateMDS); | ||
Value *VectorMatchCmp = Builder.CreateIntrinsic( | ||
Intrinsic::vp_icmp, {VectorLhsLoad->getType()}, | ||
{VectorLhsLoad, VectorRhsLoad, Pred, AllTrueMask, VL}, nullptr, | ||
"mismatch.cmp"); | ||
Value *CTZ = Builder.CreateIntrinsic( | ||
Intrinsic::vp_cttz_elts, {ResType, VectorMatchCmp->getType()}, | ||
{VectorMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(false), AllTrueMask, | ||
VL}); | ||
Value *MismatchFound = Builder.CreateICmpNE(CTZ, VL); | ||
auto *VectorEarlyExit = BranchInst::Create(VectorLoopMismatchBlock, | ||
VectorLoopIncBlock, MismatchFound); | ||
Builder.Insert(VectorEarlyExit); | ||
|
||
DTU.applyUpdates( | ||
{{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock}, | ||
{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}}); | ||
|
||
// Increment the index counter and calculate the predicate for the next | ||
// iteration of the loop. We branch back to the start of the loop if there | ||
// is at least one active lane. | ||
Builder.SetInsertPoint(VectorLoopIncBlock); | ||
Value *VL64 = Builder.CreateZExt(VL, I64Type); | ||
Value *NewVectorIndexPhi = | ||
Builder.CreateAdd(VectorIndexPhi, VL64, "", | ||
/*HasNUW=*/true, /*HasNSW=*/true); | ||
VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock); | ||
Value *ExitCond = Builder.CreateICmpNE(NewVectorIndexPhi, ExtEnd); | ||
auto *VectorLoopBranchBack = | ||
BranchInst::Create(VectorLoopStartBlock, EndBlock, ExitCond); | ||
Builder.Insert(VectorLoopBranchBack); | ||
|
||
DTU.applyUpdates( | ||
{{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock}, | ||
{DominatorTree::Insert, VectorLoopIncBlock, EndBlock}}); | ||
|
||
// If we found a mismatch then we need to calculate which lane in the vector | ||
// had a mismatch and add that on to the current loop index. | ||
Builder.SetInsertPoint(VectorLoopMismatchBlock); | ||
|
||
// Add LCSSA phis for CTZ and VectorIndexPhi. | ||
auto *CTZLCSSAPhi = Builder.CreatePHI(CTZ->getType(), 1, "ctz"); | ||
CTZLCSSAPhi->addIncoming(CTZ, VectorLoopStartBlock); | ||
auto *VectorIndexLCSSAPhi = | ||
Builder.CreatePHI(VectorIndexPhi->getType(), 1, "mismatch_vector_index"); | ||
VectorIndexLCSSAPhi->addIncoming(VectorIndexPhi, VectorLoopStartBlock); | ||
|
||
Value *CTZI64 = Builder.CreateZExt(CTZLCSSAPhi, I64Type); | ||
Value *VectorLoopRes64 = Builder.CreateAdd(VectorIndexLCSSAPhi, CTZI64, "", | ||
/*HasNUW=*/true, /*HasNSW=*/true); | ||
return Builder.CreateTrunc(VectorLoopRes64, ResType); | ||
} | ||
|
||
Value *LoopIdiomVectorize::expandFindMismatch( | ||
IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA, | ||
GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) { | ||
|
@@ -613,8 +741,22 @@ Value *LoopIdiomVectorize::expandFindMismatch( | |
// processed in each iteration, etc. | ||
Builder.SetInsertPoint(VectorLoopPreheaderBlock); | ||
|
||
Value *VectorLoopRes = | ||
createMaskedFindMismatch(Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd); | ||
// At this point we know two things must be true: | ||
// 1. Start <= End | ||
// 2. ExtMaxLen <= MinPageSize due to the page checks. | ||
// Therefore, we know that we can use a 64-bit induction variable that | ||
// starts from 0 -> ExtMaxLen and it will not overflow. | ||
Value *VectorLoopRes = nullptr; | ||
switch (VectorizeStyle) { | ||
case LoopIdiomVectorizeStyle::Masked: | ||
VectorLoopRes = | ||
createMaskedFindMismatch(Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd); | ||
break; | ||
case LoopIdiomVectorizeStyle::Predicated: | ||
VectorLoopRes = createPredicatedFindMismatch(Builder, DTU, GEPA, GEPB, | ||
ExtStart, ExtEnd); | ||
break; | ||
} | ||
|
||
Builder.Insert(BranchInst::Create(EndBlock)); | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: LoopIdiomVectorize is the only user of this TTI hook.