Skip to content

Commit 641be94

Browse files
ThomasRaouxthomasfaingnaert
authored andcommitted
[NVPTX] Make i16x2 a native type and add supported vec instructions (llvm#65799)
recommit llvm#65432 with minor bug fix for bitcasts
1 parent b241821 commit 641be94

15 files changed

+850
-200
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -610,9 +610,10 @@ bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) {
610610
bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
611611
SDValue Vector = N->getOperand(0);
612612

613-
// We only care about f16x2 as it's the only real vector type we
613+
// We only care about 16x2 as it's the only real vector type we
614614
// need to deal with.
615-
if (Vector.getSimpleValueType() != MVT::v2f16)
615+
MVT VT = Vector.getSimpleValueType();
616+
if (!Isv2x16VT(VT))
616617
return false;
617618

618619
// Find and record all uses of this vector that extract element 0 or 1.
@@ -825,6 +826,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
825826
return Opcode_i16;
826827
case MVT::v2f16:
827828
case MVT::v2bf16:
829+
case MVT::v2i16:
828830
return Opcode_i32;
829831
case MVT::f32:
830832
return Opcode_f32;
@@ -906,9 +908,8 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
906908
// Vector Setting
907909
unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
908910
if (SimpleVT.isVector()) {
909-
assert((LoadedVT == MVT::v2f16 || LoadedVT == MVT::v2bf16) &&
910-
"Unexpected vector type");
911-
// v2f16/v2bf16 is loaded using ld.b32
911+
assert(Isv2x16VT(LoadedVT) && "Unexpected vector type");
912+
// v2f16/v2bf16/v2i16 is loaded using ld.b32
912913
fromTypeWidth = 32;
913914
}
914915

@@ -1058,10 +1059,10 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
10581059

10591060
EVT EltVT = N->getValueType(0);
10601061

1061-
// v8f16 is a special case. PTX doesn't have ld.v8.f16
1062-
// instruction. Instead, we split the vector into v2f16 chunks and
1062+
// v8x16 is a special case. PTX doesn't have ld.v8.16
1063+
// instruction. Instead, we split the vector into v2x16 chunks and
10631064
// load them with ld.v4.b32.
1064-
if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16) {
1065+
if (Isv2x16VT(EltVT)) {
10651066
assert(N->getOpcode() == NVPTXISD::LoadV4 && "Unexpected load opcode.");
10661067
EltVT = MVT::i32;
10671068
FromType = NVPTX::PTXLdStInstCode::Untyped;
@@ -1257,10 +1258,12 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12571258
if (EltVT.isVector()) {
12581259
NumElts = EltVT.getVectorNumElements();
12591260
EltVT = EltVT.getVectorElementType();
1260-
// vectors of f16 are loaded/stored as multiples of v2f16 elements.
1261-
if (EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) {
1261+
// vectors of 16bits type are loaded/stored as multiples of v2x16 elements.
1262+
if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) ||
1263+
(EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) ||
1264+
(EltVT == MVT::i16 && N->getValueType(0) == MVT::v2i16)) {
12621265
assert(NumElts % 2 == 0 && "Vector must have even number of elements");
1263-
EltVT = MVT::v2f16;
1266+
EltVT = N->getValueType(0);
12641267
NumElts /= 2;
12651268
}
12661269
}
@@ -1674,9 +1677,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
16741677
MVT ScalarVT = SimpleVT.getScalarType();
16751678
unsigned toTypeWidth = ScalarVT.getSizeInBits();
16761679
if (SimpleVT.isVector()) {
1677-
assert((StoreVT == MVT::v2f16 || StoreVT == MVT::v2bf16) &&
1678-
"Unexpected vector type");
1679-
// v2f16 is stored using st.b32
1680+
assert(Isv2x16VT(StoreVT) && "Unexpected vector type");
1681+
// v2x16 is stored using st.b32
16801682
toTypeWidth = 32;
16811683
}
16821684

@@ -1840,10 +1842,10 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
18401842
return false;
18411843
}
18421844

1843-
// v8f16 is a special case. PTX doesn't have st.v8.f16
1844-
// instruction. Instead, we split the vector into v2f16 chunks and
1845+
// v8x16 is a special case. PTX doesn't have st.v8.x16
1846+
// instruction. Instead, we split the vector into v2x16 chunks and
18451847
// store them with st.v4.b32.
1846-
if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16) {
1848+
if (Isv2x16VT(EltVT)) {
18471849
assert(N->getOpcode() == NVPTXISD::StoreV4 && "Unexpected load opcode.");
18481850
EltVT = MVT::i32;
18491851
ToType = NVPTX::PTXLdStInstCode::Untyped;

0 commit comments

Comments
 (0)