@@ -78,6 +78,13 @@ class LoopIdiomVectorize {
78
78
const TargetTransformInfo *TTI;
79
79
const DataLayout *DL;
80
80
81
+ // Blocks that will be used for inserting vectorized code.
82
+ BasicBlock *EndBlock = nullptr ;
83
+ BasicBlock *VectorLoopPreheaderBlock = nullptr ;
84
+ BasicBlock *VectorLoopStartBlock = nullptr ;
85
+ BasicBlock *VectorLoopMismatchBlock = nullptr ;
86
+ BasicBlock *VectorLoopIncBlock = nullptr ;
87
+
81
88
public:
82
89
explicit LoopIdiomVectorize (DominatorTree *DT, LoopInfo *LI,
83
90
const TargetTransformInfo *TTI,
@@ -95,9 +102,16 @@ class LoopIdiomVectorize {
95
102
SmallVectorImpl<BasicBlock *> &ExitBlocks);
96
103
97
104
bool recognizeByteCompare ();
105
+
98
106
Value *expandFindMismatch (IRBuilder<> &Builder, DomTreeUpdater &DTU,
99
107
GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
100
108
Instruction *Index, Value *Start, Value *MaxLen);
109
+
110
+ Value *createMaskedFindMismatch (IRBuilder<> &Builder, DomTreeUpdater &DTU,
111
+ GetElementPtrInst *GEPA,
112
+ GetElementPtrInst *GEPB, Value *ExtStart,
113
+ Value *ExtEnd);
114
+
101
115
void transformByteCompare (GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
102
116
PHINode *IndPhi, Value *MaxLen, Instruction *Index,
103
117
Value *Start, bool IncIdx, BasicBlock *FoundBB,
@@ -331,6 +345,115 @@ bool LoopIdiomVectorize::recognizeByteCompare() {
331
345
return true ;
332
346
}
333
347
348
// Emits the predicated (masked) scalable-vector loop that looks for the first
// byte at which the two buffers differ, and returns the index of that byte.
//
// Builder  - IR builder; the caller positions it in VectorLoopPreheaderBlock
//            before calling. On return it is positioned at the end of
//            VectorLoopMismatchBlock, which has no terminator yet (the caller
//            adds the branch).
// DTU      - dominator tree updater; receives an Insert update for every CFG
//            edge created here.
// GEPA/GEPB - only their pointer operands and inbounds flags are used.
// ExtStart/ExtEnd - i64 values passed as the (base, limit) operands of
//            llvm.get.active.lane.mask; ExtStart is also the initial index.
//
// Returns an i32 value defined in VectorLoopMismatchBlock: the vector index at
// the start of the failing iteration plus the position of the first
// mismatching lane.
//
// Relies on the member blocks VectorLoopPreheaderBlock, VectorLoopStartBlock,
// VectorLoopIncBlock, VectorLoopMismatchBlock and EndBlock having been set by
// the caller; they default to nullptr.
+ Value *LoopIdiomVectorize::createMaskedFindMismatch (
349
+ IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
350
+ GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
351
+ Type *I64Type = Builder.getInt64Ty ();
352
+ Type *ResType = Builder.getInt32Ty ();
353
+ Type *LoadType = Builder.getInt8Ty ();
354
+ Value *PtrA = GEPA->getPointerOperand ();
355
+ Value *PtrB = GEPB->getPointerOperand ();
356
+
357
// NOTE(review): the facts below are established by the caller, not by this
// function. Here the induction variable starts at ExtStart and is bounded by
// ExtEnd; "ExtMaxLen" and "MinPageSize" are not visible in this function -
// confirm against the caller.
+ // At this point we know two things must be true:
358
+ // 1. Start <= End
359
+ // 2. ExtMaxLen <= MinPageSize due to the page checks.
360
+ // Therefore, we know that we can use a 64-bit induction variable that
361
+ // starts from 0 -> ExtMaxLen and it will not overflow.
362
// Predicate type: <vscale x 16 x i1>, one lane per byte that is compared.
+ ScalableVectorType *PredVTy =
363
+ ScalableVectorType::get (Builder.getInt1Ty (), 16 );
364
+
365
// Lane mask for the first iteration, computed in the preheader.
+ Value *InitialPred = Builder.CreateIntrinsic (
366
+ Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
367
+
368
// Per-iteration step of the index: vscale * 16, i.e. the lane count of PredVTy.
+ Value *VecLen = Builder.CreateIntrinsic (Intrinsic::vscale, {I64Type}, {});
369
+ VecLen = Builder.CreateMul (VecLen, ConstantInt::get (I64Type, 16 ), " " ,
370
+ /* HasNUW=*/ true , /* HasNSW=*/ true );
371
+
372
// All-false predicate, used further down to clear inactive lanes of the
// comparison result.
+ Value *PFalse = Builder.CreateVectorSplat (PredVTy->getElementCount (),
373
+ Builder.getInt1 (false ));
374
+
375
// Terminate the preheader with an unconditional branch into the vector loop.
+ BranchInst *JumpToVectorLoop = BranchInst::Create (VectorLoopStartBlock);
376
+ Builder.Insert (JumpToVectorLoop);
377
+
378
+ DTU.applyUpdates ({{DominatorTree::Insert, VectorLoopPreheaderBlock,
379
+ VectorLoopStartBlock}});
380
+
381
+ // Set up the first vector loop block by creating the PHIs, doing the vector
382
+ // loads and comparing the vectors.
383
+ Builder.SetInsertPoint (VectorLoopStartBlock);
384
// Both PHIs get their second incoming value (from VectorLoopIncBlock) once
// that block has been populated below.
+ PHINode *LoopPred = Builder.CreatePHI (PredVTy, 2 , " mismatch_vec_loop_pred" );
385
+ LoopPred->addIncoming (InitialPred, VectorLoopPreheaderBlock);
386
+ PHINode *VectorIndexPhi = Builder.CreatePHI (I64Type, 2 , " mismatch_vec_index" );
387
+ VectorIndexPhi->addIncoming (ExtStart, VectorLoopPreheaderBlock);
388
+ Type *VectorLoadType = ScalableVectorType::get (Builder.getInt8Ty (), 16 );
389
// Lanes disabled by the mask read as zero in both loads.
+ Value *Passthru = ConstantInt::getNullValue (VectorLoadType);
390
+
391
// Byte-wise GEPs off the original base pointers, keeping each source GEP's
// inbounds flag.
+ Value *VectorLhsGep =
392
+ Builder.CreateGEP (LoadType, PtrA, VectorIndexPhi, " " , GEPA->isInBounds ());
393
+ Value *VectorLhsLoad = Builder.CreateMaskedLoad (VectorLoadType, VectorLhsGep,
394
+ Align (1 ), LoopPred, Passthru);
395
+
396
+ Value *VectorRhsGep =
397
+ Builder.CreateGEP (LoadType, PtrB, VectorIndexPhi, " " , GEPB->isInBounds ());
398
+ Value *VectorRhsLoad = Builder.CreateMaskedLoad (VectorLoadType, VectorRhsGep,
399
+ Align (1 ), LoopPred, Passthru);
400
+
401
// Despite its name, VectorMatchCmp is true in lanes that MISMATCH (icmp ne).
// The select forces lanes outside LoopPred to false, and the or-reduction
// yields a single i1 "some active lane differs" used as the exit condition.
+ Value *VectorMatchCmp = Builder.CreateICmpNE (VectorLhsLoad, VectorRhsLoad);
402
+ VectorMatchCmp = Builder.CreateSelect (LoopPred, VectorMatchCmp, PFalse);
403
+ Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce (VectorMatchCmp);
404
// True edge: mismatch found; false edge: advance to the next chunk.
+ BranchInst *VectorEarlyExit = BranchInst::Create (
405
+ VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
406
+ Builder.Insert (VectorEarlyExit);
407
+
408
+ DTU.applyUpdates (
409
+ {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
410
+ {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
411
+
412
+ // Increment the index counter and calculate the predicate for the next
413
+ // iteration of the loop. We branch back to the start of the loop if there
414
+ // is at least one active lane.
415
+ Builder.SetInsertPoint (VectorLoopIncBlock);
416
+ Value *NewVectorIndexPhi =
417
+ Builder.CreateAdd (VectorIndexPhi, VecLen, " " ,
418
+ /* HasNUW=*/ true , /* HasNSW=*/ true );
419
+ VectorIndexPhi->addIncoming (NewVectorIndexPhi, VectorLoopIncBlock);
420
+ Value *NewPred =
421
+ Builder.CreateIntrinsic (Intrinsic::get_active_lane_mask,
422
+ {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
423
+ LoopPred->addIncoming (NewPred, VectorLoopIncBlock);
424
+
425
// Only lane 0 of the new mask is tested. NOTE(review): this matches "at least
// one active lane" provided the lane mask can only have later lanes active
// when lane 0 is - confirm against the get.active.lane.mask definition.
+ Value *PredHasActiveLanes =
426
+ Builder.CreateExtractElement (NewPred, uint64_t (0 ));
427
// Back edge while lane 0 is active; otherwise leave the loop for EndBlock
// without having found a mismatch.
+ BranchInst *VectorLoopBranchBack =
428
+ BranchInst::Create (VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
429
+ Builder.Insert (VectorLoopBranchBack);
430
+
431
+ DTU.applyUpdates (
432
+ {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
433
+ {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
434
+
435
+ // If we found a mismatch then we need to calculate which lane in the vector
436
+ // had a mismatch and add that on to the current loop index.
437
+ Builder.SetInsertPoint (VectorLoopMismatchBlock);
438
// Single-entry PHIs that carry the loop values into the exit block; each has
// VectorLoopStartBlock as its only incoming block.
+ PHINode *FoundPred = Builder.CreatePHI (PredVTy, 1 , " mismatch_vec_found_pred" );
439
+ FoundPred->addIncoming (VectorMatchCmp, VectorLoopStartBlock);
440
+ PHINode *LastLoopPred =
441
+ Builder.CreatePHI (PredVTy, 1 , " mismatch_vec_last_loop_pred" );
442
+ LastLoopPred->addIncoming (LoopPred, VectorLoopStartBlock);
443
+ PHINode *VectorFoundIndex =
444
+ Builder.CreatePHI (I64Type, 1 , " mismatch_vec_found_index" );
445
+ VectorFoundIndex->addIncoming (VectorIndexPhi, VectorLoopStartBlock);
446
+
447
// Count trailing zero elements to get the position of the first mismatching
// active lane. ZeroIsPoison is set to true: this block is entered only via the
// branch taken when the or-reduction of the masked mismatch vector was true,
// so the operand has at least one set lane.
+ Value *PredMatchCmp = Builder.CreateAnd (LastLoopPred, FoundPred);
448
+ Value *Ctz = Builder.CreateIntrinsic (
449
+ Intrinsic::experimental_cttz_elts, {ResType , PredMatchCmp->getType ()},
450
+ {PredMatchCmp, /* ZeroIsPoison=*/ Builder.getInt1 (true )});
451
// Widen the lane offset to i64, add it to the chunk's start index, then
// narrow the sum to the i32 result type.
+ Ctz = Builder.CreateZExt (Ctz, I64Type);
452
+ Value *VectorLoopRes64 = Builder.CreateAdd (VectorFoundIndex, Ctz, " " ,
453
+ /* HasNUW=*/ true , /* HasNSW=*/ true );
454
+ return Builder.CreateTrunc (VectorLoopRes64, ResType );
455
+ }
456
+
334
457
Value *LoopIdiomVectorize::expandFindMismatch (
335
458
IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
336
459
GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
@@ -345,8 +468,7 @@ Value *LoopIdiomVectorize::expandFindMismatch(
345
468
Type *ResType = Builder.getInt32Ty ();
346
469
347
470
// Split block in the original loop preheader.
348
- BasicBlock *EndBlock =
349
- SplitBlock (Preheader, PHBranch, DT, LI, nullptr , " mismatch_end" );
471
+ EndBlock = SplitBlock (Preheader, PHBranch, DT, LI, nullptr , " mismatch_end" );
350
472
351
473
// Create the blocks that we're going to need:
352
474
// 1. A block for checking the zero-extended length exceeds 0
@@ -370,17 +492,17 @@ Value *LoopIdiomVectorize::expandFindMismatch(
370
492
BasicBlock *MemCheckBlock = BasicBlock::Create (
371
493
Ctx, " mismatch_mem_check" , EndBlock->getParent (), EndBlock);
372
494
373
- BasicBlock * VectorLoopPreheaderBlock = BasicBlock::Create (
495
+ VectorLoopPreheaderBlock = BasicBlock::Create (
374
496
Ctx, " mismatch_vec_loop_preheader" , EndBlock->getParent (), EndBlock);
375
497
376
- BasicBlock * VectorLoopStartBlock = BasicBlock::Create (
377
- Ctx, " mismatch_vec_loop " , EndBlock->getParent (), EndBlock);
498
+ VectorLoopStartBlock = BasicBlock::Create (Ctx, " mismatch_vec_loop " ,
499
+ EndBlock->getParent (), EndBlock);
378
500
379
- BasicBlock * VectorLoopIncBlock = BasicBlock::Create (
380
- Ctx, " mismatch_vec_loop_inc " , EndBlock->getParent (), EndBlock);
501
+ VectorLoopIncBlock = BasicBlock::Create (Ctx, " mismatch_vec_loop_inc " ,
502
+ EndBlock->getParent (), EndBlock);
381
503
382
- BasicBlock * VectorLoopMismatchBlock = BasicBlock::Create (
383
- Ctx, " mismatch_vec_loop_found " , EndBlock->getParent (), EndBlock);
504
+ VectorLoopMismatchBlock = BasicBlock::Create (Ctx, " mismatch_vec_loop_found " ,
505
+ EndBlock->getParent (), EndBlock);
384
506
385
507
BasicBlock *LoopPreHeaderBlock = BasicBlock::Create (
386
508
Ctx, " mismatch_loop_pre" , EndBlock->getParent (), EndBlock);
@@ -491,104 +613,8 @@ Value *LoopIdiomVectorize::expandFindMismatch(
491
613
// processed in each iteration, etc.
492
614
Builder.SetInsertPoint (VectorLoopPreheaderBlock);
493
615
494
- // At this point we know two things must be true:
495
- // 1. Start <= End
496
- // 2. ExtMaxLen <= MinPageSize due to the page checks.
497
- // Therefore, we know that we can use a 64-bit induction variable that
498
- // starts from 0 -> ExtMaxLen and it will not overflow.
499
- ScalableVectorType *PredVTy =
500
- ScalableVectorType::get (Builder.getInt1Ty (), 16 );
501
-
502
- Value *InitialPred = Builder.CreateIntrinsic (
503
- Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
504
-
505
- Value *VecLen = Builder.CreateIntrinsic (Intrinsic::vscale, {I64Type}, {});
506
- VecLen = Builder.CreateMul (VecLen, ConstantInt::get (I64Type, 16 ), " " ,
507
- /* HasNUW=*/ true , /* HasNSW=*/ true );
508
-
509
- Value *PFalse = Builder.CreateVectorSplat (PredVTy->getElementCount (),
510
- Builder.getInt1 (false ));
511
-
512
- BranchInst *JumpToVectorLoop = BranchInst::Create (VectorLoopStartBlock);
513
- Builder.Insert (JumpToVectorLoop);
514
-
515
- DTU.applyUpdates ({{DominatorTree::Insert, VectorLoopPreheaderBlock,
516
- VectorLoopStartBlock}});
517
-
518
- // Set up the first vector loop block by creating the PHIs, doing the vector
519
- // loads and comparing the vectors.
520
- Builder.SetInsertPoint (VectorLoopStartBlock);
521
- PHINode *LoopPred = Builder.CreatePHI (PredVTy, 2 , " mismatch_vec_loop_pred" );
522
- LoopPred->addIncoming (InitialPred, VectorLoopPreheaderBlock);
523
- PHINode *VectorIndexPhi = Builder.CreatePHI (I64Type, 2 , " mismatch_vec_index" );
524
- VectorIndexPhi->addIncoming (ExtStart, VectorLoopPreheaderBlock);
525
- Type *VectorLoadType = ScalableVectorType::get (Builder.getInt8Ty (), 16 );
526
- Value *Passthru = ConstantInt::getNullValue (VectorLoadType);
527
-
528
- Value *VectorLhsGep =
529
- Builder.CreateGEP (LoadType, PtrA, VectorIndexPhi, " " , GEPA->isInBounds ());
530
- Value *VectorLhsLoad = Builder.CreateMaskedLoad (VectorLoadType, VectorLhsGep,
531
- Align (1 ), LoopPred, Passthru);
532
-
533
- Value *VectorRhsGep =
534
- Builder.CreateGEP (LoadType, PtrB, VectorIndexPhi, " " , GEPB->isInBounds ());
535
- Value *VectorRhsLoad = Builder.CreateMaskedLoad (VectorLoadType, VectorRhsGep,
536
- Align (1 ), LoopPred, Passthru);
537
-
538
- Value *VectorMatchCmp = Builder.CreateICmpNE (VectorLhsLoad, VectorRhsLoad);
539
- VectorMatchCmp = Builder.CreateSelect (LoopPred, VectorMatchCmp, PFalse);
540
- Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce (VectorMatchCmp);
541
- BranchInst *VectorEarlyExit = BranchInst::Create (
542
- VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
543
- Builder.Insert (VectorEarlyExit);
544
-
545
- DTU.applyUpdates (
546
- {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
547
- {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
548
-
549
- // Increment the index counter and calculate the predicate for the next
550
- // iteration of the loop. We branch back to the start of the loop if there
551
- // is at least one active lane.
552
- Builder.SetInsertPoint (VectorLoopIncBlock);
553
- Value *NewVectorIndexPhi =
554
- Builder.CreateAdd (VectorIndexPhi, VecLen, " " ,
555
- /* HasNUW=*/ true , /* HasNSW=*/ true );
556
- VectorIndexPhi->addIncoming (NewVectorIndexPhi, VectorLoopIncBlock);
557
- Value *NewPred =
558
- Builder.CreateIntrinsic (Intrinsic::get_active_lane_mask,
559
- {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
560
- LoopPred->addIncoming (NewPred, VectorLoopIncBlock);
561
-
562
- Value *PredHasActiveLanes =
563
- Builder.CreateExtractElement (NewPred, uint64_t (0 ));
564
- BranchInst *VectorLoopBranchBack =
565
- BranchInst::Create (VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
566
- Builder.Insert (VectorLoopBranchBack);
567
-
568
- DTU.applyUpdates (
569
- {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
570
- {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
571
-
572
- // If we found a mismatch then we need to calculate which lane in the vector
573
- // had a mismatch and add that on to the current loop index.
574
- Builder.SetInsertPoint (VectorLoopMismatchBlock);
575
- PHINode *FoundPred = Builder.CreatePHI (PredVTy, 1 , " mismatch_vec_found_pred" );
576
- FoundPred->addIncoming (VectorMatchCmp, VectorLoopStartBlock);
577
- PHINode *LastLoopPred =
578
- Builder.CreatePHI (PredVTy, 1 , " mismatch_vec_last_loop_pred" );
579
- LastLoopPred->addIncoming (LoopPred, VectorLoopStartBlock);
580
- PHINode *VectorFoundIndex =
581
- Builder.CreatePHI (I64Type, 1 , " mismatch_vec_found_index" );
582
- VectorFoundIndex->addIncoming (VectorIndexPhi, VectorLoopStartBlock);
583
-
584
- Value *PredMatchCmp = Builder.CreateAnd (LastLoopPred, FoundPred);
585
- Value *Ctz = Builder.CreateIntrinsic (
586
- Intrinsic::experimental_cttz_elts, {ResType , PredMatchCmp->getType ()},
587
- {PredMatchCmp, /* ZeroIsPoison=*/ Builder.getInt1 (true )});
588
- Ctz = Builder.CreateZExt (Ctz, I64Type);
589
- Value *VectorLoopRes64 = Builder.CreateAdd (VectorFoundIndex, Ctz, " " ,
590
- /* HasNUW=*/ true , /* HasNSW=*/ true );
591
- Value *VectorLoopRes = Builder.CreateTrunc (VectorLoopRes64, ResType );
616
+ Value *VectorLoopRes =
617
+ createMaskedFindMismatch (Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd);
592
618
593
619
Builder.Insert (BranchInst::Create (EndBlock));
594
620
0 commit comments