@@ -610,9 +610,10 @@ bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) {
610
610
bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT (SDNode *N) {
611
611
SDValue Vector = N->getOperand (0 );
612
612
613
- // We only care about f16x2 as it's the only real vector type we
613
+ // We only care about v2x16 types as they're the only real vector types we
614
614
// need to deal with.
615
- if (Vector.getSimpleValueType () != MVT::v2f16)
615
+ MVT VT = Vector.getSimpleValueType ();
616
+ if (!Isv2x16VT (VT))
616
617
return false ;
617
618
618
619
// Find and record all uses of this vector that extract element 0 or 1.
@@ -825,6 +826,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
825
826
return Opcode_i16;
826
827
case MVT::v2f16:
827
828
case MVT::v2bf16:
829
+ case MVT::v2i16:
828
830
return Opcode_i32;
829
831
case MVT::f32:
830
832
return Opcode_f32;
@@ -906,9 +908,8 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
906
908
// Vector Setting
907
909
unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
908
910
if (SimpleVT.isVector ()) {
909
- assert ((LoadedVT == MVT::v2f16 || LoadedVT == MVT::v2bf16) &&
910
- " Unexpected vector type" );
911
- // v2f16/v2bf16 is loaded using ld.b32
911
+ assert (Isv2x16VT (LoadedVT) && " Unexpected vector type" );
912
+ // v2f16/v2bf16/v2i16 is loaded using ld.b32
912
913
fromTypeWidth = 32 ;
913
914
}
914
915
@@ -1058,10 +1059,10 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
1058
1059
1059
1060
EVT EltVT = N->getValueType (0 );
1060
1061
1061
- // v8f16 is a special case. PTX doesn't have ld.v8.f16
1062
- // instruction. Instead, we split the vector into v2f16 chunks and
1062
+ // v8x16 is a special case. PTX doesn't have ld.v8.x16
1063
+ // instruction. Instead, we split the vector into v2x16 chunks and
1063
1064
// load them with ld.v4.b32.
1064
- if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16 ) {
1065
+ if (Isv2x16VT ( EltVT) ) {
1065
1066
assert (N->getOpcode () == NVPTXISD::LoadV4 && " Unexpected load opcode." );
1066
1067
EltVT = MVT::i32;
1067
1068
FromType = NVPTX::PTXLdStInstCode::Untyped;
@@ -1257,10 +1258,12 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
1257
1258
if (EltVT.isVector ()) {
1258
1259
NumElts = EltVT.getVectorNumElements ();
1259
1260
EltVT = EltVT.getVectorElementType ();
1260
- // vectors of f16 are loaded/stored as multiples of v2f16 elements.
1261
- if (EltVT == MVT::f16 && N->getValueType (0 ) == MVT::v2f16) {
1261
+ // vectors of 16-bit types are loaded/stored as multiples of v2x16 elements.
1262
+ if ((EltVT == MVT::f16 && N->getValueType (0 ) == MVT::v2f16) ||
1263
+ (EltVT == MVT::bf16 && N->getValueType (0 ) == MVT::v2bf16) ||
1264
+ (EltVT == MVT::i16 && N->getValueType (0 ) == MVT::v2i16)) {
1262
1265
assert (NumElts % 2 == 0 && " Vector must have even number of elements" );
1263
- EltVT = MVT::v2f16 ;
1266
+ EltVT = N-> getValueType ( 0 ) ;
1264
1267
NumElts /= 2 ;
1265
1268
}
1266
1269
}
@@ -1674,9 +1677,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
1674
1677
MVT ScalarVT = SimpleVT.getScalarType ();
1675
1678
unsigned toTypeWidth = ScalarVT.getSizeInBits ();
1676
1679
if (SimpleVT.isVector ()) {
1677
- assert ((StoreVT == MVT::v2f16 || StoreVT == MVT::v2bf16) &&
1678
- " Unexpected vector type" );
1679
- // v2f16 is stored using st.b32
1680
+ assert (Isv2x16VT (StoreVT) && " Unexpected vector type" );
1681
+ // v2x16 is stored using st.b32
1680
1682
toTypeWidth = 32 ;
1681
1683
}
1682
1684
@@ -1840,10 +1842,10 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
1840
1842
return false ;
1841
1843
}
1842
1844
1843
- // v8f16 is a special case. PTX doesn't have st.v8.f16
1844
- // instruction. Instead, we split the vector into v2f16 chunks and
1845
+ // v8x16 is a special case. PTX doesn't have st.v8.x16
1846
+ // instruction. Instead, we split the vector into v2x16 chunks and
1845
1847
// store them with st.v4.b32.
1846
- if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16 ) {
1848
+ if (Isv2x16VT ( EltVT) ) {
1847
1849
assert (N->getOpcode () == NVPTXISD::StoreV4 && " Unexpected load opcode." );
1848
1850
EltVT = MVT::i32;
1849
1851
ToType = NVPTX::PTXLdStInstCode::Untyped;
0 commit comments