@@ -59,19 +59,34 @@ static cl::opt<bool> DisableAll("disable-loop-idiom-vectorize-all", cl::Hidden,
59
59
cl::init (false ),
60
60
cl::desc(" Disable Loop Idiom Vectorize Pass." ));
61
61
62
// Hidden command-line override for the vectorization style used by the loop
// idiom transform. It accepts two values, one selecting masked vector
// intrinsics and one selecting VP intrinsics, and defaults to Masked. The
// option only takes effect when it is given explicitly: the pass's run()
// checks getNumOccurrences() and otherwise keeps the style the pass object
// was configured with.
+ static cl::opt<LoopIdiomVectorizeStyle>
63
+ LITVecStyle (" loop-idiom-vectorize-style" , cl::Hidden,
64
+ cl::desc (" The vectorization style for loop idiom transform." ),
65
+ cl::values(clEnumValN(LoopIdiomVectorizeStyle::Masked, " masked" ,
66
+ " Use masked vector intrinsics" ),
67
+ clEnumValN(LoopIdiomVectorizeStyle::Predicated,
68
+ " predicated" , " Use VP intrinsics" )),
69
+ cl::init(LoopIdiomVectorizeStyle::Masked));
70
+
62
71
// Hidden boolean flag, off by default. According to its description the pass
// still runs when this is set, but byte-compare loops are left unconverted.
// NOTE(review): the code that reads this flag is not part of this excerpt.
static cl::opt<bool >
63
72
DisableByteCmp (" disable-loop-idiom-vectorize-bytecmp" , cl::Hidden,
64
73
cl::init (false ),
65
74
cl::desc(" Proceed with Loop Idiom Vectorize Pass, but do "
66
75
" not convert byte-compare loop(s)." ));
67
76
77
// Hidden command-line override for the byte-compare vectorization factor,
// defaulting to 16. Like the style option, it only replaces the factor the
// pass object was configured with when it is given explicitly (run() checks
// getNumOccurrences()). The resulting factor is used as the minimum element
// count of the scalable vector types built for the byte-compare loop and as
// the multiplier applied to vscale.
// NOTE(review): no range check on the value is visible here (e.g. 0 or a
// factor the target cannot legalize) - confirm it is validated elsewhere.
+ static cl::opt<unsigned >
78
+ ByteCmpVF (" loop-idiom-vectorize-bytecmp-vf" , cl::Hidden,
79
+ cl::desc (" The vectorization factor for byte-compare patterns." ),
80
+ cl::init(16 ));
81
+
68
82
// Hidden boolean flag, off by default, that requests verification of the
// loops this pass generates (per its description).
// NOTE(review): the code that reads this flag is not part of this excerpt.
static cl::opt<bool >
69
83
VerifyLoops (" loop-idiom-vectorize-verify" , cl::Hidden, cl::init(false ),
70
84
cl::desc(" Verify loops generated Loop Idiom Vectorize Pass." ));
71
85
72
86
namespace {
73
-
74
87
class LoopIdiomVectorize {
88
+ LoopIdiomVectorizeStyle VectorizeStyle;
89
+ unsigned ByteCompareVF;
75
90
Loop *CurLoop = nullptr ;
76
91
DominatorTree *DT;
77
92
LoopInfo *LI;
@@ -86,10 +101,11 @@ class LoopIdiomVectorize {
86
101
BasicBlock *VectorLoopIncBlock = nullptr ;
87
102
88
103
public:
89
- explicit LoopIdiomVectorize (DominatorTree *DT, LoopInfo *LI,
90
- const TargetTransformInfo *TTI,
91
- const DataLayout *DL)
92
- : DT(DT), LI(LI), TTI(TTI), DL(DL) {}
104
// Construct the transform helper. S selects which flavour of vector loop is
// emitted for the idiom, VF is the vectorization factor used for byte-compare
// loops, and the remaining arguments are the analyses and data layout the
// transform works with. The constructor only stores its arguments.
+ LoopIdiomVectorize (LoopIdiomVectorizeStyle S, unsigned VF, DominatorTree *DT,
105
+ LoopInfo *LI, const TargetTransformInfo *TTI,
106
+ const DataLayout *DL)
107
+ : VectorizeStyle(S), ByteCompareVF(VF), DT(DT), LI(LI), TTI(TTI), DL(DL) {
108
+ }
93
109
94
110
bool run (Loop *L);
95
111
@@ -111,6 +127,10 @@ class LoopIdiomVectorize {
111
127
GetElementPtrInst *GEPA,
112
128
GetElementPtrInst *GEPB, Value *ExtStart,
113
129
Value *ExtEnd);
130
// VP-intrinsic counterpart of createMaskedFindMismatch, chosen when the
// vectorization style is Predicated. It emits the vector loop that loads from
// the pointer operands of GEPA and GEPB starting at index ExtStart and
// stopping at ExtEnd, and returns the 32-bit index of the first mismatching
// byte, computed in the mismatch block.
+ Value *createPredicatedFindMismatch (IRBuilder<> &Builder, DomTreeUpdater &DTU,
131
+ GetElementPtrInst *GEPA,
132
+ GetElementPtrInst *GEPB, Value *ExtStart,
133
+ Value *ExtEnd);
114
134
115
135
void transformByteCompare (GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
116
136
PHINode *IndPhi, Value *MaxLen, Instruction *Index,
@@ -128,8 +148,16 @@ PreservedAnalyses LoopIdiomVectorizePass::run(Loop &L, LoopAnalysisManager &AM,
128
148
129
149
const auto *DL = &L.getHeader ()->getDataLayout ();
130
150
131
- LoopIdiomVectorize LIT (&AR.DT , &AR.LI , &AR.TTI , DL);
132
- if (!LIT.run (&L))
151
+ LoopIdiomVectorizeStyle VecStyle = VectorizeStyle;
152
+ if (LITVecStyle.getNumOccurrences ())
153
+ VecStyle = LITVecStyle;
154
+
155
+ unsigned BCVF = ByteCompareVF;
156
+ if (ByteCmpVF.getNumOccurrences ())
157
+ BCVF = ByteCmpVF;
158
+
159
+ LoopIdiomVectorize LIV (VecStyle, BCVF, &AR.DT , &AR.LI , &AR.TTI , DL);
160
+ if (!LIV.run (&L))
133
161
return PreservedAnalyses::all ();
134
162
135
163
return PreservedAnalyses::none ();
@@ -354,20 +382,16 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch(
354
382
Value *PtrA = GEPA->getPointerOperand ();
355
383
Value *PtrB = GEPB->getPointerOperand ();
356
384
357
- // At this point we know two things must be true:
358
- // 1. Start <= End
359
- // 2. ExtMaxLen <= MinPageSize due to the page checks.
360
- // Therefore, we know that we can use a 64-bit induction variable that
361
- // starts from 0 -> ExtMaxLen and it will not overflow.
362
385
ScalableVectorType *PredVTy =
363
- ScalableVectorType::get (Builder.getInt1Ty (), 16 );
386
+ ScalableVectorType::get (Builder.getInt1Ty (), ByteCompareVF );
364
387
365
388
Value *InitialPred = Builder.CreateIntrinsic (
366
389
Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
367
390
368
391
Value *VecLen = Builder.CreateIntrinsic (Intrinsic::vscale, {I64Type}, {});
369
- VecLen = Builder.CreateMul (VecLen, ConstantInt::get (I64Type, 16 ), " " ,
370
- /* HasNUW=*/ true , /* HasNSW=*/ true );
392
+ VecLen =
393
+ Builder.CreateMul (VecLen, ConstantInt::get (I64Type, ByteCompareVF), " " ,
394
+ /* HasNUW=*/ true , /* HasNSW=*/ true );
371
395
372
396
Value *PFalse = Builder.CreateVectorSplat (PredVTy->getElementCount (),
373
397
Builder.getInt1 (false ));
@@ -385,7 +409,8 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch(
385
409
LoopPred->addIncoming (InitialPred, VectorLoopPreheaderBlock);
386
410
PHINode *VectorIndexPhi = Builder.CreatePHI (I64Type, 2 , " mismatch_vec_index" );
387
411
VectorIndexPhi->addIncoming (ExtStart, VectorLoopPreheaderBlock);
388
- Type *VectorLoadType = ScalableVectorType::get (Builder.getInt8Ty (), 16 );
412
+ Type *VectorLoadType =
413
+ ScalableVectorType::get (Builder.getInt8Ty (), ByteCompareVF);
389
414
Value *Passthru = ConstantInt::getNullValue (VectorLoadType);
390
415
391
416
Value *VectorLhsGep =
@@ -454,6 +479,109 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch(
454
479
return Builder.CreateTrunc (VectorLoopRes64, ResType );
455
480
}
456
481
482
+ Value *LoopIdiomVectorize::createPredicatedFindMismatch (
483
+ IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
484
+ GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
485
+ Type *I64Type = Builder.getInt64Ty ();
486
+ Type *I32Type = Builder.getInt32Ty ();
487
+ Type *ResType = I32Type;
488
+ Type *LoadType = Builder.getInt8Ty ();
489
+ Value *PtrA = GEPA->getPointerOperand ();
490
+ Value *PtrB = GEPB->getPointerOperand ();
491
+
492
+ auto *JumpToVectorLoop = BranchInst::Create (VectorLoopStartBlock);
493
+ Builder.Insert (JumpToVectorLoop);
494
+
495
+ DTU.applyUpdates ({{DominatorTree::Insert, VectorLoopPreheaderBlock,
496
+ VectorLoopStartBlock}});
497
+
498
+ // Set up the first Vector loop block by creating the PHIs, doing the vector
499
+ // loads and comparing the vectors.
500
+ Builder.SetInsertPoint (VectorLoopStartBlock);
501
+ auto *VectorIndexPhi = Builder.CreatePHI (I64Type, 2 , " mismatch_vector_index" );
502
+ VectorIndexPhi->addIncoming (ExtStart, VectorLoopPreheaderBlock);
503
+
504
+ // Calculate AVL by subtracting the vector loop index from the trip count
505
+ Value *AVL = Builder.CreateSub (ExtEnd, VectorIndexPhi, " avl" , /* HasNUW=*/ true ,
506
+ /* HasNSW=*/ true );
507
+
508
+ auto *VectorLoadType = ScalableVectorType::get (LoadType, ByteCompareVF);
509
+ auto *VF = ConstantInt::get (I32Type, ByteCompareVF);
510
+
511
+ Value *VL = Builder.CreateIntrinsic (Intrinsic::experimental_get_vector_length,
512
+ {I64Type}, {AVL, VF, Builder.getTrue ()});
513
+ Value *GepOffset = VectorIndexPhi;
514
+
515
+ Value *VectorLhsGep =
516
+ Builder.CreateGEP (LoadType, PtrA, GepOffset, " " , GEPA->isInBounds ());
517
+ VectorType *TrueMaskTy =
518
+ VectorType::get (Builder.getInt1Ty (), VectorLoadType->getElementCount ());
519
+ Value *AllTrueMask = Constant::getAllOnesValue (TrueMaskTy);
520
+ Value *VectorLhsLoad = Builder.CreateIntrinsic (
521
+ Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType ()},
522
+ {VectorLhsGep, AllTrueMask, VL}, nullptr , " lhs.load" );
523
+
524
+ Value *VectorRhsGep =
525
+ Builder.CreateGEP (LoadType, PtrB, GepOffset, " " , GEPB->isInBounds ());
526
+ Value *VectorRhsLoad = Builder.CreateIntrinsic (
527
+ Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType ()},
528
+ {VectorRhsGep, AllTrueMask, VL}, nullptr , " rhs.load" );
529
+
530
+ StringRef PredicateStr = CmpInst::getPredicateName (CmpInst::ICMP_NE);
531
+ auto *PredicateMDS = MDString::get (VectorLhsLoad->getContext (), PredicateStr);
532
+ Value *Pred = MetadataAsValue::get (VectorLhsLoad->getContext (), PredicateMDS);
533
+ Value *VectorMatchCmp = Builder.CreateIntrinsic (
534
+ Intrinsic::vp_icmp, {VectorLhsLoad->getType ()},
535
+ {VectorLhsLoad, VectorRhsLoad, Pred, AllTrueMask, VL}, nullptr ,
536
+ " mismatch.cmp" );
537
+ Value *CTZ = Builder.CreateIntrinsic (
538
+ Intrinsic::vp_cttz_elts, {ResType , VectorMatchCmp->getType ()},
539
+ {VectorMatchCmp, /* ZeroIsPoison=*/ Builder.getInt1 (false ), AllTrueMask,
540
+ VL});
541
+ Value *MismatchFound = Builder.CreateICmpNE (CTZ, VL);
542
+ auto *VectorEarlyExit = BranchInst::Create (VectorLoopMismatchBlock,
543
+ VectorLoopIncBlock, MismatchFound);
544
+ Builder.Insert (VectorEarlyExit);
545
+
546
+ DTU.applyUpdates (
547
+ {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
548
+ {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
549
+
550
+ // Increment the index counter and calculate the predicate for the next
551
+ // iteration of the loop. We branch back to the start of the loop if there
552
+ // is at least one active lane.
553
+ Builder.SetInsertPoint (VectorLoopIncBlock);
554
+ Value *VL64 = Builder.CreateZExt (VL, I64Type);
555
+ Value *NewVectorIndexPhi =
556
+ Builder.CreateAdd (VectorIndexPhi, VL64, " " ,
557
+ /* HasNUW=*/ true , /* HasNSW=*/ true );
558
+ VectorIndexPhi->addIncoming (NewVectorIndexPhi, VectorLoopIncBlock);
559
+ Value *ExitCond = Builder.CreateICmpNE (NewVectorIndexPhi, ExtEnd);
560
+ auto *VectorLoopBranchBack =
561
+ BranchInst::Create (VectorLoopStartBlock, EndBlock, ExitCond);
562
+ Builder.Insert (VectorLoopBranchBack);
563
+
564
+ DTU.applyUpdates (
565
+ {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
566
+ {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
567
+
568
+ // If we found a mismatch then we need to calculate which lane in the vector
569
+ // had a mismatch and add that on to the current loop index.
570
+ Builder.SetInsertPoint (VectorLoopMismatchBlock);
571
+
572
+ // Add LCSSA phis for CTZ and VectorIndexPhi.
573
+ auto *CTZLCSSAPhi = Builder.CreatePHI (CTZ->getType (), 1 , " ctz" );
574
+ CTZLCSSAPhi->addIncoming (CTZ, VectorLoopStartBlock);
575
+ auto *VectorIndexLCSSAPhi =
576
+ Builder.CreatePHI (VectorIndexPhi->getType (), 1 , " mismatch_vector_index" );
577
+ VectorIndexLCSSAPhi->addIncoming (VectorIndexPhi, VectorLoopStartBlock);
578
+
579
+ Value *CTZI64 = Builder.CreateZExt (CTZLCSSAPhi, I64Type);
580
+ Value *VectorLoopRes64 = Builder.CreateAdd (VectorIndexLCSSAPhi, CTZI64, " " ,
581
+ /* HasNUW=*/ true , /* HasNSW=*/ true );
582
+ return Builder.CreateTrunc (VectorLoopRes64, ResType );
583
+ }
584
+
457
585
Value *LoopIdiomVectorize::expandFindMismatch (
458
586
IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
459
587
GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
@@ -613,8 +741,22 @@ Value *LoopIdiomVectorize::expandFindMismatch(
613
741
// processed in each iteration, etc.
614
742
Builder.SetInsertPoint (VectorLoopPreheaderBlock);
615
743
616
- Value *VectorLoopRes =
617
- createMaskedFindMismatch (Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd);
744
+ // At this point we know two things must be true:
745
+ // 1. Start <= End
746
+ // 2. ExtMaxLen <= MinPageSize due to the page checks.
747
+ // Therefore, we know that we can use a 64-bit induction variable that
748
+ // starts from 0 -> ExtMaxLen and it will not overflow.
749
+ Value *VectorLoopRes = nullptr ;
750
+ switch (VectorizeStyle) {
751
+ case LoopIdiomVectorizeStyle::Masked:
752
+ VectorLoopRes =
753
+ createMaskedFindMismatch (Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd);
754
+ break ;
755
+ case LoopIdiomVectorizeStyle::Predicated:
756
+ VectorLoopRes = createPredicatedFindMismatch (Builder, DTU, GEPA, GEPB,
757
+ ExtStart, ExtEnd);
758
+ break ;
759
+ }
618
760
619
761
Builder.Insert (BranchInst::Create (EndBlock));
620
762
0 commit comments