Skip to content

Commit c18d9ea

Browse files
authored
[AArch64] Generalize bfdotq_lane patterns to work for f32/i32 duplanes (#171146)
This also removes an overly specific pattern from AArch64InstrInfo.td that this change makes redundant. Fixes #170883.
1 parent e473342 commit c18d9ea

File tree

3 files changed

+79
-47
lines changed

llvm/lib/Target/AArch64/AArch64InstrFormats.td

Lines changed: 48 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,20 @@
1010
// Describe AArch64 instructions format here
1111
//
1212

13+
// Helper class to convert vector element types to integers.
14+
class ChangeElementTypeToInteger<ValueType InVT> {
15+
ValueType VT = !cond(
16+
!eq(InVT, v2f32): v2i32,
17+
!eq(InVT, v4f32): v4i32,
18+
// TODO: Other types.
19+
true : untyped);
20+
}
21+
22+
class VTPair<ValueType A, ValueType B> {
23+
ValueType VT0 = A;
24+
ValueType VT1 = B;
25+
}
26+
1327
// Format specifies the encoding used by the instruction. This is part of the
1428
// ad-hoc solution used to emit machine instruction encodings by our machine
1529
// code emitter.
@@ -8952,36 +8966,6 @@ multiclass SIMDThreeSameVectorBFDot<bit U, string asm> {
89528966
v4f32, v8bf16>;
89538967
}
89548968

8955-
class BaseSIMDThreeSameVectorBF16DotI<bit Q, bit U, string asm,
8956-
string dst_kind, string lhs_kind,
8957-
string rhs_kind,
8958-
RegisterOperand RegType,
8959-
ValueType AccumType,
8960-
ValueType InputType>
8961-
: BaseSIMDIndexedTied<Q, U, 0b0, 0b01, 0b1111,
8962-
RegType, RegType, V128, VectorIndexS,
8963-
asm, "", dst_kind, lhs_kind, rhs_kind,
8964-
[(set (AccumType RegType:$dst),
8965-
(AccumType (int_aarch64_neon_bfdot
8966-
(AccumType RegType:$Rd),
8967-
(InputType RegType:$Rn),
8968-
(InputType (bitconvert (AccumType
8969-
(AArch64duplane32 (v4f32 V128:$Rm),
8970-
VectorIndexS:$idx)))))))]> {
8971-
8972-
bits<2> idx;
8973-
let Inst{21} = idx{0}; // L
8974-
let Inst{11} = idx{1}; // H
8975-
}
8976-
8977-
multiclass SIMDThreeSameVectorBF16DotI<bit U, string asm> {
8978-
8979-
def v4bf16 : BaseSIMDThreeSameVectorBF16DotI<0, U, asm, ".2s", ".4h",
8980-
".2h", V64, v2f32, v4bf16>;
8981-
def v8bf16 : BaseSIMDThreeSameVectorBF16DotI<1, U, asm, ".4s", ".8h",
8982-
".2h", V128, v4f32, v8bf16>;
8983-
}
8984-
89858969
let mayRaiseFPException = 1, Uses = [FPCR] in
89868970
class SIMDBF16MLAL<bit Q, string asm, SDPatternOperator OpNode>
89878971
: BaseSIMDThreeSameVectorTied<Q, 0b1, 0b110, 0b11111, V128, asm, ".4s",
@@ -9054,6 +9038,40 @@ class BF16ToSinglePrecision<string asm>
90549038
}
90559039
} // End of let mayStore = 0, mayLoad = 0, hasSideEffects = 0
90569040

9041+
multiclass BaseSIMDThreeSameVectorBF16DotI<bit Q, bit U, string asm,
9042+
string dst_kind, string lhs_kind,
9043+
string rhs_kind,
9044+
RegisterOperand RegType,
9045+
ValueType AccumType,
9046+
ValueType InputType> {
9047+
let mayLoad = 0, mayStore = 0, hasSideEffects = 0 in {
9048+
def NAME : BaseSIMDIndexedTied<Q, U, 0b0, 0b01, 0b1111, RegType, RegType, V128, VectorIndexS,
9049+
asm, "", dst_kind, lhs_kind, rhs_kind, []>
9050+
{
9051+
bits<2> idx;
9052+
let Inst{21} = idx{0}; // L
9053+
let Inst{11} = idx{1}; // H
9054+
}
9055+
}
9056+
9057+
foreach DupTypes = [VTPair<AccumType, v4f32>,
9058+
VTPair<ChangeElementTypeToInteger<AccumType>.VT, v4i32>] in {
9059+
def : Pat<(AccumType (int_aarch64_neon_bfdot
9060+
(AccumType RegType:$Rd), (InputType RegType:$Rn),
9061+
(InputType (bitconvert
9062+
(DupTypes.VT0 (AArch64duplane32 (DupTypes.VT1
9063+
(bitconvert (v8bf16 V128:$Rm))), VectorIndexS:$Idx)))))),
9064+
(!cast<Instruction>(NAME) $Rd, $Rn, $Rm, VectorIndexS:$Idx)>;
9065+
}
9066+
}
9067+
9068+
multiclass SIMDThreeSameVectorBF16DotI<bit U, string asm> {
9069+
defm v4bf16 : BaseSIMDThreeSameVectorBF16DotI<0, U, asm, ".2s", ".4h",
9070+
".2h", V64, v2f32, v4bf16>;
9071+
defm v8bf16 : BaseSIMDThreeSameVectorBF16DotI<1, U, asm, ".4s", ".8h",
9072+
".2h", V128, v4f32, v8bf16>;
9073+
}
9074+
90579075
//----------------------------------------------------------------------------
90589076
class BaseSIMDThreeSameVectorIndexB<bit Q, bit U, bits<2> sz, bits<4> opc,
90599077
string asm, string dst_kind,

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1735,23 +1735,6 @@ def BFCVTN2 : SIMD_BFCVTN2;
17351735

17361736
def : Pat<(concat_vectors (v4bf16 V64:$Rd), (any_fpround (v4f32 V128:$Rn))),
17371737
(BFCVTN2 (v8bf16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Rd, dsub)), V128:$Rn)>;
1738-
1739-
// Vector-scalar BFDOT:
1740-
// The second source operand of the 64-bit variant of BF16DOTlane is a 128-bit
1741-
// register (the instruction uses a single 32-bit lane from it), so the pattern
1742-
// is a bit tricky.
1743-
def : Pat<(v2f32 (int_aarch64_neon_bfdot
1744-
(v2f32 V64:$Rd), (v4bf16 V64:$Rn),
1745-
(v4bf16 (bitconvert
1746-
(v2i32 (AArch64duplane32
1747-
(v4i32 (bitconvert
1748-
(v8bf16 (insert_subvector undef,
1749-
(v4bf16 V64:$Rm),
1750-
(i64 0))))),
1751-
VectorIndexS:$idx)))))),
1752-
(BF16DOTlanev4bf16 (v2f32 V64:$Rd), (v4bf16 V64:$Rn),
1753-
(SUBREG_TO_REG (i32 0), V64:$Rm, dsub),
1754-
VectorIndexS:$idx)>;
17551738
}
17561739

17571740
let Predicates = [HasNEONandIsStreamingSafe, HasBF16] in {

llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,37 @@ entry:
151151
ret <4 x float> %vbfmlaltq_v3.i
152152
}
153153

154+
define <4 x float> @test_vbfdotq_laneq_f32_v4i32_shufflevector(<8 x bfloat> %a, <8 x bfloat> %b) {
155+
; CHECK-LABEL: test_vbfdotq_laneq_f32_v4i32_shufflevector:
156+
; CHECK: // %bb.0: // %entry
157+
; CHECK-NEXT: movi v2.2d, #0000000000000000
158+
; CHECK-NEXT: bfdot v2.4s, v0.8h, v1.2h[0]
159+
; CHECK-NEXT: mov v0.16b, v2.16b
160+
; CHECK-NEXT: ret
161+
entry:
162+
%0 = bitcast <8 x bfloat> %b to <4 x i32>
163+
%1 = shufflevector <4 x i32> %0, <4 x i32> poison, <4 x i32> zeroinitializer
164+
%2 = bitcast <4 x i32> %1 to <8 x bfloat>
165+
%vbfdotq = call <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v8bf16(<4 x float> zeroinitializer, <8 x bfloat> %a, <8 x bfloat> %2)
166+
ret <4 x float> %vbfdotq
167+
}
168+
169+
define <2 x float> @test_vbfdotq_laneq_f32_v2i32_shufflevector(<4 x bfloat> %a, <4 x bfloat> %b) {
170+
; CHECK-LABEL: test_vbfdotq_laneq_f32_v2i32_shufflevector:
171+
; CHECK: // %bb.0: // %entry
172+
; CHECK-NEXT: movi d2, #0000000000000000
173+
; CHECK-NEXT: // kill: def $d1 killed $d1 def $q1
174+
; CHECK-NEXT: bfdot v2.2s, v0.4h, v1.2h[0]
175+
; CHECK-NEXT: fmov d0, d2
176+
; CHECK-NEXT: ret
177+
entry:
178+
%0 = bitcast <4 x bfloat> %b to <2 x i32>
179+
%1 = shufflevector <2 x i32> %0, <2 x i32> poison, <2 x i32> zeroinitializer
180+
%2 = bitcast <2 x i32> %1 to <4 x bfloat>
181+
%vbfdotq = call <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v4bf16(<2 x float> zeroinitializer, <4 x bfloat> %a, <4 x bfloat> %2)
182+
ret <2 x float> %vbfdotq
183+
}
184+
154185
declare <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v4bf16(<2 x float>, <4 x bfloat>, <4 x bfloat>)
155186
declare <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v8bf16(<4 x float>, <8 x bfloat>, <8 x bfloat>)
156187
declare <4 x float> @llvm.aarch64.neon.bfmmla(<4 x float>, <8 x bfloat>, <8 x bfloat>)

0 commit comments

Comments
 (0)