Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 192 additions & 0 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ class VectorCombine {
bool foldShuffleOfCastops(Instruction &I);
bool foldShuffleOfShuffles(Instruction &I);
bool foldPermuteOfIntrinsic(Instruction &I);
bool foldShufflesOfLengthChangingShuffles(Instruction &I);
bool foldShuffleOfIntrinsics(Instruction &I);
bool foldShuffleToIdentity(Instruction &I);
bool foldShuffleFromReductions(Instruction &I);
Expand Down Expand Up @@ -2878,6 +2879,195 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
return true;
}

/// Try to convert a chain of length-preserving shuffles that are fed by
/// length-changing shuffles from the same source, e.g. a chain of length 3:
///
///   "shuffle (shuffle (shuffle x, (shuffle y, undef)),
///                     (shuffle y, undef)),
///            (shuffle y, undef)"
///
/// into a single shuffle fed by a length-changing shuffle:
///
///   "shuffle x, (shuffle y, undef)"
///
/// Such chains arise e.g. from folding extract/insert sequences.
///
/// The chain is walked starting at \p I. In every step exactly one operand of
/// the current trunk shuffle must be a "leaf" (a shuffle whose operand type
/// differs from the type of \p I); the other operand becomes the next trunk.
/// All leaves must read from one common non-undef value Y, and their masks
/// must agree on every lane they define. A step is only added to the chain if
/// it does not raise the cost of the replacement by more than the cost of the
/// instructions it removes.
///
/// \returns true if \p I was replaced by the combined shuffle(s).
bool VectorCombine::foldShufflesOfLengthChangingShuffles(Instruction &I) {
  FixedVectorType *TrunkType = dyn_cast<FixedVectorType>(I.getType());
  if (!TrunkType)
    return false;

  unsigned ChainLength = 0;
  // Mask of the combined outer shuffle: elements below NumTrunkElts select
  // from the current trunk, elements at or above it select from the combined
  // leaf.
  SmallVector<int> Mask;
  // Mask of the combined length-changing shuffle of Y.
  SmallVector<int> YMask;
  InstructionCost OldCost = 0;
  InstructionCost NewCost = 0;
  Value *Trunk = &I;
  unsigned NumTrunkElts = TrunkType->getNumElements();
  // Common source of all accepted leaves; stays null while every accepted leaf
  // only had undef operands.
  Value *Y = nullptr;
  // Operand type of the accepted leaves. All leaves must share it so that the
  // combined YMask indexes a single source type.
  FixedVectorType *LeafSrcType = nullptr;

  for (;;) {
    // Match the current trunk against (commutations of) the pattern
    // "shuffle trunk', (shuffle y, undef)"
    ArrayRef<int> OuterMask;
    Value *OuterV0, *OuterV1;
    // Interior links of the chain are removed by the fold, so they must not
    // have other users. The root (ChainLength == 0) is replaced instead.
    if (ChainLength != 0 && !Trunk->hasOneUse())
      break;
    if (!match(Trunk, m_Shuffle(m_Value(OuterV0), m_Value(OuterV1),
                                m_Mask(OuterMask))))
      break;
    if (OuterV0->getType() != TrunkType) {
      // This shuffle is not length-preserving, so it cannot be part of the
      // chain.
      break;
    }

    ArrayRef<int> InnerMask0, InnerMask1;
    Value *A0, *A1, *B0, *B1;
    bool Match0 =
        match(OuterV0, m_Shuffle(m_Value(A0), m_Value(B0), m_Mask(InnerMask0)));
    bool Match1 =
        match(OuterV1, m_Shuffle(m_Value(A1), m_Value(B1), m_Mask(InnerMask1)));
    // A leaf is a shuffle operand whose own operands have a different type
    // than the chain, i.e. a length-changing shuffle.
    bool Match0Leaf = Match0 && A0->getType() != I.getType();
    bool Match1Leaf = Match1 && A1->getType() != I.getType();
    if (Match0Leaf == Match1Leaf) {
      // Only handle the case of exactly one leaf in each step. The "two leaves"
      // case is handled by foldShuffleOfShuffles.
      break;
    }

    // Canonicalize so that the leaf is always the second operand
    // (OuterV1/A1/B1/InnerMask1), commuting the outer mask accordingly.
    SmallVector<int> CommutedOuterMask;
    if (Match0Leaf) {
      std::swap(OuterV0, OuterV1);
      std::swap(InnerMask0, InnerMask1);
      std::swap(A0, A1);
      std::swap(B0, B1);
      llvm::append_range(CommutedOuterMask, OuterMask);
      for (int &M : CommutedOuterMask) {
        if (M == PoisonMaskElem)
          continue;
        if (M < (int)NumTrunkElts)
          M += NumTrunkElts;
        else
          M -= NumTrunkElts;
      }
      OuterMask = CommutedOuterMask;
    }
    // The leaf is replaced by the combined leaf, so it must have no other
    // users.
    if (!OuterV1->hasOneUse())
      break;

    // Determine the common leaf source. Work on a local candidate so that a
    // step rejected further down cannot leak its operand into the result.
    Value *StepY = Y;
    if (!isa<UndefValue>(A1)) {
      if (!StepY)
        StepY = A1;
      else if (StepY != A1)
        break;
    }
    if (!isa<UndefValue>(B1)) {
      if (!StepY)
        StepY = B1;
      else if (StepY != B1)
        break;
    }

    auto *YType = cast<FixedVectorType>(A1->getType());
    // A leaf with only undef operands passes the checks above regardless of
    // its type. Reject it if its source type differs from the earlier leaves,
    // otherwise its mask indices could be out of range for the final shuffle.
    if (LeafSrcType && LeafSrcType != YType)
      break;

    // Rewrite the leaf mask as a single-source mask: both leaf operands are
    // either Y or undef, so second-operand lanes are folded onto the first.
    int NumLeafElts = YType->getNumElements();
    SmallVector<int> LocalYMask(InnerMask1);
    for (int &M : LocalYMask) {
      if (M >= NumLeafElts)
        M -= NumLeafElts;
    }

    InstructionCost LocalOldCost =
        TTI.getInstructionCost(cast<User>(Trunk), CostKind) +
        TTI.getInstructionCost(cast<User>(OuterV1), CostKind);

    // Handle the initial (start of chain) case.
    if (!ChainLength) {
      Mask.assign(OuterMask);
      YMask.assign(LocalYMask);
      OldCost = NewCost = LocalOldCost;
      Trunk = OuterV0;
      Y = StepY;
      LeafSrcType = YType;
      ChainLength++;
      continue;
    }

    // For the non-root case, first attempt to combine masks.
    SmallVector<int> NewYMask(YMask);
    bool Valid = true;
    for (auto [CombinedM, LeafM] : llvm::zip(NewYMask, LocalYMask)) {
      if (LeafM == -1 || CombinedM == LeafM)
        continue;
      if (CombinedM == -1) {
        CombinedM = LeafM;
      } else {
        // Both masks define this lane, but with different source elements.
        Valid = false;
        break;
      }
    }
    if (!Valid)
      break;

    // Compose the accumulated outer mask with this step's outer mask: lanes
    // that selected from the old trunk are redirected through OuterMask, lanes
    // that selected from the leaf (or are poison) stay as they are.
    SmallVector<int> NewMask;
    NewMask.reserve(NumTrunkElts);
    for (int M : Mask) {
      if (M < 0 || M >= static_cast<int>(NumTrunkElts))
        NewMask.push_back(M);
      else
        NewMask.push_back(OuterMask[M]);
    }

    // Break the chain if adding this new step complicates the shuffles such
    // that it would increase the new cost by more than the old cost of this
    // step.
    InstructionCost LocalNewCost =
        TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, TrunkType,
                           YType, NewYMask, CostKind) +
        TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, TrunkType,
                           TrunkType, NewMask, CostKind);

    if (LocalNewCost >= NewCost && LocalOldCost < LocalNewCost - NewCost)
      break;

    LLVM_DEBUG({
      if (ChainLength == 1) {
        dbgs() << "Found chain of shuffles fed by length-changing shuffles: "
               << I << '\n';
      }
      dbgs() << " next chain link: " << *Trunk << '\n'
             << " old cost: " << (OldCost + LocalOldCost)
             << " new cost: " << LocalNewCost << '\n';
    });

    // The step is accepted: commit all per-step state.
    Mask = NewMask;
    YMask = NewYMask;
    OldCost += LocalOldCost;
    NewCost = LocalNewCost;
    Trunk = OuterV0;
    Y = StepY;
    ChainLength++;
  }
  if (ChainLength <= 1)
    return false;

  // Every accepted leaf was a shuffle of undef operands only, so there is no
  // source value to build the combined leaf from.
  if (!Y)
    return false;

  if (llvm::all_of(Mask, [&](int M) {
        return M < 0 || M >= static_cast<int>(NumTrunkElts);
      })) {
    // Produce a canonical simplified form if all elements are sourced from Y.
    for (int &M : Mask) {
      if (M >= static_cast<int>(NumTrunkElts))
        M = YMask[M - NumTrunkElts];
    }
    Value *Root =
        Builder.CreateShuffleVector(Y, PoisonValue::get(Y->getType()), Mask);
    replaceValue(I, *Root);
    return true;
  }

  Value *Leaf =
      Builder.CreateShuffleVector(Y, PoisonValue::get(Y->getType()), YMask);
  Value *Root = Builder.CreateShuffleVector(Trunk, Leaf, Mask);
  replaceValue(I, *Root);
  return true;
}

/// Try to convert
/// "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)".
bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) {
Expand Down Expand Up @@ -4799,6 +4989,8 @@ bool VectorCombine::run() {
return true;
if (foldPermuteOfIntrinsic(I))
return true;
if (foldShufflesOfLengthChangingShuffles(I))
return true;
if (foldShuffleOfIntrinsics(I))
return true;
if (foldSelectShuffle(I))
Expand Down
Loading
Loading