@@ -79,6 +79,13 @@ class LoopIdiomVectorize {
79
79
const TargetTransformInfo *TTI;
80
80
const DataLayout *DL;
81
81
82
+ // Blocks that will be used for inserting vectorized code.
83
+ BasicBlock *EndBlock = nullptr ;
84
+ BasicBlock *VectorLoopPreheaderBlock = nullptr ;
85
+ BasicBlock *VectorLoopStartBlock = nullptr ;
86
+ BasicBlock *VectorLoopMismatchBlock = nullptr ;
87
+ BasicBlock *VectorLoopIncBlock = nullptr ;
88
+
82
89
public:
83
90
explicit LoopIdiomVectorize (DominatorTree *DT, LoopInfo *LI,
84
91
const TargetTransformInfo *TTI,
@@ -96,9 +103,15 @@ class LoopIdiomVectorize {
96
103
SmallVectorImpl<BasicBlock *> &ExitBlocks);
97
104
98
105
bool recognizeByteCompare ();
106
+
99
107
Value *expandFindMismatch (IRBuilder<> &Builder, DomTreeUpdater &DTU,
100
108
GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
101
109
Instruction *Index, Value *Start, Value *MaxLen);
110
+
111
+ Value *createMaskedFindMismatch (IRBuilder<> &Builder, GetElementPtrInst *GEPA,
112
+ GetElementPtrInst *GEPB, Value *ExtStart,
113
+ Value *ExtEnd);
114
+
102
115
void transformByteCompare (GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
103
116
PHINode *IndPhi, Value *MaxLen, Instruction *Index,
104
117
Value *Start, bool IncIdx, BasicBlock *FoundBB,
@@ -332,6 +345,106 @@ bool LoopIdiomVectorize::recognizeByteCompare() {
332
345
return true ;
333
346
}
334
347
348
+ Value *LoopIdiomVectorize::createMaskedFindMismatch (IRBuilder<> &Builder,
349
+ GetElementPtrInst *GEPA,
350
+ GetElementPtrInst *GEPB,
351
+ Value *ExtStart,
352
+ Value *ExtEnd) {
353
+ Type *I64Type = Builder.getInt64Ty ();
354
+ Type *ResType = Builder.getInt32Ty ();
355
+ Type *LoadType = Builder.getInt8Ty ();
356
+ Value *PtrA = GEPA->getPointerOperand ();
357
+ Value *PtrB = GEPB->getPointerOperand ();
358
+
359
+ // At this point we know two things must be true:
360
+ // 1. Start <= End
361
+ // 2. ExtMaxLen <= MinPageSize due to the page checks.
362
+ // Therefore, we know that we can use a 64-bit induction variable that
363
+ // starts from 0 -> ExtMaxLen and it will not overflow.
364
+ ScalableVectorType *PredVTy =
365
+ ScalableVectorType::get (Builder.getInt1Ty (), 16 );
366
+
367
+ Value *InitialPred = Builder.CreateIntrinsic (
368
+ Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
369
+
370
+ Value *VecLen = Builder.CreateIntrinsic (Intrinsic::vscale, {I64Type}, {});
371
+ VecLen = Builder.CreateMul (VecLen, ConstantInt::get (I64Type, 16 ), " " ,
372
+ /* HasNUW=*/ true , /* HasNSW=*/ true );
373
+
374
+ Value *PFalse = Builder.CreateVectorSplat (PredVTy->getElementCount (),
375
+ Builder.getInt1 (false ));
376
+
377
+ BranchInst *JumpToVectorLoop = BranchInst::Create (VectorLoopStartBlock);
378
+ Builder.Insert (JumpToVectorLoop);
379
+
380
+ // Set up the first vector loop block by creating the PHIs, doing the vector
381
+ // loads and comparing the vectors.
382
+ Builder.SetInsertPoint (VectorLoopStartBlock);
383
+ PHINode *LoopPred = Builder.CreatePHI (PredVTy, 2 , " mismatch_vec_loop_pred" );
384
+ LoopPred->addIncoming (InitialPred, VectorLoopPreheaderBlock);
385
+ PHINode *VectorIndexPhi = Builder.CreatePHI (I64Type, 2 , " mismatch_vec_index" );
386
+ VectorIndexPhi->addIncoming (ExtStart, VectorLoopPreheaderBlock);
387
+ Type *VectorLoadType = ScalableVectorType::get (Builder.getInt8Ty (), 16 );
388
+ Value *Passthru = ConstantInt::getNullValue (VectorLoadType);
389
+
390
+ Value *VectorLhsGep =
391
+ Builder.CreateGEP (LoadType, PtrA, VectorIndexPhi, " " , GEPA->isInBounds ());
392
+ Value *VectorLhsLoad = Builder.CreateMaskedLoad (VectorLoadType, VectorLhsGep,
393
+ Align (1 ), LoopPred, Passthru);
394
+
395
+ Value *VectorRhsGep =
396
+ Builder.CreateGEP (LoadType, PtrB, VectorIndexPhi, " " , GEPB->isInBounds ());
397
+ Value *VectorRhsLoad = Builder.CreateMaskedLoad (VectorLoadType, VectorRhsGep,
398
+ Align (1 ), LoopPred, Passthru);
399
+
400
+ Value *VectorMatchCmp = Builder.CreateICmpNE (VectorLhsLoad, VectorRhsLoad);
401
+ VectorMatchCmp = Builder.CreateSelect (LoopPred, VectorMatchCmp, PFalse);
402
+ Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce (VectorMatchCmp);
403
+ BranchInst *VectorEarlyExit = BranchInst::Create (
404
+ VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
405
+ Builder.Insert (VectorEarlyExit);
406
+
407
+ // Increment the index counter and calculate the predicate for the next
408
+ // iteration of the loop. We branch back to the start of the loop if there
409
+ // is at least one active lane.
410
+ Builder.SetInsertPoint (VectorLoopIncBlock);
411
+ Value *NewVectorIndexPhi =
412
+ Builder.CreateAdd (VectorIndexPhi, VecLen, " " ,
413
+ /* HasNUW=*/ true , /* HasNSW=*/ true );
414
+ VectorIndexPhi->addIncoming (NewVectorIndexPhi, VectorLoopIncBlock);
415
+ Value *NewPred =
416
+ Builder.CreateIntrinsic (Intrinsic::get_active_lane_mask,
417
+ {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
418
+ LoopPred->addIncoming (NewPred, VectorLoopIncBlock);
419
+
420
+ Value *PredHasActiveLanes =
421
+ Builder.CreateExtractElement (NewPred, uint64_t (0 ));
422
+ BranchInst *VectorLoopBranchBack =
423
+ BranchInst::Create (VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
424
+ Builder.Insert (VectorLoopBranchBack);
425
+
426
+ // If we found a mismatch then we need to calculate which lane in the vector
427
+ // had a mismatch and add that on to the current loop index.
428
+ Builder.SetInsertPoint (VectorLoopMismatchBlock);
429
+ PHINode *FoundPred = Builder.CreatePHI (PredVTy, 1 , " mismatch_vec_found_pred" );
430
+ FoundPred->addIncoming (VectorMatchCmp, VectorLoopStartBlock);
431
+ PHINode *LastLoopPred =
432
+ Builder.CreatePHI (PredVTy, 1 , " mismatch_vec_last_loop_pred" );
433
+ LastLoopPred->addIncoming (LoopPred, VectorLoopStartBlock);
434
+ PHINode *VectorFoundIndex =
435
+ Builder.CreatePHI (I64Type, 1 , " mismatch_vec_found_index" );
436
+ VectorFoundIndex->addIncoming (VectorIndexPhi, VectorLoopStartBlock);
437
+
438
+ Value *PredMatchCmp = Builder.CreateAnd (LastLoopPred, FoundPred);
439
+ Value *Ctz = Builder.CreateIntrinsic (
440
+ Intrinsic::experimental_cttz_elts, {ResType , PredMatchCmp->getType ()},
441
+ {PredMatchCmp, /* ZeroIsPoison=*/ Builder.getInt1 (true )});
442
+ Ctz = Builder.CreateZExt (Ctz, I64Type);
443
+ Value *VectorLoopRes64 = Builder.CreateAdd (VectorFoundIndex, Ctz, " " ,
444
+ /* HasNUW=*/ true , /* HasNSW=*/ true );
445
+ return Builder.CreateTrunc (VectorLoopRes64, ResType );
446
+ }
447
+
335
448
Value *LoopIdiomVectorize::expandFindMismatch (
336
449
IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
337
450
GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
@@ -346,8 +459,7 @@ Value *LoopIdiomVectorize::expandFindMismatch(
346
459
Type *ResType = Builder.getInt32Ty ();
347
460
348
461
// Split block in the original loop preheader.
349
- BasicBlock *EndBlock =
350
- SplitBlock (Preheader, PHBranch, DT, LI, nullptr , " mismatch_end" );
462
+ EndBlock = SplitBlock (Preheader, PHBranch, DT, LI, nullptr , " mismatch_end" );
351
463
352
464
// Create the blocks that we're going to need:
353
465
// 1. A block for checking the zero-extended length exceeds 0
@@ -371,17 +483,17 @@ Value *LoopIdiomVectorize::expandFindMismatch(
371
483
BasicBlock *MemCheckBlock = BasicBlock::Create (
372
484
Ctx, " mismatch_mem_check" , EndBlock->getParent (), EndBlock);
373
485
374
- BasicBlock * VectorLoopPreheaderBlock = BasicBlock::Create (
486
+ VectorLoopPreheaderBlock = BasicBlock::Create (
375
487
Ctx, " mismatch_vec_loop_preheader" , EndBlock->getParent (), EndBlock);
376
488
377
- BasicBlock * VectorLoopStartBlock = BasicBlock::Create (
378
- Ctx, " mismatch_vec_loop " , EndBlock->getParent (), EndBlock);
489
+ VectorLoopStartBlock = BasicBlock::Create (Ctx, " mismatch_vec_loop " ,
490
+ EndBlock->getParent (), EndBlock);
379
491
380
- BasicBlock * VectorLoopIncBlock = BasicBlock::Create (
381
- Ctx, " mismatch_vec_loop_inc " , EndBlock->getParent (), EndBlock);
492
+ VectorLoopIncBlock = BasicBlock::Create (Ctx, " mismatch_vec_loop_inc " ,
493
+ EndBlock->getParent (), EndBlock);
382
494
383
- BasicBlock * VectorLoopMismatchBlock = BasicBlock::Create (
384
- Ctx, " mismatch_vec_loop_found " , EndBlock->getParent (), EndBlock);
495
+ VectorLoopMismatchBlock = BasicBlock::Create (Ctx, " mismatch_vec_loop_found " ,
496
+ EndBlock->getParent (), EndBlock);
385
497
386
498
BasicBlock *LoopPreHeaderBlock = BasicBlock::Create (
387
499
Ctx, " mismatch_loop_pre" , EndBlock->getParent (), EndBlock);
@@ -492,93 +604,8 @@ Value *LoopIdiomVectorize::expandFindMismatch(
492
604
// processed in each iteration, etc.
493
605
Builder.SetInsertPoint (VectorLoopPreheaderBlock);
494
606
495
- // At this point we know two things must be true:
496
- // 1. Start <= End
497
- // 2. ExtMaxLen <= MinPageSize due to the page checks.
498
- // Therefore, we know that we can use a 64-bit induction variable that
499
- // starts from 0 -> ExtMaxLen and it will not overflow.
500
- ScalableVectorType *PredVTy =
501
- ScalableVectorType::get (Builder.getInt1Ty (), 16 );
502
-
503
- Value *InitialPred = Builder.CreateIntrinsic (
504
- Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
505
-
506
- Value *VecLen = Builder.CreateIntrinsic (Intrinsic::vscale, {I64Type}, {});
507
- VecLen = Builder.CreateMul (VecLen, ConstantInt::get (I64Type, 16 ), " " ,
508
- /* HasNUW=*/ true , /* HasNSW=*/ true );
509
-
510
- Value *PFalse = Builder.CreateVectorSplat (PredVTy->getElementCount (),
511
- Builder.getInt1 (false ));
512
-
513
- BranchInst *JumpToVectorLoop = BranchInst::Create (VectorLoopStartBlock);
514
- Builder.Insert (JumpToVectorLoop);
515
-
516
- // Set up the first vector loop block by creating the PHIs, doing the vector
517
- // loads and comparing the vectors.
518
- Builder.SetInsertPoint (VectorLoopStartBlock);
519
- PHINode *LoopPred = Builder.CreatePHI (PredVTy, 2 , " mismatch_vec_loop_pred" );
520
- LoopPred->addIncoming (InitialPred, VectorLoopPreheaderBlock);
521
- PHINode *VectorIndexPhi = Builder.CreatePHI (I64Type, 2 , " mismatch_vec_index" );
522
- VectorIndexPhi->addIncoming (ExtStart, VectorLoopPreheaderBlock);
523
- Type *VectorLoadType = ScalableVectorType::get (Builder.getInt8Ty (), 16 );
524
- Value *Passthru = ConstantInt::getNullValue (VectorLoadType);
525
-
526
- Value *VectorLhsGep =
527
- Builder.CreateGEP (LoadType, PtrA, VectorIndexPhi, " " , GEPA->isInBounds ());
528
- Value *VectorLhsLoad = Builder.CreateMaskedLoad (VectorLoadType, VectorLhsGep,
529
- Align (1 ), LoopPred, Passthru);
530
-
531
- Value *VectorRhsGep =
532
- Builder.CreateGEP (LoadType, PtrB, VectorIndexPhi, " " , GEPB->isInBounds ());
533
- Value *VectorRhsLoad = Builder.CreateMaskedLoad (VectorLoadType, VectorRhsGep,
534
- Align (1 ), LoopPred, Passthru);
535
-
536
- Value *VectorMatchCmp = Builder.CreateICmpNE (VectorLhsLoad, VectorRhsLoad);
537
- VectorMatchCmp = Builder.CreateSelect (LoopPred, VectorMatchCmp, PFalse);
538
- Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce (VectorMatchCmp);
539
- BranchInst *VectorEarlyExit = BranchInst::Create (
540
- VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
541
- Builder.Insert (VectorEarlyExit);
542
-
543
- // Increment the index counter and calculate the predicate for the next
544
- // iteration of the loop. We branch back to the start of the loop if there
545
- // is at least one active lane.
546
- Builder.SetInsertPoint (VectorLoopIncBlock);
547
- Value *NewVectorIndexPhi =
548
- Builder.CreateAdd (VectorIndexPhi, VecLen, " " ,
549
- /* HasNUW=*/ true , /* HasNSW=*/ true );
550
- VectorIndexPhi->addIncoming (NewVectorIndexPhi, VectorLoopIncBlock);
551
- Value *NewPred =
552
- Builder.CreateIntrinsic (Intrinsic::get_active_lane_mask,
553
- {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
554
- LoopPred->addIncoming (NewPred, VectorLoopIncBlock);
555
-
556
- Value *PredHasActiveLanes =
557
- Builder.CreateExtractElement (NewPred, uint64_t (0 ));
558
- BranchInst *VectorLoopBranchBack =
559
- BranchInst::Create (VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
560
- Builder.Insert (VectorLoopBranchBack);
561
-
562
- // If we found a mismatch then we need to calculate which lane in the vector
563
- // had a mismatch and add that on to the current loop index.
564
- Builder.SetInsertPoint (VectorLoopMismatchBlock);
565
- PHINode *FoundPred = Builder.CreatePHI (PredVTy, 1 , " mismatch_vec_found_pred" );
566
- FoundPred->addIncoming (VectorMatchCmp, VectorLoopStartBlock);
567
- PHINode *LastLoopPred =
568
- Builder.CreatePHI (PredVTy, 1 , " mismatch_vec_last_loop_pred" );
569
- LastLoopPred->addIncoming (LoopPred, VectorLoopStartBlock);
570
- PHINode *VectorFoundIndex =
571
- Builder.CreatePHI (I64Type, 1 , " mismatch_vec_found_index" );
572
- VectorFoundIndex->addIncoming (VectorIndexPhi, VectorLoopStartBlock);
573
-
574
- Value *PredMatchCmp = Builder.CreateAnd (LastLoopPred, FoundPred);
575
- Value *Ctz = Builder.CreateIntrinsic (
576
- Intrinsic::experimental_cttz_elts, {ResType , PredMatchCmp->getType ()},
577
- {PredMatchCmp, /* ZeroIsPoison=*/ Builder.getInt1 (true )});
578
- Ctz = Builder.CreateZExt (Ctz, I64Type);
579
- Value *VectorLoopRes64 = Builder.CreateAdd (VectorFoundIndex, Ctz, " " ,
580
- /* HasNUW=*/ true , /* HasNSW=*/ true );
581
- Value *VectorLoopRes = Builder.CreateTrunc (VectorLoopRes64, ResType );
607
+ Value *VectorLoopRes =
608
+ createMaskedFindMismatch (Builder, GEPA, GEPB, ExtStart, ExtEnd);
582
609
583
610
Builder.Insert (BranchInst::Create (EndBlock));
584
611
0 commit comments