Skip to content

Commit f122182

Browse files
committed
[DAGCombine] Fix multi-use miscompile in load combine
The load combine replaces a number of original loads with one new load and also replaces the output chains of the original loads with the output chain of the new load. This is only correct if the old loads actually get removed, otherwise they may get incorrectly reordered. The code did enforce that all involved operations are one-use (which also guarantees that the loads will be removed), with one exception: For vector loads, multi-use was allowed to support multiple extract elements from one load. This patch collects these extract elements, and then validates that the loads are only used inside them. I think an alternative fix would be to replace the uses of the old output chains with TokenFactors that include both the old output chains and the new output chain. However, I think the proposed patch is preferable, as the profitability of the transform in the general multi-use case is unclear, as it may increase the overall number of loads. Fixes llvm#80911.
1 parent 69ddf1e commit f122182

File tree

4 files changed

+51
-25
lines changed

4 files changed

+51
-25
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8668,6 +8668,7 @@ using SDByteProvider = ByteProvider<SDNode *>;
86688668
static std::optional<SDByteProvider>
86698669
calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
86708670
std::optional<uint64_t> VectorIndex,
8671+
SmallPtrSetImpl<SDNode *> &ExtractElements,
86718672
unsigned StartingIndex = 0) {
86728673

86738674
// Typical i64 by i8 pattern requires recursion up to 8 calls depth
@@ -8694,12 +8695,12 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
86948695

86958696
switch (Op.getOpcode()) {
86968697
case ISD::OR: {
8697-
auto LHS =
8698-
calculateByteProvider(Op->getOperand(0), Index, Depth + 1, VectorIndex);
8698+
auto LHS = calculateByteProvider(Op->getOperand(0), Index, Depth + 1,
8699+
VectorIndex, ExtractElements);
86998700
if (!LHS)
87008701
return std::nullopt;
8701-
auto RHS =
8702-
calculateByteProvider(Op->getOperand(1), Index, Depth + 1, VectorIndex);
8702+
auto RHS = calculateByteProvider(Op->getOperand(1), Index, Depth + 1,
8703+
VectorIndex, ExtractElements);
87038704
if (!RHS)
87048705
return std::nullopt;
87058706

@@ -8726,7 +8727,8 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
87268727
return Index < ByteShift
87278728
? SDByteProvider::getConstantZero()
87288729
: calculateByteProvider(Op->getOperand(0), Index - ByteShift,
8729-
Depth + 1, VectorIndex, Index);
8730+
Depth + 1, VectorIndex, ExtractElements,
8731+
Index);
87308732
}
87318733
case ISD::ANY_EXTEND:
87328734
case ISD::SIGN_EXTEND:
@@ -8743,11 +8745,12 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
87438745
SDByteProvider::getConstantZero())
87448746
: std::nullopt;
87458747
return calculateByteProvider(NarrowOp, Index, Depth + 1, VectorIndex,
8746-
StartingIndex);
8748+
ExtractElements, StartingIndex);
87478749
}
87488750
case ISD::BSWAP:
87498751
return calculateByteProvider(Op->getOperand(0), ByteWidth - Index - 1,
8750-
Depth + 1, VectorIndex, StartingIndex);
8752+
Depth + 1, VectorIndex, ExtractElements,
8753+
StartingIndex);
87518754
case ISD::EXTRACT_VECTOR_ELT: {
87528755
auto OffsetOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
87538756
if (!OffsetOp)
@@ -8772,8 +8775,9 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
87728775
if ((*VectorIndex + 1) * NarrowByteWidth <= StartingIndex)
87738776
return std::nullopt;
87748777

8778+
ExtractElements.insert(Op.getNode());
87758779
return calculateByteProvider(Op->getOperand(0), Index, Depth + 1,
8776-
VectorIndex, StartingIndex);
8780+
VectorIndex, ExtractElements, StartingIndex);
87778781
}
87788782
case ISD::LOAD: {
87798783
auto L = cast<LoadSDNode>(Op.getNode());
@@ -9110,6 +9114,7 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
91109114
SDValue Chain;
91119115

91129116
SmallPtrSet<LoadSDNode *, 8> Loads;
9117+
SmallPtrSet<SDNode *, 8> ExtractElements;
91139118
std::optional<SDByteProvider> FirstByteProvider;
91149119
int64_t FirstOffset = INT64_MAX;
91159120

@@ -9119,7 +9124,9 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
91199124
unsigned ZeroExtendedBytes = 0;
91209125
for (int i = ByteWidth - 1; i >= 0; --i) {
91219126
auto P =
9122-
calculateByteProvider(SDValue(N, 0), i, 0, /*VectorIndex*/ std::nullopt,
9127+
calculateByteProvider(SDValue(N, 0), i, 0,
9128+
/*VectorIndex*/ std::nullopt, ExtractElements,
9129+
91239130
/*StartingIndex*/ i);
91249131
if (!P)
91259132
return SDValue();
@@ -9245,6 +9252,14 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
92459252
if (!Allowed || !Fast)
92469253
return SDValue();
92479254

9255+
// calculateByteProvider() allows multi-use for vector loads. Ensure that
9256+
// all uses are in vector element extracts that are part of the pattern.
9257+
for (LoadSDNode *L : Loads)
9258+
if (L->getMemoryVT().isVector())
9259+
for (auto It = L->use_begin(); It != L->use_end(); ++It)
9260+
if (It.getUse().getResNo() == 0 && !ExtractElements.contains(*It))
9261+
return SDValue();
9262+
92489263
SDValue NewLoad =
92499264
DAG.getExtLoad(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, SDLoc(N), VT,
92509265
Chain, FirstLoad->getBasePtr(),

llvm/test/CodeGen/AArch64/load-combine.ll

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -606,10 +606,12 @@ define void @short_vector_to_i32_unused_high_i8(ptr %in, ptr %out, ptr %p) {
606606
; CHECK-LABEL: short_vector_to_i32_unused_high_i8:
607607
; CHECK: // %bb.0:
608608
; CHECK-NEXT: ldr s0, [x0]
609-
; CHECK-NEXT: ldrh w9, [x0]
610609
; CHECK-NEXT: ushll v0.8h, v0.8b, #0
611-
; CHECK-NEXT: umov w8, v0.h[2]
612-
; CHECK-NEXT: orr w8, w9, w8, lsl #16
610+
; CHECK-NEXT: umov w8, v0.h[1]
611+
; CHECK-NEXT: umov w9, v0.h[0]
612+
; CHECK-NEXT: umov w10, v0.h[2]
613+
; CHECK-NEXT: bfi w9, w8, #8, #8
614+
; CHECK-NEXT: orr w8, w9, w10, lsl #16
613615
; CHECK-NEXT: str w8, [x1]
614616
; CHECK-NEXT: ret
615617
%ld = load <4 x i8>, ptr %in, align 4

llvm/test/CodeGen/AMDGPU/combine-vload-extract.ll

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,14 @@ define i64 @load_3xi16_combine(ptr addrspace(1) %p) #0 {
205205
; GCN-LABEL: load_3xi16_combine:
206206
; GCN: ; %bb.0:
207207
; GCN-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
208-
; GCN-NEXT: global_load_dword v2, v[0:1], off
209-
; GCN-NEXT: global_load_ushort v3, v[0:1], off offset:4
208+
; GCN-NEXT: global_load_dword v3, v[0:1], off
209+
; GCN-NEXT: global_load_ushort v2, v[0:1], off offset:4
210+
; GCN-NEXT: s_mov_b32 s4, 0xffff
210211
; GCN-NEXT: s_waitcnt vmcnt(1)
211-
; GCN-NEXT: v_mov_b32_e32 v0, v2
212+
; GCN-NEXT: v_and_b32_e32 v0, 0xffff0000, v3
213+
; GCN-NEXT: v_and_or_b32 v0, v3, s4, v0
212214
; GCN-NEXT: s_waitcnt vmcnt(0)
213-
; GCN-NEXT: v_mov_b32_e32 v1, v3
215+
; GCN-NEXT: v_mov_b32_e32 v1, v2
214216
; GCN-NEXT: s_setpc_b64 s[30:31]
215217
%gep.p = getelementptr i16, ptr addrspace(1) %p, i32 1
216218
%gep.2p = getelementptr i16, ptr addrspace(1) %p, i32 2

llvm/test/CodeGen/X86/load-combine.ll

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1283,26 +1283,33 @@ define i32 @zext_load_i32_by_i8_bswap_shl_16(ptr %arg) {
12831283
ret i32 %tmp8
12841284
}
12851285

1286-
; FIXME: This is a miscompile.
12871286
define i32 @pr80911_vector_load_multiuse(ptr %ptr, ptr %clobber) nounwind {
12881287
; CHECK-LABEL: pr80911_vector_load_multiuse:
12891288
; CHECK: # %bb.0:
1289+
; CHECK-NEXT: pushl %edi
12901290
; CHECK-NEXT: pushl %esi
1291-
; CHECK-NEXT: movl {{[0-9]+}}(%esp), %ecx
12921291
; CHECK-NEXT: movl {{[0-9]+}}(%esp), %edx
1293-
; CHECK-NEXT: movl (%edx), %esi
1294-
; CHECK-NEXT: movzwl (%edx), %eax
1295-
; CHECK-NEXT: movl $0, (%ecx)
1296-
; CHECK-NEXT: movl %esi, (%edx)
1292+
; CHECK-NEXT: movl {{[0-9]+}}(%esp), %esi
1293+
; CHECK-NEXT: movzbl (%esi), %ecx
1294+
; CHECK-NEXT: movzbl 1(%esi), %eax
1295+
; CHECK-NEXT: movzwl 2(%esi), %edi
1296+
; CHECK-NEXT: movl $0, (%edx)
1297+
; CHECK-NEXT: movw %di, 2(%esi)
1298+
; CHECK-NEXT: movb %al, 1(%esi)
1299+
; CHECK-NEXT: movb %cl, (%esi)
1300+
; CHECK-NEXT: shll $8, %eax
1301+
; CHECK-NEXT: orl %ecx, %eax
12971302
; CHECK-NEXT: popl %esi
1303+
; CHECK-NEXT: popl %edi
12981304
; CHECK-NEXT: retl
12991305
;
13001306
; CHECK64-LABEL: pr80911_vector_load_multiuse:
13011307
; CHECK64: # %bb.0:
1302-
; CHECK64-NEXT: movzwl (%rdi), %eax
1308+
; CHECK64-NEXT: movaps (%rdi), %xmm0
13031309
; CHECK64-NEXT: movl $0, (%rsi)
1304-
; CHECK64-NEXT: movl (%rdi), %ecx
1305-
; CHECK64-NEXT: movl %ecx, (%rdi)
1310+
; CHECK64-NEXT: movss %xmm0, (%rdi)
1311+
; CHECK64-NEXT: movaps %xmm0, -{{[0-9]+}}(%rsp)
1312+
; CHECK64-NEXT: movzwl -{{[0-9]+}}(%rsp), %eax
13061313
; CHECK64-NEXT: retq
13071314
%load = load <4 x i8>, ptr %ptr, align 16
13081315
store i32 0, ptr %clobber

0 commit comments

Comments
 (0)