Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 15 additions & 34 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1150,15 +1150,12 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
return true;
}

static unsigned getLoadStoreVectorNumElts(SDNode *N) {
static unsigned getStoreVectorNumElts(SDNode *N) {
switch (N->getOpcode()) {
case NVPTXISD::LoadV2:
case NVPTXISD::StoreV2:
return 2;
case NVPTXISD::LoadV4:
case NVPTXISD::StoreV4:
return 4;
case NVPTXISD::LoadV8:
case NVPTXISD::StoreV8:
return 8;
default:
Expand All @@ -1171,7 +1168,6 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
const EVT MemEVT = LD->getMemoryVT();
if (!MemEVT.isSimple())
return false;
const MVT MemVT = MemEVT.getSimpleVT();

// Address Space Setting
const auto CodeAddrSpace = getAddrSpace(LD);
Expand All @@ -1191,18 +1187,15 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
// Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
// Read at least 8 bits (predicates are stored as 8-bit values)
// The last operand holds the original LoadSDNode::getExtensionType() value
const unsigned TotalWidth = MemVT.getSizeInBits();
const unsigned ExtensionType =
N->getConstantOperandVal(N->getNumOperands() - 1);
const unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
? NVPTX::PTXLdStInstCode::Signed
: NVPTX::PTXLdStInstCode::Untyped;

const unsigned FromTypeWidth = TotalWidth / getLoadStoreVectorNumElts(N);
const unsigned FromTypeWidth = getFromTypeWidthForLoad(LD);

assert(!(EltVT.isVector() && ExtensionType != ISD::NON_EXTLOAD));
assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
FromTypeWidth <= 128 && TotalWidth <= 256 && "Invalid width for load");

const auto [Base, Offset] = selectADDR(N->getOperand(1), CurDAG);
SDValue Ops[] = {getI32Imm(Ordering, DL),
Expand Down Expand Up @@ -1247,30 +1240,23 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
const EVT LoadedEVT = LD->getMemoryVT();
if (!LoadedEVT.isSimple())
return false;
const MVT LoadedVT = LoadedEVT.getSimpleVT();

SDLoc DL(LD);

const unsigned TotalWidth = LoadedVT.getSizeInBits();
unsigned ExtensionType;
unsigned NumElts;
if (const auto *Load = dyn_cast<LoadSDNode>(LD)) {
ExtensionType = Load->getExtensionType();
NumElts = 1;
} else {
ExtensionType = LD->getConstantOperandVal(LD->getNumOperands() - 1);
NumElts = getLoadStoreVectorNumElts(LD);
}
const unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
? NVPTX::PTXLdStInstCode::Signed
: NVPTX::PTXLdStInstCode::Untyped;

const unsigned FromTypeWidth = TotalWidth / NumElts;
const unsigned FromTypeWidth = getFromTypeWidthForLoad(LD);

assert(!(LD->getSimpleValueType(0).isVector() &&
ExtensionType != ISD::NON_EXTLOAD));
assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
FromTypeWidth <= 128 && TotalWidth <= 256 && "Invalid width for load");

const auto [Base, Offset] = selectADDR(LD->getOperand(1), CurDAG);
SDValue Ops[] = {getI32Imm(FromType, DL), getI32Imm(FromTypeWidth, DL), Base,
Expand Down Expand Up @@ -1309,26 +1295,21 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
return true;
}

// Returns the per-element width in bits for a load-like MemSDNode: the
// memory type's total size divided by the number of loaded results
// (getNumValues() - 1 excludes the chain output).
unsigned NVPTXDAGToDAGISel::getFromTypeWidthForLoad(const MemSDNode *Mem) {
  auto MemBits = Mem->getMemoryVT().getSizeInBits();
  auto NumLoadedValues = Mem->getNumValues() - 1;
  auto EltBits = MemBits / NumLoadedValues;
  // Elements must be a power-of-two width in [8, 128] bits, and the whole
  // access must fit in 256 bits.
  assert(isPowerOf2_32(EltBits) && EltBits >= 8 && EltBits <= 128 &&
         MemBits <= 256 && "Invalid width for load");
  return EltBits;
}

bool NVPTXDAGToDAGISel::tryLDU(SDNode *N) {
auto *LD = cast<MemSDNode>(N);

unsigned NumElts;
switch (N->getOpcode()) {
default:
llvm_unreachable("Unexpected opcode");
case ISD::INTRINSIC_W_CHAIN:
NumElts = 1;
break;
case NVPTXISD::LDUV2:
NumElts = 2;
break;
case NVPTXISD::LDUV4:
NumElts = 4;
break;
}

SDLoc DL(N);
const unsigned FromTypeWidth = LD->getMemoryVT().getSizeInBits() / NumElts;
const unsigned FromTypeWidth = getFromTypeWidthForLoad(LD);
const MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy;

// If this is an LDU intrinsic, the address is the third operand. If its an
Expand Down Expand Up @@ -1443,7 +1424,7 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
// - for integer type, always use 'u'
const unsigned TotalWidth = StoreVT.getSimpleVT().getSizeInBits();

const unsigned NumElts = getLoadStoreVectorNumElts(ST);
const unsigned NumElts = getStoreVectorNumElts(ST);

SmallVector<SDValue, 16> Ops;
for (auto &V : ST->ops().slice(1, NumElts))
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {

public:
static NVPTX::AddressSpace getAddrSpace(const MemSDNode *N);
static unsigned getFromTypeWidthForLoad(const MemSDNode *Mem);
};

class NVPTXDAGToDAGISelLegacy : public SelectionDAGISelLegacy {
Expand Down
96 changes: 24 additions & 72 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "NVPTXISelLowering.h"
#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "NVPTX.h"
#include "NVPTXISelDAGToDAG.h"
#include "NVPTXSubtarget.h"
#include "NVPTXTargetMachine.h"
#include "NVPTXTargetObjectFile.h"
Expand Down Expand Up @@ -5242,76 +5243,6 @@ static SDValue PerformFADDCombine(SDNode *N,
return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
}

static SDValue PerformANDCombine(SDNode *N,
                                 TargetLowering::DAGCombinerInfo &DCI) {
  // The type legalizer turns a vector load of i8 values into a zextload to
  // i16 registers, optionally ANY_EXTENDs it (if the target type is
  // integer), and ANDs off the high 8 bits. Since we turn this load into a
  // target-specific DAG node, the DAG combiner fails to eliminate these AND
  // nodes. Do that here.
  SDValue Val = N->getOperand(0);
  SDValue Mask = N->getOperand(1);

  // Canonicalize so that a (possible) load ends up in Val.
  if (isa<ConstantSDNode>(Val))
    std::swap(Val, Mask);

  // Generally, we will see zextload -> IMOV16rr -> ANY_EXTEND -> and.
  // Remember the extend so it can be re-inserted as a zext below.
  SDValue AExt;
  if (Val.getOpcode() == ISD::ANY_EXTEND) {
    AExt = Val;
    Val = Val->getOperand(0);
  }

  // Only target-specific vector loads are of interest.
  if (Val->getOpcode() != NVPTXISD::LoadV2 &&
      Val->getOpcode() != NVPTXISD::LoadV4)
    return SDValue();

  // The AND must mask off exactly the top 8 bits with a constant.
  const ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
  if (!MaskCnst || MaskCnst->getZExtValue() != 0xff)
    return SDValue();

  MemSDNode *Mem = dyn_cast<MemSDNode>(Val);
  if (!Mem) {
    // Not a MemSDNode?!?
    return SDValue();
  }

  // We only handle the i8 case.
  const EVT MemVT = Mem->getMemoryVT();
  if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8)
    return SDValue();

  // If for some reason the load is a sextload, the AND is needed to zero
  // out the high 8 bits, so keep it.
  const unsigned ExtType =
      Val->getConstantOperandVal(Val->getNumOperands() - 1);
  if (ExtType == ISD::SEXTLOAD)
    return SDValue();

  bool AddTo = false;
  if (AExt.getNode() != nullptr) {
    // Re-insert the ext as a zext.
    Val = DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), AExt.getValueType(),
                          Val);
    AddTo = true;
  }

  // If we get here, the AND is unnecessary. Just replace it with the load.
  DCI.CombineTo(N, Val, AddTo);

  return SDValue();
}

static SDValue PerformREMCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
CodeGenOptLevel OptLevel) {
Expand Down Expand Up @@ -5983,8 +5914,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return PerformADDCombine(N, DCI, OptLevel);
case ISD::ADDRSPACECAST:
return combineADDRSPACECAST(N, DCI);
case ISD::AND:
return PerformANDCombine(N, DCI);
case ISD::SIGN_EXTEND:
case ISD::ZERO_EXTEND:
return combineMulWide(N, DCI, OptLevel);
Expand Down Expand Up @@ -6609,6 +6538,24 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
}
}

static void computeKnownBitsForLoadV(const SDValue Op, KnownBits &Known) {
MemSDNode *LD = cast<MemSDNode>(Op);

// We can't do anything without knowing the sign bit.
auto ExtType = LD->getConstantOperandVal(LD->getNumOperands() - 1);
if (ExtType == ISD::SEXTLOAD)
return;

// ExtLoading to vector types is weird and may not work well with known bits.
auto DestVT = LD->getValueType(0);
if (DestVT.isVector())
return;

assert(Known.getBitWidth() == DestVT.getSizeInBits());
auto ElementBitWidth = NVPTXDAGToDAGISel::getFromTypeWidthForLoad(LD);
Known.Zero.setHighBits(Known.getBitWidth() - ElementBitWidth);
}

void NVPTXTargetLowering::computeKnownBitsForTargetNode(
const SDValue Op, KnownBits &Known, const APInt &DemandedElts,
const SelectionDAG &DAG, unsigned Depth) const {
Expand All @@ -6618,6 +6565,11 @@ void NVPTXTargetLowering::computeKnownBitsForTargetNode(
case NVPTXISD::PRMT:
computeKnownBitsForPRMT(Op, Known, DAG, Depth);
break;
case NVPTXISD::LoadV2:
case NVPTXISD::LoadV4:
case NVPTXISD::LoadV8:
computeKnownBitsForLoadV(Op, Known);
break;
default:
break;
}
Expand Down
29 changes: 29 additions & 0 deletions llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
Original file line number Diff line number Diff line change
Expand Up @@ -103,5 +103,34 @@ define <2 x i8> @test_call_2xi8(<2 x i8> %a) {
%res = call <2 x i8> @test_call_2xi8(<2 x i8> %a)
ret <2 x i8> %res
}

; Check that the high-bit mask (and.b32 ... 16711935) left over from the
; legalized <2 x i8> load is eliminated at -O3 via known-bits on the
; target vector-load node; the CHECK lines below are autogenerated.
define <2 x float> @test_uitofp_2xi8(<2 x i8> %a) {
; O0-LABEL: test_uitofp_2xi8(
; O0: {
; O0-NEXT: .reg .b16 %rs<3>;
; O0-NEXT: .reg .b32 %r<4>;
; O0-EMPTY:
; O0-NEXT: // %bb.0:
; O0-NEXT: ld.param.v2.b8 {%rs1, %rs2}, [test_uitofp_2xi8_param_0];
; O0-NEXT: mov.b32 %r1, {%rs1, %rs2};
; O0-NEXT: cvt.rn.f32.u16 %r2, %rs2;
; O0-NEXT: cvt.rn.f32.u16 %r3, %rs1;
; O0-NEXT: st.param.v2.b32 [func_retval0], {%r3, %r2};
; O0-NEXT: ret;
;
; O3-LABEL: test_uitofp_2xi8(
; O3: {
; O3-NEXT: .reg .b16 %rs<3>;
; O3-NEXT: .reg .b32 %r<3>;
; O3-EMPTY:
; O3-NEXT: // %bb.0:
; O3-NEXT: ld.param.v2.b8 {%rs1, %rs2}, [test_uitofp_2xi8_param_0];
; O3-NEXT: cvt.rn.f32.u16 %r1, %rs2;
; O3-NEXT: cvt.rn.f32.u16 %r2, %rs1;
; O3-NEXT: st.param.v2.b32 [func_retval0], {%r2, %r1};
; O3-NEXT: ret;
  %1 = uitofp <2 x i8> %a to <2 x float>
  ret <2 x float> %1
}
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
; COMMON: {{.*}}
9 changes: 4 additions & 5 deletions llvm/test/CodeGen/NVPTX/shift-opt.ll
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,17 @@ define <2 x i16> @test_vec(<2 x i16> %x, <2 x i8> %y) {
; CHECK-LABEL: test_vec(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<7>;
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-NEXT: .reg .b32 %r<4>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.v2.b16 {%rs1, %rs2}, [test_vec_param_0];
; CHECK-NEXT: ld.param.v2.b8 {%rs3, %rs4}, [test_vec_param_1];
; CHECK-NEXT: mov.b32 %r1, {%rs3, %rs4};
; CHECK-NEXT: and.b32 %r2, %r1, 16711935;
; CHECK-NEXT: shr.u16 %rs5, %rs2, 5;
; CHECK-NEXT: shr.u16 %rs6, %rs1, 5;
; CHECK-NEXT: mov.b32 %r3, {%rs6, %rs5};
; CHECK-NEXT: or.b32 %r4, %r3, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
; CHECK-NEXT: mov.b32 %r2, {%rs6, %rs5};
; CHECK-NEXT: or.b32 %r3, %r2, %r1;
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
%ext = zext <2 x i8> %y to <2 x i16>
%shl = shl <2 x i16> %ext, splat(i16 5)
Expand Down