Skip to content

Commit 9c89faa

Browse files
authored
[LoopVectorizer][AArch64] Add support for partial reduce subtraction (#123636)
1 parent 499d6da commit 9c89faa

File tree

5 files changed

+257
-57
lines changed

5 files changed

+257
-57
lines changed

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4683,7 +4683,9 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
46834683
InstructionCost Invalid = InstructionCost::getInvalid();
46844684
InstructionCost Cost(TTI::TCC_Basic);
46854685

4686-
if (Opcode != Instruction::Add)
4686+
// Sub opcodes currently only occur in chained cases.
4687+
// Independent partial reduction subtractions are still costed as an add.
4688+
if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
46874689
return Invalid;
46884690

46894691
if (InputTypeA != InputTypeB)

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8804,6 +8804,10 @@ bool VPRecipeBuilder::getScaledReductions(
88048804
return false;
88058805

88068806
using namespace llvm::PatternMatch;
8807+
// Use the side-effect of match to replace BinOp only if the pattern is
8808+
// matched; we don't care at this point whether it actually matched.
8809+
match(BinOp, m_Neg(m_BinOp(BinOp)));
8810+
88078811
Value *A, *B;
88088812
if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
88098813
!match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
@@ -8936,6 +8940,19 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
89368940
std::swap(BinOp, Accumulator);
89378941

89388942
unsigned ReductionOpcode = Reduction->getOpcode();
8943+
if (ReductionOpcode == Instruction::Sub) {
8944+
VPBasicBlock *ParentBlock = Builder.getInsertBlock();
8945+
assert(ParentBlock && "Builder must have an insert block.");
8946+
8947+
auto *const Zero = ConstantInt::get(Reduction->getType(), 0);
8948+
SmallVector<VPValue *, 2> Ops;
8949+
Ops.push_back(Plan.getOrAddLiveIn(Zero));
8950+
Ops.push_back(BinOp);
8951+
BinOp = new VPWidenRecipe(*Reduction, make_range(Ops.begin(), Ops.end()));
8952+
ParentBlock->appendRecipe(BinOp->getDefiningRecipe());
8953+
ReductionOpcode = Instruction::Add;
8954+
}
8955+
89398956
if (CM.blockNeedsPredicationForAnyReason(Reduction->getParent())) {
89408957
assert((ReductionOpcode == Instruction::Add ||
89418958
ReductionOpcode == Instruction::Sub) &&

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "llvm/IR/Instruction.h"
2828
#include "llvm/IR/Instructions.h"
2929
#include "llvm/IR/Intrinsics.h"
30+
#include "llvm/IR/PatternMatch.h"
3031
#include "llvm/IR/Type.h"
3132
#include "llvm/IR/Value.h"
3233
#include "llvm/IR/VectorBuilder.h"
@@ -284,13 +285,18 @@ InstructionCost
284285
VPPartialReductionRecipe::computeCost(ElementCount VF,
285286
VPCostContext &Ctx) const {
286287
std::optional<unsigned> Opcode = std::nullopt;
287-
VPRecipeBase *BinOpR = getOperand(0)->getDefiningRecipe();
288+
VPValue *BinOp = getOperand(0);
288289

289290
// If the partial reduction is predicated, a select will be operand 0 rather
290291
// than the binary op.
291292
using namespace llvm::VPlanPatternMatch;
292293
if (match(getOperand(0), m_Select(m_VPValue(), m_VPValue(), m_VPValue())))
293-
BinOpR = BinOpR->getOperand(1)->getDefiningRecipe();
294+
BinOp = BinOp->getDefiningRecipe()->getOperand(1);
295+
296+
// If BinOp is a negation, use the side effect of match to assign the actual
297+
// binary operation to BinOp.
298+
match(BinOp, m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(BinOp)));
299+
VPRecipeBase *BinOpR = BinOp->getDefiningRecipe();
294300

295301
if (auto *WidenR = dyn_cast<VPWidenRecipe>(BinOpR))
296302
Opcode = std::make_optional(WidenR->getOpcode());

0 commit comments

Comments
 (0)