Skip to content

[SLP]Remove ExtraArgs from reductions. #99923

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 29 additions & 81 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16644,8 +16644,6 @@ class HorizontalReduction {
SmallVector<SmallVector<Value *>> ReducedVals;
/// Maps reduced value to the corresponding reduction operation.
DenseMap<Value *, SmallVector<Instruction *>> ReducedValsToOps;
// Use map vector to make stable output.
MapVector<Instruction *, Value *> ExtraArgs;
WeakTrackingVH ReductionRoot;
/// The type of reduction operation.
RecurKind RdxKind;
Expand Down Expand Up @@ -16978,30 +16976,26 @@ class HorizontalReduction {
// gather all the reduced values, sorting them by their value id.
BasicBlock *BB = Root->getParent();
bool IsCmpSelMinMax = isCmpSelMinMax(Root);
SmallVector<Instruction *> Worklist(1, Root);
SmallVector<std::pair<Instruction *, unsigned>> Worklist(
1, std::make_pair(Root, 0));
// Checks if the operands of the \p TreeN instruction are also reduction
// operations or should be treated as reduced values or an extra argument,
// which is not part of the reduction.
auto CheckOperands = [&](Instruction *TreeN,
SmallVectorImpl<Value *> &ExtraArgs,
SmallVectorImpl<Value *> &PossibleReducedVals,
SmallVectorImpl<Instruction *> &ReductionOps) {
SmallVectorImpl<Instruction *> &ReductionOps,
unsigned Level) {
for (int I : reverse(seq<int>(getFirstOperandIndex(TreeN),
getNumberOfOperands(TreeN)))) {
Value *EdgeVal = getRdxOperand(TreeN, I);
ReducedValsToOps[EdgeVal].push_back(TreeN);
auto *EdgeInst = dyn_cast<Instruction>(EdgeVal);
// Edge has wrong parent - mark as an extra argument.
if (EdgeInst && !isVectorLikeInstWithConstOps(EdgeInst) &&
!hasSameParent(EdgeInst, BB)) {
ExtraArgs.push_back(EdgeVal);
continue;
}
// If the edge is not an instruction, or it is different from the main
// reduction opcode or has too many uses - possible reduced value.
// Also, do not try to reduce const values, if the operation is not
// foldable.
if (!EdgeInst || getRdxKind(EdgeInst) != RdxKind ||
if (!EdgeInst || Level > RecursionMaxDepth ||
getRdxKind(EdgeInst) != RdxKind ||
IsCmpSelMinMax != isCmpSelMinMax(EdgeInst) ||
!hasRequiredNumberOfUses(IsCmpSelMinMax, EdgeInst) ||
!isVectorizable(RdxKind, EdgeInst) ||
Expand All @@ -17025,6 +17019,7 @@ class HorizontalReduction {
SmallSet<size_t, 2> LoadKeyUsed;

auto GenerateLoadsSubkey = [&](size_t Key, LoadInst *LI) {
Key = hash_combine(hash_value(LI->getParent()), Key);
Value *Ptr = getUnderlyingObject(LI->getPointerOperand());
if (LoadKeyUsed.contains(Key)) {
auto LIt = LoadsMap.find(Ptr);
Expand Down Expand Up @@ -17055,40 +17050,23 @@ class HorizontalReduction {
};

while (!Worklist.empty()) {
Instruction *TreeN = Worklist.pop_back_val();
SmallVector<Value *> Args;
auto [TreeN, Level] = Worklist.pop_back_val();
SmallVector<Value *> PossibleRedVals;
SmallVector<Instruction *> PossibleReductionOps;
CheckOperands(TreeN, Args, PossibleRedVals, PossibleReductionOps);
// If too many extra args - mark the instruction itself as a reduction
// value, not a reduction operation.
if (Args.size() < 2) {
addReductionOps(TreeN);
// Add extra args.
if (!Args.empty()) {
assert(Args.size() == 1 && "Expected only single argument.");
ExtraArgs[TreeN] = Args.front();
}
// Add reduction values. The values are sorted for better vectorization
// results.
for (Value *V : PossibleRedVals) {
size_t Key, Idx;
std::tie(Key, Idx) = generateKeySubkey(V, &TLI, GenerateLoadsSubkey,
/*AllowAlternate=*/false);
++PossibleReducedVals[Key][Idx]
.insert(std::make_pair(V, 0))
.first->second;
}
Worklist.append(PossibleReductionOps.rbegin(),
PossibleReductionOps.rend());
} else {
CheckOperands(TreeN, PossibleRedVals, PossibleReductionOps, Level);
addReductionOps(TreeN);
// Add reduction values. The values are sorted for better vectorization
// results.
for (Value *V : PossibleRedVals) {
size_t Key, Idx;
std::tie(Key, Idx) = generateKeySubkey(TreeN, &TLI, GenerateLoadsSubkey,
std::tie(Key, Idx) = generateKeySubkey(V, &TLI, GenerateLoadsSubkey,
/*AllowAlternate=*/false);
++PossibleReducedVals[Key][Idx]
.insert(std::make_pair(TreeN, 0))
.insert(std::make_pair(V, 0))
.first->second;
}
for (Instruction *I : reverse(PossibleReductionOps))
Worklist.emplace_back(I, I->getParent() == BB ? 0 : Level + 1);
}
auto PossibleReducedValsVect = PossibleReducedVals.takeVector();
// Sort values by the total number of values kinds to start the reduction
Expand Down Expand Up @@ -17165,18 +17143,9 @@ class HorizontalReduction {

// Track the reduced values in case if they are replaced by extractelement
// because of the vectorization.
DenseMap<Value *, WeakTrackingVH> TrackedVals(
ReducedVals.size() * ReducedVals.front().size() + ExtraArgs.size());
BoUpSLP::ExtraValueToDebugLocsMap ExternallyUsedValues;
DenseMap<Value *, WeakTrackingVH> TrackedVals(ReducedVals.size() *
ReducedVals.front().size());
SmallVector<std::pair<Value *, Value *>> ReplacedExternals;
ExternallyUsedValues.reserve(ExtraArgs.size() + 1);
// The same extra argument may be used several times, so log each attempt
// to use it.
for (const std::pair<Instruction *, Value *> &Pair : ExtraArgs) {
assert(Pair.first && "DebugLoc must be set.");
ExternallyUsedValues[Pair.second].push_back(Pair.first);
TrackedVals.try_emplace(Pair.second, Pair.second);
}

// The compare instruction of a min/max is the insertion point for new
// instructions and may be replaced with a new compare instruction.
Expand Down Expand Up @@ -17211,13 +17180,9 @@ class HorizontalReduction {
// Initialize the final value in the reduction.
return Res;
};
bool AnyBoolLogicOp =
any_of(ReductionOps.back(), [](Value *V) {
return isBoolLogicOp(cast<Instruction>(V));
});
// The reduction root is used as the insertion point for new instructions,
// so set it as externally used to prevent it from being deleted.
ExternallyUsedValues[ReductionRoot];
bool AnyBoolLogicOp = any_of(ReductionOps.back(), [](Value *V) {
return isBoolLogicOp(cast<Instruction>(V));
});
SmallDenseSet<Value *> IgnoreList(ReductionOps.size() *
ReductionOps.front().size());
for (ReductionOpsType &RdxOps : ReductionOps)
Expand Down Expand Up @@ -17439,8 +17404,11 @@ class HorizontalReduction {
V.reorderBottomToTop(/*IgnoreReorder=*/true);
// Keep extracted other reduction values, if they are used in the
// vectorization trees.
BoUpSLP::ExtraValueToDebugLocsMap LocalExternallyUsedValues(
ExternallyUsedValues);
BoUpSLP::ExtraValueToDebugLocsMap LocalExternallyUsedValues;
// The reduction root is used as the insertion point for new
// instructions, so set it as externally used to prevent it from being
// deleted.
LocalExternallyUsedValues[ReductionRoot];
for (unsigned Cnt = 0, Sz = ReducedVals.size(); Cnt < Sz; ++Cnt) {
if (Cnt == I || (ShuffledExtracts && Cnt == I - 1))
continue;
Expand Down Expand Up @@ -17487,23 +17455,6 @@ class HorizontalReduction {
for (Value *RdxVal : VL)
if (RequiredExtract.contains(RdxVal))
LocalExternallyUsedValues[RdxVal];
// Update LocalExternallyUsedValues for the scalar, replaced by
// extractelement instructions.
DenseMap<Value *, Value *> ReplacementToExternal;
for (const std::pair<Value *, Value *> &Pair : ReplacedExternals)
ReplacementToExternal.try_emplace(Pair.second, Pair.first);
for (const std::pair<Value *, Value *> &Pair : ReplacedExternals) {
Value *Ext = Pair.first;
auto RIt = ReplacementToExternal.find(Ext);
while (RIt != ReplacementToExternal.end()) {
Ext = RIt->second;
RIt = ReplacementToExternal.find(Ext);
}
auto *It = ExternallyUsedValues.find(Ext);
if (It == ExternallyUsedValues.end())
continue;
LocalExternallyUsedValues[Pair.second].append(It->second);
}
V.buildExternalUses(LocalExternallyUsedValues);

V.computeMinimumValueSizes();
Expand Down Expand Up @@ -17705,11 +17656,6 @@ class HorizontalReduction {
ExtraReductions.emplace_back(RedOp, RdxVal);
}
}
for (auto &Pair : ExternallyUsedValues) {
// Add each externally used value to the final reduction.
for (auto *I : Pair.second)
ExtraReductions.emplace_back(I, Pair.first);
}
// Iterate through all not-vectorized reduction values/extra arguments.
bool InitStep = true;
while (ExtraReductions.size() > 1) {
Expand Down Expand Up @@ -17861,6 +17807,8 @@ class HorizontalReduction {
assert(IsSupportedHorRdxIdentityOp &&
"The optimization of matched scalar identity horizontal reductions "
"must be supported.");
if (Cnt == 1)
return VectorizedValue;
switch (RdxKind) {
case RecurKind::Add: {
// res = mul vv, n
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ define void @test() {
; CHECK-NEXT: [[PHI1:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[OP_RDX25:%.*]], [[LOOP]] ]
; CHECK-NEXT: [[TMP6:%.*]] = phi <8 x i64> [ [[TMP0]], [[ENTRY]] ], [ [[TMP1]], [[LOOP]] ]
; CHECK-NEXT: [[TMP7:%.*]] = mul <8 x i64> [[TMP6]], <i64 4, i64 4, i64 4, i64 4, i64 4, i64 4, i64 4, i64 4>
; CHECK-NEXT: [[TMP5:%.*]] = mul <8 x i64> [[TMP1]], <i64 2, i64 2, i64 2, i64 2, i64 2, i64 2, i64 2, i64 2>
; CHECK-NEXT: [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP7]])
; CHECK-NEXT: [[TMP8:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP1]])
; CHECK-NEXT: [[TMP10:%.*]] = mul i64 [[TMP8]], 2
; CHECK-NEXT: [[OP_RDX33:%.*]] = add i64 [[TMP10]], [[TMP9]]
; CHECK-NEXT: [[OP_RDX25]] = add i64 [[OP_RDX33]], [[TMP3]]
; CHECK-NEXT: [[TMP8:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP5]])
; CHECK-NEXT: [[OP_RDX16:%.*]] = add i64 [[TMP9]], [[TMP8]]
; CHECK-NEXT: [[OP_RDX25]] = add i64 [[OP_RDX16]], [[TMP3]]
; CHECK-NEXT: br label [[LOOP]]
;
entry:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,24 @@
define i32 @foo(i32 %a) {
; CHECK-LABEL: @foo(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = insertelement <2 x i32> <i32 poison, i32 0>, i32 [[A:%.*]], i32 0
; CHECK-NEXT: [[TMP1:%.*]] = sub nsw <2 x i32> zeroinitializer, [[TMP0]]
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <2 x i32> [[TMP1]], <2 x i32> poison, <4 x i32> <i32 0, i32 1, i32 1, i32 1>
; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x i32> [[TMP2]], i32 1
; CHECK-NEXT: [[TMP0:%.*]] = sub nsw i32 0, [[A:%.*]]
; CHECK-NEXT: [[LOCAL:%.*]] = sub nsw i32 0, 0
; CHECK-NEXT: br i1 false, label [[BB5:%.*]], label [[BB1:%.*]]
; CHECK: bb1:
; CHECK-NEXT: [[TMP4:%.*]] = mul <2 x i32> [[TMP1]], <i32 1, i32 3>
; CHECK-NEXT: [[TMP5:%.*]] = extractelement <2 x i32> [[TMP4]], i32 0
; CHECK-NEXT: [[TMP6:%.*]] = extractelement <2 x i32> [[TMP4]], i32 1
; CHECK-NEXT: [[OP_RDX10:%.*]] = add i32 [[TMP6]], [[TMP5]]
; CHECK-NEXT: [[OP_RDX11:%.*]] = add i32 [[OP_RDX10]], 0
; CHECK-NEXT: [[TMP1:%.*]] = mul i32 [[LOCAL]], 3
; CHECK-NEXT: [[OP_RDX2:%.*]] = add i32 [[TMP1]], [[TMP0]]
; CHECK-NEXT: [[OP_RDX3:%.*]] = add i32 [[OP_RDX2]], 0
; CHECK-NEXT: br label [[BB3:%.*]]
; CHECK: bb2:
; CHECK-NEXT: br label [[BB3]]
; CHECK: bb3:
; CHECK-NEXT: [[P1:%.*]] = phi i32 [ [[OP_RDX11]], [[BB1]] ], [ 0, [[BB2:%.*]] ]
; CHECK-NEXT: [[P1:%.*]] = phi i32 [ [[OP_RDX3]], [[BB1]] ], [ 0, [[BB2:%.*]] ]
; CHECK-NEXT: ret i32 0
; CHECK: bb4:
; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <2 x i32> [[TMP1]], <2 x i32> poison, <4 x i32> <i32 1, i32 1, i32 1, i32 1>
; CHECK-NEXT: [[TMP8:%.*]] = add <4 x i32> [[TMP7]], [[TMP2]]
; CHECK-NEXT: [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP8]])
; CHECK-NEXT: [[OP_RDX8:%.*]] = add i32 [[TMP9]], 0
; CHECK-NEXT: [[OP_RDX9:%.*]] = add i32 [[OP_RDX8]], [[TMP3]]
; CHECK-NEXT: ret i32 [[OP_RDX9]]
; CHECK-NEXT: [[TMP2:%.*]] = mul i32 [[LOCAL]], 8
; CHECK-NEXT: [[OP_RDX:%.*]] = add i32 [[TMP2]], [[TMP0]]
; CHECK-NEXT: [[OP_RDX1:%.*]] = add i32 [[OP_RDX]], 0
; CHECK-NEXT: ret i32 [[OP_RDX1]]
; CHECK: bb5:
; CHECK-NEXT: br label [[BB4:%.*]]
;
Expand Down
Loading
Loading