Skip to content

[NVPTX] Make i16x2 a native type and add supported vec instructions #65432

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Sep 7, 2023
38 changes: 19 additions & 19 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -612,10 +612,10 @@ bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) {
bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
SDValue Vector = N->getOperand(0);

// We only care about f16x2 as it's the only real vector type we
// We only care about 16x2 as it's the only real vector type we
// need to deal with.
MVT VT = Vector.getSimpleValueType();
if (!(VT == MVT::v2f16 || VT == MVT::v2bf16))
if (!Isv2x16VT(VT))
return false;
// Find and record all uses of this vector that extract element 0 or 1.
SmallVector<SDNode *, 4> E0, E1;
Expand Down Expand Up @@ -828,6 +828,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
return Opcode_i16;
case MVT::v2f16:
case MVT::v2bf16:
case MVT::v2i16:
return Opcode_i32;
case MVT::f32:
return Opcode_f32;
Expand Down Expand Up @@ -909,9 +910,8 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
// Vector Setting
unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
if (SimpleVT.isVector()) {
assert((LoadedVT == MVT::v2f16 || LoadedVT == MVT::v2bf16) &&
"Unexpected vector type");
// v2f16/v2bf16 is loaded using ld.b32
assert(Isv2x16VT(LoadedVT) && "Unexpected vector type");
// v2f16/v2bf16/v2i16 is loaded using ld.b32
fromTypeWidth = 32;
}

Expand Down Expand Up @@ -1061,10 +1061,10 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {

EVT EltVT = N->getValueType(0);

// v8f16 is a special case. PTX doesn't have ld.v8.f16
// instruction. Instead, we split the vector into v2f16 chunks and
// v8x16 is a special case. PTX doesn't have ld.v8.16
// instruction. Instead, we split the vector into v2x16 chunks and
// load them with ld.v4.b32.
if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16) {
if (Isv2x16VT(EltVT)) {
assert(N->getOpcode() == NVPTXISD::LoadV4 && "Unexpected load opcode.");
EltVT = MVT::i32;
FromType = NVPTX::PTXLdStInstCode::Untyped;
Expand Down Expand Up @@ -1260,12 +1260,13 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
if (EltVT.isVector()) {
NumElts = EltVT.getVectorNumElements();
EltVT = EltVT.getVectorElementType();
// vectors of f16 are loaded/stored as multiples of v2f16 elements.
// vectors of 16bits type are loaded/stored as multiples of v2x16 elements.
if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) ||
(EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16)) {
assert(NumElts % 2 == 0 && "Vector must have even number of elements");
EltVT = N->getValueType(0);
NumElts /= 2;
(EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) ||
(EltVT == MVT::i16 && N->getValueType(0) == MVT::v2i16)) {
assert(NumElts % 2 == 0 && "Vector must have even number of elements");
EltVT = N->getValueType(0);
NumElts /= 2;
}
}

Expand Down Expand Up @@ -1678,9 +1679,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
MVT ScalarVT = SimpleVT.getScalarType();
unsigned toTypeWidth = ScalarVT.getSizeInBits();
if (SimpleVT.isVector()) {
assert((StoreVT == MVT::v2f16 || StoreVT == MVT::v2bf16) &&
"Unexpected vector type");
// v2f16 is stored using st.b32
assert(Isv2x16VT(StoreVT) && "Unexpected vector type");
// v2x16 is stored using st.b32
toTypeWidth = 32;
}

Expand Down Expand Up @@ -1844,10 +1844,10 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
return false;
}

// v8f16 is a special case. PTX doesn't have st.v8.f16
// instruction. Instead, we split the vector into v2f16 chunks and
// v8x16 is a special case. PTX doesn't have st.v8.x16
// instruction. Instead, we split the vector into v2x16 chunks and
// store them with st.v4.b32.
if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16) {
if (Isv2x16VT(EltVT)) {
assert(N->getOpcode() == NVPTXISD::StoreV4 && "Unexpected load opcode.");
EltVT = MVT::i32;
ToType = NVPTX::PTXLdStInstCode::Untyped;
Expand Down
Loading