@@ -60,19 +60,34 @@ static cl::opt<bool> DisableAll("disable-loop-idiom-vectorize-all", cl::Hidden,
60
60
cl::init (false ),
61
61
cl::desc(" Disable Loop Idiom Vectorize Pass." ));
62
62
63
+ static cl::opt<LoopIdiomVectorizeStyle>
64
+ LITVecStyle (" loop-idiom-vectorize-style" , cl::Hidden,
65
+ cl::desc (" The vectorization style for loop idiom transform." ),
66
+ cl::values(clEnumValN(LoopIdiomVectorizeStyle::Masked, " masked" ,
67
+ " Use masked vector intrinsics" ),
68
+ clEnumValN(LoopIdiomVectorizeStyle::Predicated,
69
+ " predicated" , " Use VP intrinsics" )),
70
+ cl::init(LoopIdiomVectorizeStyle::Masked));
71
+
63
72
static cl::opt<bool >
64
73
DisableByteCmp (" disable-loop-idiom-vectorize-bytecmp" , cl::Hidden,
65
74
cl::init (false ),
66
75
cl::desc(" Proceed with Loop Idiom Vectorize Pass, but do "
67
76
" not convert byte-compare loop(s)." ));
68
77
78
+ static cl::opt<unsigned >
79
+ ByteCmpVF (" loop-idiom-vectorize-bytecmp-vf" , cl::Hidden,
80
+ cl::desc (" The vectorization factor for byte-compare patterns." ),
81
+ cl::init(16 ));
82
+
69
83
static cl::opt<bool >
70
84
VerifyLoops (" loop-idiom-vectorize-verify" , cl::Hidden, cl::init(false ),
71
85
cl::desc(" Verify loops generated Loop Idiom Vectorize Pass." ));
72
86
73
87
namespace {
74
-
75
88
class LoopIdiomVectorize {
89
+ LoopIdiomVectorizeStyle VectorizeStyle;
90
+ unsigned ByteCompareVF;
76
91
Loop *CurLoop = nullptr ;
77
92
DominatorTree *DT;
78
93
LoopInfo *LI;
@@ -87,10 +102,11 @@ class LoopIdiomVectorize {
87
102
BasicBlock *VectorLoopIncBlock = nullptr ;
88
103
89
104
public:
90
- explicit LoopIdiomVectorize (DominatorTree *DT, LoopInfo *LI,
91
- const TargetTransformInfo *TTI,
92
- const DataLayout *DL)
93
- : DT(DT), LI(LI), TTI(TTI), DL(DL) {}
105
+ LoopIdiomVectorize (LoopIdiomVectorizeStyle S, unsigned VF, DominatorTree *DT,
106
+ LoopInfo *LI, const TargetTransformInfo *TTI,
107
+ const DataLayout *DL)
108
+ : VectorizeStyle(S), ByteCompareVF(VF), DT(DT), LI(LI), TTI(TTI), DL(DL) {
109
+ }
94
110
95
111
bool run (Loop *L);
96
112
@@ -111,6 +127,10 @@ class LoopIdiomVectorize {
111
127
Value *createMaskedFindMismatch (IRBuilder<> &Builder, GetElementPtrInst *GEPA,
112
128
GetElementPtrInst *GEPB, Value *ExtStart,
113
129
Value *ExtEnd);
130
+ Value *createPredicatedFindMismatch (IRBuilder<> &Builder,
131
+ GetElementPtrInst *GEPA,
132
+ GetElementPtrInst *GEPB, Value *ExtStart,
133
+ Value *ExtEnd);
114
134
115
135
void transformByteCompare (GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
116
136
PHINode *IndPhi, Value *MaxLen, Instruction *Index,
@@ -128,8 +148,16 @@ PreservedAnalyses LoopIdiomVectorizePass::run(Loop &L, LoopAnalysisManager &AM,
128
148
129
149
const auto *DL = &L.getHeader ()->getModule ()->getDataLayout ();
130
150
131
- LoopIdiomVectorize LIT (&AR.DT , &AR.LI , &AR.TTI , DL);
132
- if (!LIT.run (&L))
151
+ LoopIdiomVectorizeStyle VecStyle = VectorizeStyle;
152
+ if (LITVecStyle.getNumOccurrences ())
153
+ VecStyle = LITVecStyle;
154
+
155
+ unsigned BCVF = ByteCompareVF;
156
+ if (ByteCmpVF.getNumOccurrences ())
157
+ BCVF = ByteCmpVF;
158
+
159
+ LoopIdiomVectorize LIV (VecStyle, BCVF, &AR.DT , &AR.LI , &AR.TTI , DL);
160
+ if (!LIV.run (&L))
133
161
return PreservedAnalyses::all ();
134
162
135
163
return PreservedAnalyses::none ();
@@ -362,14 +390,15 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch(IRBuilder<> &Builder,
362
390
// Therefore, we know that we can use a 64-bit induction variable that
363
391
// starts from 0 -> ExtMaxLen and it will not overflow.
364
392
ScalableVectorType *PredVTy =
365
- ScalableVectorType::get (Builder.getInt1Ty (), 16 );
393
+ ScalableVectorType::get (Builder.getInt1Ty (), ByteCompareVF );
366
394
367
395
Value *InitialPred = Builder.CreateIntrinsic (
368
396
Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
369
397
370
398
Value *VecLen = Builder.CreateIntrinsic (Intrinsic::vscale, {I64Type}, {});
371
- VecLen = Builder.CreateMul (VecLen, ConstantInt::get (I64Type, 16 ), " " ,
372
- /* HasNUW=*/ true , /* HasNSW=*/ true );
399
+ VecLen =
400
+ Builder.CreateMul (VecLen, ConstantInt::get (I64Type, ByteCompareVF), " " ,
401
+ /* HasNUW=*/ true , /* HasNSW=*/ true );
373
402
374
403
Value *PFalse = Builder.CreateVectorSplat (PredVTy->getElementCount (),
375
404
Builder.getInt1 (false ));
@@ -384,7 +413,8 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch(IRBuilder<> &Builder,
384
413
LoopPred->addIncoming (InitialPred, VectorLoopPreheaderBlock);
385
414
PHINode *VectorIndexPhi = Builder.CreatePHI (I64Type, 2 , " mismatch_vec_index" );
386
415
VectorIndexPhi->addIncoming (ExtStart, VectorLoopPreheaderBlock);
387
- Type *VectorLoadType = ScalableVectorType::get (Builder.getInt8Ty (), 16 );
416
+ Type *VectorLoadType =
417
+ ScalableVectorType::get (Builder.getInt8Ty (), ByteCompareVF);
388
418
Value *Passthru = ConstantInt::getNullValue (VectorLoadType);
389
419
390
420
Value *VectorLhsGep =
@@ -445,6 +475,112 @@ Value *LoopIdiomVectorize::createMaskedFindMismatch(IRBuilder<> &Builder,
445
475
return Builder.CreateTrunc (VectorLoopRes64, ResType );
446
476
}
447
477
478
+ Value *LoopIdiomVectorize::createPredicatedFindMismatch (IRBuilder<> &Builder,
479
+ GetElementPtrInst *GEPA,
480
+ GetElementPtrInst *GEPB,
481
+ Value *ExtStart,
482
+ Value *ExtEnd) {
483
+ Type *I64Type = Builder.getInt64Ty ();
484
+ Type *I32Type = Builder.getInt32Ty ();
485
+ Type *ResType = I32Type;
486
+ Type *LoadType = Builder.getInt8Ty ();
487
+ Value *PtrA = GEPA->getPointerOperand ();
488
+ Value *PtrB = GEPB->getPointerOperand ();
489
+
490
+ // At this point we know two things must be true:
491
+ // 1. Start <= End
492
+ // 2. ExtMaxLen <= 4096 due to the page checks.
493
+ // Therefore, we know that we can use a 64-bit induction variable that
494
+ // starts from 0 -> ExtMaxLen and it will not overflow.
495
+ auto *JumpToVectorLoop = BranchInst::Create (VectorLoopStartBlock);
496
+ Builder.Insert (JumpToVectorLoop);
497
+
498
+ // Set up the first Vector loop block by creating the PHIs, doing the vector
499
+ // loads and comparing the vectors.
500
+ Builder.SetInsertPoint (VectorLoopStartBlock);
501
+ auto *VectorIndexPhi = Builder.CreatePHI (I64Type, 2 , " mismatch_vector_index" );
502
+ VectorIndexPhi->addIncoming (ExtStart, VectorLoopPreheaderBlock);
503
+
504
+ // Calculate AVL by subtracting the vector loop index from the trip count
505
+ Value *AVL = Builder.CreateSub (ExtEnd, VectorIndexPhi, " avl" , /* HasNUW=*/ true ,
506
+ /* HasNSW=*/ true );
507
+
508
+ auto *VectorLoadType = ScalableVectorType::get (LoadType, ByteCompareVF);
509
+ auto *VF = ConstantInt::get (
510
+ I32Type, VectorLoadType->getElementCount ().getKnownMinValue ());
511
+ auto *IsScalable = ConstantInt::getBool (
512
+ Builder.getContext (), VectorLoadType->getElementCount ().isScalable ());
513
+
514
+ Value *VL = Builder.CreateIntrinsic (Intrinsic::experimental_get_vector_length,
515
+ {I64Type}, {AVL, VF, IsScalable});
516
+ Value *GepOffset = VectorIndexPhi;
517
+
518
+ Value *VectorLhsGep = Builder.CreateGEP (LoadType, PtrA, GepOffset);
519
+ if (GEPA->isInBounds ())
520
+ cast<GetElementPtrInst>(VectorLhsGep)->setIsInBounds (true );
521
+ VectorType *TrueMaskTy =
522
+ VectorType::get (Builder.getInt1Ty (), VectorLoadType->getElementCount ());
523
+ Value *AllTrueMask = Constant::getAllOnesValue (TrueMaskTy);
524
+ Value *VectorLhsLoad = Builder.CreateIntrinsic (
525
+ Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType ()},
526
+ {VectorLhsGep, AllTrueMask, VL}, nullptr , " lhs.load" );
527
+
528
+ Value *VectorRhsGep = Builder.CreateGEP (LoadType, PtrB, GepOffset);
529
+ if (GEPB->isInBounds ())
530
+ cast<GetElementPtrInst>(VectorRhsGep)->setIsInBounds (true );
531
+ Value *VectorRhsLoad = Builder.CreateIntrinsic (
532
+ Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType ()},
533
+ {VectorRhsGep, AllTrueMask, VL}, nullptr , " rhs.load" );
534
+
535
+ StringRef PredicateStr = CmpInst::getPredicateName (CmpInst::ICMP_NE);
536
+ auto *PredicateMDS = MDString::get (VectorLhsLoad->getContext (), PredicateStr);
537
+ Value *Pred = MetadataAsValue::get (VectorLhsLoad->getContext (), PredicateMDS);
538
+ Value *VectorMatchCmp = Builder.CreateIntrinsic (
539
+ Intrinsic::vp_icmp, {VectorLhsLoad->getType ()},
540
+ {VectorLhsLoad, VectorRhsLoad, Pred, AllTrueMask, VL}, nullptr ,
541
+ " mismatch.cmp" );
542
+ Value *CTZ = Builder.CreateIntrinsic (
543
+ Intrinsic::vp_cttz_elts, {ResType , VectorMatchCmp->getType ()},
544
+ {VectorMatchCmp, /* ZeroIsPoison=*/ Builder.getInt1 (true ), AllTrueMask,
545
+ VL});
546
+ // RISC-V refines/lowers the poison returned by vp.cttz.elts to -1.
547
+ Value *MismatchFound =
548
+ Builder.CreateICmpSGE (CTZ, ConstantInt::get (ResType , 0 ));
549
+ auto *VectorEarlyExit = BranchInst::Create (VectorLoopMismatchBlock,
550
+ VectorLoopIncBlock, MismatchFound);
551
+ Builder.Insert (VectorEarlyExit);
552
+
553
+ // Increment the index counter and calculate the predicate for the next
554
+ // iteration of the loop. We branch back to the start of the loop if there
555
+ // is at least one active lane.
556
+ Builder.SetInsertPoint (VectorLoopIncBlock);
557
+ Value *VL64 = Builder.CreateZExt (VL, I64Type);
558
+ Value *NewVectorIndexPhi =
559
+ Builder.CreateAdd (VectorIndexPhi, VL64, " " ,
560
+ /* HasNUW=*/ true , /* HasNSW=*/ true );
561
+ VectorIndexPhi->addIncoming (NewVectorIndexPhi, VectorLoopIncBlock);
562
+ Value *ExitCond = Builder.CreateICmpNE (NewVectorIndexPhi, ExtEnd);
563
+ auto *VectorLoopBranchBack =
564
+ BranchInst::Create (VectorLoopStartBlock, EndBlock, ExitCond);
565
+ Builder.Insert (VectorLoopBranchBack);
566
+
567
+ // If we found a mismatch then we need to calculate which lane in the vector
568
+ // had a mismatch and add that on to the current loop index.
569
+ Builder.SetInsertPoint (VectorLoopMismatchBlock);
570
+
571
+ // Add LCSSA phis for CTZ and VectorIndexPhi.
572
+ auto *CTZLCSSAPhi = Builder.CreatePHI (CTZ->getType (), 1 , " ctz" );
573
+ CTZLCSSAPhi->addIncoming (CTZ, VectorLoopStartBlock);
574
+ auto *VectorIndexLCSSAPhi =
575
+ Builder.CreatePHI (VectorIndexPhi->getType (), 1 , " mismatch_vector_index" );
576
+ VectorIndexLCSSAPhi->addIncoming (VectorIndexPhi, VectorLoopStartBlock);
577
+
578
+ Value *CTZI64 = Builder.CreateZExt (CTZLCSSAPhi, I64Type);
579
+ Value *VectorLoopRes64 = Builder.CreateAdd (VectorIndexLCSSAPhi, CTZI64, " " ,
580
+ /* HasNUW=*/ true , /* HasNSW=*/ true );
581
+ return Builder.CreateTrunc (VectorLoopRes64, ResType );
582
+ }
583
+
448
584
Value *LoopIdiomVectorize::expandFindMismatch (
449
585
IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
450
586
GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
@@ -604,8 +740,17 @@ Value *LoopIdiomVectorize::expandFindMismatch(
604
740
// processed in each iteration, etc.
605
741
Builder.SetInsertPoint (VectorLoopPreheaderBlock);
606
742
607
- Value *VectorLoopRes =
608
- createMaskedFindMismatch (Builder, GEPA, GEPB, ExtStart, ExtEnd);
743
+ Value *VectorLoopRes = nullptr ;
744
+ switch (VectorizeStyle) {
745
+ case LoopIdiomVectorizeStyle::Masked:
746
+ VectorLoopRes =
747
+ createMaskedFindMismatch (Builder, GEPA, GEPB, ExtStart, ExtEnd);
748
+ break ;
749
+ case LoopIdiomVectorizeStyle::Predicated:
750
+ VectorLoopRes =
751
+ createPredicatedFindMismatch (Builder, GEPA, GEPB, ExtStart, ExtEnd);
752
+ break ;
753
+ }
609
754
610
755
Builder.Insert (BranchInst::Create (EndBlock));
611
756
0 commit comments