Skip to content

Commit 055db3e

Browse files
authored
[LV] Optimise latch exit induction users for some early exit loops (#128880)
This is the first of two PRs that attempt to improve the IR generated in the exit blocks of vectorised loops with uncountable early exits. In this PR I am improving the generated code for users of induction variables in early exit loops that have a unique exit block, when exiting via the latch. I have moved some of the code for calculating the exit values in latch exit blocks from `optimizeInductionExitUsers` into a new function `optimizeLatchExitInductionUser`. I intend to follow this up very soon with another patch to optimise the code for induction users in the vector.early.exit block.
1 parent 1ff10fa commit 055db3e

File tree

2 files changed

+191
-100
lines changed

2 files changed

+191
-100
lines changed

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 63 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -737,67 +737,76 @@ static VPWidenInductionRecipe *getOptimizableIVOf(VPValue *VPV) {
737737
return IsWideIVInc() ? WideIV : nullptr;
738738
}
739739

740-
void VPlanTransforms::optimizeInductionExitUsers(
741-
VPlan &Plan, DenseMap<VPValue *, VPValue *> &EndValues) {
740+
/// Attempts to optimize the induction variable exit values for users in the
741+
/// exit block coming from the latch in the original scalar loop.
742+
static VPValue *
743+
optimizeLatchExitInductionUser(VPlan &Plan, VPTypeAnalysis &TypeInfo,
744+
VPBlockBase *PredVPBB, VPValue *Op,
745+
DenseMap<VPValue *, VPValue *> &EndValues) {
742746
using namespace VPlanPatternMatch;
743-
SmallVector<VPIRBasicBlock *> ExitVPBBs(Plan.getExitBlocks());
744-
if (ExitVPBBs.size() != 1)
745-
return;
746747

747-
VPIRBasicBlock *ExitVPBB = ExitVPBBs[0];
748-
VPBlockBase *PredVPBB = ExitVPBB->getSinglePredecessor();
749-
if (!PredVPBB)
750-
return;
751-
assert(PredVPBB == Plan.getMiddleBlock() &&
752-
"predecessor must be the middle block");
753-
754-
VPTypeAnalysis TypeInfo(Plan.getCanonicalIV()->getScalarType());
755-
VPBuilder B(Plan.getMiddleBlock()->getTerminator());
756-
for (VPRecipeBase &R : *ExitVPBB) {
757-
auto *ExitIRI = cast<VPIRInstruction>(&R);
758-
if (!isa<PHINode>(ExitIRI->getInstruction()))
759-
break;
748+
VPValue *Incoming;
749+
if (!match(Op, m_VPInstruction<VPInstruction::ExtractFromEnd>(
750+
m_VPValue(Incoming), m_SpecificInt(1))))
751+
return nullptr;
760752

761-
VPValue *Incoming;
762-
if (!match(ExitIRI->getOperand(0),
763-
m_VPInstruction<VPInstruction::ExtractFromEnd>(
764-
m_VPValue(Incoming), m_SpecificInt(1))))
765-
continue;
753+
auto *WideIV = getOptimizableIVOf(Incoming);
754+
if (!WideIV)
755+
return nullptr;
766756

767-
auto *WideIV = getOptimizableIVOf(Incoming);
768-
if (!WideIV)
769-
continue;
770-
VPValue *EndValue = EndValues.lookup(WideIV);
771-
assert(EndValue && "end value must have been pre-computed");
757+
VPValue *EndValue = EndValues.lookup(WideIV);
758+
assert(EndValue && "end value must have been pre-computed");
759+
760+
// `getOptimizableIVOf()` always returns the pre-incremented IV, so if it
761+
// changed it means the exit is using the incremented value, so we don't
762+
// need to subtract the step.
763+
if (Incoming != WideIV)
764+
return EndValue;
765+
766+
// Otherwise, subtract the step from the EndValue.
767+
VPBuilder B(cast<VPBasicBlock>(PredVPBB)->getTerminator());
768+
VPValue *Step = WideIV->getStepValue();
769+
Type *ScalarTy = TypeInfo.inferScalarType(WideIV);
770+
if (ScalarTy->isIntegerTy())
771+
return B.createNaryOp(Instruction::Sub, {EndValue, Step}, {}, "ind.escape");
772+
if (ScalarTy->isPointerTy()) {
773+
auto *Zero = Plan.getOrAddLiveIn(
774+
ConstantInt::get(Step->getLiveInIRValue()->getType(), 0));
775+
return B.createPtrAdd(EndValue,
776+
B.createNaryOp(Instruction::Sub, {Zero, Step}), {},
777+
"ind.escape");
778+
}
779+
if (ScalarTy->isFloatingPointTy()) {
780+
const auto &ID = WideIV->getInductionDescriptor();
781+
return B.createNaryOp(
782+
ID.getInductionBinOp()->getOpcode() == Instruction::FAdd
783+
? Instruction::FSub
784+
: Instruction::FAdd,
785+
{EndValue, Step}, {ID.getInductionBinOp()->getFastMathFlags()});
786+
}
787+
llvm_unreachable("all possible induction types must be handled");
788+
return nullptr;
789+
}
772790

773-
if (Incoming != WideIV) {
774-
ExitIRI->setOperand(0, EndValue);
775-
continue;
776-
}
791+
void VPlanTransforms::optimizeInductionExitUsers(
792+
VPlan &Plan, DenseMap<VPValue *, VPValue *> &EndValues) {
793+
VPBlockBase *MiddleVPBB = Plan.getMiddleBlock();
794+
VPTypeAnalysis TypeInfo(Plan.getCanonicalIV()->getScalarType());
795+
for (VPIRBasicBlock *ExitVPBB : Plan.getExitBlocks()) {
796+
for (VPRecipeBase &R : *ExitVPBB) {
797+
auto *ExitIRI = cast<VPIRInstruction>(&R);
798+
if (!isa<PHINode>(ExitIRI->getInstruction()))
799+
break;
777800

778-
VPValue *Escape = nullptr;
779-
VPValue *Step = WideIV->getStepValue();
780-
Type *ScalarTy = TypeInfo.inferScalarType(WideIV);
781-
if (ScalarTy->isIntegerTy()) {
782-
Escape =
783-
B.createNaryOp(Instruction::Sub, {EndValue, Step}, {}, "ind.escape");
784-
} else if (ScalarTy->isPointerTy()) {
785-
auto *Zero = Plan.getOrAddLiveIn(
786-
ConstantInt::get(Step->getLiveInIRValue()->getType(), 0));
787-
Escape = B.createPtrAdd(EndValue,
788-
B.createNaryOp(Instruction::Sub, {Zero, Step}),
789-
{}, "ind.escape");
790-
} else if (ScalarTy->isFloatingPointTy()) {
791-
const auto &ID = WideIV->getInductionDescriptor();
792-
Escape = B.createNaryOp(
793-
ID.getInductionBinOp()->getOpcode() == Instruction::FAdd
794-
? Instruction::FSub
795-
: Instruction::FAdd,
796-
{EndValue, Step}, {ID.getInductionBinOp()->getFastMathFlags()});
797-
} else {
798-
llvm_unreachable("all possible induction types must be handled");
801+
for (auto [Idx, PredVPBB] : enumerate(ExitVPBB->getPredecessors())) {
802+
if (PredVPBB == MiddleVPBB)
803+
if (VPValue *Escape = optimizeLatchExitInductionUser(
804+
Plan, TypeInfo, PredVPBB, ExitIRI->getOperand(Idx),
805+
EndValues))
806+
ExitIRI->setOperand(Idx, Escape);
807+
// TODO: Optimize early exit induction users in follow-on patch.
808+
}
799809
}
800-
ExitIRI->setOperand(0, Escape);
801810
}
802811
}
803812

0 commit comments

Comments
 (0)