@@ -79,6 +79,13 @@ class LoopIdiomVectorize {
7979 const TargetTransformInfo *TTI;
8080 const DataLayout *DL;
8181
82+ // Blocks that will be used for inserting vectorized code.
83+ BasicBlock *EndBlock = nullptr ;
84+ BasicBlock *VectorLoopPreheaderBlock = nullptr ;
85+ BasicBlock *VectorLoopStartBlock = nullptr ;
86+ BasicBlock *VectorLoopMismatchBlock = nullptr ;
87+ BasicBlock *VectorLoopIncBlock = nullptr ;
88+
8289public:
8390 explicit LoopIdiomVectorize (DominatorTree *DT, LoopInfo *LI,
8491 const TargetTransformInfo *TTI,
@@ -96,9 +103,15 @@ class LoopIdiomVectorize {
96103 SmallVectorImpl<BasicBlock *> &ExitBlocks);
97104
98105 bool recognizeByteCompare ();
106+
99107 Value *expandFindMismatch (IRBuilder<> &Builder, DomTreeUpdater &DTU,
100108 GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
101109 Instruction *Index, Value *Start, Value *MaxLen);
110+
111+ Value *createMaskedFindMismatch (IRBuilder<> &Builder, GetElementPtrInst *GEPA,
112+ GetElementPtrInst *GEPB, Value *ExtStart,
113+ Value *ExtEnd);
114+
102115 void transformByteCompare (GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
103116 PHINode *IndPhi, Value *MaxLen, Instruction *Index,
104117 Value *Start, bool IncIdx, BasicBlock *FoundBB,
@@ -332,6 +345,106 @@ bool LoopIdiomVectorize::recognizeByteCompare() {
332345 return true ;
333346}
334347
348+ Value *LoopIdiomVectorize::createMaskedFindMismatch (IRBuilder<> &Builder,
349+ GetElementPtrInst *GEPA,
350+ GetElementPtrInst *GEPB,
351+ Value *ExtStart,
352+ Value *ExtEnd) {
353+ Type *I64Type = Builder.getInt64Ty ();
354+ Type *ResType = Builder.getInt32Ty ();
355+ Type *LoadType = Builder.getInt8Ty ();
356+ Value *PtrA = GEPA->getPointerOperand ();
357+ Value *PtrB = GEPB->getPointerOperand ();
358+
359+ // At this point we know two things must be true:
360+ // 1. Start <= End
361+ // 2. ExtMaxLen <= MinPageSize due to the page checks.
362+ // Therefore, we know that we can use a 64-bit induction variable that
363+ // starts from 0 -> ExtMaxLen and it will not overflow.
364+ ScalableVectorType *PredVTy =
365+ ScalableVectorType::get (Builder.getInt1Ty (), 16 );
366+
367+ Value *InitialPred = Builder.CreateIntrinsic (
368+ Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
369+
370+ Value *VecLen = Builder.CreateIntrinsic (Intrinsic::vscale, {I64Type}, {});
371+ VecLen = Builder.CreateMul (VecLen, ConstantInt::get (I64Type, 16 ), " " ,
372+ /* HasNUW=*/ true , /* HasNSW=*/ true );
373+
374+ Value *PFalse = Builder.CreateVectorSplat (PredVTy->getElementCount (),
375+ Builder.getInt1 (false ));
376+
377+ BranchInst *JumpToVectorLoop = BranchInst::Create (VectorLoopStartBlock);
378+ Builder.Insert (JumpToVectorLoop);
379+
380+ // Set up the first vector loop block by creating the PHIs, doing the vector
381+ // loads and comparing the vectors.
382+ Builder.SetInsertPoint (VectorLoopStartBlock);
383+ PHINode *LoopPred = Builder.CreatePHI (PredVTy, 2 , " mismatch_vec_loop_pred" );
384+ LoopPred->addIncoming (InitialPred, VectorLoopPreheaderBlock);
385+ PHINode *VectorIndexPhi = Builder.CreatePHI (I64Type, 2 , " mismatch_vec_index" );
386+ VectorIndexPhi->addIncoming (ExtStart, VectorLoopPreheaderBlock);
387+ Type *VectorLoadType = ScalableVectorType::get (Builder.getInt8Ty (), 16 );
388+ Value *Passthru = ConstantInt::getNullValue (VectorLoadType);
389+
390+ Value *VectorLhsGep =
391+ Builder.CreateGEP (LoadType, PtrA, VectorIndexPhi, " " , GEPA->isInBounds ());
392+ Value *VectorLhsLoad = Builder.CreateMaskedLoad (VectorLoadType, VectorLhsGep,
393+ Align (1 ), LoopPred, Passthru);
394+
395+ Value *VectorRhsGep =
396+ Builder.CreateGEP (LoadType, PtrB, VectorIndexPhi, " " , GEPB->isInBounds ());
397+ Value *VectorRhsLoad = Builder.CreateMaskedLoad (VectorLoadType, VectorRhsGep,
398+ Align (1 ), LoopPred, Passthru);
399+
400+ Value *VectorMatchCmp = Builder.CreateICmpNE (VectorLhsLoad, VectorRhsLoad);
401+ VectorMatchCmp = Builder.CreateSelect (LoopPred, VectorMatchCmp, PFalse);
402+ Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce (VectorMatchCmp);
403+ BranchInst *VectorEarlyExit = BranchInst::Create (
404+ VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
405+ Builder.Insert (VectorEarlyExit);
406+
407+ // Increment the index counter and calculate the predicate for the next
408+ // iteration of the loop. We branch back to the start of the loop if there
409+ // is at least one active lane.
410+ Builder.SetInsertPoint (VectorLoopIncBlock);
411+ Value *NewVectorIndexPhi =
412+ Builder.CreateAdd (VectorIndexPhi, VecLen, " " ,
413+ /* HasNUW=*/ true , /* HasNSW=*/ true );
414+ VectorIndexPhi->addIncoming (NewVectorIndexPhi, VectorLoopIncBlock);
415+ Value *NewPred =
416+ Builder.CreateIntrinsic (Intrinsic::get_active_lane_mask,
417+ {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
418+ LoopPred->addIncoming (NewPred, VectorLoopIncBlock);
419+
420+ Value *PredHasActiveLanes =
421+ Builder.CreateExtractElement (NewPred, uint64_t (0 ));
422+ BranchInst *VectorLoopBranchBack =
423+ BranchInst::Create (VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
424+ Builder.Insert (VectorLoopBranchBack);
425+
426+ // If we found a mismatch then we need to calculate which lane in the vector
427+ // had a mismatch and add that on to the current loop index.
428+ Builder.SetInsertPoint (VectorLoopMismatchBlock);
429+ PHINode *FoundPred = Builder.CreatePHI (PredVTy, 1 , " mismatch_vec_found_pred" );
430+ FoundPred->addIncoming (VectorMatchCmp, VectorLoopStartBlock);
431+ PHINode *LastLoopPred =
432+ Builder.CreatePHI (PredVTy, 1 , " mismatch_vec_last_loop_pred" );
433+ LastLoopPred->addIncoming (LoopPred, VectorLoopStartBlock);
434+ PHINode *VectorFoundIndex =
435+ Builder.CreatePHI (I64Type, 1 , " mismatch_vec_found_index" );
436+ VectorFoundIndex->addIncoming (VectorIndexPhi, VectorLoopStartBlock);
437+
438+ Value *PredMatchCmp = Builder.CreateAnd (LastLoopPred, FoundPred);
439+ Value *Ctz = Builder.CreateIntrinsic (
440+ Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType ()},
441+ {PredMatchCmp, /* ZeroIsPoison=*/ Builder.getInt1 (true )});
442+ Ctz = Builder.CreateZExt (Ctz, I64Type);
443+ Value *VectorLoopRes64 = Builder.CreateAdd (VectorFoundIndex, Ctz, " " ,
444+ /* HasNUW=*/ true , /* HasNSW=*/ true );
445+ return Builder.CreateTrunc (VectorLoopRes64, ResType);
446+ }
447+
335448Value *LoopIdiomVectorize::expandFindMismatch (
336449 IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
337450 GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
@@ -346,8 +459,7 @@ Value *LoopIdiomVectorize::expandFindMismatch(
346459 Type *ResType = Builder.getInt32Ty ();
347460
348461 // Split block in the original loop preheader.
349- BasicBlock *EndBlock =
350- SplitBlock (Preheader, PHBranch, DT, LI, nullptr , " mismatch_end" );
462+ EndBlock = SplitBlock (Preheader, PHBranch, DT, LI, nullptr , " mismatch_end" );
351463
352464 // Create the blocks that we're going to need:
353465 // 1. A block for checking the zero-extended length exceeds 0
@@ -371,17 +483,17 @@ Value *LoopIdiomVectorize::expandFindMismatch(
371483 BasicBlock *MemCheckBlock = BasicBlock::Create (
372484 Ctx, " mismatch_mem_check" , EndBlock->getParent (), EndBlock);
373485
374- BasicBlock * VectorLoopPreheaderBlock = BasicBlock::Create (
486+ VectorLoopPreheaderBlock = BasicBlock::Create (
375487 Ctx, " mismatch_vec_loop_preheader" , EndBlock->getParent (), EndBlock);
376488
377- BasicBlock * VectorLoopStartBlock = BasicBlock::Create (
378- Ctx, " mismatch_vec_loop " , EndBlock->getParent (), EndBlock);
489+ VectorLoopStartBlock = BasicBlock::Create (Ctx, " mismatch_vec_loop " ,
490+ EndBlock->getParent (), EndBlock);
379491
380- BasicBlock * VectorLoopIncBlock = BasicBlock::Create (
381- Ctx, " mismatch_vec_loop_inc " , EndBlock->getParent (), EndBlock);
492+ VectorLoopIncBlock = BasicBlock::Create (Ctx, " mismatch_vec_loop_inc " ,
493+ EndBlock->getParent (), EndBlock);
382494
383- BasicBlock * VectorLoopMismatchBlock = BasicBlock::Create (
384- Ctx, " mismatch_vec_loop_found " , EndBlock->getParent (), EndBlock);
495+ VectorLoopMismatchBlock = BasicBlock::Create (Ctx, " mismatch_vec_loop_found " ,
496+ EndBlock->getParent (), EndBlock);
385497
386498 BasicBlock *LoopPreHeaderBlock = BasicBlock::Create (
387499 Ctx, " mismatch_loop_pre" , EndBlock->getParent (), EndBlock);
@@ -492,93 +604,8 @@ Value *LoopIdiomVectorize::expandFindMismatch(
492604 // processed in each iteration, etc.
493605 Builder.SetInsertPoint (VectorLoopPreheaderBlock);
494606
495- // At this point we know two things must be true:
496- // 1. Start <= End
497- // 2. ExtMaxLen <= MinPageSize due to the page checks.
498- // Therefore, we know that we can use a 64-bit induction variable that
499- // starts from 0 -> ExtMaxLen and it will not overflow.
500- ScalableVectorType *PredVTy =
501- ScalableVectorType::get (Builder.getInt1Ty (), 16 );
502-
503- Value *InitialPred = Builder.CreateIntrinsic (
504- Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
505-
506- Value *VecLen = Builder.CreateIntrinsic (Intrinsic::vscale, {I64Type}, {});
507- VecLen = Builder.CreateMul (VecLen, ConstantInt::get (I64Type, 16 ), " " ,
508- /* HasNUW=*/ true , /* HasNSW=*/ true );
509-
510- Value *PFalse = Builder.CreateVectorSplat (PredVTy->getElementCount (),
511- Builder.getInt1 (false ));
512-
513- BranchInst *JumpToVectorLoop = BranchInst::Create (VectorLoopStartBlock);
514- Builder.Insert (JumpToVectorLoop);
515-
516- // Set up the first vector loop block by creating the PHIs, doing the vector
517- // loads and comparing the vectors.
518- Builder.SetInsertPoint (VectorLoopStartBlock);
519- PHINode *LoopPred = Builder.CreatePHI (PredVTy, 2 , " mismatch_vec_loop_pred" );
520- LoopPred->addIncoming (InitialPred, VectorLoopPreheaderBlock);
521- PHINode *VectorIndexPhi = Builder.CreatePHI (I64Type, 2 , " mismatch_vec_index" );
522- VectorIndexPhi->addIncoming (ExtStart, VectorLoopPreheaderBlock);
523- Type *VectorLoadType = ScalableVectorType::get (Builder.getInt8Ty (), 16 );
524- Value *Passthru = ConstantInt::getNullValue (VectorLoadType);
525-
526- Value *VectorLhsGep =
527- Builder.CreateGEP (LoadType, PtrA, VectorIndexPhi, " " , GEPA->isInBounds ());
528- Value *VectorLhsLoad = Builder.CreateMaskedLoad (VectorLoadType, VectorLhsGep,
529- Align (1 ), LoopPred, Passthru);
530-
531- Value *VectorRhsGep =
532- Builder.CreateGEP (LoadType, PtrB, VectorIndexPhi, " " , GEPB->isInBounds ());
533- Value *VectorRhsLoad = Builder.CreateMaskedLoad (VectorLoadType, VectorRhsGep,
534- Align (1 ), LoopPred, Passthru);
535-
536- Value *VectorMatchCmp = Builder.CreateICmpNE (VectorLhsLoad, VectorRhsLoad);
537- VectorMatchCmp = Builder.CreateSelect (LoopPred, VectorMatchCmp, PFalse);
538- Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce (VectorMatchCmp);
539- BranchInst *VectorEarlyExit = BranchInst::Create (
540- VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
541- Builder.Insert (VectorEarlyExit);
542-
543- // Increment the index counter and calculate the predicate for the next
544- // iteration of the loop. We branch back to the start of the loop if there
545- // is at least one active lane.
546- Builder.SetInsertPoint (VectorLoopIncBlock);
547- Value *NewVectorIndexPhi =
548- Builder.CreateAdd (VectorIndexPhi, VecLen, " " ,
549- /* HasNUW=*/ true , /* HasNSW=*/ true );
550- VectorIndexPhi->addIncoming (NewVectorIndexPhi, VectorLoopIncBlock);
551- Value *NewPred =
552- Builder.CreateIntrinsic (Intrinsic::get_active_lane_mask,
553- {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
554- LoopPred->addIncoming (NewPred, VectorLoopIncBlock);
555-
556- Value *PredHasActiveLanes =
557- Builder.CreateExtractElement (NewPred, uint64_t (0 ));
558- BranchInst *VectorLoopBranchBack =
559- BranchInst::Create (VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
560- Builder.Insert (VectorLoopBranchBack);
561-
562- // If we found a mismatch then we need to calculate which lane in the vector
563- // had a mismatch and add that on to the current loop index.
564- Builder.SetInsertPoint (VectorLoopMismatchBlock);
565- PHINode *FoundPred = Builder.CreatePHI (PredVTy, 1 , " mismatch_vec_found_pred" );
566- FoundPred->addIncoming (VectorMatchCmp, VectorLoopStartBlock);
567- PHINode *LastLoopPred =
568- Builder.CreatePHI (PredVTy, 1 , " mismatch_vec_last_loop_pred" );
569- LastLoopPred->addIncoming (LoopPred, VectorLoopStartBlock);
570- PHINode *VectorFoundIndex =
571- Builder.CreatePHI (I64Type, 1 , " mismatch_vec_found_index" );
572- VectorFoundIndex->addIncoming (VectorIndexPhi, VectorLoopStartBlock);
573-
574- Value *PredMatchCmp = Builder.CreateAnd (LastLoopPred, FoundPred);
575- Value *Ctz = Builder.CreateIntrinsic (
576- Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType ()},
577- {PredMatchCmp, /* ZeroIsPoison=*/ Builder.getInt1 (true )});
578- Ctz = Builder.CreateZExt (Ctz, I64Type);
579- Value *VectorLoopRes64 = Builder.CreateAdd (VectorFoundIndex, Ctz, " " ,
580- /* HasNUW=*/ true , /* HasNSW=*/ true );
581- Value *VectorLoopRes = Builder.CreateTrunc (VectorLoopRes64, ResType);
607+ Value *VectorLoopRes =
608+ createMaskedFindMismatch (Builder, GEPA, GEPB, ExtStart, ExtEnd);
582609
583610 Builder.Insert (BranchInst::Create (EndBlock));
584611
0 commit comments