Skip to content

Commit 577631f

Browse files
committed
Reapply "[VPlan] Add transformation to narrow interleave groups. (#106441)"
This reverts commit ff3e2ba. The recommitted version limits the transform to cases where no interleaving is taking place, to avoid a mis-compile when interleaving. Original commit message: This patch adds a new narrowInterleaveGroups transform, which tries to convert a plan with interleave groups with VF elements to a plan that instead replaces the interleave groups with wide loads and stores processing VF elements. This effectively is a very simple form of loop-aware SLP, where we use interleave groups to identify candidates. This initial version is quite restricted and hopefully serves as a starting point for how to best model those kinds of transforms. Depends on #106431. Fixes #82936. PR: #106441
1 parent a074831 commit 577631f

6 files changed

+165
-187
lines changed

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

+81-13
Original file line numberDiff line numberDiff line change
@@ -2247,6 +2247,36 @@ void VPlanTransforms::materializeBroadcasts(VPlan &Plan) {
22472247
}
22482248
}
22492249

2250+
/// Returns true if \p V is VPWidenLoadRecipe or VPInterleaveRecipe that can be
2251+
/// converted to a narrower recipe. \p V is used by a wide recipe \p WideMember
2252+
/// that feeds a store interleave group at index \p Idx, \p WideMember0 is the
2253+
/// recipe feeding the same interleave group at index 0. A VPWidenLoadRecipe can
2254+
/// be narrowed to an index-independent load if it feeds all wide ops at all
2255+
/// indices (checked by via the operands of the wide recipe at lane0, \p
2256+
/// WideMember0). A VPInterleaveRecipe can be narrowed to a wide load, if \p V
2257+
/// is defined at \p Idx of a load interleave group.
2258+
static bool canNarrowLoad(VPWidenRecipe *WideMember0, VPWidenRecipe *WideMember,
2259+
VPValue *V, unsigned Idx) {
2260+
auto *DefR = V->getDefiningRecipe();
2261+
if (!DefR)
2262+
return false;
2263+
if (auto *W = dyn_cast<VPWidenLoadRecipe>(DefR))
2264+
return !W->getMask() &&
2265+
all_of(zip(WideMember0->operands(), WideMember->operands()),
2266+
[V](const auto P) {
2267+
// V must be as at the same places in both WideMember0 and
2268+
// WideMember.
2269+
const auto &[WideMember0Op, WideMemberOp] = P;
2270+
return (WideMember0Op == V) == (WideMemberOp == V);
2271+
});
2272+
2273+
if (auto *IR = dyn_cast<VPInterleaveRecipe>(DefR))
2274+
return IR->getInterleaveGroup()->getFactor() ==
2275+
IR->getInterleaveGroup()->getNumMembers() &&
2276+
IR->getVPValue(Idx) == V;
2277+
return false;
2278+
}
2279+
22502280
/// Returns true if \p IR is a full interleave group with factor and number of
22512281
/// members both equal to \p VF. The interleave group must also access the full
22522282
/// vector width \p VectorRegWidth.
@@ -2284,7 +2314,7 @@ void VPlanTransforms::narrowInterleaveGroups(VPlan &Plan, ElementCount VF,
22842314
unsigned VectorRegWidth) {
22852315
using namespace llvm::VPlanPatternMatch;
22862316
VPRegionBlock *VectorLoop = Plan.getVectorLoopRegion();
2287-
if (VF.isScalable() || !VectorLoop)
2317+
if (VF.isScalable() || !VectorLoop || Plan.getUF() != 1)
22882318
return;
22892319

22902320
VPCanonicalIVPHIRecipe *CanonicalIV = Plan.getCanonicalIV();
@@ -2309,6 +2339,8 @@ void VPlanTransforms::narrowInterleaveGroups(VPlan &Plan, ElementCount VF,
23092339
if (R.mayWriteToMemory() && !InterleaveR)
23102340
return;
23112341

2342+
// All other ops are allowed, but we reject uses that cannot be converted
2343+
// when checking all allowed consumers (store interleave groups) below.
23122344
if (!InterleaveR)
23132345
continue;
23142346

@@ -2323,7 +2355,7 @@ void VPlanTransforms::narrowInterleaveGroups(VPlan &Plan, ElementCount VF,
23232355

23242356
// For now, we only support full interleave groups storing load interleave
23252357
// groups.
2326-
if (!all_of(enumerate(InterleaveR->getStoredValues()), [](auto Op) {
2358+
if (all_of(enumerate(InterleaveR->getStoredValues()), [](auto Op) {
23272359
VPRecipeBase *DefR = Op.value()->getDefiningRecipe();
23282360
if (!DefR)
23292361
return false;
@@ -2333,31 +2365,67 @@ void VPlanTransforms::narrowInterleaveGroups(VPlan &Plan, ElementCount VF,
23332365
IR->getInterleaveGroup()->getNumMembers() &&
23342366
IR->getVPValue(Op.index()) == Op.value();
23352367
})) {
2368+
StoreGroups.push_back(InterleaveR);
2369+
continue;
2370+
}
2371+
2372+
// Check if all values feeding InterleaveR are matching wide recipes, with
2373+
// operands that can be narrowed.
2374+
auto *WideMember0 = dyn_cast_or_null<VPWidenRecipe>(
2375+
InterleaveR->getStoredValues()[0]->getDefiningRecipe());
2376+
if (!WideMember0)
23362377
return;
2378+
for (const auto &[I, V] : enumerate(InterleaveR->getStoredValues())) {
2379+
auto *R = dyn_cast<VPWidenRecipe>(V->getDefiningRecipe());
2380+
if (!R || R->getOpcode() != WideMember0->getOpcode() ||
2381+
R->getNumOperands() > 2)
2382+
return;
2383+
if (any_of(R->operands(), [WideMember0, Idx = I, R](VPValue *V) {
2384+
return !canNarrowLoad(WideMember0, R, V, Idx);
2385+
}))
2386+
return;
23372387
}
23382388
StoreGroups.push_back(InterleaveR);
23392389
}
23402390

23412391
if (StoreGroups.empty())
23422392
return;
23432393

2344-
// Convert InterleaveGroup R to a single VPWidenLoadRecipe.
2394+
// Convert InterleaveGroup \p R to a single VPWidenLoadRecipe.
23452395
auto NarrowOp = [](VPRecipeBase *R) -> VPValue * {
2346-
auto *LoadGroup = cast<VPInterleaveRecipe>(R);
2347-
// Narrow interleave group to wide load, as transformed VPlan will only
2396+
if (auto *LoadGroup = dyn_cast<VPInterleaveRecipe>(R)) {
2397+
// Narrow interleave group to wide load, as transformed VPlan will only
2398+
// process one original iteration.
2399+
auto *L = new VPWidenLoadRecipe(
2400+
*cast<LoadInst>(LoadGroup->getInterleaveGroup()->getInsertPos()),
2401+
LoadGroup->getAddr(), LoadGroup->getMask(), /*Consecutive=*/true,
2402+
/*Reverse=*/false, LoadGroup->getDebugLoc());
2403+
L->insertBefore(LoadGroup);
2404+
return L;
2405+
}
2406+
2407+
auto *WideLoad = cast<VPWidenLoadRecipe>(R);
2408+
2409+
// Narrow wide load to uniform scalar load, as transformed VPlan will only
23482410
// process one original iteration.
2349-
auto *L = new VPWidenLoadRecipe(
2350-
*cast<LoadInst>(LoadGroup->getInterleaveGroup()->getInsertPos()),
2351-
LoadGroup->getAddr(), LoadGroup->getMask(), /*Consecutive=*/true,
2352-
/*Reverse=*/false, LoadGroup->getDebugLoc());
2353-
L->insertBefore(LoadGroup);
2354-
return L;
2411+
auto *N = new VPReplicateRecipe(&WideLoad->getIngredient(),
2412+
WideLoad->operands(), /*IsUniform*/ true);
2413+
N->insertBefore(WideLoad);
2414+
return N;
23552415
};
23562416

23572417
// Narrow operation tree rooted at store groups.
23582418
for (auto *StoreGroup : StoreGroups) {
2359-
VPValue *Res =
2360-
NarrowOp(StoreGroup->getStoredValues()[0]->getDefiningRecipe());
2419+
VPValue *Res = nullptr;
2420+
if (auto *WideMember0 = dyn_cast<VPWidenRecipe>(
2421+
StoreGroup->getStoredValues()[0]->getDefiningRecipe())) {
2422+
for (unsigned Idx = 0, E = WideMember0->getNumOperands(); Idx != E; ++Idx)
2423+
WideMember0->setOperand(
2424+
Idx, NarrowOp(WideMember0->getOperand(Idx)->getDefiningRecipe()));
2425+
Res = WideMember0;
2426+
} else {
2427+
Res = NarrowOp(StoreGroup->getStoredValues()[0]->getDefiningRecipe());
2428+
}
23612429

23622430
auto *S = new VPWidenStoreRecipe(
23632431
*cast<StoreInst>(StoreGroup->getInterleaveGroup()->getInsertPos()),

llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-cost.ll

+11-11
Original file line numberDiff line numberDiff line change
@@ -100,27 +100,27 @@ define void @test_complex_add_double(ptr %res, ptr noalias %A, ptr noalias %B, i
100100
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds nuw { double, double }, ptr [[B]], i64 [[TMP0]]
101101
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds nuw { double, double }, ptr [[B]], i64 [[TMP1]]
102102
; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <4 x double>, ptr [[TMP2]], align 4
103-
; CHECK-NEXT: [[STRIDED_VEC:%.*]] = shufflevector <4 x double> [[WIDE_VEC]], <4 x double> poison, <2 x i32> <i32 0, i32 2>
103+
; CHECK-NEXT: [[STRIDED_VEC4:%.*]] = shufflevector <4 x double> [[WIDE_VEC]], <4 x double> poison, <2 x i32> <i32 0, i32 2>
104104
; CHECK-NEXT: [[STRIDED_VEC1:%.*]] = shufflevector <4 x double> [[WIDE_VEC]], <4 x double> poison, <2 x i32> <i32 1, i32 3>
105105
; CHECK-NEXT: [[WIDE_VEC2:%.*]] = load <4 x double>, ptr [[TMP3]], align 4
106-
; CHECK-NEXT: [[STRIDED_VEC3:%.*]] = shufflevector <4 x double> [[WIDE_VEC2]], <4 x double> poison, <2 x i32> <i32 0, i32 2>
107-
; CHECK-NEXT: [[STRIDED_VEC4:%.*]] = shufflevector <4 x double> [[WIDE_VEC2]], <4 x double> poison, <2 x i32> <i32 1, i32 3>
106+
; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = shufflevector <4 x double> [[WIDE_VEC2]], <4 x double> poison, <2 x i32> <i32 0, i32 2>
107+
; CHECK-NEXT: [[STRIDED_VEC5:%.*]] = shufflevector <4 x double> [[WIDE_VEC2]], <4 x double> poison, <2 x i32> <i32 1, i32 3>
108108
; CHECK-NEXT: [[WIDE_VEC5:%.*]] = load <4 x double>, ptr [[TMP4]], align 4
109-
; CHECK-NEXT: [[STRIDED_VEC6:%.*]] = shufflevector <4 x double> [[WIDE_VEC5]], <4 x double> poison, <2 x i32> <i32 0, i32 2>
109+
; CHECK-NEXT: [[STRIDED_VEC10:%.*]] = shufflevector <4 x double> [[WIDE_VEC5]], <4 x double> poison, <2 x i32> <i32 0, i32 2>
110110
; CHECK-NEXT: [[STRIDED_VEC7:%.*]] = shufflevector <4 x double> [[WIDE_VEC5]], <4 x double> poison, <2 x i32> <i32 1, i32 3>
111111
; CHECK-NEXT: [[WIDE_VEC8:%.*]] = load <4 x double>, ptr [[TMP5]], align 4
112-
; CHECK-NEXT: [[STRIDED_VEC9:%.*]] = shufflevector <4 x double> [[WIDE_VEC8]], <4 x double> poison, <2 x i32> <i32 0, i32 2>
113-
; CHECK-NEXT: [[STRIDED_VEC10:%.*]] = shufflevector <4 x double> [[WIDE_VEC8]], <4 x double> poison, <2 x i32> <i32 1, i32 3>
114-
; CHECK-NEXT: [[TMP6:%.*]] = fadd <2 x double> [[STRIDED_VEC]], [[STRIDED_VEC6]]
115-
; CHECK-NEXT: [[TMP7:%.*]] = fadd <2 x double> [[STRIDED_VEC3]], [[STRIDED_VEC9]]
116-
; CHECK-NEXT: [[TMP8:%.*]] = fadd <2 x double> [[STRIDED_VEC1]], [[STRIDED_VEC7]]
112+
; CHECK-NEXT: [[WIDE_LOAD3:%.*]] = shufflevector <4 x double> [[WIDE_VEC8]], <4 x double> poison, <2 x i32> <i32 0, i32 2>
113+
; CHECK-NEXT: [[STRIDED_VEC11:%.*]] = shufflevector <4 x double> [[WIDE_VEC8]], <4 x double> poison, <2 x i32> <i32 1, i32 3>
117114
; CHECK-NEXT: [[TMP9:%.*]] = fadd <2 x double> [[STRIDED_VEC4]], [[STRIDED_VEC10]]
115+
; CHECK-NEXT: [[TMP7:%.*]] = fadd <2 x double> [[WIDE_LOAD1]], [[WIDE_LOAD3]]
116+
; CHECK-NEXT: [[TMP8:%.*]] = fadd <2 x double> [[STRIDED_VEC1]], [[STRIDED_VEC7]]
117+
; CHECK-NEXT: [[TMP15:%.*]] = fadd <2 x double> [[STRIDED_VEC5]], [[STRIDED_VEC11]]
118118
; CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds nuw { double, double }, ptr [[RES]], i64 [[TMP0]]
119119
; CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds nuw { double, double }, ptr [[RES]], i64 [[TMP1]]
120-
; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <2 x double> [[TMP6]], <2 x double> [[TMP8]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
120+
; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <2 x double> [[TMP9]], <2 x double> [[TMP8]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
121121
; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = shufflevector <4 x double> [[TMP12]], <4 x double> poison, <4 x i32> <i32 0, i32 2, i32 1, i32 3>
122122
; CHECK-NEXT: store <4 x double> [[INTERLEAVED_VEC]], ptr [[TMP10]], align 4
123-
; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <2 x double> [[TMP7]], <2 x double> [[TMP9]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
123+
; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <2 x double> [[TMP7]], <2 x double> [[TMP15]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
124124
; CHECK-NEXT: [[INTERLEAVED_VEC11:%.*]] = shufflevector <4 x double> [[TMP13]], <4 x double> poison, <4 x i32> <i32 0, i32 2, i32 1, i32 3>
125125
; CHECK-NEXT: store <4 x double> [[INTERLEAVED_VEC11]], ptr [[TMP11]], align 4
126126
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4

llvm/test/Transforms/LoopVectorize/AArch64/transform-narrow-interleave-to-widen-memory-unroll.ll

+13-5
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,19 @@ define void @load_store_interleave_group(ptr noalias %data) {
1919
; CHECK-NEXT: [[TMP3:%.*]] = shl nsw i64 [[TMP1]], 1
2020
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP2]]
2121
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds i64, ptr [[DATA]], i64 [[TMP3]]
22-
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <2 x i64>, ptr [[TMP4]], align 8
23-
; CHECK-NEXT: [[WIDE_LOAD1:%.*]] = load <2 x i64>, ptr [[TMP5]], align 8
24-
; CHECK-NEXT: store <2 x i64> [[WIDE_LOAD]], ptr [[TMP4]], align 8
25-
; CHECK-NEXT: store <2 x i64> [[WIDE_LOAD1]], ptr [[TMP5]], align 8
26-
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2
22+
; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <4 x i64>, ptr [[TMP4]], align 8
23+
; CHECK-NEXT: [[STRIDED_VEC:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> <i32 0, i32 2>
24+
; CHECK-NEXT: [[STRIDED_VEC1:%.*]] = shufflevector <4 x i64> [[WIDE_VEC]], <4 x i64> poison, <2 x i32> <i32 1, i32 3>
25+
; CHECK-NEXT: [[WIDE_VEC2:%.*]] = load <4 x i64>, ptr [[TMP5]], align 8
26+
; CHECK-NEXT: [[STRIDED_VEC3:%.*]] = shufflevector <4 x i64> [[WIDE_VEC2]], <4 x i64> poison, <2 x i32> <i32 0, i32 2>
27+
; CHECK-NEXT: [[STRIDED_VEC4:%.*]] = shufflevector <4 x i64> [[WIDE_VEC2]], <4 x i64> poison, <2 x i32> <i32 1, i32 3>
28+
; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <2 x i64> [[STRIDED_VEC]], <2 x i64> [[STRIDED_VEC1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
29+
; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = shufflevector <4 x i64> [[TMP8]], <4 x i64> poison, <4 x i32> <i32 0, i32 2, i32 1, i32 3>
30+
; CHECK-NEXT: store <4 x i64> [[INTERLEAVED_VEC]], ptr [[TMP4]], align 8
31+
; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <2 x i64> [[STRIDED_VEC3]], <2 x i64> [[STRIDED_VEC4]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
32+
; CHECK-NEXT: [[INTERLEAVED_VEC5:%.*]] = shufflevector <4 x i64> [[TMP7]], <4 x i64> poison, <4 x i32> <i32 0, i32 2, i32 1, i32 3>
33+
; CHECK-NEXT: store <4 x i64> [[INTERLEAVED_VEC5]], ptr [[TMP5]], align 8
34+
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
2735
; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], 100
2836
; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
2937
; CHECK: [[MIDDLE_BLOCK]]:

0 commit comments

Comments
 (0)