
Commit 44cfbef

[AArch64] Lower partial add reduction to udot or svdot (#101010)
This patch introduces lowering of the partial add reduction intrinsic to a udot or sdot instruction for AArch64. It also adds a `shouldExpandPartialReductionIntrinsic` target hook, from which AArch64 returns false in the cases where the intrinsic can be lowered directly.
1 parent df3d70b commit 44cfbef
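
For orientation, here is a minimal sketch of the IR shape this lowering targets (the function name and value names are illustrative; the pattern mirrors the new test added below): a zero- or sign-extended multiply feeding the partial-reduction intrinsic, with a 4:1 element-count ratio between the multiply and the accumulator.

; Illustrative only: with -mattr=+sve2 this pattern is now selected to
; "udot z0.s, z1.b, z2.b" instead of being expanded into extracts and adds.
declare <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32>, <vscale x 16 x i32>)

define <vscale x 4 x i32> @example(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
  %mul = mul <vscale x 16 x i32> %a.wide, %b.wide
  %red = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mul)
  ret <vscale x 4 x i32> %red
}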

File tree: 7 files changed, +217 -25 lines

llvm/include/llvm/CodeGen/SelectionDAG.h

Lines changed: 5 additions & 0 deletions
@@ -1594,6 +1594,11 @@ class SelectionDAG {
   /// the target's desired shift amount type.
   SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
 
+  /// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
+  /// its operands and ReducedTy is the intrinsic's return type.
+  SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
+                              SDValue Op2);
+
   /// Expand the specified \c ISD::VAARG node as the Legalize pass would.
   SDValue expandVAArg(SDNode *Node);

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 7 additions & 0 deletions
@@ -453,6 +453,13 @@ class TargetLoweringBase {
     return true;
   }
 
+  /// Return true if the @llvm.experimental.vector.partial.reduce.* intrinsic
+  /// should be expanded using generic code in SelectionDAGBuilder.
+  virtual bool
+  shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const {
+    return true;
+  }
+
   /// Return true if the @llvm.get.active.lane.mask intrinsic should be expanded
   /// using generic code in SelectionDAGBuilder.
   virtual bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const {

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 30 additions & 0 deletions
@@ -74,6 +74,7 @@
 #include <cassert>
 #include <cstdint>
 #include <cstdlib>
+#include <deque>
 #include <limits>
 #include <optional>
 #include <set>
@@ -2439,6 +2440,35 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
   return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
 }
 
+SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
+                                          SDValue Op2) {
+  EVT FullTy = Op2.getValueType();
+
+  unsigned Stride = ReducedTy.getVectorMinNumElements();
+  unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
+
+  // Collect all of the subvectors
+  std::deque<SDValue> Subvectors = {Op1};
+  for (unsigned I = 0; I < ScaleFactor; I++) {
+    auto SourceIndex = getVectorIdxConstant(I * Stride, DL);
+    Subvectors.push_back(
+        getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, {Op2, SourceIndex}));
+  }
+
+  // Flatten the subvector tree
+  while (Subvectors.size() > 1) {
+    Subvectors.push_back(
+        getNode(ISD::ADD, DL, ReducedTy, {Subvectors[0], Subvectors[1]}));
+    Subvectors.pop_front();
+    Subvectors.pop_front();
+  }
+
+  assert(Subvectors.size() == 1 &&
+         "There should only be one subvector after tree flattening");
+
+  return Subvectors[0];
+}
+
 SDValue SelectionDAG::expandVAArg(SDNode *Node) {
   SDLoc dl(Node);
   const TargetLowering &TLI = getTargetLoweringInfo();
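
As a rough illustration of what getPartialReduceAdd builds, here is a hypothetical IR-level equivalent for a fixed-width <16 x i32> to <4 x i32> partial reduction (for exposition only; the function name is made up, the real helper operates on SelectionDAG nodes, and it sums the subvectors as a tree, which is equivalent for integer addition):

; Hypothetical equivalent of the generic expansion: split the wide operand
; into Stride-sized subvectors and add them all together with the accumulator.
declare <4 x i32> @llvm.vector.extract.v4i32.v16i32(<16 x i32>, i64)

define <4 x i32> @partial_reduce_add_expansion(<4 x i32> %acc, <16 x i32> %wide) {
  %s0 = call <4 x i32> @llvm.vector.extract.v4i32.v16i32(<16 x i32> %wide, i64 0)
  %s1 = call <4 x i32> @llvm.vector.extract.v4i32.v16i32(<16 x i32> %wide, i64 4)
  %s2 = call <4 x i32> @llvm.vector.extract.v4i32.v16i32(<16 x i32> %wide, i64 8)
  %s3 = call <4 x i32> @llvm.vector.extract.v4i32.v16i32(<16 x i32> %wide, i64 12)
  %t0 = add <4 x i32> %acc, %s0
  %t1 = add <4 x i32> %t0, %s1
  %t2 = add <4 x i32> %t1, %s2
  %t3 = add <4 x i32> %t2, %s3
  ret <4 x i32> %t3
}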

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 6 additions & 25 deletions
@@ -8038,34 +8038,15 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
     return;
   }
   case Intrinsic::experimental_vector_partial_reduce_add: {
-    SDValue OpNode = getValue(I.getOperand(1));
-    EVT ReducedTy = EVT::getEVT(I.getType());
-    EVT FullTy = OpNode.getValueType();
 
-    unsigned Stride = ReducedTy.getVectorMinNumElements();
-    unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
-
-    // Collect all of the subvectors
-    std::deque<SDValue> Subvectors;
-    Subvectors.push_back(getValue(I.getOperand(0)));
-    for (unsigned i = 0; i < ScaleFactor; i++) {
-      auto SourceIndex = DAG.getVectorIdxConstant(i * Stride, sdl);
-      Subvectors.push_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, sdl, ReducedTy,
-                                       {OpNode, SourceIndex}));
-    }
-
-    // Flatten the subvector tree
-    while (Subvectors.size() > 1) {
-      Subvectors.push_back(DAG.getNode(ISD::ADD, sdl, ReducedTy,
-                                       {Subvectors[0], Subvectors[1]}));
-      Subvectors.pop_front();
-      Subvectors.pop_front();
+    if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
+      visitTargetIntrinsic(I, Intrinsic);
+      return;
     }
 
-    assert(Subvectors.size() == 1 &&
-           "There should only be one subvector after tree flattening");
-
-    setValue(&I, Subvectors[0]);
+    setValue(&I, DAG.getPartialReduceAdd(sdl, EVT::getEVT(I.getType()),
+                                         getValue(I.getOperand(0)),
+                                         getValue(I.getOperand(1))));
     return;
   }
   case Intrinsic::experimental_cttz_elts: {

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 70 additions & 0 deletions
@@ -1988,6 +1988,15 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
   return false;
 }
 
+bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
+    const IntrinsicInst *I) const {
+  if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add)
+    return true;
+
+  EVT VT = EVT::getEVT(I->getType());
+  return VT != MVT::nxv4i32 && VT != MVT::nxv2i64;
+}
+
 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
   if (!Subtarget->isSVEorStreamingSVEAvailable())
     return true;
@@ -21763,6 +21772,61 @@ static SDValue tryCombineWhileLo(SDNode *N,
   return SDValue(N, 0);
 }
 
+SDValue tryLowerPartialReductionToDot(SDNode *N,
+                                      const AArch64Subtarget *Subtarget,
+                                      SelectionDAG &DAG) {
+
+  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
+         getIntrinsicID(N) ==
+             Intrinsic::experimental_vector_partial_reduce_add &&
+         "Expected a partial reduction node");
+
+  if (!Subtarget->isSVEorStreamingSVEAvailable())
+    return SDValue();
+
+  SDLoc DL(N);
+
+  // The narrower of the two operands. Used as the accumulator
+  auto NarrowOp = N->getOperand(1);
+  auto MulOp = N->getOperand(2);
+  if (MulOp->getOpcode() != ISD::MUL)
+    return SDValue();
+
+  auto ExtA = MulOp->getOperand(0);
+  auto ExtB = MulOp->getOperand(1);
+  bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+  bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
+  if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
+    return SDValue();
+
+  auto A = ExtA->getOperand(0);
+  auto B = ExtB->getOperand(0);
+  if (A.getValueType() != B.getValueType())
+    return SDValue();
+
+  unsigned Opcode = 0;
+
+  if (IsSExt)
+    Opcode = AArch64ISD::SDOT;
+  else if (IsZExt)
+    Opcode = AArch64ISD::UDOT;
+
+  assert(Opcode != 0 && "Unexpected dot product case encountered.");
+
+  EVT ReducedType = N->getValueType(0);
+  EVT MulSrcType = A.getValueType();
+
+  // Dot products operate on chunks of four elements so there must be four times
+  // as many elements in the wide type
+  if (ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8)
+    return DAG.getNode(Opcode, DL, MVT::nxv4i32, NarrowOp, A, B);
+
+  if (ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16)
+    return DAG.getNode(Opcode, DL, MVT::nxv2i64, NarrowOp, A, B);
+
+  return SDValue();
+}
+
 static SDValue performIntrinsicCombine(SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        const AArch64Subtarget *Subtarget) {
@@ -21771,6 +21835,12 @@ static SDValue performIntrinsicCombine(SDNode *N,
   switch (IID) {
   default:
     break;
+  case Intrinsic::experimental_vector_partial_reduce_add: {
+    if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
+      return Dot;
+    return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
+                                   N->getOperand(1), N->getOperand(2));
+  }
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:
     return tryCombineFixedPointConvert(N, DCI, DAG);

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 3 additions & 0 deletions
@@ -998,6 +998,9 @@ class AArch64TargetLowering : public TargetLowering {
 
   bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const override;
 
+  bool
+  shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
+
   bool shouldExpandCttzElements(EVT VT) const override;
 
   /// If a change in streaming mode is required on entry to/return from a
Lines changed: 96 additions & 0 deletions (new test file)
@@ -0,0 +1,96 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
+
+define <vscale x 4 x i32> @dotp(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: dotp:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    udot z0.s, z1.b, z2.b
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @dotp_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: dotp_wide:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    udot z0.d, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+  %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @dotp_sext(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: dotp_sext:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: dotp_wide_sext:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sdot z0.d, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+  %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @not_dotp(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
+; CHECK-LABEL: not_dotp:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    and z1.h, z1.h, #0xff
+; CHECK-NEXT:    and z2.h, z2.h, #0xff
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    uunpklo z3.s, z1.h
+; CHECK-NEXT:    uunpklo z4.s, z2.h
+; CHECK-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEXT:    mla z0.s, p0/m, z3.s, z4.s
+; CHECK-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i32>
+  %b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i32>
+  %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @not_dotp_wide(<vscale x 2 x i64> %acc, <vscale x 4 x i16> %a, <vscale x 4 x i16> %b) {
+; CHECK-LABEL: not_dotp_wide:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    and z1.s, z1.s, #0xffff
+; CHECK-NEXT:    and z2.s, z2.s, #0xffff
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    uunpklo z3.d, z1.s
+; CHECK-NEXT:    uunpklo z4.d, z2.s
+; CHECK-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEXT:    mla z0.d, p0/m, z3.d, z4.d
+; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 4 x i16> %a to <vscale x 4 x i64>
+  %b.wide = zext <vscale x 4 x i16> %b to <vscale x 4 x i64>
+  %mult = mul nuw nsw <vscale x 4 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
