Skip to content

Commit 7868964

Browse files
committed
Rebase and update tests
1 parent 7a62406 commit 7868964

File tree

2 files changed

+101
-35
lines changed

2 files changed

+101
-35
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1872,14 +1872,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
18721872
// 8to64
18731873
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
18741874

1875-
// USDOT
1876-
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i64, Custom);
1877-
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i32, Custom);
1878-
1879-
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i64, Custom);
1880-
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i32, Custom);
1881-
setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i16, Custom);
1882-
setPartialReduceMLAAction(MVT::nxv16i8, MVT::nxv32i8, Custom);
1875+
if (Subtarget->hasMatMulInt8()) {
1876+
// USDOT
1877+
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i64, Custom);
1878+
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i32, Custom);
1879+
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i64, Custom);
1880+
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i32, Custom);
1881+
setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i16, Custom);
1882+
setPartialReduceMLAAction(MVT::nxv16i8, MVT::nxv32i8, Custom);
1883+
}
18831884
}
18841885

18851886
// Handle operations that are only available in non-streaming SVE mode.
@@ -29495,21 +29496,24 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
2949529496
return Scatter;
2949629497
}
2949729498

29498-
/// If a PARTIAL_REDUCE_MLA node comes in with an accumulator-input type pairing
29499-
/// of nxv2i64/nxv16i8, we cannot directly lower it to a (u|s)dot. We can
29500-
/// however still make use of the dot product instruction by instead
29501-
/// accumulating over two steps: nxv16i8 -> nxv4i32 -> nxv2i64.
2950229499
SDValue
2950329500
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
2950429501
SelectionDAG &DAG) const {
29505-
SDLoc DL(Op);
29502+
if (SDValue UsdotNode = LowerPARTIAL_REDUCE_MLAToUSDOT(Op, DAG))
29503+
return UsdotNode;
2950629504

29507-
SDValue Acc = Op.getOperand(0);
2950829505
SDValue LHS = Op.getOperand(1);
29509-
SDValue RHS = Op.getOperand(2);
2951029506
EVT ResultVT = Op.getValueType();
29511-
assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
29507+
// If a PARTIAL_REDUCE_MLA node comes in with an accumulator-input type
29508+
// pairing of nxv2i64/nxv16i8, we cannot directly lower it to a (u|s)dot. We
29509+
// can however still make use of the dot product instruction by instead
29510+
// accumulating over two steps: nxv16i8 -> nxv4i32 -> nxv2i64.
29511+
if (ResultVT != MVT::nxv2i64 || LHS.getValueType() != MVT::nxv16i8)
29512+
return SDValue();
2951229513

29514+
SDLoc DL(Op);
29515+
SDValue Acc = Op.getOperand(0);
29516+
SDValue RHS = Op.getOperand(2);
2951329517
SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
2951429518
DAG.getConstant(0, DL, MVT::nxv4i32), LHS, RHS);
2951529519

@@ -29529,13 +29533,13 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
2952929533
return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
2953029534
}
2953129535

29532-
// Lower PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(MulOpLHS), SEXT(MulOpRHS)), Splat 1)
29533-
// to USDOT(Acc, MulOpLHS, MulOpRHS)
29534-
// Lower PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(MulOpLHS), ZEXT(MulOpRHS)), Splat 1)
29535-
// to USDOT(Acc, MulOpRHS, MulOpLHS)
29536+
// partial.reduce.umla(acc, mul(zext(mulOpLHS), sext(mulOpRHS)), splat(1))
29537+
// -> USDOT(acc, mulOpLHS, mulOpRHS)
29538+
// partial.reduce.smla(acc, mul(sext(mulOpLHS), zext(mulOpRHS)), splat(1))
29539+
// -> USDOT(acc, mulOpRHS, mulOpLHS)
2953629540
SDValue
2953729541
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLAToUSDOT(SDValue Op,
29538-
SelectionDAG &DAG) const {
29542+
SelectionDAG &DAG) const {
2953929543
bool Scalable = Op.getValueType().isScalableVector();
2954029544
auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
2954129545
if (Scalable && !Subtarget.isSVEorStreamingSVEAvailable())
@@ -29591,7 +29595,7 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLAToUSDOT(SDValue Op,
2959129595
// Do not let this case be split: there is no nxv2i64 version of usdot, so go through an i32 accumulator instead.
2959229596
if ((AccVT == MVT::nxv4i64 && MulOpLHSVT == MVT::nxv16i8) ||
2959329597
(AccVT == MVT::v4i64 && MulOpLHSVT == MVT::v16i8)) {
29594-
EVT AccVTI32 = (AccVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
29598+
EVT AccVTI32 = AccVT.isScalableVector() ? MVT::nxv4i32 : MVT::v4i32;
2959529599

2959629600
SDValue DotI32 =
2959729601
DAG.getNode(Opcode, DL, AccVTI32, DAG.getConstant(0, DL, AccVTI32),

llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll

Lines changed: 75 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM
44
; RUN: llc -mtriple=aarch64 -mattr=+sve,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE
55
; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE2
6-
; RUN: llc -mtriple=aarch64 -mattr=+sme -force-streaming -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SME
6+
; RUN: llc -mtriple=aarch64 -mattr=+sve,+sme,+i8mm -force-streaming -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SME
77

88
define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
99
; CHECK-LABEL: udot:
@@ -299,12 +299,43 @@ define <vscale x 4 x i64> @usdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
299299
;
300300
; CHECK-NEWLOWERING-LABEL: usdot_8to64:
301301
; CHECK-NEWLOWERING: // %bb.0: // %entry
302-
; CHECK-NEWLOWERING-NEXT: mov z4.s, #0 // =0x0
303-
; CHECK-NEWLOWERING-NEXT: usdot z4.s, z2.b, z3.b
304-
; CHECK-NEWLOWERING-NEXT: sunpklo z2.d, z4.s
305-
; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z4.s
306-
; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z2.d
307-
; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z3.d
302+
; CHECK-NEWLOWERING-NEXT: uunpkhi z4.h, z2.b
303+
; CHECK-NEWLOWERING-NEXT: uunpklo z2.h, z2.b
304+
; CHECK-NEWLOWERING-NEXT: sunpkhi z5.h, z3.b
305+
; CHECK-NEWLOWERING-NEXT: sunpklo z3.h, z3.b
306+
; CHECK-NEWLOWERING-NEXT: ptrue p0.d
307+
; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z4.h
308+
; CHECK-NEWLOWERING-NEXT: uunpklo z7.s, z2.h
309+
; CHECK-NEWLOWERING-NEXT: sunpklo z24.s, z5.h
310+
; CHECK-NEWLOWERING-NEXT: sunpklo z25.s, z3.h
311+
; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z4.h
312+
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
313+
; CHECK-NEWLOWERING-NEXT: sunpkhi z5.s, z5.h
314+
; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
315+
; CHECK-NEWLOWERING-NEXT: uunpklo z26.d, z6.s
316+
; CHECK-NEWLOWERING-NEXT: uunpklo z27.d, z7.s
317+
; CHECK-NEWLOWERING-NEXT: sunpklo z28.d, z24.s
318+
; CHECK-NEWLOWERING-NEXT: sunpklo z29.d, z25.s
319+
; CHECK-NEWLOWERING-NEXT: uunpkhi z6.d, z6.s
320+
; CHECK-NEWLOWERING-NEXT: uunpkhi z7.d, z7.s
321+
; CHECK-NEWLOWERING-NEXT: sunpkhi z24.d, z24.s
322+
; CHECK-NEWLOWERING-NEXT: sunpkhi z25.d, z25.s
323+
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
324+
; CHECK-NEWLOWERING-NEXT: uunpklo z26.d, z4.s
325+
; CHECK-NEWLOWERING-NEXT: sunpklo z28.d, z5.s
326+
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
327+
; CHECK-NEWLOWERING-NEXT: uunpklo z27.d, z2.s
328+
; CHECK-NEWLOWERING-NEXT: sunpklo z29.d, z3.s
329+
; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z4.s
330+
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
331+
; CHECK-NEWLOWERING-NEXT: sunpkhi z5.d, z5.s
332+
; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z3.s
333+
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z6.d, z24.d
334+
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z7.d, z25.d
335+
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
336+
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
337+
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z4.d, z5.d
338+
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z2.d, z3.d
308339
; CHECK-NEWLOWERING-NEXT: ret
309340
entry:
310341
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
@@ -369,12 +400,43 @@ define <vscale x 4 x i64> @sudot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
369400
;
370401
; CHECK-NEWLOWERING-LABEL: sudot_8to64:
371402
; CHECK-NEWLOWERING: // %bb.0: // %entry
372-
; CHECK-NEWLOWERING-NEXT: mov z4.s, #0 // =0x0
373-
; CHECK-NEWLOWERING-NEXT: usdot z4.s, z3.b, z2.b
374-
; CHECK-NEWLOWERING-NEXT: sunpklo z2.d, z4.s
375-
; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z4.s
376-
; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z2.d
377-
; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z3.d
403+
; CHECK-NEWLOWERING-NEXT: sunpkhi z4.h, z2.b
404+
; CHECK-NEWLOWERING-NEXT: sunpklo z2.h, z2.b
405+
; CHECK-NEWLOWERING-NEXT: uunpkhi z5.h, z3.b
406+
; CHECK-NEWLOWERING-NEXT: uunpklo z3.h, z3.b
407+
; CHECK-NEWLOWERING-NEXT: ptrue p0.d
408+
; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z4.h
409+
; CHECK-NEWLOWERING-NEXT: sunpklo z7.s, z2.h
410+
; CHECK-NEWLOWERING-NEXT: uunpklo z24.s, z5.h
411+
; CHECK-NEWLOWERING-NEXT: uunpklo z25.s, z3.h
412+
; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z4.h
413+
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
414+
; CHECK-NEWLOWERING-NEXT: uunpkhi z5.s, z5.h
415+
; CHECK-NEWLOWERING-NEXT: uunpkhi z3.s, z3.h
416+
; CHECK-NEWLOWERING-NEXT: sunpklo z26.d, z6.s
417+
; CHECK-NEWLOWERING-NEXT: sunpklo z27.d, z7.s
418+
; CHECK-NEWLOWERING-NEXT: uunpklo z28.d, z24.s
419+
; CHECK-NEWLOWERING-NEXT: uunpklo z29.d, z25.s
420+
; CHECK-NEWLOWERING-NEXT: sunpkhi z6.d, z6.s
421+
; CHECK-NEWLOWERING-NEXT: sunpkhi z7.d, z7.s
422+
; CHECK-NEWLOWERING-NEXT: uunpkhi z24.d, z24.s
423+
; CHECK-NEWLOWERING-NEXT: uunpkhi z25.d, z25.s
424+
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
425+
; CHECK-NEWLOWERING-NEXT: sunpklo z26.d, z4.s
426+
; CHECK-NEWLOWERING-NEXT: uunpklo z28.d, z5.s
427+
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
428+
; CHECK-NEWLOWERING-NEXT: sunpklo z27.d, z2.s
429+
; CHECK-NEWLOWERING-NEXT: uunpklo z29.d, z3.s
430+
; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z4.s
431+
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
432+
; CHECK-NEWLOWERING-NEXT: uunpkhi z5.d, z5.s
433+
; CHECK-NEWLOWERING-NEXT: uunpkhi z3.d, z3.s
434+
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z6.d, z24.d
435+
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z7.d, z25.d
436+
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
437+
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
438+
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z4.d, z5.d
439+
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z2.d, z3.d
378440
; CHECK-NEWLOWERING-NEXT: ret
379441
entry:
380442
%a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>

0 commit comments

Comments
 (0)