Skip to content

Commit d5d2c0d

Browse files
committed
[LoopVectorizer] Add support for chaining partial reductions
1 parent ab3e008 commit d5d2c0d

File tree

6 files changed

+619
-26
lines changed

6 files changed

+619
-26
lines changed

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
368368
InstructionCost Invalid = InstructionCost::getInvalid();
369369
InstructionCost Cost(TTI::TCC_Basic);
370370

371-
if (Opcode != Instruction::Add)
371+
if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
372372
return Invalid;
373373

374374
EVT InputEVT = EVT::getEVT(InputType);

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8799,12 +8799,10 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
87998799
/// are valid so recipes can be formed later.
88008800
void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
88018801
// Find all possible partial reductions.
8802-
SmallVector<std::pair<PartialReductionChain, unsigned>, 1>
8802+
SmallVector<std::pair<PartialReductionChain, unsigned>>
88038803
PartialReductionChains;
88048804
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
8805-
if (std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
8806-
getScaledReduction(Phi, RdxDesc, Range))
8807-
PartialReductionChains.push_back(*Pair);
8805+
PartialReductionChains.append(getScaledReduction(Phi, RdxDesc.getLoopExitInstr(), Range));
88088806

88098807
// A partial reduction is invalid if any of its extends are used by
88108808
// something that isn't another partial reduction. This is because the
@@ -8832,48 +8830,65 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
88328830
}
88338831
}
88348832

8835-
std::optional<std::pair<PartialReductionChain, unsigned>>
8836-
VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8837-
const RecurrenceDescriptor &Rdx,
8833+
SmallVector<std::pair<PartialReductionChain, unsigned>>
8834+
VPRecipeBuilder::getScaledReduction(Instruction *PHI,
8835+
Instruction *RdxExitInstr,
88388836
VFRange &Range) {
8837+
SmallVector<std::pair<PartialReductionChain, unsigned>> Chains;
8838+
8839+
if(!CM.TheLoop->contains(RdxExitInstr))
8840+
return Chains;
8841+
88398842
// TODO: Allow scaling reductions when predicating. The select at
88408843
// the end of the loop chooses between the phi value and most recent
88418844
// reduction result, both of which have different VFs to the active lane
88428845
// mask when scaling.
8843-
if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr()->getParent()))
8844-
return std::nullopt;
8846+
if (CM.blockNeedsPredicationForAnyReason(RdxExitInstr->getParent()))
8847+
return Chains;
88458848

8846-
auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr());
8849+
auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr);
88478850
if (!Update)
8848-
return std::nullopt;
8851+
return Chains;
88498852

88508853
Value *Op = Update->getOperand(0);
88518854
if (Op == PHI)
88528855
Op = Update->getOperand(1);
88538856

8857+
if (auto *OpInst = dyn_cast<Instruction>(Op)) {
8858+
auto SR0 = getScaledReduction(PHI, OpInst, Range);
8859+
if(!SR0.empty()) {
8860+
Chains.append(SR0);
8861+
PHI = SR0.rbegin()->first.Reduction;
8862+
8863+
Op = Update->getOperand(0);
8864+
if (Op == PHI)
8865+
Op = Update->getOperand(1);
8866+
}
8867+
}
8868+
88548869
auto *BinOp = dyn_cast<BinaryOperator>(Op);
88558870
if (!BinOp || !BinOp->hasOneUse())
8856-
return std::nullopt;
8871+
return Chains;
88578872

88588873
using namespace llvm::PatternMatch;
88598874
Value *A, *B;
88608875
if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
88618876
!match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
8862-
return std::nullopt;
8877+
return Chains;
88638878

88648879
Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
88658880
Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
88668881

88678882
// Check that the extends extend from the same type.
88688883
if (A->getType() != B->getType())
8869-
return std::nullopt;
8884+
return Chains;
88708885

88718886
TTI::PartialReductionExtendKind OpAExtend =
88728887
TargetTransformInfo::getPartialReductionExtendKind(ExtA);
88738888
TTI::PartialReductionExtendKind OpBExtend =
88748889
TargetTransformInfo::getPartialReductionExtendKind(ExtB);
88758890

8876-
PartialReductionChain Chain(Rdx.getLoopExitInstr(), ExtA, ExtB, BinOp);
8891+
PartialReductionChain Chain(RdxExitInstr, ExtA, ExtB, BinOp);
88778892

88788893
unsigned TargetScaleFactor =
88798894
PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
@@ -8887,9 +8902,9 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
88878902
return Cost.isValid();
88888903
},
88898904
Range))
8890-
return std::make_pair(Chain, TargetScaleFactor);
8905+
Chains.push_back(std::make_pair(Chain, TargetScaleFactor));
88918906

8892-
return std::nullopt;
8907+
return Chains;
88938908
}
88948909

88958910
VPRecipeBase *
@@ -8986,7 +9001,8 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
89869001

89879002
VPValue *BinOp = Operands[0];
89889003
VPValue *Phi = Operands[1];
8989-
if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe()))
9004+
VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe();
9005+
if (isa<VPReductionPHIRecipe>(BinOpRecipe) || isa<VPPartialReductionRecipe>(BinOpRecipe))
89909006
std::swap(BinOp, Phi);
89919007

89929008
return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp, Phi,

llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,8 @@ class VPRecipeBuilder {
144144
/// Returns null if no scaled reduction was found, otherwise a pair with a
145145
/// struct containing reduction information and the scaling factor between the
146146
/// number of elements in the input and output.
147-
std::optional<std::pair<PartialReductionChain, unsigned>>
148-
getScaledReduction(PHINode *PHI, const RecurrenceDescriptor &Rdx,
147+
SmallVector<std::pair<PartialReductionChain, unsigned>>
148+
getScaledReduction(Instruction *PHI, Instruction *RdxExitInstr,
149149
VFRange &Range);
150150

151151
public:

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2463,12 +2463,13 @@ class VPPartialReductionRecipe : public VPSingleDefRecipe {
24632463
: VPSingleDefRecipe(VPDef::VPPartialReductionSC,
24642464
ArrayRef<VPValue *>({Op0, Op1}), ReductionInst),
24652465
Opcode(Opcode) {
2466-
assert(isa<VPReductionPHIRecipe>(getOperand(1)->getDefiningRecipe()) &&
2466+
auto *DefiningRecipe = getOperand(1)->getDefiningRecipe();
2467+
assert((isa<VPReductionPHIRecipe>(DefiningRecipe) || isa<VPPartialReductionRecipe>(DefiningRecipe)) &&
24672468
"Unexpected operand order for partial reduction recipe");
24682469
}
24692470
~VPPartialReductionRecipe() override = default;
24702471
VPPartialReductionRecipe *clone() override {
2471-
return new VPPartialReductionRecipe(Opcode, getOperand(0), getOperand(1));
2472+
return new VPPartialReductionRecipe(Opcode, getOperand(0), getOperand(1), getUnderlyingInstr());
24722473
}
24732474

24742475
VP_CLASSOF_IMPL(VPDef::VPPartialReductionSC)

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -326,13 +326,21 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
326326
State.setDebugLocFrom(getDebugLoc());
327327
auto &Builder = State.Builder;
328328

329-
assert(getOpcode() == Instruction::Add &&
330-
"Unhandled partial reduction opcode");
331-
332329
Value *BinOpVal = State.get(getOperand(0));
333330
Value *PhiVal = State.get(getOperand(1));
334331
assert(PhiVal && BinOpVal && "Phi and Mul must be set");
335332

333+
auto Opcode = getOpcode();
334+
335+
// Currently we don't have a partial_reduce_sub intrinsic,
336+
// so mimic the behaviour by negating the second operand
337+
if(Opcode == Instruction::Sub) {
338+
BinOpVal = Builder.CreateSub(Constant::getNullValue(BinOpVal->getType()), BinOpVal);
339+
Opcode = Instruction::Add;
340+
}
341+
342+
assert(Opcode == Instruction::Add && "Unhandled partial reduction opcode");
343+
336344
Type *RetTy = PhiVal->getType();
337345

338346
CallInst *V = Builder.CreateIntrinsic(

0 commit comments

Comments
 (0)