Skip to content

Commit 0be0a01

Browse files
committed
[LoopIdiomVectorize][NFC] Factoring out the part that handles vectorization strategy
To pave the way for porting LIV to RISC-V, which uses VP intrinsics for vectors. NFC.
1 parent 4db9d37 commit 0be0a01

File tree

1 file changed

+123
-96
lines changed

1 file changed

+123
-96
lines changed

llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp

Lines changed: 123 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,13 @@ class LoopIdiomVectorize {
7979
const TargetTransformInfo *TTI;
8080
const DataLayout *DL;
8181

82+
// Blocks that will be used for inserting vectorized code.
83+
BasicBlock *EndBlock = nullptr;
84+
BasicBlock *VectorLoopPreheaderBlock = nullptr;
85+
BasicBlock *VectorLoopStartBlock = nullptr;
86+
BasicBlock *VectorLoopMismatchBlock = nullptr;
87+
BasicBlock *VectorLoopIncBlock = nullptr;
88+
8289
public:
8390
explicit LoopIdiomVectorize(DominatorTree *DT, LoopInfo *LI,
8491
const TargetTransformInfo *TTI,
@@ -96,9 +103,15 @@ class LoopIdiomVectorize {
96103
SmallVectorImpl<BasicBlock *> &ExitBlocks);
97104

98105
bool recognizeByteCompare();
106+
99107
Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
100108
GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
101109
Instruction *Index, Value *Start, Value *MaxLen);
110+
111+
Value *createMaskedFindMismatch(IRBuilder<> &Builder, GetElementPtrInst *GEPA,
112+
GetElementPtrInst *GEPB, Value *ExtStart,
113+
Value *ExtEnd);
114+
102115
void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
103116
PHINode *IndPhi, Value *MaxLen, Instruction *Index,
104117
Value *Start, bool IncIdx, BasicBlock *FoundBB,
@@ -332,6 +345,106 @@ bool LoopIdiomVectorize::recognizeByteCompare() {
332345
return true;
333346
}
334347

348+
Value *LoopIdiomVectorize::createMaskedFindMismatch(IRBuilder<> &Builder,
349+
GetElementPtrInst *GEPA,
350+
GetElementPtrInst *GEPB,
351+
Value *ExtStart,
352+
Value *ExtEnd) {
353+
Type *I64Type = Builder.getInt64Ty();
354+
Type *ResType = Builder.getInt32Ty();
355+
Type *LoadType = Builder.getInt8Ty();
356+
Value *PtrA = GEPA->getPointerOperand();
357+
Value *PtrB = GEPB->getPointerOperand();
358+
359+
// At this point we know two things must be true:
360+
// 1. Start <= End
361+
// 2. ExtMaxLen <= MinPageSize due to the page checks.
362+
// Therefore, we know that we can use a 64-bit induction variable that
363+
// starts from 0 -> ExtMaxLen and it will not overflow.
364+
ScalableVectorType *PredVTy =
365+
ScalableVectorType::get(Builder.getInt1Ty(), 16);
366+
367+
Value *InitialPred = Builder.CreateIntrinsic(
368+
Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
369+
370+
Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
371+
VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, 16), "",
372+
/*HasNUW=*/true, /*HasNSW=*/true);
373+
374+
Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
375+
Builder.getInt1(false));
376+
377+
BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
378+
Builder.Insert(JumpToVectorLoop);
379+
380+
// Set up the first vector loop block by creating the PHIs, doing the vector
381+
// loads and comparing the vectors.
382+
Builder.SetInsertPoint(VectorLoopStartBlock);
383+
PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred");
384+
LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock);
385+
PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index");
386+
VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
387+
Type *VectorLoadType = ScalableVectorType::get(Builder.getInt8Ty(), 16);
388+
Value *Passthru = ConstantInt::getNullValue(VectorLoadType);
389+
390+
Value *VectorLhsGep =
391+
Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi, "", GEPA->isInBounds());
392+
Value *VectorLhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorLhsGep,
393+
Align(1), LoopPred, Passthru);
394+
395+
Value *VectorRhsGep =
396+
Builder.CreateGEP(LoadType, PtrB, VectorIndexPhi, "", GEPB->isInBounds());
397+
Value *VectorRhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorRhsGep,
398+
Align(1), LoopPred, Passthru);
399+
400+
Value *VectorMatchCmp = Builder.CreateICmpNE(VectorLhsLoad, VectorRhsLoad);
401+
VectorMatchCmp = Builder.CreateSelect(LoopPred, VectorMatchCmp, PFalse);
402+
Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(VectorMatchCmp);
403+
BranchInst *VectorEarlyExit = BranchInst::Create(
404+
VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
405+
Builder.Insert(VectorEarlyExit);
406+
407+
// Increment the index counter and calculate the predicate for the next
408+
// iteration of the loop. We branch back to the start of the loop if there
409+
// is at least one active lane.
410+
Builder.SetInsertPoint(VectorLoopIncBlock);
411+
Value *NewVectorIndexPhi =
412+
Builder.CreateAdd(VectorIndexPhi, VecLen, "",
413+
/*HasNUW=*/true, /*HasNSW=*/true);
414+
VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
415+
Value *NewPred =
416+
Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
417+
{PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
418+
LoopPred->addIncoming(NewPred, VectorLoopIncBlock);
419+
420+
Value *PredHasActiveLanes =
421+
Builder.CreateExtractElement(NewPred, uint64_t(0));
422+
BranchInst *VectorLoopBranchBack =
423+
BranchInst::Create(VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
424+
Builder.Insert(VectorLoopBranchBack);
425+
426+
// If we found a mismatch then we need to calculate which lane in the vector
427+
// had a mismatch and add that on to the current loop index.
428+
Builder.SetInsertPoint(VectorLoopMismatchBlock);
429+
PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_vec_found_pred");
430+
FoundPred->addIncoming(VectorMatchCmp, VectorLoopStartBlock);
431+
PHINode *LastLoopPred =
432+
Builder.CreatePHI(PredVTy, 1, "mismatch_vec_last_loop_pred");
433+
LastLoopPred->addIncoming(LoopPred, VectorLoopStartBlock);
434+
PHINode *VectorFoundIndex =
435+
Builder.CreatePHI(I64Type, 1, "mismatch_vec_found_index");
436+
VectorFoundIndex->addIncoming(VectorIndexPhi, VectorLoopStartBlock);
437+
438+
Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred);
439+
Value *Ctz = Builder.CreateIntrinsic(
440+
Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType()},
441+
{PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(true)});
442+
Ctz = Builder.CreateZExt(Ctz, I64Type);
443+
Value *VectorLoopRes64 = Builder.CreateAdd(VectorFoundIndex, Ctz, "",
444+
/*HasNUW=*/true, /*HasNSW=*/true);
445+
return Builder.CreateTrunc(VectorLoopRes64, ResType);
446+
}
447+
335448
Value *LoopIdiomVectorize::expandFindMismatch(
336449
IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
337450
GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
@@ -346,8 +459,7 @@ Value *LoopIdiomVectorize::expandFindMismatch(
346459
Type *ResType = Builder.getInt32Ty();
347460

348461
// Split block in the original loop preheader.
349-
BasicBlock *EndBlock =
350-
SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end");
462+
EndBlock = SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end");
351463

352464
// Create the blocks that we're going to need:
353465
// 1. A block for checking the zero-extended length exceeds 0
@@ -371,17 +483,17 @@ Value *LoopIdiomVectorize::expandFindMismatch(
371483
BasicBlock *MemCheckBlock = BasicBlock::Create(
372484
Ctx, "mismatch_mem_check", EndBlock->getParent(), EndBlock);
373485

374-
BasicBlock *VectorLoopPreheaderBlock = BasicBlock::Create(
486+
VectorLoopPreheaderBlock = BasicBlock::Create(
375487
Ctx, "mismatch_vec_loop_preheader", EndBlock->getParent(), EndBlock);
376488

377-
BasicBlock *VectorLoopStartBlock = BasicBlock::Create(
378-
Ctx, "mismatch_vec_loop", EndBlock->getParent(), EndBlock);
489+
VectorLoopStartBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop",
490+
EndBlock->getParent(), EndBlock);
379491

380-
BasicBlock *VectorLoopIncBlock = BasicBlock::Create(
381-
Ctx, "mismatch_vec_loop_inc", EndBlock->getParent(), EndBlock);
492+
VectorLoopIncBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_inc",
493+
EndBlock->getParent(), EndBlock);
382494

383-
BasicBlock *VectorLoopMismatchBlock = BasicBlock::Create(
384-
Ctx, "mismatch_vec_loop_found", EndBlock->getParent(), EndBlock);
495+
VectorLoopMismatchBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_found",
496+
EndBlock->getParent(), EndBlock);
385497

386498
BasicBlock *LoopPreHeaderBlock = BasicBlock::Create(
387499
Ctx, "mismatch_loop_pre", EndBlock->getParent(), EndBlock);
@@ -492,93 +604,8 @@ Value *LoopIdiomVectorize::expandFindMismatch(
492604
// processed in each iteration, etc.
493605
Builder.SetInsertPoint(VectorLoopPreheaderBlock);
494606

495-
// At this point we know two things must be true:
496-
// 1. Start <= End
497-
// 2. ExtMaxLen <= MinPageSize due to the page checks.
498-
// Therefore, we know that we can use a 64-bit induction variable that
499-
// starts from 0 -> ExtMaxLen and it will not overflow.
500-
ScalableVectorType *PredVTy =
501-
ScalableVectorType::get(Builder.getInt1Ty(), 16);
502-
503-
Value *InitialPred = Builder.CreateIntrinsic(
504-
Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
505-
506-
Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
507-
VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, 16), "",
508-
/*HasNUW=*/true, /*HasNSW=*/true);
509-
510-
Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
511-
Builder.getInt1(false));
512-
513-
BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
514-
Builder.Insert(JumpToVectorLoop);
515-
516-
// Set up the first vector loop block by creating the PHIs, doing the vector
517-
// loads and comparing the vectors.
518-
Builder.SetInsertPoint(VectorLoopStartBlock);
519-
PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred");
520-
LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock);
521-
PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index");
522-
VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
523-
Type *VectorLoadType = ScalableVectorType::get(Builder.getInt8Ty(), 16);
524-
Value *Passthru = ConstantInt::getNullValue(VectorLoadType);
525-
526-
Value *VectorLhsGep =
527-
Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi, "", GEPA->isInBounds());
528-
Value *VectorLhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorLhsGep,
529-
Align(1), LoopPred, Passthru);
530-
531-
Value *VectorRhsGep =
532-
Builder.CreateGEP(LoadType, PtrB, VectorIndexPhi, "", GEPB->isInBounds());
533-
Value *VectorRhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorRhsGep,
534-
Align(1), LoopPred, Passthru);
535-
536-
Value *VectorMatchCmp = Builder.CreateICmpNE(VectorLhsLoad, VectorRhsLoad);
537-
VectorMatchCmp = Builder.CreateSelect(LoopPred, VectorMatchCmp, PFalse);
538-
Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(VectorMatchCmp);
539-
BranchInst *VectorEarlyExit = BranchInst::Create(
540-
VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
541-
Builder.Insert(VectorEarlyExit);
542-
543-
// Increment the index counter and calculate the predicate for the next
544-
// iteration of the loop. We branch back to the start of the loop if there
545-
// is at least one active lane.
546-
Builder.SetInsertPoint(VectorLoopIncBlock);
547-
Value *NewVectorIndexPhi =
548-
Builder.CreateAdd(VectorIndexPhi, VecLen, "",
549-
/*HasNUW=*/true, /*HasNSW=*/true);
550-
VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
551-
Value *NewPred =
552-
Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
553-
{PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
554-
LoopPred->addIncoming(NewPred, VectorLoopIncBlock);
555-
556-
Value *PredHasActiveLanes =
557-
Builder.CreateExtractElement(NewPred, uint64_t(0));
558-
BranchInst *VectorLoopBranchBack =
559-
BranchInst::Create(VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
560-
Builder.Insert(VectorLoopBranchBack);
561-
562-
// If we found a mismatch then we need to calculate which lane in the vector
563-
// had a mismatch and add that on to the current loop index.
564-
Builder.SetInsertPoint(VectorLoopMismatchBlock);
565-
PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_vec_found_pred");
566-
FoundPred->addIncoming(VectorMatchCmp, VectorLoopStartBlock);
567-
PHINode *LastLoopPred =
568-
Builder.CreatePHI(PredVTy, 1, "mismatch_vec_last_loop_pred");
569-
LastLoopPred->addIncoming(LoopPred, VectorLoopStartBlock);
570-
PHINode *VectorFoundIndex =
571-
Builder.CreatePHI(I64Type, 1, "mismatch_vec_found_index");
572-
VectorFoundIndex->addIncoming(VectorIndexPhi, VectorLoopStartBlock);
573-
574-
Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred);
575-
Value *Ctz = Builder.CreateIntrinsic(
576-
Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType()},
577-
{PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(true)});
578-
Ctz = Builder.CreateZExt(Ctz, I64Type);
579-
Value *VectorLoopRes64 = Builder.CreateAdd(VectorFoundIndex, Ctz, "",
580-
/*HasNUW=*/true, /*HasNSW=*/true);
581-
Value *VectorLoopRes = Builder.CreateTrunc(VectorLoopRes64, ResType);
607+
Value *VectorLoopRes =
608+
createMaskedFindMismatch(Builder, GEPA, GEPB, ExtStart, ExtEnd);
582609

583610
Builder.Insert(BranchInst::Create(EndBlock));
584611

0 commit comments

Comments
 (0)