Skip to content

Commit 5de0579

Browse files
committed
[LV][NFC] Refactor code for extracting first active element
Refactor the code to extract the first active element of a vector in the early exit block, in preparation for PR #130766. I've replaced the VPInstruction::ExtractFirstActive nodes with a combination of a new VPInstruction::FirstActiveLane node and a Instruction::ExtractElement node.
1 parent b910610 commit 5de0579

File tree

5 files changed

+46
-30
lines changed

5 files changed

+46
-30
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -879,9 +879,8 @@ class VPInstruction : public VPRecipeWithIRFlags,
879879
// Returns a scalar boolean value, which is true if any lane of its (only
880880
// boolean) vector operand is true.
881881
AnyOf,
882-
// Extracts the first active lane of a vector, where the first operand is
883-
// the predicate, and the second operand is the vector to extract.
884-
ExtractFirstActive,
882+
// Calculates the first active lane index of the vector predicate operand.
883+
FirstActiveLane,
885884
};
886885

887886
private:

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
5050
return SetResultTyFromOp();
5151

5252
switch (Opcode) {
53+
case Instruction::ExtractElement:
54+
return inferScalarType(R->getOperand(0));
5355
case Instruction::Select: {
5456
Type *ResTy = inferScalarType(R->getOperand(1));
5557
VPValue *OtherV = R->getOperand(2);
@@ -78,7 +80,8 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
7880
case VPInstruction::CanonicalIVIncrementForPart:
7981
case VPInstruction::AnyOf:
8082
return SetResultTyFromOp();
81-
case VPInstruction::ExtractFirstActive:
83+
case VPInstruction::FirstActiveLane:
84+
return Type::getIntNTy(Ctx, 64);
8285
case VPInstruction::ExtractFromEnd: {
8386
Type *BaseTy = inferScalarType(R->getOperand(0));
8487
if (auto *VecTy = dyn_cast<VectorType>(BaseTy))

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,11 @@ Value *VPInstruction::generate(VPTransformState &State) {
461461
Value *A = State.get(getOperand(0));
462462
return Builder.CreateNot(A, Name);
463463
}
464+
case Instruction::ExtractElement: {
465+
Value *Vec = State.get(getOperand(0));
466+
Value *Idx = State.get(getOperand(1), true);
467+
return Builder.CreateExtractElement(Vec, Idx, Name);
468+
}
464469
case Instruction::ICmp: {
465470
bool OnlyFirstLaneUsed = vputils::onlyFirstLaneUsed(this);
466471
Value *A = State.get(getOperand(0), OnlyFirstLaneUsed);
@@ -705,12 +710,10 @@ Value *VPInstruction::generate(VPTransformState &State) {
705710
Value *A = State.get(getOperand(0));
706711
return Builder.CreateOrReduce(A);
707712
}
708-
case VPInstruction::ExtractFirstActive: {
709-
Value *Vec = State.get(getOperand(0));
710-
Value *Mask = State.get(getOperand(1));
711-
Value *Ctz = Builder.CreateCountTrailingZeroElems(
712-
Builder.getInt64Ty(), Mask, true, "first.active.lane");
713-
return Builder.CreateExtractElement(Vec, Ctz, "early.exit.value");
713+
case VPInstruction::FirstActiveLane: {
714+
Value *Mask = State.get(getOperand(0));
715+
return Builder.CreateCountTrailingZeroElems(Builder.getInt64Ty(), Mask,
716+
true, Name);
714717
}
715718
default:
716719
llvm_unreachable("Unsupported opcode for instruction");
@@ -737,22 +740,24 @@ InstructionCost VPInstruction::computeCost(ElementCount VF,
737740
}
738741

739742
switch (getOpcode()) {
743+
case Instruction::ExtractElement: {
744+
// Add on the cost of extracting the element.
745+
auto *VecTy = toVectorTy(Ctx.Types.inferScalarType(getOperand(0)), VF);
746+
return Ctx.TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy,
747+
Ctx.CostKind);
748+
}
740749
case VPInstruction::AnyOf: {
741750
auto *VecTy = toVectorTy(Ctx.Types.inferScalarType(this), VF);
742751
return Ctx.TTI.getArithmeticReductionCost(
743752
Instruction::Or, cast<VectorType>(VecTy), std::nullopt, Ctx.CostKind);
744753
}
745-
case VPInstruction::ExtractFirstActive: {
754+
case VPInstruction::FirstActiveLane: {
746755
// Calculate the cost of determining the lane index.
747-
auto *PredTy = toVectorTy(Ctx.Types.inferScalarType(getOperand(1)), VF);
756+
auto *PredTy = toVectorTy(Ctx.Types.inferScalarType(getOperand(0)), VF);
748757
IntrinsicCostAttributes Attrs(Intrinsic::experimental_cttz_elts,
749758
Type::getInt64Ty(Ctx.LLVMCtx),
750759
{PredTy, Type::getInt1Ty(Ctx.LLVMCtx)});
751-
InstructionCost Cost = Ctx.TTI.getIntrinsicInstrCost(Attrs, Ctx.CostKind);
752-
// Add on the cost of extracting the element.
753-
auto *VecTy = toVectorTy(Ctx.Types.inferScalarType(getOperand(0)), VF);
754-
return Cost + Ctx.TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy,
755-
Ctx.CostKind);
760+
return Ctx.TTI.getIntrinsicInstrCost(Attrs, Ctx.CostKind);
756761
}
757762
default:
758763
// TODO: Compute cost other VPInstructions once the legacy cost model has
@@ -765,7 +770,8 @@ InstructionCost VPInstruction::computeCost(ElementCount VF,
765770

766771
bool VPInstruction::isVectorToScalar() const {
767772
return getOpcode() == VPInstruction::ExtractFromEnd ||
768-
getOpcode() == VPInstruction::ExtractFirstActive ||
773+
getOpcode() == Instruction::ExtractElement ||
774+
getOpcode() == VPInstruction::FirstActiveLane ||
769775
getOpcode() == VPInstruction::ComputeReductionResult ||
770776
getOpcode() == VPInstruction::AnyOf;
771777
}
@@ -824,13 +830,14 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
824830
if (Instruction::isBinaryOp(getOpcode()))
825831
return false;
826832
switch (getOpcode()) {
833+
case Instruction::ExtractElement:
827834
case Instruction::ICmp:
828835
case Instruction::Select:
829836
case VPInstruction::AnyOf:
830837
case VPInstruction::CalculateTripCountMinusVF:
831838
case VPInstruction::CanonicalIVIncrementForPart:
832839
case VPInstruction::ExtractFromEnd:
833-
case VPInstruction::ExtractFirstActive:
840+
case VPInstruction::FirstActiveLane:
834841
case VPInstruction::FirstOrderRecurrenceSplice:
835842
case VPInstruction::LogicalAnd:
836843
case VPInstruction::Not:
@@ -939,7 +946,6 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
939946
case VPInstruction::Broadcast:
940947
O << "broadcast";
941948
break;
942-
943949
case VPInstruction::ExtractFromEnd:
944950
O << "extract-from-end";
945951
break;
@@ -955,8 +961,8 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
955961
case VPInstruction::AnyOf:
956962
O << "any-of";
957963
break;
958-
case VPInstruction::ExtractFirstActive:
959-
O << "extract-first-active";
964+
case VPInstruction::FirstActiveLane:
965+
O << "first-active-lane";
960966
break;
961967
default:
962968
O << Instruction::getOpcodeName(getOpcode());

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2157,10 +2157,14 @@ void VPlanTransforms::handleUncountableEarlyExit(
21572157
ExitIRI->extractLastLaneOfOperand(MiddleBuilder);
21582158
}
21592159
// Add the incoming value from the early exit.
2160-
if (!IncomingFromEarlyExit->isLiveIn())
2161-
IncomingFromEarlyExit =
2162-
EarlyExitB.createNaryOp(VPInstruction::ExtractFirstActive,
2163-
{IncomingFromEarlyExit, EarlyExitTakenCond});
2160+
if (!IncomingFromEarlyExit->isLiveIn()) {
2161+
VPValue *FirstActiveLane = EarlyExitB.createNaryOp(
2162+
VPInstruction::FirstActiveLane, {EarlyExitTakenCond}, nullptr,
2163+
"first.active.lane");
2164+
IncomingFromEarlyExit = EarlyExitB.createNaryOp(
2165+
Instruction::ExtractElement, {IncomingFromEarlyExit, FirstActiveLane},
2166+
nullptr, "early.exit.value");
2167+
}
21642168
ExitIRI->addOperand(IncomingFromEarlyExit);
21652169
}
21662170
MiddleBuilder.createNaryOp(VPInstruction::BranchOnCond, {IsEarlyExitTaken});

llvm/test/Transforms/LoopVectorize/AArch64/early_exit_costs.ll

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@ define i64 @same_exit_block_pre_inc_use1_sve() #1 {
1111
; CHECK-LABEL: LV: Checking a loop in 'same_exit_block_pre_inc_use1_sve'
1212
; CHECK: LV: Selecting VF: vscale x 16
1313
; CHECK: Calculating cost of work in exit block vector.early.exit
14-
; CHECK-NEXT: Cost of 6 for VF vscale x 16: EMIT vp<{{.*}}> = extract-first-active
15-
; CHECK-NEXT: Cost of 6 for VF vscale x 16: EMIT vp<{{.*}}> = extract-first-active
14+
; CHECK-NEXT: Cost of 4 for VF vscale x 16: EMIT vp<{{.*}}> = first-active-lane vp<{{.*}}>
15+
; CHECK-NEXT: Cost of 2 for VF vscale x 16: EMIT vp<{{.*}}> = extractelement ir<{{.*}}>, vp<{{.*}}>
16+
; CHECK-NEXT: Cost of 4 for VF vscale x 16: EMIT vp<{{.*}}>.1 = first-active-lane vp<{{.*}}>
17+
; CHECK-NEXT: Cost of 2 for VF vscale x 16: EMIT vp<{{.*}}>.1 = extractelement ir<{{.*}}>, vp<%first.active.lane>.1
1618
; CHECK: LV: Minimum required TC for runtime checks to be profitable:32
1719
entry:
1820
%p1 = alloca [1024 x i8]
@@ -48,8 +50,10 @@ define i64 @same_exit_block_pre_inc_use1_nosve() {
4850
; CHECK-LABEL: LV: Checking a loop in 'same_exit_block_pre_inc_use1_nosve'
4951
; CHECK: LV: Selecting VF: 16
5052
; CHECK: Calculating cost of work in exit block vector.early.exit
51-
; CHECK-NEXT: Cost of 50 for VF 16: EMIT vp<{{.*}}> = extract-first-active
52-
; CHECK-NEXT: Cost of 50 for VF 16: EMIT vp<{{.*}}> = extract-first-active
53+
; CHECK-NEXT: Cost of 48 for VF 16: EMIT vp<{{.*}}> = first-active-lane vp<{{.*}}>
54+
; CHECK-NEXT: Cost of 2 for VF 16: EMIT vp<{{.*}}> = extractelement ir<{{.*}}>, vp<{{.*}}>
55+
; CHECK-NEXT: Cost of 48 for VF 16: EMIT vp<{{.*}}>.1 = first-active-lane vp<{{.*}}>
56+
; CHECK-NEXT: Cost of 2 for VF 16: EMIT vp<{{.*}}>.1 = extractelement ir<{{.*}}>, vp<%first.active.lane>.1
5357
; CHECK: LV: Minimum required TC for runtime checks to be profitable:176
5458
; CHECK-NEXT: LV: Vectorization is not beneficial: expected trip count < minimum profitable VF (64 < 176)
5559
; CHECK-NEXT: LV: Too many memory checks needed.

0 commit comments

Comments
 (0)