-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[AArch64] Extend vecreduce to udot/sdot transformation to support usdot #120094
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AArch64] Extend vecreduce to udot/sdot transformation to support usdot #120094
Conversation
@llvm/pr-subscribers-backend-aarch64 Author: Igor Kirillov (igogo-x86) ChangesPatch is 226.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120094.diff 2 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c19265613c706d..3ef6ef356465d3 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -18283,16 +18283,38 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
unsigned ExtOpcode = Op0.getOpcode();
SDValue A = Op0;
SDValue B;
+ unsigned DotOpcode;
if (ExtOpcode == ISD::MUL) {
A = Op0.getOperand(0);
B = Op0.getOperand(1);
- if (A.getOpcode() != B.getOpcode() ||
- A.getOperand(0).getValueType() != B.getOperand(0).getValueType())
+ if (A.getOperand(0).getValueType() != B.getOperand(0).getValueType())
return SDValue();
- ExtOpcode = A.getOpcode();
- }
- if (ExtOpcode != ISD::ZERO_EXTEND && ExtOpcode != ISD::SIGN_EXTEND)
+ auto OpCodeA = A.getOpcode();
+ if (OpCodeA != ISD::ZERO_EXTEND && OpCodeA != ISD::SIGN_EXTEND)
+ return SDValue();
+
+ auto OpCodeB = B.getOpcode();
+ if (OpCodeB != ISD::ZERO_EXTEND && OpCodeB != ISD::SIGN_EXTEND)
+ return SDValue();
+
+ if (OpCodeA == OpCodeB) {
+ DotOpcode =
+ OpCodeA == ISD::ZERO_EXTEND ? AArch64ISD::UDOT : AArch64ISD::SDOT;
+ } else {
+ // Check USDOT support support
+ if (!ST->hasMatMulInt8())
+ return SDValue();
+ DotOpcode = AArch64ISD::USDOT;
+ if (OpCodeA == ISD::SIGN_EXTEND)
+ std::swap(A, B);
+ }
+ } else if (ExtOpcode == ISD::ZERO_EXTEND) {
+ DotOpcode = AArch64ISD::UDOT;
+ } else if (ExtOpcode == ISD::SIGN_EXTEND) {
+ DotOpcode = AArch64ISD::SDOT;
+ } else {
return SDValue();
+ }
EVT Op0VT = A.getOperand(0).getValueType();
bool IsValidElementCount = Op0VT.getVectorNumElements() % 8 == 0;
@@ -18318,8 +18340,6 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
NumOfVecReduce = Op0VT.getVectorNumElements() / 8;
TargetType = MVT::v2i32;
}
- auto DotOpcode =
- (ExtOpcode == ISD::ZERO_EXTEND) ? AArch64ISD::UDOT : AArch64ISD::SDOT;
// Handle the case where we need to generate only one Dot operation.
if (NumOfVecReduce == 1) {
SDValue Zeros = DAG.getConstant(0, DL, TargetType);
diff --git a/llvm/test/CodeGen/AArch64/neon-dotreduce.ll b/llvm/test/CodeGen/AArch64/neon-dotreduce.ll
index c345c1e50bbbb7..05ac2956da00c7 100644
--- a/llvm/test/CodeGen/AArch64/neon-dotreduce.ll
+++ b/llvm/test/CodeGen/AArch64/neon-dotreduce.ll
@@ -1,22 +1,28 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple aarch64-linux-gnu -mattr=+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-SD
-; RUN: llc -mtriple aarch64-linux-gnu -mattr=+dotprod -global-isel -global-isel-abort=2 < %s 2>&1 | FileCheck %s --check-prefixes=CHECK,CHECK-GI
+; RUN: llc -mtriple aarch64-linux-gnu -mattr=+dotprod,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-SD
+; RUN: llc -mtriple aarch64-linux-gnu -mattr=+dotprod,+i8mm -global-isel -global-isel-abort=2 < %s 2>&1 | FileCheck %s --check-prefixes=CHECK,CHECK-GI
; CHECK-GI: warning: Instruction selection used fallback path for test_udot_v5i8
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_udot_v5i8_nomla
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_sdot_v5i8
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_sdot_v5i8_double
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_sdot_v5i8_double_nomla
+; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_usdot_v5i8
+; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_usdot_v5i8_double
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_udot_v25i8
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_udot_v25i8_nomla
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_sdot_v25i8
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_sdot_v25i8_double
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_sdot_v25i8_double_nomla
+; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_usdot_v25i8
+; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_usdot_v25i8_double
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_udot_v33i8
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_udot_v33i8_nomla
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_sdot_v33i8
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_sdot_v33i8_double
; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_sdot_v33i8_double_nomla
+; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_usdot_v33i8
+; CHECK-GI-NEXT: warning: Instruction selection used fallback path for test_usdot_v33i8_double
declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>)
declare i32 @llvm.vector.reduce.add.v5i32(<5 x i32>)
@@ -290,6 +296,128 @@ entry:
ret i32 %x
}
+define i32 @test_usdot_v4i8(ptr nocapture readonly %a, ptr nocapture readonly %b, i32 %sum) {
+; CHECK-SD-LABEL: test_usdot_v4i8:
+; CHECK-SD: // %bb.0: // %entry
+; CHECK-SD-NEXT: ldr s0, [x0]
+; CHECK-SD-NEXT: ldr s1, [x1]
+; CHECK-SD-NEXT: ushll v0.8h, v0.8b, #0
+; CHECK-SD-NEXT: sshll v1.8h, v1.8b, #0
+; CHECK-SD-NEXT: smull v0.4s, v1.4h, v0.4h
+; CHECK-SD-NEXT: addv s0, v0.4s
+; CHECK-SD-NEXT: fmov w8, s0
+; CHECK-SD-NEXT: add w0, w8, w2
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: test_usdot_v4i8:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: ldr w8, [x0]
+; CHECK-GI-NEXT: ldr w9, [x1]
+; CHECK-GI-NEXT: fmov s0, w8
+; CHECK-GI-NEXT: fmov s2, w9
+; CHECK-GI-NEXT: uxtb w8, w8
+; CHECK-GI-NEXT: sxtb w9, w9
+; CHECK-GI-NEXT: mov b1, v0.b[1]
+; CHECK-GI-NEXT: mov b3, v0.b[2]
+; CHECK-GI-NEXT: mov b5, v2.b[2]
+; CHECK-GI-NEXT: mov b4, v0.b[3]
+; CHECK-GI-NEXT: mov b0, v2.b[1]
+; CHECK-GI-NEXT: mov b6, v2.b[3]
+; CHECK-GI-NEXT: fmov s2, w9
+; CHECK-GI-NEXT: fmov w10, s1
+; CHECK-GI-NEXT: fmov w11, s3
+; CHECK-GI-NEXT: fmov s1, w8
+; CHECK-GI-NEXT: fmov w13, s5
+; CHECK-GI-NEXT: fmov w8, s4
+; CHECK-GI-NEXT: fmov w12, s0
+; CHECK-GI-NEXT: uxtb w10, w10
+; CHECK-GI-NEXT: uxtb w11, w11
+; CHECK-GI-NEXT: sxtb w13, w13
+; CHECK-GI-NEXT: uxtb w8, w8
+; CHECK-GI-NEXT: sxtb w12, w12
+; CHECK-GI-NEXT: mov v1.h[1], w10
+; CHECK-GI-NEXT: fmov w10, s6
+; CHECK-GI-NEXT: fmov s0, w11
+; CHECK-GI-NEXT: fmov s3, w13
+; CHECK-GI-NEXT: mov v2.h[1], w12
+; CHECK-GI-NEXT: sxtb w10, w10
+; CHECK-GI-NEXT: mov v0.h[1], w8
+; CHECK-GI-NEXT: ushll v1.4s, v1.4h, #0
+; CHECK-GI-NEXT: mov v3.h[1], w10
+; CHECK-GI-NEXT: sshll v2.4s, v2.4h, #0
+; CHECK-GI-NEXT: ushll v0.4s, v0.4h, #0
+; CHECK-GI-NEXT: sshll v3.4s, v3.4h, #0
+; CHECK-GI-NEXT: mov v1.d[1], v0.d[0]
+; CHECK-GI-NEXT: mov v2.d[1], v3.d[0]
+; CHECK-GI-NEXT: mul v0.4s, v2.4s, v1.4s
+; CHECK-GI-NEXT: addv s0, v0.4s
+; CHECK-GI-NEXT: fmov w8, s0
+; CHECK-GI-NEXT: add w0, w8, w2
+; CHECK-GI-NEXT: ret
+entry:
+ %0 = load <4 x i8>, ptr %a
+ %1 = zext <4 x i8> %0 to <4 x i32>
+ %2 = load <4 x i8>, ptr %b
+ %3 = sext <4 x i8> %2 to <4 x i32>
+ %4 = mul nsw <4 x i32> %3, %1
+ %5 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %4)
+ %op.extra = add nsw i32 %5, %sum
+ ret i32 %op.extra
+}
+
+define i32 @test_usdot_v4i8_double(<4 x i8> %a, <4 x i8> %b, <4 x i8> %c, <4 x i8> %d) {
+; CHECK-SD-LABEL: test_usdot_v4i8_double:
+; CHECK-SD: // %bb.0: // %entry
+; CHECK-SD-NEXT: ushll v3.4s, v3.4h, #0
+; CHECK-SD-NEXT: bic v2.4h, #255, lsl #8
+; CHECK-SD-NEXT: ushll v1.4s, v1.4h, #0
+; CHECK-SD-NEXT: bic v0.4h, #255, lsl #8
+; CHECK-SD-NEXT: shl v3.4s, v3.4s, #24
+; CHECK-SD-NEXT: ushll v2.4s, v2.4h, #0
+; CHECK-SD-NEXT: shl v1.4s, v1.4s, #24
+; CHECK-SD-NEXT: ushll v0.4s, v0.4h, #0
+; CHECK-SD-NEXT: sshr v3.4s, v3.4s, #24
+; CHECK-SD-NEXT: sshr v1.4s, v1.4s, #24
+; CHECK-SD-NEXT: mul v2.4s, v2.4s, v3.4s
+; CHECK-SD-NEXT: mla v2.4s, v0.4s, v1.4s
+; CHECK-SD-NEXT: addv s0, v2.4s
+; CHECK-SD-NEXT: fmov w0, s0
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: test_usdot_v4i8_double:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: ushll v1.4s, v1.4h, #0
+; CHECK-GI-NEXT: ushll v3.4s, v3.4h, #0
+; CHECK-GI-NEXT: movi v4.2d, #0x0000ff000000ff
+; CHECK-GI-NEXT: ushll v0.4s, v0.4h, #0
+; CHECK-GI-NEXT: ushll v2.4s, v2.4h, #0
+; CHECK-GI-NEXT: shl v1.4s, v1.4s, #24
+; CHECK-GI-NEXT: shl v3.4s, v3.4s, #24
+; CHECK-GI-NEXT: and v0.16b, v0.16b, v4.16b
+; CHECK-GI-NEXT: and v2.16b, v2.16b, v4.16b
+; CHECK-GI-NEXT: sshr v1.4s, v1.4s, #24
+; CHECK-GI-NEXT: sshr v3.4s, v3.4s, #24
+; CHECK-GI-NEXT: mul v0.4s, v0.4s, v1.4s
+; CHECK-GI-NEXT: mul v1.4s, v2.4s, v3.4s
+; CHECK-GI-NEXT: addv s0, v0.4s
+; CHECK-GI-NEXT: addv s1, v1.4s
+; CHECK-GI-NEXT: fmov w8, s0
+; CHECK-GI-NEXT: fmov w9, s1
+; CHECK-GI-NEXT: add w0, w8, w9
+; CHECK-GI-NEXT: ret
+entry:
+ %az = zext <4 x i8> %a to <4 x i32>
+ %bz = sext <4 x i8> %b to <4 x i32>
+ %m1 = mul nuw nsw <4 x i32> %az, %bz
+ %r1 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %m1)
+ %cz = zext <4 x i8> %c to <4 x i32>
+ %dz = sext <4 x i8> %d to <4 x i32>
+ %m2 = mul nuw nsw <4 x i32> %cz, %dz
+ %r2 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %m2)
+ %x = add i32 %r1, %r2
+ ret i32 %x
+}
+
define i32 @test_udot_v5i8(ptr nocapture readonly %a, ptr nocapture readonly %b, i32 %sum) {
; CHECK-LABEL: test_udot_v5i8:
; CHECK: // %bb.0: // %entry
@@ -414,6 +542,65 @@ entry:
ret i32 %x
}
+define i32 @test_usdot_v5i8(ptr nocapture readonly %a, ptr nocapture readonly %b, i32 %sum) {
+; CHECK-LABEL: test_usdot_v5i8:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ldr d0, [x0]
+; CHECK-NEXT: ldr d1, [x1]
+; CHECK-NEXT: movi v3.2d, #0000000000000000
+; CHECK-NEXT: ushll v0.8h, v0.8b, #0
+; CHECK-NEXT: sshll v1.8h, v1.8b, #0
+; CHECK-NEXT: smull2 v2.4s, v1.8h, v0.8h
+; CHECK-NEXT: mov v3.s[0], v2.s[0]
+; CHECK-NEXT: smlal v3.4s, v1.4h, v0.4h
+; CHECK-NEXT: addv s0, v3.4s
+; CHECK-NEXT: fmov w8, s0
+; CHECK-NEXT: add w0, w8, w2
+; CHECK-NEXT: ret
+entry:
+ %0 = load <5 x i8>, ptr %a
+ %1 = zext <5 x i8> %0 to <5 x i32>
+ %2 = load <5 x i8>, ptr %b
+ %3 = sext <5 x i8> %2 to <5 x i32>
+ %4 = mul nsw <5 x i32> %3, %1
+ %5 = call i32 @llvm.vector.reduce.add.v5i32(<5 x i32> %4)
+ %op.extra = add nsw i32 %5, %sum
+ ret i32 %op.extra
+}
+
+define i32 @test_usdot_v5i8_double(<5 x i8> %a, <5 x i8> %b, <5 x i8> %c, <5 x i8> %d) {
+; CHECK-LABEL: test_usdot_v5i8_double:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: sshll v1.8h, v1.8b, #0
+; CHECK-NEXT: ushll v0.8h, v0.8b, #0
+; CHECK-NEXT: ushll v2.8h, v2.8b, #0
+; CHECK-NEXT: sshll v3.8h, v3.8b, #0
+; CHECK-NEXT: movi v5.2d, #0000000000000000
+; CHECK-NEXT: movi v6.2d, #0000000000000000
+; CHECK-NEXT: smull2 v4.4s, v0.8h, v1.8h
+; CHECK-NEXT: smull2 v7.4s, v2.8h, v3.8h
+; CHECK-NEXT: mov v6.s[0], v4.s[0]
+; CHECK-NEXT: mov v5.s[0], v7.s[0]
+; CHECK-NEXT: smlal v6.4s, v0.4h, v1.4h
+; CHECK-NEXT: smlal v5.4s, v2.4h, v3.4h
+; CHECK-NEXT: add v0.4s, v6.4s, v5.4s
+; CHECK-NEXT: addv s0, v0.4s
+; CHECK-NEXT: fmov w0, s0
+; CHECK-NEXT: ret
+entry:
+ %az = zext <5 x i8> %a to <5 x i32>
+ %bz = sext <5 x i8> %b to <5 x i32>
+ %m1 = mul nuw nsw <5 x i32> %az, %bz
+ %r1 = call i32 @llvm.vector.reduce.add.v5i32(<5 x i32> %m1)
+ %cz = zext <5 x i8> %c to <5 x i32>
+ %dz = sext <5 x i8> %d to <5 x i32>
+ %m2 = mul nuw nsw <5 x i32> %cz, %dz
+ %r2 = call i32 @llvm.vector.reduce.add.v5i32(<5 x i32> %m2)
+ %x = add i32 %r1, %r2
+ ret i32 %x
+}
+
+
define i32 @test_udot_v8i8(ptr nocapture readonly %a, ptr nocapture readonly %b) {
; CHECK-LABEL: test_udot_v8i8:
; CHECK: // %bb.0: // %entry
@@ -508,6 +695,77 @@ entry:
ret i32 %2
}
+define i32 @test_usdot_v8i8(ptr nocapture readonly %a, ptr nocapture readonly %b) {
+; CHECK-SD-LABEL: test_usdot_v8i8:
+; CHECK-SD: // %bb.0: // %entry
+; CHECK-SD-NEXT: movi v0.2d, #0000000000000000
+; CHECK-SD-NEXT: ldr d1, [x0]
+; CHECK-SD-NEXT: ldr d2, [x1]
+; CHECK-SD-NEXT: usdot v0.2s, v1.8b, v2.8b
+; CHECK-SD-NEXT: addp v0.2s, v0.2s, v0.2s
+; CHECK-SD-NEXT: fmov w0, s0
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: test_usdot_v8i8:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: ldr d0, [x0]
+; CHECK-GI-NEXT: ldr d1, [x1]
+; CHECK-GI-NEXT: ushll v0.8h, v0.8b, #0
+; CHECK-GI-NEXT: sshll v1.8h, v1.8b, #0
+; CHECK-GI-NEXT: ushll2 v2.4s, v0.8h, #0
+; CHECK-GI-NEXT: sshll2 v3.4s, v1.8h, #0
+; CHECK-GI-NEXT: ushll v0.4s, v0.4h, #0
+; CHECK-GI-NEXT: sshll v1.4s, v1.4h, #0
+; CHECK-GI-NEXT: mul v2.4s, v3.4s, v2.4s
+; CHECK-GI-NEXT: mla v2.4s, v1.4s, v0.4s
+; CHECK-GI-NEXT: addv s0, v2.4s
+; CHECK-GI-NEXT: fmov w0, s0
+; CHECK-GI-NEXT: ret
+entry:
+ %0 = load <8 x i8>, ptr %a
+ %1 = zext <8 x i8> %0 to <8 x i32>
+ %2 = load <8 x i8>, ptr %b
+ %3 = sext <8 x i8> %2 to <8 x i32>
+ %4 = mul nsw <8 x i32> %3, %1
+ %5 = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %4)
+ ret i32 %5
+}
+
+define i32 @test_usdot_swapped_operands_v8i8(ptr nocapture readonly %a, ptr nocapture readonly %b) {
+; CHECK-SD-LABEL: test_usdot_swapped_operands_v8i8:
+; CHECK-SD: // %bb.0: // %entry
+; CHECK-SD-NEXT: movi v0.2d, #0000000000000000
+; CHECK-SD-NEXT: ldr d1, [x0]
+; CHECK-SD-NEXT: ldr d2, [x1]
+; CHECK-SD-NEXT: usdot v0.2s, v2.8b, v1.8b
+; CHECK-SD-NEXT: addp v0.2s, v0.2s, v0.2s
+; CHECK-SD-NEXT: fmov w0, s0
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: test_usdot_swapped_operands_v8i8:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: ldr d0, [x0]
+; CHECK-GI-NEXT: ldr d1, [x1]
+; CHECK-GI-NEXT: sshll v0.8h, v0.8b, #0
+; CHECK-GI-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-GI-NEXT: sshll2 v2.4s, v0.8h, #0
+; CHECK-GI-NEXT: ushll2 v3.4s, v1.8h, #0
+; CHECK-GI-NEXT: sshll v0.4s, v0.4h, #0
+; CHECK-GI-NEXT: ushll v1.4s, v1.4h, #0
+; CHECK-GI-NEXT: mul v2.4s, v3.4s, v2.4s
+; CHECK-GI-NEXT: mla v2.4s, v1.4s, v0.4s
+; CHECK-GI-NEXT: addv s0, v2.4s
+; CHECK-GI-NEXT: fmov w0, s0
+; CHECK-GI-NEXT: ret
+entry:
+ %0 = load <8 x i8>, ptr %a
+ %1 = sext <8 x i8> %0 to <8 x i32>
+ %2 = load <8 x i8>, ptr %b
+ %3 = zext <8 x i8> %2 to <8 x i32>
+ %4 = mul nsw <8 x i32> %3, %1
+ %5 = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %4)
+ ret i32 %5
+}
define i32 @test_udot_v16i8(ptr nocapture readonly %a, ptr nocapture readonly %b, i32 %sum) {
; CHECK-LABEL: test_udot_v16i8:
@@ -587,6 +845,101 @@ entry:
ret i32 %2
}
+define i32 @test_usdot_v16i8(ptr nocapture readonly %a, ptr nocapture readonly %b, i32 %sum) {
+; CHECK-SD-LABEL: test_usdot_v16i8:
+; CHECK-SD: // %bb.0: // %entry
+; CHECK-SD-NEXT: movi v0.2d, #0000000000000000
+; CHECK-SD-NEXT: ldr q1, [x0]
+; CHECK-SD-NEXT: ldr q2, [x1]
+; CHECK-SD-NEXT: usdot v0.4s, v1.16b, v2.16b
+; CHECK-SD-NEXT: addv s0, v0.4s
+; CHECK-SD-NEXT: fmov w8, s0
+; CHECK-SD-NEXT: add w0, w8, w2
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: test_usdot_v16i8:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: ldr q0, [x0]
+; CHECK-GI-NEXT: ldr q1, [x1]
+; CHECK-GI-NEXT: ushll v2.8h, v0.8b, #0
+; CHECK-GI-NEXT: ushll2 v0.8h, v0.16b, #0
+; CHECK-GI-NEXT: sshll v3.8h, v1.8b, #0
+; CHECK-GI-NEXT: sshll2 v1.8h, v1.16b, #0
+; CHECK-GI-NEXT: ushll2 v4.4s, v2.8h, #0
+; CHECK-GI-NEXT: ushll2 v5.4s, v0.8h, #0
+; CHECK-GI-NEXT: sshll2 v6.4s, v3.8h, #0
+; CHECK-GI-NEXT: sshll2 v7.4s, v1.8h, #0
+; CHECK-GI-NEXT: ushll v2.4s, v2.4h, #0
+; CHECK-GI-NEXT: ushll v0.4s, v0.4h, #0
+; CHECK-GI-NEXT: sshll v3.4s, v3.4h, #0
+; CHECK-GI-NEXT: sshll v1.4s, v1.4h, #0
+; CHECK-GI-NEXT: mul v4.4s, v6.4s, v4.4s
+; CHECK-GI-NEXT: mul v5.4s, v7.4s, v5.4s
+; CHECK-GI-NEXT: mla v4.4s, v3.4s, v2.4s
+; CHECK-GI-NEXT: mla v5.4s, v1.4s, v0.4s
+; CHECK-GI-NEXT: add v0.4s, v4.4s, v5.4s
+; CHECK-GI-NEXT: addv s0, v0.4s
+; CHECK-GI-NEXT: fmov w8, s0
+; CHECK-GI-NEXT: add w0, w8, w2
+; CHECK-GI-NEXT: ret
+entry:
+ %0 = load <16 x i8>, ptr %a
+ %1 = zext <16 x i8> %0 to <16 x i32>
+ %2 = load <16 x i8>, ptr %b
+ %3 = sext <16 x i8> %2 to <16 x i32>
+ %4 = mul nsw <16 x i32> %3, %1
+ %5 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %4)
+ %op.extra = add nsw i32 %5, %sum
+ ret i32 %op.extra
+}
+
+define i32 @test_usdot_swapped_operands_v16i8(ptr nocapture readonly %a, ptr nocapture readonly %b, i32 %sum) {
+; CHECK-SD-LABEL: test_usdot_swapped_operands_v16i8:
+; CHECK-SD: // %bb.0: // %entry
+; CHECK-SD-NEXT: movi v0.2d, #0000000000000000
+; CHECK-SD-NEXT: ldr q1, [x0]
+; CHECK-SD-NEXT: ldr q2, [x1]
+; CHECK-SD-NEXT: usdot v0.4s, v2.16b, v1.16b
+; CHECK-SD-NEXT: addv s0, v0.4s
+; CHECK-SD-NEXT: fmov w8, s0
+; CHECK-SD-NEXT: add w0, w8, w2
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: test_usdot_swapped_operands_v16i8:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: ldr q0, [x0]
+; CHECK-GI-NEXT: ldr q1, [x1]
+; CHECK-GI-NEXT: sshll v2.8h, v0.8b, #0
+; CHECK-GI-NEXT: sshll2 v0.8h, v0.16b, #0
+; CHECK-GI-NEXT: ushll v3.8h, v1.8b, #0
+; CHECK-GI-NEXT: ushll2 v1.8h, v1.16b, #0
+; CHECK-GI-NEXT: sshll2 v4.4s, v2.8h, #0
+; CHECK-GI-NEXT: sshll2 v5.4s, v0.8h, #0
+; CHECK-GI-NEXT: ushll2 v6.4s, v3.8h, #0
+; CHECK-GI-NEXT: ushll2 v7.4s, v1.8h, #0
+; CHECK-GI-NEXT: sshll v2.4s, v2.4h, #0
+; CHECK-GI-NEXT: sshll v0.4s, v0.4h, #0
+; CHECK-GI-NEXT: ushll v3.4s, v3.4h, #0
+; CHECK-GI-NEXT: ushll v1.4s, v1.4h, #0
+; CHECK-GI-NEXT: mul v4.4s, v6.4s, v4.4s
+; CHECK-GI-NEXT: mul v5.4s, v7.4s, v5.4s
+; CHECK-GI-NEXT: mla v4.4s, v3.4s, v2.4s
+; CHECK-GI-NEXT: mla v5.4s, v1.4s, v0.4s
+; CHECK-GI-NEXT: add v0.4s, v4.4s, v5.4s
+; CHECK-GI-NEXT: addv s0, v0.4s
+; CHECK-GI-NEXT: fmov w8, s0
+; CHECK-GI-NEXT: add w0, w8, w2
+; CHECK-GI-NEXT: ret
+entry:
+ %0 = load <16 x i8>, ptr %a
+ %1 = sext <16 x i8> %0 to <16 x i32>
+ %2 = load <16 x i8>, ptr %b
+ %3 = zext <16 x i8> %2 to <16 x i32>
+ %4 = mul nsw <16 x i32> %3, %1
+ %5 = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %4)
+ %op.extra = add nsw i32 %5, %sum
+ ret i32 %op.extra
+}
define i32 @test_udot_v8i8_double(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d) {
; CHECK-SD-LABEL: test_udot_v8i8_double:
@@ -860,19 +1213,253 @@ entry:
ret i32 %x
}
-define i32 @test_udot_v24i8(ptr nocapture readonly %a, ptr nocapture readonly %b, i32 %sum) {
-; CHECK-SD-LABEL: test_udot_v24i8:
+
+define i32 @test_usdot_v8i8_double(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d) {
+; CHECK-SD-LABEL: test_usdot_v8i8_double:
; CHECK-SD: // %bb.0: // %entry
-; CHECK-SD-NEXT: movi v0.2d, #0000000000000000
-; CHECK-SD-NEXT: movi v1.2d, #0000000000000000
-; CHECK-SD-NEXT: ldr q2, [x0]
-; CHECK-SD-NEXT: ldr q3, [x1]
-; CHECK-SD-NEXT: ldr d4, [x0, #16]
-; CHECK-SD-NEXT: ldr d5, [x1, #16]
-; CHECK-SD-NEXT: udot v1.2s, v5.8b, v4.8b
-; CHECK-SD-NEXT: udot v0.4s, v3.16b, v2.16b
-; CHECK-SD-NEXT: addp v1.2s, v1.2s, v1.2s
-; CHECK-SD-NEXT: addv s0, v0.4s
+; CHECK-SD-NEXT: movi v4.2d, #0000000000000000
+; CHECK-SD-NEXT: movi v5.2d, #0000000000000000
+; CHECK-SD-NEXT: usdot v5.2s, v0.8b, v1.8b
+; CHECK-SD-NEXT: usdot v4.2s, v2.8b, v3.8b
+; CHECK-SD-NEXT: add v0.2s, v5.2s, v4.2s
+; CHECK-SD-NEXT: addp v0.2s, v0.2s, v0.2s
+; CHECK-SD-NEXT: fmov w0, s0
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: test_usdot_v8i8_double:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: ushll v0.8h, v0.8b, #0
+; CHECK-GI-NEXT: sshll v1.8h, v1.8b, #0
+; CHECK-GI-NEXT: ushll v2.8h, v2.8b, #0
+; CHECK-GI-NEXT: sshll v3.8h, v3.8b, #0
+; CHECK-GI-NEXT: ushll2 v4.4s, v0.8h, #0
+; CHECK-GI-NEXT: sshll2 v5.4s, v1.8h, #0
+; CHECK-GI-NEXT: ushll2 v6.4s, v2.8h, #0
+; CHECK-GI-NEXT: sshll2 v7.4s, v3.8h, #0
+; CHECK-...
[truncated]
|
@@ -860,19 +1213,253 @@ entry: | |||
ret i32 %x | |||
} | |||
|
|||
define i32 @test_udot_v24i8(ptr nocapture readonly %a, ptr nocapture readonly %b, i32 %sum) { | |||
; CHECK-SD-LABEL: test_udot_v24i8: | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't change the original tests; that's git diff
doing a poor job. Running git diff --minimal
gives better representation
; RUN: llc -mtriple aarch64-linux-gnu -mattr=+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-SD | ||
; RUN: llc -mtriple aarch64-linux-gnu -mattr=+dotprod -global-isel -global-isel-abort=2 < %s 2>&1 | FileCheck %s --check-prefixes=CHECK,CHECK-GI | ||
; RUN: llc -mtriple aarch64-linux-gnu -mattr=+dotprod,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-SD | ||
; RUN: llc -mtriple aarch64-linux-gnu -mattr=+dotprod,+i8mm -global-isel -global-isel-abort=2 < %s 2>&1 | FileCheck %s --check-prefixes=CHECK,CHECK-GI |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure it is a good place for the tests, but this file is responsible for testing the functionality of the performVecReduceAddCombine
function and other tests with +i8mm
have different purposes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. LGTM
eead9ce
to
31e138d
Compare
No description provided.