Skip to content

Commit 670d208

Browse files
authored
[AArch64] Implement promotion type legalisation for histogram intrinsic (#101017)
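Currently the histogram intrinsic (llvm.experimental.vector.histogram.add) only allows i32 and i64 types for the memory locations to be updated, matching the restrictions of the histcnt instruction. This patch adds support for legalising narrower bucket types (i8 and i16, as well as i32 buckets addressed through 64-bit lanes) via promotion: the increment is extended to the histcnt element width, the buckets are read with an extending gather, and the updated values are written back with a truncating scatter.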
1 parent 908c89e commit 670d208

File tree

2 files changed

+142
-12
lines changed

2 files changed

+142
-12
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1776,9 +1776,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
17761776
setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
17771777

17781778
// Histcnt is SVE2 only
1779-
if (Subtarget->hasSVE2())
1779+
if (Subtarget->hasSVE2()) {
17801780
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::Other,
17811781
Custom);
1782+
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::i8, Custom);
1783+
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::i16, Custom);
1784+
}
17821785
}
17831786

17841787

@@ -28175,11 +28178,18 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
2817528178

2817628179
EVT IncVT = Inc.getValueType();
2817728180
EVT IndexVT = Index.getValueType();
28178-
EVT MemVT = EVT::getVectorVT(*DAG.getContext(), IncVT,
28179-
IndexVT.getVectorElementCount());
28181+
LLVMContext &Ctx = *DAG.getContext();
28182+
ElementCount EC = IndexVT.getVectorElementCount();
28183+
EVT MemVT = EVT::getVectorVT(Ctx, IncVT, EC);
28184+
EVT IncExtVT =
28185+
EVT::getIntegerVT(Ctx, AArch64::SVEBitsPerBlock / EC.getKnownMinValue());
28186+
EVT IncSplatVT = EVT::getVectorVT(Ctx, IncExtVT, EC);
28187+
bool ExtTrunc = IncSplatVT != MemVT;
28188+
2818028189
SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
28181-
SDValue PassThru = DAG.getSplatVector(MemVT, DL, Zero);
28182-
SDValue IncSplat = DAG.getSplatVector(MemVT, DL, Inc);
28190+
SDValue PassThru = DAG.getSplatVector(IncSplatVT, DL, Zero);
28191+
SDValue IncSplat = DAG.getSplatVector(
28192+
IncSplatVT, DL, DAG.getAnyExtOrTrunc(Inc, DL, IncExtVT));
2818328193
SDValue Ops[] = {Chain, PassThru, Mask, Ptr, Index, Scale};
2818428194

2818528195
MachineMemOperand *MMO = HG->getMemOperand();
@@ -28188,18 +28198,19 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
2818828198
MMO->getPointerInfo(), MachineMemOperand::MOLoad, MMO->getSize(),
2818928199
MMO->getAlign(), MMO->getAAInfo());
2819028200
ISD::MemIndexType IndexType = HG->getIndexType();
28191-
SDValue Gather =
28192-
DAG.getMaskedGather(DAG.getVTList(MemVT, MVT::Other), MemVT, DL, Ops,
28193-
GMMO, IndexType, ISD::NON_EXTLOAD);
28201+
SDValue Gather = DAG.getMaskedGather(
28202+
DAG.getVTList(IncSplatVT, MVT::Other), MemVT, DL, Ops, GMMO, IndexType,
28203+
ExtTrunc ? ISD::EXTLOAD : ISD::NON_EXTLOAD);
2819428204

2819528205
SDValue GChain = Gather.getValue(1);
2819628206

2819728207
// Perform the histcnt, multiply by inc, add to bucket data.
28198-
SDValue ID = DAG.getTargetConstant(Intrinsic::aarch64_sve_histcnt, DL, IncVT);
28208+
SDValue ID =
28209+
DAG.getTargetConstant(Intrinsic::aarch64_sve_histcnt, DL, IncExtVT);
2819928210
SDValue HistCnt =
2820028211
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, IndexVT, ID, Mask, Index, Index);
28201-
SDValue Mul = DAG.getNode(ISD::MUL, DL, MemVT, HistCnt, IncSplat);
28202-
SDValue Add = DAG.getNode(ISD::ADD, DL, MemVT, Gather, Mul);
28212+
SDValue Mul = DAG.getNode(ISD::MUL, DL, IncSplatVT, HistCnt, IncSplat);
28213+
SDValue Add = DAG.getNode(ISD::ADD, DL, IncSplatVT, Gather, Mul);
2820328214

2820428215
// Create an MMO for the scatter, without load|store flags.
2820528216
MachineMemOperand *SMMO = DAG.getMachineFunction().getMachineMemOperand(
@@ -28208,7 +28219,7 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
2820828219

2820928220
SDValue ScatterOps[] = {GChain, Add, Mask, Ptr, Index, Scale};
2821028221
SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MemVT, DL,
28211-
ScatterOps, SMMO, IndexType, false);
28222+
ScatterOps, SMMO, IndexType, ExtTrunc);
2821228223
return Scatter;
2821328224
}
2821428225

llvm/test/CodeGen/AArch64/sve2-histcnt.ll

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,123 @@ define void @histogram_i32_literal_noscale(ptr %base, <vscale x 4 x i32> %indice
5050
ret void
5151
}
5252

53+
define void @histogram_i32_promote(ptr %base, <vscale x 2 x i64> %indices, <vscale x 2 x i1> %mask, i32 %inc) #0 {
54+
; CHECK-LABEL: histogram_i32_promote:
55+
; CHECK: // %bb.0:
56+
; CHECK-NEXT: histcnt z1.d, p0/z, z0.d, z0.d
57+
; CHECK-NEXT: // kill: def $w1 killed $w1 def $x1
58+
; CHECK-NEXT: mov z3.d, x1
59+
; CHECK-NEXT: ld1w { z2.d }, p0/z, [x0, z0.d, lsl #2]
60+
; CHECK-NEXT: ptrue p1.d
61+
; CHECK-NEXT: mad z1.d, p1/m, z3.d, z2.d
62+
; CHECK-NEXT: st1w { z1.d }, p0, [x0, z0.d, lsl #2]
63+
; CHECK-NEXT: ret
64+
%buckets = getelementptr i32, ptr %base, <vscale x 2 x i64> %indices
65+
call void @llvm.experimental.vector.histogram.add.nxv2p0.i32(<vscale x 2 x ptr> %buckets, i32 %inc, <vscale x 2 x i1> %mask)
66+
ret void
67+
}
68+
69+
define void @histogram_i16(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i16 %inc) #0 {
70+
; CHECK-LABEL: histogram_i16:
71+
; CHECK: // %bb.0:
72+
; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
73+
; CHECK-NEXT: mov z3.s, w1
74+
; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
75+
; CHECK-NEXT: ptrue p1.s
76+
; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
77+
; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
78+
; CHECK-NEXT: ret
79+
%buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
80+
call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 %inc, <vscale x 4 x i1> %mask)
81+
ret void
82+
}
83+
84+
define void @histogram_i8(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i8 %inc) #0 {
85+
; CHECK-LABEL: histogram_i8:
86+
; CHECK: // %bb.0:
87+
; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
88+
; CHECK-NEXT: mov z3.s, w1
89+
; CHECK-NEXT: ld1b { z2.s }, p0/z, [x0, z0.s, sxtw]
90+
; CHECK-NEXT: ptrue p1.s
91+
; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
92+
; CHECK-NEXT: st1b { z1.s }, p0, [x0, z0.s, sxtw]
93+
; CHECK-NEXT: ret
94+
%buckets = getelementptr i8, ptr %base, <vscale x 4 x i32> %indices
95+
call void @llvm.experimental.vector.histogram.add.nxv4p0.i8(<vscale x 4 x ptr> %buckets, i8 %inc, <vscale x 4 x i1> %mask)
96+
ret void
97+
}
98+
99+
define void @histogram_i16_2_lane(ptr %base, <vscale x 2 x i64> %indices, <vscale x 2 x i1> %mask, i16 %inc) #0 {
100+
; CHECK-LABEL: histogram_i16_2_lane:
101+
; CHECK: // %bb.0:
102+
; CHECK-NEXT: histcnt z1.d, p0/z, z0.d, z0.d
103+
; CHECK-NEXT: // kill: def $w1 killed $w1 def $x1
104+
; CHECK-NEXT: mov z3.d, x1
105+
; CHECK-NEXT: ld1h { z2.d }, p0/z, [x0, z0.d, lsl #1]
106+
; CHECK-NEXT: ptrue p1.d
107+
; CHECK-NEXT: mad z1.d, p1/m, z3.d, z2.d
108+
; CHECK-NEXT: st1h { z1.d }, p0, [x0, z0.d, lsl #1]
109+
; CHECK-NEXT: ret
110+
%buckets = getelementptr i16, ptr %base, <vscale x 2 x i64> %indices
111+
call void @llvm.experimental.vector.histogram.add.nxv2p0.i16(<vscale x 2 x ptr> %buckets, i16 %inc, <vscale x 2 x i1> %mask)
112+
ret void
113+
}
114+
115+
define void @histogram_i8_2_lane(ptr %base, <vscale x 2 x i64> %indices, <vscale x 2 x i1> %mask, i8 %inc) #0 {
116+
; CHECK-LABEL: histogram_i8_2_lane:
117+
; CHECK: // %bb.0:
118+
; CHECK-NEXT: histcnt z1.d, p0/z, z0.d, z0.d
119+
; CHECK-NEXT: // kill: def $w1 killed $w1 def $x1
120+
; CHECK-NEXT: mov z3.d, x1
121+
; CHECK-NEXT: ld1b { z2.d }, p0/z, [x0, z0.d]
122+
; CHECK-NEXT: ptrue p1.d
123+
; CHECK-NEXT: mad z1.d, p1/m, z3.d, z2.d
124+
; CHECK-NEXT: st1b { z1.d }, p0, [x0, z0.d]
125+
; CHECK-NEXT: ret
126+
%buckets = getelementptr i8, ptr %base, <vscale x 2 x i64> %indices
127+
call void @llvm.experimental.vector.histogram.add.nxv2p0.i8(<vscale x 2 x ptr> %buckets, i8 %inc, <vscale x 2 x i1> %mask)
128+
ret void
129+
}
130+
131+
define void @histogram_i16_literal_1(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
132+
; CHECK-LABEL: histogram_i16_literal_1:
133+
; CHECK: // %bb.0:
134+
; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
135+
; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
136+
; CHECK-NEXT: add z1.s, z2.s, z1.s
137+
; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
138+
; CHECK-NEXT: ret
139+
%buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
140+
call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 1, <vscale x 4 x i1> %mask)
141+
ret void
142+
}
143+
144+
define void @histogram_i16_literal_2(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
145+
; CHECK-LABEL: histogram_i16_literal_2:
146+
; CHECK: // %bb.0:
147+
; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
148+
; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
149+
; CHECK-NEXT: adr z1.s, [z2.s, z1.s, lsl #1]
150+
; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
151+
; CHECK-NEXT: ret
152+
%buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
153+
call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 2, <vscale x 4 x i1> %mask)
154+
ret void
155+
}
156+
157+
define void @histogram_i16_literal_3(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
158+
; CHECK-LABEL: histogram_i16_literal_3:
159+
; CHECK: // %bb.0:
160+
; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
161+
; CHECK-NEXT: mov z3.s, #3 // =0x3
162+
; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
163+
; CHECK-NEXT: ptrue p1.s
164+
; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
165+
; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
166+
; CHECK-NEXT: ret
167+
%buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
168+
call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 3, <vscale x 4 x i1> %mask)
169+
ret void
170+
}
171+
53172
attributes #0 = { "target-features"="+sve2" vscale_range(1, 16) }

0 commit comments

Comments
 (0)