Skip to content

Commit 26df370

Browse files
committed
[LoopVectorizer][AArch64] Add support for partial reduce subtraction
Instead of implementing a new intrinsic for subtracting partial reductions, generate a negation instruction for the second operand of the partial reduction.
1 parent 8353aa2 commit 26df370

File tree

6 files changed

+198
-61
lines changed

6 files changed

+198
-61
lines changed

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4673,7 +4673,7 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
46734673
InstructionCost Invalid = InstructionCost::getInvalid();
46744674
InstructionCost Cost(TTI::TCC_Basic);
46754675

4676-
if (Opcode != Instruction::Add)
4676+
if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
46774677
return Invalid;
46784678

46794679
if (InputTypeA != InputTypeB)

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -8697,8 +8697,9 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
86978697

86988698
// Build up a set of partial reduction bin ops for efficient use checking.
86998699
SmallSet<User *, 4> PartialReductionBinOps;
8700-
for (const auto &[PartialRdx, _] : PartialReductionChains)
8700+
for (const auto &[PartialRdx, _] : PartialReductionChains) {
87018701
PartialReductionBinOps.insert(PartialRdx.BinOp);
8702+
}
87028703

87038704
auto ExtendIsOnlyUsedByPartialReductions =
87048705
[&PartialReductionBinOps](Instruction *Extend) {
@@ -8761,20 +8762,23 @@ bool VPRecipeBuilder::getScaledReductions(
87618762
return false;
87628763

87638764
using namespace llvm::PatternMatch;
8765+
BinaryOperator *ExtendedBinOp = BinOp;
8766+
match(BinOp, m_Neg(m_BinOp(ExtendedBinOp)));
8767+
87648768
Value *A, *B;
8765-
if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
8766-
!match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
8769+
if (!match(ExtendedBinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
8770+
!match(ExtendedBinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
87678771
return false;
87688772

8769-
Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
8770-
Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
8773+
Instruction *ExtA = cast<Instruction>(ExtendedBinOp->getOperand(0));
8774+
Instruction *ExtB = cast<Instruction>(ExtendedBinOp->getOperand(1));
87718775

87728776
TTI::PartialReductionExtendKind OpAExtend =
87738777
TargetTransformInfo::getPartialReductionExtendKind(ExtA);
87748778
TTI::PartialReductionExtendKind OpBExtend =
87758779
TargetTransformInfo::getPartialReductionExtendKind(ExtB);
87768780

8777-
PartialReductionChain Chain(RdxExitInstr, ExtA, ExtB, BinOp);
8781+
PartialReductionChain Chain(RdxExitInstr, ExtA, ExtB, ExtendedBinOp);
87788782

87798783
unsigned TargetScaleFactor =
87808784
PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
@@ -8785,7 +8789,7 @@ bool VPRecipeBuilder::getScaledReductions(
87858789
InstructionCost Cost = TTI->getPartialReductionCost(
87868790
Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
87878791
VF, OpAExtend, OpBExtend,
8788-
std::make_optional(BinOp->getOpcode()));
8792+
std::make_optional(ExtendedBinOp->getOpcode()));
87898793
return Cost.isValid();
87908794
},
87918795
Range)) {

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 25 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
#include "llvm/IR/Instruction.h"
2626
#include "llvm/IR/Instructions.h"
2727
#include "llvm/IR/Intrinsics.h"
28+
#include "llvm/IR/PatternMatch.h"
2829
#include "llvm/IR/Type.h"
2930
#include "llvm/IR/Value.h"
3031
#include "llvm/IR/VectorBuilder.h"
@@ -282,7 +283,20 @@ InstructionCost
282283
VPPartialReductionRecipe::computeCost(ElementCount VF,
283284
VPCostContext &Ctx) const {
284285
std::optional<unsigned> Opcode = std::nullopt;
285-
VPRecipeBase *BinOpR = getOperand(0)->getDefiningRecipe();
286+
VPValue *BinOp = getOperand(0);
287+
VPRecipeBase *BinOpR = BinOp->getDefiningRecipe();
288+
289+
using namespace llvm::PatternMatch;
290+
if (auto *UnderInst =
291+
dyn_cast_if_present<Instruction>(BinOp->getUnderlyingValue())) {
292+
if (match(UnderInst, m_Neg(m_BinOp()))) {
293+
BinOpR = BinOpR->getOperand(1)->getDefiningRecipe();
294+
}
295+
}
296+
// BinOp is never used again, any further interaction should be via the
297+
// defining recipe `BinOpR`
298+
BinOp = nullptr;
299+
286300
if (auto *WidenR = dyn_cast<VPWidenRecipe>(BinOpR))
287301
Opcode = std::make_optional(WidenR->getOpcode());
288302

@@ -318,13 +332,20 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
318332
State.setDebugLocFrom(getDebugLoc());
319333
auto &Builder = State.Builder;
320334

321-
assert(getOpcode() == Instruction::Add &&
322-
"Unhandled partial reduction opcode");
323-
324335
Value *BinOpVal = State.get(getOperand(0));
325336
Value *PhiVal = State.get(getOperand(1));
326337
assert(PhiVal && BinOpVal && "Phi and Mul must be set");
327338

339+
unsigned Opcode = getOpcode();
340+
341+
if (Opcode == Instruction::Sub) {
342+
bool HasNSW = cast<Instruction>(BinOpVal)->hasNoSignedWrap();
343+
BinOpVal = Builder.CreateNeg(BinOpVal, "", HasNSW);
344+
Opcode = Instruction::Add;
345+
}
346+
347+
assert(Opcode == Instruction::Add && "Unhandled partial reduction opcode");
348+
328349
Type *RetTy = PhiVal->getType();
329350

330351
CallInst *V = Builder.CreateIntrinsic(

0 commit comments

Comments
 (0)