Skip to content

Commit 34d5f25

Browse files
committed
[LoopVectorizer] Add support for chaining partial reductions
1 parent e33f456 commit 34d5f25

File tree

6 files changed

+623
-25
lines changed

6 files changed

+623
-25
lines changed

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
368368
InstructionCost Invalid = InstructionCost::getInvalid();
369369
InstructionCost Cost(TTI::TCC_Basic);
370370

371-
if (Opcode != Instruction::Add)
371+
if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
372372
return Invalid;
373373

374374
if (InputTypeA != InputTypeB)

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8790,12 +8790,12 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
87908790
/// are valid so recipes can be formed later.
87918791
void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
87928792
// Find all possible partial reductions.
8793-
SmallVector<std::pair<PartialReductionChain, unsigned>, 1>
8793+
SmallVector<std::pair<PartialReductionChain, unsigned>>
87948794
PartialReductionChains;
8795-
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
8796-
if (std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
8797-
getScaledReduction(Phi, RdxDesc, Range))
8798-
PartialReductionChains.push_back(*Pair);
8795+
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars()) {
8796+
if (auto SR = getScaledReduction(Phi, RdxDesc.getLoopExitInstr(), Range))
8797+
PartialReductionChains.append(*SR);
8798+
}
87998799

88008800
// A partial reduction is invalid if any of its extends are used by
88018801
// something that isn't another partial reduction. This is because the
@@ -8823,26 +8823,42 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
88238823
}
88248824
}
88258825

8826-
std::optional<std::pair<PartialReductionChain, unsigned>>
8827-
VPRecipeBuilder::getScaledReduction(PHINode *PHI,
8828-
const RecurrenceDescriptor &Rdx,
8826+
std::optional<SmallVector<std::pair<PartialReductionChain, unsigned>>>
8827+
VPRecipeBuilder::getScaledReduction(Instruction *PHI,
8828+
Instruction *RdxExitInstr,
88298829
VFRange &Range) {
8830+
8831+
if(!CM.TheLoop->contains(RdxExitInstr))
8832+
return std::nullopt;
8833+
88308834
// TODO: Allow scaling reductions when predicating. The select at
88318835
// the end of the loop chooses between the phi value and most recent
88328836
// reduction result, both of which have different VFs to the active lane
88338837
// mask when scaling.
8834-
if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr()->getParent()))
8838+
if (CM.blockNeedsPredicationForAnyReason(RdxExitInstr->getParent()))
88358839
return std::nullopt;
88368840

8837-
auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr());
8841+
auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr);
88388842
if (!Update)
88398843
return std::nullopt;
88408844

88418845
Value *Op = Update->getOperand(0);
88428846
Value *PhiOp = Update->getOperand(1);
8843-
if (Op == PHI) {
8844-
Op = Update->getOperand(1);
8845-
PhiOp = Update->getOperand(0);
8847+
if (Op == PHI)
8848+
std::swap(Op, PhiOp);
8849+
8850+
SmallVector<std::pair<PartialReductionChain, unsigned>> Chains;
8851+
8852+
if (auto *OpInst = dyn_cast<Instruction>(Op)) {
8853+
if(auto SR0 = getScaledReduction(PHI, OpInst, Range)) {
8854+
Chains.append(*SR0);
8855+
PHI = SR0->rbegin()->first.Reduction;
8856+
8857+
Op = Update->getOperand(0);
8858+
PhiOp = Update->getOperand(1);
8859+
if (Op == PHI)
8860+
std::swap(Op, PhiOp);
8861+
}
88468862
}
88478863
if (PhiOp != PHI)
88488864
return std::nullopt;
@@ -8860,12 +8876,16 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
88608876
Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
88618877
Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
88628878

8879+
// Check that the extends extend from the same type.
8880+
if (A->getType() != B->getType())
8881+
return std::nullopt;
8882+
88638883
TTI::PartialReductionExtendKind OpAExtend =
88648884
TargetTransformInfo::getPartialReductionExtendKind(ExtA);
88658885
TTI::PartialReductionExtendKind OpBExtend =
88668886
TargetTransformInfo::getPartialReductionExtendKind(ExtB);
88678887

8868-
PartialReductionChain Chain(Rdx.getLoopExitInstr(), ExtA, ExtB, BinOp);
8888+
PartialReductionChain Chain(RdxExitInstr, ExtA, ExtB, BinOp);
88698889

88708890
unsigned TargetScaleFactor =
88718891
PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
@@ -8880,9 +8900,9 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
88808900
return Cost.isValid();
88818901
},
88828902
Range))
8883-
return std::make_pair(Chain, TargetScaleFactor);
8903+
Chains.push_back(std::make_pair(Chain, TargetScaleFactor));
88848904

8885-
return std::nullopt;
8905+
return Chains;
88868906
}
88878907

88888908
VPRecipeBase *
@@ -8979,7 +8999,8 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
89798999

89809000
VPValue *BinOp = Operands[0];
89819001
VPValue *Phi = Operands[1];
8982-
if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe()))
9002+
VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe();
9003+
if (isa<VPReductionPHIRecipe>(BinOpRecipe) || isa<VPPartialReductionRecipe>(BinOpRecipe))
89839004
std::swap(BinOp, Phi);
89849005

89859006
return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp, Phi,

llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,8 @@ class VPRecipeBuilder {
144144
/// Returns std::nullopt if no scaled reduction was found, otherwise a list of
145145
/// pairs, each holding a struct containing reduction information and the
146146
/// scaling factor between the number of elements in the input and output.
147-
std::optional<std::pair<PartialReductionChain, unsigned>>
148-
getScaledReduction(PHINode *PHI, const RecurrenceDescriptor &Rdx,
147+
std::optional<SmallVector<std::pair<PartialReductionChain, unsigned>>>
148+
getScaledReduction(Instruction *PHI, Instruction *RdxExitInstr,
149149
VFRange &Range);
150150

151151
public:

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2453,13 +2453,14 @@ class VPPartialReductionRecipe : public VPSingleDefRecipe {
24532453
: VPSingleDefRecipe(VPDef::VPPartialReductionSC,
24542454
ArrayRef<VPValue *>({Op0, Op1}), ReductionInst),
24552455
Opcode(Opcode) {
2456-
assert(isa<VPReductionPHIRecipe>(getOperand(1)->getDefiningRecipe()) &&
2456+
auto *DefiningRecipe = getOperand(1)->getDefiningRecipe();
2457+
assert((isa<VPReductionPHIRecipe>(DefiningRecipe) || isa<VPPartialReductionRecipe>(DefiningRecipe)) &&
24572458
"Unexpected operand order for partial reduction recipe");
24582459
}
24592460
~VPPartialReductionRecipe() override = default;
24602461

24612462
VPPartialReductionRecipe *clone() override {
2462-
return new VPPartialReductionRecipe(Opcode, getOperand(0), getOperand(1));
2463+
return new VPPartialReductionRecipe(Opcode, getOperand(0), getOperand(1), getUnderlyingInstr());
24632464
}
24642465

24652466
VP_CLASSOF_IMPL(VPDef::VPPartialReductionSC)

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -317,13 +317,21 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
317317
State.setDebugLocFrom(getDebugLoc());
318318
auto &Builder = State.Builder;
319319

320-
assert(getOpcode() == Instruction::Add &&
321-
"Unhandled partial reduction opcode");
322-
323320
Value *BinOpVal = State.get(getOperand(0));
324321
Value *PhiVal = State.get(getOperand(1));
325322
assert(PhiVal && BinOpVal && "Phi and Mul must be set");
326323

324+
auto Opcode = getOpcode();
325+
326+
// Currently we don't have a partial_reduce_sub intrinsic, so mimic the
327+
// behaviour by negating BinOpVal and treating the opcode as an add.
328+
if(Opcode == Instruction::Sub) {
329+
BinOpVal = Builder.CreateSub(Constant::getNullValue(BinOpVal->getType()), BinOpVal);
330+
Opcode = Instruction::Add;
331+
}
332+
333+
assert(Opcode == Instruction::Add && "Unhandled partial reduction opcode");
334+
327335
Type *RetTy = PhiVal->getType();
328336

329337
CallInst *V = Builder.CreateIntrinsic(

0 commit comments

Comments
 (0)