Skip to content

[AArch64][LoopVectorize] Use upper bound trip count instead of the constant TC when choosing max VF #67697

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 24 additions & 23 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1662,14 +1662,14 @@ class LoopVectorizationCostModel {
/// elements is a power-of-2 larger than zero. If scalable vectorization is
/// disabled or unsupported, then the scalable part will be equal to
/// ElementCount::getScalable(0).
FixedScalableVFPair computeFeasibleMaxVF(unsigned ConstTripCount,
FixedScalableVFPair computeFeasibleMaxVF(unsigned MaxTripCount,
ElementCount UserVF,
bool FoldTailByMasking);

/// \return the maximized element count based on the targets vector
/// registers and the loop trip-count, but limited to a maximum safe VF.
/// This is a helper function of computeFeasibleMaxVF.
ElementCount getMaximizedVFForTarget(unsigned ConstTripCount,
ElementCount getMaximizedVFForTarget(unsigned MaxTripCount,
unsigned SmallestType,
unsigned WidestType,
ElementCount MaxSafeVF,
Expand Down Expand Up @@ -4811,7 +4811,7 @@ LoopVectorizationCostModel::getMaxLegalScalableVF(unsigned MaxSafeElements) {
}

FixedScalableVFPair LoopVectorizationCostModel::computeFeasibleMaxVF(
unsigned ConstTripCount, ElementCount UserVF, bool FoldTailByMasking) {
unsigned MaxTripCount, ElementCount UserVF, bool FoldTailByMasking) {
MinBWs = computeMinimumValueSizes(TheLoop->getBlocks(), *DB, &TTI);
unsigned SmallestType, WidestType;
std::tie(SmallestType, WidestType) = getSmallestAndWidestTypes();
Expand Down Expand Up @@ -4899,12 +4899,12 @@ FixedScalableVFPair LoopVectorizationCostModel::computeFeasibleMaxVF(
FixedScalableVFPair Result(ElementCount::getFixed(1),
ElementCount::getScalable(0));
if (auto MaxVF =
getMaximizedVFForTarget(ConstTripCount, SmallestType, WidestType,
getMaximizedVFForTarget(MaxTripCount, SmallestType, WidestType,
MaxSafeFixedVF, FoldTailByMasking))
Result.FixedVF = MaxVF;

if (auto MaxVF =
getMaximizedVFForTarget(ConstTripCount, SmallestType, WidestType,
getMaximizedVFForTarget(MaxTripCount, SmallestType, WidestType,
MaxSafeScalableVF, FoldTailByMasking))
if (MaxVF.isScalable()) {
Result.ScalableVF = MaxVF;
Expand All @@ -4928,6 +4928,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
}

unsigned TC = PSE.getSE()->getSmallConstantTripCount(TheLoop);
unsigned MaxTC = PSE.getSE()->getSmallConstantMaxTripCount(TheLoop);
LLVM_DEBUG(dbgs() << "LV: Found trip count: " << TC << '\n');
if (TC == 1) {
reportVectorizationFailure("Single iteration (non) loop",
Expand All @@ -4938,7 +4939,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {

switch (ScalarEpilogueStatus) {
case CM_ScalarEpilogueAllowed:
return computeFeasibleMaxVF(TC, UserVF, false);
return computeFeasibleMaxVF(MaxTC, UserVF, false);
case CM_ScalarEpilogueNotAllowedUsePredicate:
[[fallthrough]];
case CM_ScalarEpilogueNotNeededUsePredicate:
Expand Down Expand Up @@ -4976,7 +4977,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
LLVM_DEBUG(dbgs() << "LV: Cannot fold tail by masking: vectorize with a "
"scalar epilogue instead.\n");
ScalarEpilogueStatus = CM_ScalarEpilogueAllowed;
return computeFeasibleMaxVF(TC, UserVF, false);
return computeFeasibleMaxVF(MaxTC, UserVF, false);
}
return FixedScalableVFPair::getNone();
}
Expand All @@ -4993,7 +4994,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
InterleaveInfo.invalidateGroupsRequiringScalarEpilogue();
}

FixedScalableVFPair MaxFactors = computeFeasibleMaxVF(TC, UserVF, true);
FixedScalableVFPair MaxFactors = computeFeasibleMaxVF(MaxTC, UserVF, true);

// Avoid tail folding if the trip count is known to be a multiple of any VF
// we choose.
Expand Down Expand Up @@ -5069,7 +5070,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
}

ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget(
unsigned ConstTripCount, unsigned SmallestType, unsigned WidestType,
unsigned MaxTripCount, unsigned SmallestType, unsigned WidestType,
ElementCount MaxSafeVF, bool FoldTailByMasking) {
bool ComputeScalableMaxVF = MaxSafeVF.isScalable();
const TypeSize WidestRegister = TTI.getRegisterBitWidth(
Expand Down Expand Up @@ -5108,24 +5109,24 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget(
}

// When a scalar epilogue is required, at least one iteration of the scalar
// loop has to execute. Adjust ConstTripCount accordingly to avoid picking a
// loop has to execute. Adjust MaxTripCount accordingly to avoid picking a
// max VF that results in a dead vector loop.
if (ConstTripCount > 0 && requiresScalarEpilogue(true))
ConstTripCount -= 1;

if (ConstTripCount && ConstTripCount <= WidestRegisterMinEC &&
(!FoldTailByMasking || isPowerOf2_32(ConstTripCount))) {
// If loop trip count (TC) is known at compile time there is no point in
// choosing VF greater than TC (as done in the loop below). Select maximum
// power of two which doesn't exceed TC.
// If MaxVectorElementCount is scalable, we only fall back on a fixed VF
// when the TC is less than or equal to the known number of lanes.
auto ClampedConstTripCount = llvm::bit_floor(ConstTripCount);
if (MaxTripCount > 0 && requiresScalarEpilogue(true))
MaxTripCount -= 1;

if (MaxTripCount && MaxTripCount <= WidestRegisterMinEC &&
(!FoldTailByMasking || isPowerOf2_32(MaxTripCount))) {
// If upper bound loop trip count (TC) is known at compile time there is no
// point in choosing VF greater than TC (as done in the loop below). Select
// maximum power of two which doesn't exceed TC. If MaxVectorElementCount is
// scalable, we only fall back on a fixed VF when the TC is less than or
// equal to the known number of lanes.
auto ClampedUpperTripCount = llvm::bit_floor(MaxTripCount);
LLVM_DEBUG(dbgs() << "LV: Clamping the MaxVF to maximum power of two not "
"exceeding the constant trip count: "
<< ClampedConstTripCount << "\n");
<< ClampedUpperTripCount << "\n");
return ElementCount::get(
ClampedConstTripCount,
ClampedUpperTripCount,
FoldTailByMasking ? MaxVectorElementCount.isScalable() : false);
}

Expand Down
31 changes: 31 additions & 0 deletions llvm/test/Transforms/LoopVectorize/AArch64/clamped-trip-count.ll
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,34 @@ for.body: ; preds = %entry, %for.body
for.cond.cleanup: ; preds = %for.body
ret void
}

define void @clamped_tc_max_8(ptr nocapture %dst, i32 %n, i64 %val){
; CHECK-LABEL: define void @clamped_tc_max_8(
; CHECK: call void @llvm.masked.store.nxv8i8.p0(<vscale x 8 x i8> {{.*}}, ptr {{.*}}, i32 1, <vscale x 8 x i1> {{.*}})

entry:
%rem = and i32 %n, 63
%cmp8.not = icmp eq i32 %rem, 0
br i1 %cmp8.not, label %for.cond.cleanup, label %for.body.preheader

for.body.preheader: ; preds = %entry
%add = add nuw nsw i32 %rem, 7
%shr = lshr i32 %add, 3
%wide.trip.count = zext i32 %shr to i64
br label %for.body

for.body: ; preds = %for.body.preheader, %for.body
%indvars.iv = phi i64 [ 0, %for.body.preheader ], [ %indvars.iv.next, %for.body ]
%p_out_tail.09 = phi ptr [ %dst, %for.body.preheader ], [ %incdec.ptr, %for.body ]
%0 = shl nuw nsw i64 %indvars.iv, 3
%shr3 = lshr i64 %val, %0
%conv4 = trunc i64 %shr3 to i8
store i8 %conv4, ptr %p_out_tail.09, align 1
%incdec.ptr = getelementptr inbounds i8, ptr %p_out_tail.09, i64 1
%indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
%exitcond.not = icmp eq i64 %indvars.iv.next, %wide.trip.count
br i1 %exitcond.not, label %for.cond.cleanup, label %for.body

for.cond.cleanup: ; preds = %for.body
ret void
}