Skip to content

Commit 0b80288

Browse files
pasaulaisldrumm
authored andcommitted
[NVPTX] Preserve v16i8 vector loads when legalizing
This is done by lowering v16i8 loads into LoadV4 operations with i32 results instead of letting ReplaceLoadVector split it into smaller loads during legalization. This is done at dag-combine1 time, so that vector operations with i8 elements can be optimised away instead of being needlessly split during legalization, which involves storing to the stack and loading it back.
1 parent 906d3ff commit 0b80288

File tree

2 files changed

+166
-2
lines changed

2 files changed

+166
-2
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -701,8 +701,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
701701
setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand);
702702

703703
// We have some custom DAG combine patterns for these nodes
704-
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::FADD, ISD::MUL, ISD::SHL,
705-
ISD::SREM, ISD::UREM, ISD::EXTRACT_VECTOR_ELT,
704+
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
705+
ISD::LOAD, ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM,
706706
ISD::VSELECT});
707707

708708
// setcc for f16x2 and bf16x2 needs special handling to prevent
@@ -5479,6 +5479,45 @@ static SDValue PerformVSELECTCombine(SDNode *N,
54795479
return DCI.DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v4i8, E);
54805480
}
54815481

5482+
static SDValue PerformLOADCombine(SDNode *N,
5483+
TargetLowering::DAGCombinerInfo &DCI) {
5484+
SelectionDAG &DAG = DCI.DAG;
5485+
LoadSDNode *LD = cast<LoadSDNode>(N);
5486+
5487+
// Lower a v16i8 load into a LoadV4 operation with i32 results instead of
5488+
// letting ReplaceLoadVector split it into smaller loads during legalization.
5489+
// This is done at dag-combine1 time, so that vector operations with i8
5490+
// elements can be optimised away instead of being needlessly split during
5491+
// legalization, which involves storing to the stack and loading it back.
5492+
EVT VT = N->getValueType(0);
5493+
if (VT != MVT::v16i8)
5494+
return SDValue();
5495+
5496+
SDLoc DL(N);
5497+
5498+
// Create a v4i32 vector load operation, effectively <4 x v4i8>.
5499+
unsigned Opc = NVPTXISD::LoadV4;
5500+
EVT NewVT = MVT::v4i32;
5501+
EVT EltVT = NewVT.getVectorElementType();
5502+
unsigned NumElts = NewVT.getVectorNumElements();
5503+
EVT RetVTs[] = {EltVT, EltVT, EltVT, EltVT, MVT::Other};
5504+
SDVTList RetVTList = DAG.getVTList(RetVTs);
5505+
SmallVector<SDValue, 8> Ops(N->ops());
5506+
Ops.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
5507+
SDValue NewLoad = DAG.getMemIntrinsicNode(Opc, DL, RetVTList, Ops, NewVT,
5508+
LD->getMemOperand());
5509+
SDValue NewChain = NewLoad.getValue(NumElts);
5510+
5511+
// Create a vector of the same type returned by the original load.
5512+
SmallVector<SDValue, 4> Elts;
5513+
for (unsigned i = 0; i < NumElts; i++)
5514+
Elts.push_back(NewLoad.getValue(i));
5515+
return DCI.DAG.getMergeValues(
5516+
{DCI.DAG.getBitcast(VT, DCI.DAG.getBuildVector(NewVT, DL, Elts)),
5517+
NewChain},
5518+
DL);
5519+
}
5520+
54825521
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
54835522
DAGCombinerInfo &DCI) const {
54845523
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5498,6 +5537,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
54985537
return PerformREMCombine(N, DCI, OptLevel);
54995538
case ISD::SETCC:
55005539
return PerformSETCCCombine(N, DCI);
5540+
case ISD::LOAD:
5541+
return PerformLOADCombine(N, DCI);
55015542
case NVPTXISD::StoreRetval:
55025543
case NVPTXISD::StoreRetvalV2:
55035544
case NVPTXISD::StoreRetvalV4:

llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,126 @@ define float @ff(ptr %p) {
5252
%sum = fadd float %sum3, %v4
5353
ret float %sum
5454
}
55+
56+
define void @combine_v16i8(ptr noundef align 16 %ptr1, ptr noundef align 16 %ptr2) {
57+
; ENABLED-LABEL: combine_v16i8
58+
; ENABLED: ld.v4.u32
59+
%val0 = load i8, ptr %ptr1, align 16
60+
%ptr1.1 = getelementptr inbounds i8, ptr %ptr1, i64 1
61+
%val1 = load i8, ptr %ptr1.1, align 1
62+
%ptr1.2 = getelementptr inbounds i8, ptr %ptr1, i64 2
63+
%val2 = load i8, ptr %ptr1.2, align 2
64+
%ptr1.3 = getelementptr inbounds i8, ptr %ptr1, i64 3
65+
%val3 = load i8, ptr %ptr1.3, align 1
66+
%ptr1.4 = getelementptr inbounds i8, ptr %ptr1, i64 4
67+
%val4 = load i8, ptr %ptr1.4, align 4
68+
%ptr1.5 = getelementptr inbounds i8, ptr %ptr1, i64 5
69+
%val5 = load i8, ptr %ptr1.5, align 1
70+
%ptr1.6 = getelementptr inbounds i8, ptr %ptr1, i64 6
71+
%val6 = load i8, ptr %ptr1.6, align 2
72+
%ptr1.7 = getelementptr inbounds i8, ptr %ptr1, i64 7
73+
%val7 = load i8, ptr %ptr1.7, align 1
74+
%ptr1.8 = getelementptr inbounds i8, ptr %ptr1, i64 8
75+
%val8 = load i8, ptr %ptr1.8, align 8
76+
%ptr1.9 = getelementptr inbounds i8, ptr %ptr1, i64 9
77+
%val9 = load i8, ptr %ptr1.9, align 1
78+
%ptr1.10 = getelementptr inbounds i8, ptr %ptr1, i64 10
79+
%val10 = load i8, ptr %ptr1.10, align 2
80+
%ptr1.11 = getelementptr inbounds i8, ptr %ptr1, i64 11
81+
%val11 = load i8, ptr %ptr1.11, align 1
82+
%ptr1.12 = getelementptr inbounds i8, ptr %ptr1, i64 12
83+
%val12 = load i8, ptr %ptr1.12, align 4
84+
%ptr1.13 = getelementptr inbounds i8, ptr %ptr1, i64 13
85+
%val13 = load i8, ptr %ptr1.13, align 1
86+
%ptr1.14 = getelementptr inbounds i8, ptr %ptr1, i64 14
87+
%val14 = load i8, ptr %ptr1.14, align 2
88+
%ptr1.15 = getelementptr inbounds i8, ptr %ptr1, i64 15
89+
%val15 = load i8, ptr %ptr1.15, align 1
90+
%lane0 = zext i8 %val0 to i32
91+
%lane1 = zext i8 %val1 to i32
92+
%lane2 = zext i8 %val2 to i32
93+
%lane3 = zext i8 %val3 to i32
94+
%lane4 = zext i8 %val4 to i32
95+
%lane5 = zext i8 %val5 to i32
96+
%lane6 = zext i8 %val6 to i32
97+
%lane7 = zext i8 %val7 to i32
98+
%lane8 = zext i8 %val8 to i32
99+
%lane9 = zext i8 %val9 to i32
100+
%lane10 = zext i8 %val10 to i32
101+
%lane11 = zext i8 %val11 to i32
102+
%lane12 = zext i8 %val12 to i32
103+
%lane13 = zext i8 %val13 to i32
104+
%lane14 = zext i8 %val14 to i32
105+
%lane15 = zext i8 %val15 to i32
106+
%red.1 = add i32 %lane0, %lane1
107+
%red.2 = add i32 %red.1, %lane2
108+
%red.3 = add i32 %red.2, %lane3
109+
%red.4 = add i32 %red.3, %lane4
110+
%red.5 = add i32 %red.4, %lane5
111+
%red.6 = add i32 %red.5, %lane6
112+
%red.7 = add i32 %red.6, %lane7
113+
%red.8 = add i32 %red.7, %lane8
114+
%red.9 = add i32 %red.8, %lane9
115+
%red.10 = add i32 %red.9, %lane10
116+
%red.11 = add i32 %red.10, %lane11
117+
%red.12 = add i32 %red.11, %lane12
118+
%red.13 = add i32 %red.12, %lane13
119+
%red.14 = add i32 %red.13, %lane14
120+
%red = add i32 %red.14, %lane15
121+
store i32 %red, ptr %ptr2, align 4
122+
ret void
123+
}
124+
125+
define void @combine_v8i16(ptr noundef align 16 %ptr1, ptr noundef align 16 %ptr2) {
126+
; ENABLED-LABEL: combine_v8i16
127+
; ENABLED: ld.v4.b32
128+
%val0 = load i16, ptr %ptr1, align 16
129+
%ptr1.1 = getelementptr inbounds i16, ptr %ptr1, i64 1
130+
%val1 = load i16, ptr %ptr1.1, align 2
131+
%ptr1.2 = getelementptr inbounds i16, ptr %ptr1, i64 2
132+
%val2 = load i16, ptr %ptr1.2, align 4
133+
%ptr1.3 = getelementptr inbounds i16, ptr %ptr1, i64 3
134+
%val3 = load i16, ptr %ptr1.3, align 2
135+
%ptr1.4 = getelementptr inbounds i16, ptr %ptr1, i64 4
136+
%val4 = load i16, ptr %ptr1.4, align 4
137+
%ptr1.5 = getelementptr inbounds i16, ptr %ptr1, i64 5
138+
%val5 = load i16, ptr %ptr1.5, align 2
139+
%ptr1.6 = getelementptr inbounds i16, ptr %ptr1, i64 6
140+
%val6 = load i16, ptr %ptr1.6, align 4
141+
%ptr1.7 = getelementptr inbounds i16, ptr %ptr1, i64 7
142+
%val7 = load i16, ptr %ptr1.7, align 2
143+
%lane0 = zext i16 %val0 to i32
144+
%lane1 = zext i16 %val1 to i32
145+
%lane2 = zext i16 %val2 to i32
146+
%lane3 = zext i16 %val3 to i32
147+
%lane4 = zext i16 %val4 to i32
148+
%lane5 = zext i16 %val5 to i32
149+
%lane6 = zext i16 %val6 to i32
150+
%lane7 = zext i16 %val7 to i32
151+
%red.1 = add i32 %lane0, %lane1
152+
%red.2 = add i32 %red.1, %lane2
153+
%red.3 = add i32 %red.2, %lane3
154+
%red.4 = add i32 %red.3, %lane4
155+
%red.5 = add i32 %red.4, %lane5
156+
%red.6 = add i32 %red.5, %lane6
157+
%red = add i32 %red.6, %lane7
158+
store i32 %red, ptr %ptr2, align 4
159+
ret void
160+
}
161+
162+
define void @combine_v4i32(ptr noundef align 16 %ptr1, ptr noundef align 16 %ptr2) {
163+
; ENABLED-LABEL: combine_v4i32
164+
; ENABLED: ld.v4.u32
165+
%val0 = load i32, ptr %ptr1, align 16
166+
%ptr1.1 = getelementptr inbounds i32, ptr %ptr1, i64 1
167+
%val1 = load i32, ptr %ptr1.1, align 4
168+
%ptr1.2 = getelementptr inbounds i32, ptr %ptr1, i64 2
169+
%val2 = load i32, ptr %ptr1.2, align 8
170+
%ptr1.3 = getelementptr inbounds i32, ptr %ptr1, i64 3
171+
%val3 = load i32, ptr %ptr1.3, align 4
172+
%red.1 = add i32 %val0, %val1
173+
%red.2 = add i32 %red.1, %val2
174+
%red = add i32 %red.2, %val3
175+
store i32 %red, ptr %ptr2, align 4
176+
ret void
177+
}

0 commit comments

Comments
 (0)