Skip to content

Commit de5ff38

Browse files
authored
[LoopIdiomVectorize][NFC] Factoring out the part that handles vectorization strategy (#94682)
This paves the way for porting LoopIdiomVectorize (LIV) to RISC-V, which uses VP intrinsics for vectors. No functional change (NFC).
1 parent 0856064 commit de5ff38

File tree

1 file changed

+133
-107
lines changed

1 file changed

+133
-107
lines changed

llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp

Lines changed: 133 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,13 @@ class LoopIdiomVectorize {
7878
const TargetTransformInfo *TTI;
7979
const DataLayout *DL;
8080

81+
// Blocks that will be used for inserting vectorized code.
82+
BasicBlock *EndBlock = nullptr;
83+
BasicBlock *VectorLoopPreheaderBlock = nullptr;
84+
BasicBlock *VectorLoopStartBlock = nullptr;
85+
BasicBlock *VectorLoopMismatchBlock = nullptr;
86+
BasicBlock *VectorLoopIncBlock = nullptr;
87+
8188
public:
8289
explicit LoopIdiomVectorize(DominatorTree *DT, LoopInfo *LI,
8390
const TargetTransformInfo *TTI,
@@ -95,9 +102,16 @@ class LoopIdiomVectorize {
95102
SmallVectorImpl<BasicBlock *> &ExitBlocks);
96103

97104
bool recognizeByteCompare();
105+
98106
Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
99107
GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
100108
Instruction *Index, Value *Start, Value *MaxLen);
109+
110+
Value *createMaskedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
111+
GetElementPtrInst *GEPA,
112+
GetElementPtrInst *GEPB, Value *ExtStart,
113+
Value *ExtEnd);
114+
101115
void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
102116
PHINode *IndPhi, Value *MaxLen, Instruction *Index,
103117
Value *Start, bool IncIdx, BasicBlock *FoundBB,
@@ -331,6 +345,115 @@ bool LoopIdiomVectorize::recognizeByteCompare() {
331345
return true;
332346
}
333347

348+
// Emits the predicated (masked) vector loop that looks for the first byte at
// which the buffers based at GEPA's and GEPB's pointer operands differ.
//
// Emission starts at Builder's current insert point and then fills the member
// blocks VectorLoopStartBlock, VectorLoopIncBlock and VectorLoopMismatchBlock.
// Every CFG edge created here is reported to DTU. On return the insert point
// is left in VectorLoopMismatchBlock and that block has no terminator yet.
//
// Builder  - IR builder used for every instruction created here.
// DTU      - dominator tree updater that receives the new edges.
// GEPA     - supplies the base pointer and inbounds flag for the first buffer.
// GEPB     - supplies the base pointer and inbounds flag for the second buffer.
// ExtStart - index at which the vector loop begins (used as an i64 value).
// ExtEnd   - bound handed to get_active_lane_mask (used as an i64 value).
//
// Returns the i32 index of the first mismatch; the value is defined in
// VectorLoopMismatchBlock.
Value *LoopIdiomVectorize::createMaskedFindMismatch(
349+
IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
350+
GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
351+
Type *I64Type = Builder.getInt64Ty();
352+
Type *ResType = Builder.getInt32Ty();
353+
Type *LoadType = Builder.getInt8Ty();
354+
Value *PtrA = GEPA->getPointerOperand();
355+
Value *PtrB = GEPB->getPointerOperand();
356+
357+
// At this point we know two things must be true:
358+
// 1. Start <= End
359+
// 2. ExtMaxLen <= MinPageSize due to the page checks.
360+
// Therefore, we know that we can use a 64-bit induction variable that
361+
// starts from 0 -> ExtMaxLen and it will not overflow.
// NOTE(review): this comment was moved verbatim from expandFindMismatch.
// Start, End and ExtMaxLen are not in scope in this function; presumably
// they correspond to ExtStart/ExtEnd and the caller's checks. The induction
// variable created below starts at ExtStart, not 0. Confirm and reword.
362+
// Predicate type: 16 x vscale lanes of i1.
ScalableVectorType *PredVTy =
363+
ScalableVectorType::get(Builder.getInt1Ty(), 16);
364+
365+
// Predicate for the first iteration, computed from ExtStart and ExtEnd.
Value *InitialPred = Builder.CreateIntrinsic(
366+
Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
367+
368+
// Index step per iteration: vscale * 16, i.e. the lane count of PredVTy.
Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
369+
VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, 16), "",
370+
/*HasNUW=*/true, /*HasNSW=*/true);
371+
372+
// All-false predicate, used below to clear compare results in lanes that
// are not active in the loop predicate.
Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
373+
Builder.getInt1(false));
374+
375+
// Terminate the current block with an unconditional branch into the loop.
BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
376+
Builder.Insert(JumpToVectorLoop);
377+
378+
DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock,
379+
VectorLoopStartBlock}});
380+
381+
// Set up the first vector loop block by creating the PHIs, doing the vector
382+
// loads and comparing the vectors.
383+
Builder.SetInsertPoint(VectorLoopStartBlock);
384+
// Both PHIs get their back-edge incoming values later, once the increment
// block has been emitted.
PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred");
385+
LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock);
386+
PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index");
387+
VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
388+
Type *VectorLoadType = ScalableVectorType::get(Builder.getInt8Ty(), 16);
389+
// Zero vector passed as the passthru operand of both masked loads.
Value *Passthru = ConstantInt::getNullValue(VectorLoadType);
390+
391+
// Byte-indexed address into the first buffer; keeps GEPA's inbounds flag.
Value *VectorLhsGep =
392+
Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi, "", GEPA->isInBounds());
393+
Value *VectorLhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorLhsGep,
394+
Align(1), LoopPred, Passthru);
395+
396+
// Same for the second buffer, keeping GEPB's inbounds flag.
Value *VectorRhsGep =
397+
Builder.CreateGEP(LoadType, PtrB, VectorIndexPhi, "", GEPB->isInBounds());
398+
Value *VectorRhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorRhsGep,
399+
Align(1), LoopPred, Passthru);
400+
401+
// Lane-wise "bytes differ" compare; the select then forces every lane that
// is not set in LoopPred to false.
Value *VectorMatchCmp = Builder.CreateICmpNE(VectorLhsLoad, VectorRhsLoad);
402+
VectorMatchCmp = Builder.CreateSelect(LoopPred, VectorMatchCmp, PFalse);
403+
// Leave for the mismatch block if any lane differs, otherwise continue to
// the increment block.
Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(VectorMatchCmp);
404+
BranchInst *VectorEarlyExit = BranchInst::Create(
405+
VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
406+
Builder.Insert(VectorEarlyExit);
407+
408+
DTU.applyUpdates(
409+
{{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
410+
{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
411+
412+
// Increment the index counter and calculate the predicate for the next
413+
// iteration of the loop. We branch back to the start of the loop if there
414+
// is at least one active lane.
415+
Builder.SetInsertPoint(VectorLoopIncBlock);
416+
Value *NewVectorIndexPhi =
417+
Builder.CreateAdd(VectorIndexPhi, VecLen, "",
418+
/*HasNUW=*/true, /*HasNSW=*/true);
419+
VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
420+
Value *NewPred =
421+
Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
422+
{PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
423+
LoopPred->addIncoming(NewPred, VectorLoopIncBlock);
424+
425+
// Only lane 0 of the new predicate is tested: get_active_lane_mask sets
// lane i when (index + i) is below the bound, so an inactive lane 0 means
// no lane is active.
Value *PredHasActiveLanes =
426+
Builder.CreateExtractElement(NewPred, uint64_t(0));
427+
BranchInst *VectorLoopBranchBack =
428+
BranchInst::Create(VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
429+
Builder.Insert(VectorLoopBranchBack);
430+
431+
DTU.applyUpdates(
432+
{{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
433+
{DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
434+
435+
// If we found a mismatch then we need to calculate which lane in the vector
436+
// had a mismatch and add that on to the current loop index.
437+
Builder.SetInsertPoint(VectorLoopMismatchBlock);
438+
// Single-entry PHIs that carry the loop-block values into this block.
PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_vec_found_pred");
439+
FoundPred->addIncoming(VectorMatchCmp, VectorLoopStartBlock);
440+
PHINode *LastLoopPred =
441+
Builder.CreatePHI(PredVTy, 1, "mismatch_vec_last_loop_pred");
442+
LastLoopPred->addIncoming(LoopPred, VectorLoopStartBlock);
443+
PHINode *VectorFoundIndex =
444+
Builder.CreatePHI(I64Type, 1, "mismatch_vec_found_index");
445+
VectorFoundIndex->addIncoming(VectorIndexPhi, VectorLoopStartBlock);
446+
447+
Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred);
448+
// ZeroIsPoison=true: the only edge this function adds into this block is
// taken when the or-reduction of VectorMatchCmp is true, so at least one
// lane of PredMatchCmp is set.
Value *Ctz = Builder.CreateIntrinsic(
449+
Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType()},
450+
{PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(true)});
451+
Ctz = Builder.CreateZExt(Ctz, I64Type);
452+
// Mismatch index = index of the vector's first lane + mismatching lane.
Value *VectorLoopRes64 = Builder.CreateAdd(VectorFoundIndex, Ctz, "",
453+
/*HasNUW=*/true, /*HasNSW=*/true);
454+
// Narrow the 64-bit index to the i32 result type.
return Builder.CreateTrunc(VectorLoopRes64, ResType);
455+
}
456+
334457
Value *LoopIdiomVectorize::expandFindMismatch(
335458
IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
336459
GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
@@ -345,8 +468,7 @@ Value *LoopIdiomVectorize::expandFindMismatch(
345468
Type *ResType = Builder.getInt32Ty();
346469

347470
// Split block in the original loop preheader.
348-
BasicBlock *EndBlock =
349-
SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end");
471+
EndBlock = SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end");
350472

351473
// Create the blocks that we're going to need:
352474
// 1. A block for checking the zero-extended length exceeds 0
@@ -370,17 +492,17 @@ Value *LoopIdiomVectorize::expandFindMismatch(
370492
BasicBlock *MemCheckBlock = BasicBlock::Create(
371493
Ctx, "mismatch_mem_check", EndBlock->getParent(), EndBlock);
372494

373-
BasicBlock *VectorLoopPreheaderBlock = BasicBlock::Create(
495+
VectorLoopPreheaderBlock = BasicBlock::Create(
374496
Ctx, "mismatch_vec_loop_preheader", EndBlock->getParent(), EndBlock);
375497

376-
BasicBlock *VectorLoopStartBlock = BasicBlock::Create(
377-
Ctx, "mismatch_vec_loop", EndBlock->getParent(), EndBlock);
498+
VectorLoopStartBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop",
499+
EndBlock->getParent(), EndBlock);
378500

379-
BasicBlock *VectorLoopIncBlock = BasicBlock::Create(
380-
Ctx, "mismatch_vec_loop_inc", EndBlock->getParent(), EndBlock);
501+
VectorLoopIncBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_inc",
502+
EndBlock->getParent(), EndBlock);
381503

382-
BasicBlock *VectorLoopMismatchBlock = BasicBlock::Create(
383-
Ctx, "mismatch_vec_loop_found", EndBlock->getParent(), EndBlock);
504+
VectorLoopMismatchBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_found",
505+
EndBlock->getParent(), EndBlock);
384506

385507
BasicBlock *LoopPreHeaderBlock = BasicBlock::Create(
386508
Ctx, "mismatch_loop_pre", EndBlock->getParent(), EndBlock);
@@ -491,104 +613,8 @@ Value *LoopIdiomVectorize::expandFindMismatch(
491613
// processed in each iteration, etc.
492614
Builder.SetInsertPoint(VectorLoopPreheaderBlock);
493615

494-
// At this point we know two things must be true:
495-
// 1. Start <= End
496-
// 2. ExtMaxLen <= MinPageSize due to the page checks.
497-
// Therefore, we know that we can use a 64-bit induction variable that
498-
// starts from 0 -> ExtMaxLen and it will not overflow.
499-
ScalableVectorType *PredVTy =
500-
ScalableVectorType::get(Builder.getInt1Ty(), 16);
501-
502-
Value *InitialPred = Builder.CreateIntrinsic(
503-
Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
504-
505-
Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
506-
VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, 16), "",
507-
/*HasNUW=*/true, /*HasNSW=*/true);
508-
509-
Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
510-
Builder.getInt1(false));
511-
512-
BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
513-
Builder.Insert(JumpToVectorLoop);
514-
515-
DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock,
516-
VectorLoopStartBlock}});
517-
518-
// Set up the first vector loop block by creating the PHIs, doing the vector
519-
// loads and comparing the vectors.
520-
Builder.SetInsertPoint(VectorLoopStartBlock);
521-
PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred");
522-
LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock);
523-
PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index");
524-
VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
525-
Type *VectorLoadType = ScalableVectorType::get(Builder.getInt8Ty(), 16);
526-
Value *Passthru = ConstantInt::getNullValue(VectorLoadType);
527-
528-
Value *VectorLhsGep =
529-
Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi, "", GEPA->isInBounds());
530-
Value *VectorLhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorLhsGep,
531-
Align(1), LoopPred, Passthru);
532-
533-
Value *VectorRhsGep =
534-
Builder.CreateGEP(LoadType, PtrB, VectorIndexPhi, "", GEPB->isInBounds());
535-
Value *VectorRhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorRhsGep,
536-
Align(1), LoopPred, Passthru);
537-
538-
Value *VectorMatchCmp = Builder.CreateICmpNE(VectorLhsLoad, VectorRhsLoad);
539-
VectorMatchCmp = Builder.CreateSelect(LoopPred, VectorMatchCmp, PFalse);
540-
Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(VectorMatchCmp);
541-
BranchInst *VectorEarlyExit = BranchInst::Create(
542-
VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
543-
Builder.Insert(VectorEarlyExit);
544-
545-
DTU.applyUpdates(
546-
{{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
547-
{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
548-
549-
// Increment the index counter and calculate the predicate for the next
550-
// iteration of the loop. We branch back to the start of the loop if there
551-
// is at least one active lane.
552-
Builder.SetInsertPoint(VectorLoopIncBlock);
553-
Value *NewVectorIndexPhi =
554-
Builder.CreateAdd(VectorIndexPhi, VecLen, "",
555-
/*HasNUW=*/true, /*HasNSW=*/true);
556-
VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
557-
Value *NewPred =
558-
Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
559-
{PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
560-
LoopPred->addIncoming(NewPred, VectorLoopIncBlock);
561-
562-
Value *PredHasActiveLanes =
563-
Builder.CreateExtractElement(NewPred, uint64_t(0));
564-
BranchInst *VectorLoopBranchBack =
565-
BranchInst::Create(VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
566-
Builder.Insert(VectorLoopBranchBack);
567-
568-
DTU.applyUpdates(
569-
{{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
570-
{DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
571-
572-
// If we found a mismatch then we need to calculate which lane in the vector
573-
// had a mismatch and add that on to the current loop index.
574-
Builder.SetInsertPoint(VectorLoopMismatchBlock);
575-
PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_vec_found_pred");
576-
FoundPred->addIncoming(VectorMatchCmp, VectorLoopStartBlock);
577-
PHINode *LastLoopPred =
578-
Builder.CreatePHI(PredVTy, 1, "mismatch_vec_last_loop_pred");
579-
LastLoopPred->addIncoming(LoopPred, VectorLoopStartBlock);
580-
PHINode *VectorFoundIndex =
581-
Builder.CreatePHI(I64Type, 1, "mismatch_vec_found_index");
582-
VectorFoundIndex->addIncoming(VectorIndexPhi, VectorLoopStartBlock);
583-
584-
Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred);
585-
Value *Ctz = Builder.CreateIntrinsic(
586-
Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType()},
587-
{PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(true)});
588-
Ctz = Builder.CreateZExt(Ctz, I64Type);
589-
Value *VectorLoopRes64 = Builder.CreateAdd(VectorFoundIndex, Ctz, "",
590-
/*HasNUW=*/true, /*HasNSW=*/true);
591-
Value *VectorLoopRes = Builder.CreateTrunc(VectorLoopRes64, ResType);
616+
Value *VectorLoopRes =
617+
createMaskedFindMismatch(Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd);
592618

593619
Builder.Insert(BranchInst::Create(EndBlock));
594620

0 commit comments

Comments
 (0)