From d11030d2601bad619638a674915f0860ffd77459 Mon Sep 17 00:00:00 2001
From: Kevin McAfee
Date: Mon, 18 Aug 2025 17:36:45 +0000
Subject: [PATCH 1/6] pre-commit test

---
 llvm/test/CodeGen/NVPTX/i8x2-instructions.ll | 34 ++++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
index 53150c1a01314..7463fb5b59e02 100644
--- a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
@@ -103,5 +103,39 @@ define <2 x i8> @test_call_2xi8(<2 x i8> %a) {
   %res = call <2 x i8> @test_call_2xi8(<2 x i8> %a)
   ret <2 x i8> %res
 }
+
+define <2 x float> @test_uitofp_2xi8(<2 x i8> %a) {
+; O0-LABEL: test_uitofp_2xi8(
+; O0: {
+; O0-NEXT: .reg .b16 %rs<5>;
+; O0-NEXT: .reg .b32 %r<5>;
+; O0-EMPTY:
+; O0-NEXT: // %bb.0:
+; O0-NEXT: ld.param.v2.b8 {%rs1, %rs2}, [test_uitofp_2xi8_param_0];
+; O0-NEXT: mov.b32 %r1, {%rs1, %rs2};
+; O0-NEXT: and.b32 %r2, %r1, 16711935;
+; O0-NEXT: mov.b32 {%rs3, %rs4}, %r2;
+; O0-NEXT: cvt.rn.f32.u16 %r3, %rs4;
+; O0-NEXT: cvt.rn.f32.u16 %r4, %rs3;
+; O0-NEXT: st.param.v2.b32 [func_retval0], {%r4, %r3};
+; O0-NEXT: ret;
+;
+; O3-LABEL: test_uitofp_2xi8(
+; O3: {
+; O3-NEXT: .reg .b16 %rs<5>;
+; O3-NEXT: .reg .b32 %r<5>;
+; O3-EMPTY:
+; O3-NEXT: // %bb.0:
+; O3-NEXT: ld.param.v2.b8 {%rs1, %rs2}, [test_uitofp_2xi8_param_0];
+; O3-NEXT: mov.b32 %r1, {%rs1, %rs2};
+; O3-NEXT: and.b32 %r2, %r1, 16711935;
+; O3-NEXT: mov.b32 {%rs3, %rs4}, %r2;
+; O3-NEXT: cvt.rn.f32.u16 %r3, %rs4;
+; O3-NEXT: cvt.rn.f32.u16 %r4, %rs3;
+; O3-NEXT: st.param.v2.b32 [func_retval0], {%r4, %r3};
+; O3-NEXT: ret;
+  %1 = uitofp <2 x i8> %a to <2 x float>
+  ret <2 x float> %1
+}
 ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
 ; COMMON: {{.*}}
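The constant 16711935 checked above is 0x00FF00FF: the type legalizer widens each i8 lane to a 16-bit lane of a packed 32-bit register and then masks off the high byte of each lane. A minimal standalone C++ sketch of why that mask is redundant after a zero-extending byte load (values are hypothetical and only illustrative, not part of the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      // Two bytes zero-extended into 16-bit lanes and packed into one 32-bit
      // register, mirroring the ld.param.v2.b8 + mov.b32 sequence above.
      uint16_t lo = 0x0042, hi = 0x0017; // high 8 bits of each lane already zero
      uint32_t packed = lo | (uint32_t(hi) << 16);
      // The mask the legalizer emits (16711935 == 0x00FF00FF) changes nothing.
      assert((packed & 0x00FF00FFu) == packed);
      return 0;
    }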
From ea1f7d53c3635f1b8d3c15819a52b2cbf77e2a2e Mon Sep 17 00:00:00 2001
From: Kevin McAfee
Date: Mon, 18 Aug 2025 17:38:29 +0000
Subject: [PATCH 2/6] [NVPTX] Support vectors for AND combine

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp  | 114 ++++++++++++-------
 llvm/test/CodeGen/NVPTX/i8x2-instructions.ll |  25 ++--
 llvm/test/CodeGen/NVPTX/shift-opt.ll         |   9 +-
 3 files changed, 84 insertions(+), 64 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 74e6c139c610d..8190c23407250 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5242,6 +5242,58 @@ static SDValue PerformFADDCombine(SDNode *N,
   return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
 }
 
+// Helper function to check if an AND operation on a load can be eliminated.
+// Returns a replacement value if the load can be eliminated, else nullopt.
+static std::optional<SDValue>
+canEliminateLoadAND(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+                    SDValue Val, SDValue Mask, SDValue AExt) {
+  if (Val->getOpcode() != NVPTXISD::LoadV2 &&
+      Val->getOpcode() != NVPTXISD::LoadV4) {
+    return std::nullopt;
+  }
+
+  ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
+  if (!MaskCnst) {
+    // Not an AND with a constant
+    return std::nullopt;
+  }
+
+  uint64_t MaskVal = MaskCnst->getZExtValue();
+  if (MaskVal != 0xff) {
+    // Not an AND that chops off top 8 bits
+    return std::nullopt;
+  }
+
+  MemSDNode *Mem = dyn_cast<MemSDNode>(Val);
+  if (!Mem) {
+    // Not a MemSDNode
+    return std::nullopt;
+  }
+
+  EVT MemVT = Mem->getMemoryVT();
+  if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
+    // We only handle the i8 case
+    return std::nullopt;
+  }
+
+  unsigned ExtType = Val->getConstantOperandVal(Val->getNumOperands() - 1);
+  if (ExtType == ISD::SEXTLOAD) {
+    // If the load is a sextload, the AND is needed to zero out the high 8 bits
+    return std::nullopt;
+  }
+
+  SDValue Result = Val;
+
+  if (AExt) {
+    // Re-insert the ext as a zext.
+    Result =
+        DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), AExt.getValueType(), Val);
+  }
+
+  // If we get here, the AND is unnecessary. Replace it with the load.
+  return Result;
+}
+
 static SDValue PerformANDCombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI) {
   // The type legalizer turns a vector load of i8 values into a zextload to i16
@@ -5252,9 +5304,8 @@ static SDValue PerformANDCombine(SDNode *N,
   SDValue Val = N->getOperand(0);
   SDValue Mask = N->getOperand(1);
 
-  if (isa<ConstantSDNode>(Val)) {
+  if (isa<ConstantSDNode>(Val))
     std::swap(Val, Mask);
-  }
 
   SDValue AExt;
 
@@ -5264,49 +5315,24 @@ static SDValue PerformANDCombine(SDNode *N,
     Val = Val->getOperand(0);
   }
 
-  if (Val->getOpcode() == NVPTXISD::LoadV2 ||
-      Val->getOpcode() == NVPTXISD::LoadV4) {
-    ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
-    if (!MaskCnst) {
-      // Not an AND with a constant
-      return SDValue();
-    }
-
-    uint64_t MaskVal = MaskCnst->getZExtValue();
-    if (MaskVal != 0xff) {
-      // Not an AND that chops off top 8 bits
-      return SDValue();
-    }
-
-    MemSDNode *Mem = dyn_cast<MemSDNode>(Val);
-    if (!Mem) {
-      // Not a MemSDNode?!?
-      return SDValue();
-    }
-
-    EVT MemVT = Mem->getMemoryVT();
-    if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
-      // We only handle the i8 case
-      return SDValue();
-    }
-
-    unsigned ExtType = Val->getConstantOperandVal(Val->getNumOperands() - 1);
-    if (ExtType == ISD::SEXTLOAD) {
-      // If for some reason the load is a sextload, the and is needed to zero
-      // out the high 8 bits
-      return SDValue();
-    }
-
-    bool AddTo = false;
-    if (AExt.getNode() != nullptr) {
-      // Re-insert the ext as a zext.
-      Val = DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N),
-                            AExt.getValueType(), Val);
-      AddTo = true;
+  if (Val.getOpcode() == ISD::BUILD_VECTOR &&
+      Mask.getOpcode() == ISD::BUILD_VECTOR) {
+    assert(Val->getNumOperands() == Mask->getNumOperands() && !AExt);
+    for (unsigned I = 0; I < Val->getNumOperands(); ++I) {
+      // We know that the AExt is null and therefore the result of this call
+      // will be the BUILD_VECTOR operand or nullopt. Rather than create a new
+      // BUILD_VECTOR with the collection of operands, we can just use the
+      // original and ignore the result.
+      if (!canEliminateLoadAND(N, DCI, Val->getOperand(I), Mask->getOperand(I),
+                               AExt)
+               .has_value())
+        return SDValue();
     }
-
-    // If we get here, the AND is unnecessary.  Just replace it with the load
-    DCI.CombineTo(N, Val, AddTo);
+    DCI.CombineTo(N, Val, false);
+  } else {
+    auto Result = canEliminateLoadAND(N, DCI, Val, Mask, AExt);
+    if (Result.has_value())
+      DCI.CombineTo(N, Result.value(), AExt.getNode() != nullptr);
   }
 
   return SDValue();
diff --git a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
index 7463fb5b59e02..f4053d84593a5 100644
--- a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
@@ -107,32 +107,27 @@ define <2 x i8> @test_call_2xi8(<2 x i8> %a) {
 define <2 x float> @test_uitofp_2xi8(<2 x i8> %a) {
 ; O0-LABEL: test_uitofp_2xi8(
 ; O0: {
-; O0-NEXT: .reg .b16 %rs<5>;
-; O0-NEXT: .reg .b32 %r<5>;
+; O0-NEXT: .reg .b16 %rs<3>;
+; O0-NEXT: .reg .b32 %r<4>;
 ; O0-EMPTY:
 ; O0-NEXT: // %bb.0:
 ; O0-NEXT: ld.param.v2.b8 {%rs1, %rs2}, [test_uitofp_2xi8_param_0];
 ; O0-NEXT: mov.b32 %r1, {%rs1, %rs2};
-; O0-NEXT: and.b32 %r2, %r1, 16711935;
-; O0-NEXT: mov.b32 {%rs3, %rs4}, %r2;
-; O0-NEXT: cvt.rn.f32.u16 %r3, %rs4;
-; O0-NEXT: cvt.rn.f32.u16 %r4, %rs3;
-; O0-NEXT: st.param.v2.b32 [func_retval0], {%r4, %r3};
+; O0-NEXT: cvt.rn.f32.u16 %r2, %rs2;
+; O0-NEXT: cvt.rn.f32.u16 %r3, %rs1;
+; O0-NEXT: st.param.v2.b32 [func_retval0], {%r3, %r2};
 ; O0-NEXT: ret;
 ;
 ; O3-LABEL: test_uitofp_2xi8(
 ; O3: {
-; O3-NEXT: .reg .b16 %rs<5>;
-; O3-NEXT: .reg .b32 %r<5>;
+; O3-NEXT: .reg .b16 %rs<3>;
+; O3-NEXT: .reg .b32 %r<3>;
 ; O3-EMPTY:
 ; O3-NEXT: // %bb.0:
 ; O3-NEXT: ld.param.v2.b8 {%rs1, %rs2}, [test_uitofp_2xi8_param_0];
-; O3-NEXT: mov.b32 %r1, {%rs1, %rs2};
-; O3-NEXT: and.b32 %r2, %r1, 16711935;
-; O3-NEXT: mov.b32 {%rs3, %rs4}, %r2;
-; O3-NEXT: cvt.rn.f32.u16 %r3, %rs4;
-; O3-NEXT: cvt.rn.f32.u16 %r4, %rs3;
-; O3-NEXT: st.param.v2.b32 [func_retval0], {%r4, %r3};
+; O3-NEXT: cvt.rn.f32.u16 %r1, %rs2;
+; O3-NEXT: cvt.rn.f32.u16 %r2, %rs1;
+; O3-NEXT: st.param.v2.b32 [func_retval0], {%r2, %r1};
 ; O3-NEXT: ret;
   %1 = uitofp <2 x i8> %a to <2 x float>
   ret <2 x float> %1
diff --git a/llvm/test/CodeGen/NVPTX/shift-opt.ll b/llvm/test/CodeGen/NVPTX/shift-opt.ll
index e7866b01064c7..e0d22c62993ba 100644
--- a/llvm/test/CodeGen/NVPTX/shift-opt.ll
+++ b/llvm/test/CodeGen/NVPTX/shift-opt.ll
@@ -71,18 +71,17 @@ define <2 x i16> @test_vec(<2 x i16> %x, <2 x i8> %y) {
 ; CHECK-LABEL: test_vec(
 ; CHECK: {
 ; CHECK-NEXT: .reg .b16 %rs<7>;
-; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b32 %r<4>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT: // %bb.0:
 ; CHECK-NEXT: ld.param.v2.b16 {%rs1, %rs2}, [test_vec_param_0];
 ; CHECK-NEXT: ld.param.v2.b8 {%rs3, %rs4}, [test_vec_param_1];
 ; CHECK-NEXT: mov.b32 %r1, {%rs3, %rs4};
-; CHECK-NEXT: and.b32 %r2, %r1, 16711935;
 ; CHECK-NEXT: shr.u16 %rs5, %rs2, 5;
 ; CHECK-NEXT: shr.u16 %rs6, %rs1, 5;
-; CHECK-NEXT: mov.b32 %r3, {%rs6, %rs5};
-; CHECK-NEXT: or.b32 %r4, %r3, %r2;
-; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
+; CHECK-NEXT: mov.b32 %r2, {%rs6, %rs5};
+; CHECK-NEXT: or.b32 %r3, %r2, %r1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
 ; CHECK-NEXT: ret;
   %ext = zext <2 x i8> %y to <2 x i16>
   %shl = shl <2 x i16> %ext, splat(i16 5)
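Note on the vector path added above: the splat mask reaches the combine as a BUILD_VECTOR of per-lane constants, so each lane is tested against 0xff individually; the packed constant seen in the PTX checks is just that per-lane mask broadcast across lanes. A self-contained check of that equivalence (illustrative only):

    #include <cassert>
    #include <cstdint>

    int main() {
      // Broadcasting the per-lane mask 0xFF across two 16-bit lanes produces
      // the packed constant that appears in the PTX checks.
      uint32_t perLane = 0xFF;
      uint32_t splat = perLane | (perLane << 16);
      assert(splat == 16711935u); // 0x00FF00FF
      return 0;
    }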
From 885dcae41ee122734c5c77a0fcb4e5f99c903cb8 Mon Sep 17 00:00:00 2001
From: Kevin McAfee
Date: Mon, 18 Aug 2025 23:51:51 +0000
Subject: [PATCH 3/6] Remove AND combine and implement computeKnownBits for LoadV2/4

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 123 ++++----------------
 1 file changed, 25 insertions(+), 98 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 8190c23407250..67a5f2ae84c31 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5242,102 +5242,6 @@ static SDValue PerformFADDCombine(SDNode *N,
   return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
 }
 
-// Helper function to check if an AND operation on a load can be eliminated.
-// Returns a replacement value if the load can be eliminated, else nullopt.
-static std::optional<SDValue>
-canEliminateLoadAND(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
-                    SDValue Val, SDValue Mask, SDValue AExt) {
-  if (Val->getOpcode() != NVPTXISD::LoadV2 &&
-      Val->getOpcode() != NVPTXISD::LoadV4) {
-    return std::nullopt;
-  }
-
-  ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
-  if (!MaskCnst) {
-    // Not an AND with a constant
-    return std::nullopt;
-  }
-
-  uint64_t MaskVal = MaskCnst->getZExtValue();
-  if (MaskVal != 0xff) {
-    // Not an AND that chops off top 8 bits
-    return std::nullopt;
-  }
-
-  MemSDNode *Mem = dyn_cast<MemSDNode>(Val);
-  if (!Mem) {
-    // Not a MemSDNode
-    return std::nullopt;
-  }
-
-  EVT MemVT = Mem->getMemoryVT();
-  if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
-    // We only handle the i8 case
-    return std::nullopt;
-  }
-
-  unsigned ExtType = Val->getConstantOperandVal(Val->getNumOperands() - 1);
-  if (ExtType == ISD::SEXTLOAD) {
-    // If the load is a sextload, the AND is needed to zero out the high 8 bits
-    return std::nullopt;
-  }
-
-  SDValue Result = Val;
-
-  if (AExt) {
-    // Re-insert the ext as a zext.
-    Result =
-        DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), AExt.getValueType(), Val);
-  }
-
-  // If we get here, the AND is unnecessary. Replace it with the load.
-  return Result;
-}
-
-static SDValue PerformANDCombine(SDNode *N,
-                                 TargetLowering::DAGCombinerInfo &DCI) {
-  // The type legalizer turns a vector load of i8 values into a zextload to i16
-  // registers, optionally ANY_EXTENDs it (if target type is integer),
-  // and ANDs off the high 8 bits. Since we turn this load into a
-  // target-specific DAG node, the DAG combiner fails to eliminate these AND
-  // nodes. Do that here.
-  SDValue Val = N->getOperand(0);
-  SDValue Mask = N->getOperand(1);
-
-  if (isa<ConstantSDNode>(Val))
-    std::swap(Val, Mask);
-
-  SDValue AExt;
-
-  // Generally, we will see zextload -> IMOV16rr -> ANY_EXTEND -> and
-  if (Val.getOpcode() == ISD::ANY_EXTEND) {
-    AExt = Val;
-    Val = Val->getOperand(0);
-  }
-
-  if (Val.getOpcode() == ISD::BUILD_VECTOR &&
-      Mask.getOpcode() == ISD::BUILD_VECTOR) {
-    assert(Val->getNumOperands() == Mask->getNumOperands() && !AExt);
-    for (unsigned I = 0; I < Val->getNumOperands(); ++I) {
-      // We know that the AExt is null and therefore the result of this call
-      // will be the BUILD_VECTOR operand or nullopt. Rather than create a new
-      // BUILD_VECTOR with the collection of operands, we can just use the
-      // original and ignore the result.
-      if (!canEliminateLoadAND(N, DCI, Val->getOperand(I), Mask->getOperand(I),
-                               AExt)
-               .has_value())
-        return SDValue();
-    }
-    DCI.CombineTo(N, Val, false);
-  } else {
-    auto Result = canEliminateLoadAND(N, DCI, Val, Mask, AExt);
-    if (Result.has_value())
-      DCI.CombineTo(N, Result.value(), AExt.getNode() != nullptr);
-  }
-
-  return SDValue();
-}
-
 static SDValue PerformREMCombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  CodeGenOptLevel OptLevel) {
@@ -6009,8 +5913,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
     return PerformADDCombine(N, DCI, OptLevel);
   case ISD::ADDRSPACECAST:
     return combineADDRSPACECAST(N, DCI);
-  case ISD::AND:
-    return PerformANDCombine(N, DCI);
   case ISD::SIGN_EXTEND:
   case ISD::ZERO_EXTEND:
     return combineMulWide(N, DCI, OptLevel);
@@ -6635,6 +6537,27 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
   }
 }
 
+static void computeKnownBitsFori8VLoad(const SDValue Op, KnownBits &Known) {
+  MemSDNode *Mem = dyn_cast<MemSDNode>(Op);
+  if (!Mem) {
+    return;
+  }
+
+  EVT MemVT = Mem->getMemoryVT();
+  if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
+    return;
+  }
+
+  unsigned ExtType = Mem->getConstantOperandVal(Mem->getNumOperands() - 1);
+  if (ExtType == ISD::SEXTLOAD) {
+    Known = Known.sext(Known.getBitWidth());
+    return;
+  }
+  KnownBits HighZeros(Known.getBitWidth() - 8);
+  HighZeros.setAllZero();
+  Known.insertBits(HighZeros, 8);
+}
+
 void NVPTXTargetLowering::computeKnownBitsForTargetNode(
     const SDValue Op, KnownBits &Known, const APInt &DemandedElts,
     const SelectionDAG &DAG, unsigned Depth) const {
@@ -6644,6 +6567,10 @@ void NVPTXTargetLowering::computeKnownBitsForTargetNode(
   case NVPTXISD::PRMT:
     computeKnownBitsForPRMT(Op, Known, DAG, Depth);
     break;
+  case NVPTXISD::LoadV2:
+  case NVPTXISD::LoadV4:
+    computeKnownBitsFori8VLoad(Op, Known);
+    break;
   default:
     break;
   }
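The known-bits approach introduced above removes the need for a bespoke combine: once the target reports that the high bits of each extended lane are zero, the generic combiner's demanded-bits reasoning can prove x & 0xff == x and delete the AND by itself. A plain-integer sketch of that bookkeeping (raw masks instead of llvm::KnownBits, to stay self-contained):

    #include <cassert>
    #include <cstdint>

    int main() {
      // A zero-extending byte load into a 16-bit lane: bits [15:8] known zero.
      uint16_t knownZero = 0xFF00; // what the target hook reports
      uint16_t andMask = 0x00FF;   // the mask the legalizer emitted
      // If every bit the AND clears is already known zero, the AND is a no-op.
      assert(uint16_t(knownZero | andMask) == 0xFFFF);
      return 0;
    }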
From 2a2215ab6b8546472352ca8f04935b84b627f175 Mon Sep 17 00:00:00 2001
From: Kevin McAfee
Date: Wed, 20 Aug 2025 00:11:45 +0000
Subject: [PATCH 4/6] generalize for more types

---
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 23 +++++----------
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h   |  1 +
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 31 +++++++++++----------
 3 files changed, 24 insertions(+), 31 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 520ce4deb9a57..d86c0905cf943 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -1309,26 +1309,17 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
   return true;
 }
 
+unsigned NVPTXDAGToDAGISel::getFromTypeWidthForLoad(const MemSDNode *Mem) {
+  EVT MemVT = Mem->getMemoryVT();
+  auto ElementBitWidth = MemVT.getSizeInBits() / (Mem->getNumValues() - 1);
+  return ElementBitWidth;
+}
+
 bool NVPTXDAGToDAGISel::tryLDU(SDNode *N) {
   auto *LD = cast<MemSDNode>(N);
 
-  unsigned NumElts;
-  switch (N->getOpcode()) {
-  default:
-    llvm_unreachable("Unexpected opcode");
-  case ISD::INTRINSIC_W_CHAIN:
-    NumElts = 1;
-    break;
-  case NVPTXISD::LDUV2:
-    NumElts = 2;
-    break;
-  case NVPTXISD::LDUV4:
-    NumElts = 4;
-    break;
-  }
-
   SDLoc DL(N);
-  const unsigned FromTypeWidth = LD->getMemoryVT().getSizeInBits() / NumElts;
+  const unsigned FromTypeWidth = getFromTypeWidthForLoad(LD);
   const MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy;
 
   // If this is an LDU intrinsic, the address is the third operand. If it's an
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 65731722f5343..e2ad55bc1796d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -111,6 +111,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
 
 public:
   static NVPTX::AddressSpace getAddrSpace(const MemSDNode *N);
+  static unsigned getFromTypeWidthForLoad(const MemSDNode *Mem);
 };
 
 class NVPTXDAGToDAGISelLegacy : public SelectionDAGISelLegacy {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 67a5f2ae84c31..0d96c9353141a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -14,6 +14,7 @@
 #include "NVPTXISelLowering.h"
 #include "MCTargetDesc/NVPTXBaseInfo.h"
 #include "NVPTX.h"
+#include "NVPTXISelDAGToDAG.h"
 #include "NVPTXSubtarget.h"
 #include "NVPTXTargetMachine.h"
 #include "NVPTXTargetObjectFile.h"
@@ -6537,25 +6538,24 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
   }
 }
 
-static void computeKnownBitsFori8VLoad(const SDValue Op, KnownBits &Known) {
-  MemSDNode *Mem = dyn_cast<MemSDNode>(Op);
-  if (!Mem) {
-    return;
-  }
+static void computeKnownBitsForVLoad(const SDValue Op, KnownBits &Known) {
+  MemSDNode *LD = cast<MemSDNode>(Op);
 
-  EVT MemVT = Mem->getMemoryVT();
-  if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
+  // We can't do anything without knowing the sign bit.
+  auto ExtType = LD->getConstantOperandVal(LD->getNumOperands() - 1);
+  if (ExtType == ISD::SEXTLOAD)
     return;
-  }
 
-  unsigned ExtType = Mem->getConstantOperandVal(Mem->getNumOperands() - 1);
-  if (ExtType == ISD::SEXTLOAD) {
-    Known = Known.sext(Known.getBitWidth());
+  // Extending loads to vector result types may not work well with known-bits
+  // analysis, so bail out on them.
+  auto DestVT = LD->getValueType(0);
+  if (DestVT.isVector())
     return;
-  }
-  KnownBits HighZeros(Known.getBitWidth() - 8);
+
+  assert(Known.getBitWidth() == DestVT.getSizeInBits());
+  auto ElementBitWidth = NVPTXDAGToDAGISel::getFromTypeWidthForLoad(LD);
+  KnownBits HighZeros(Known.getBitWidth() - ElementBitWidth);
   HighZeros.setAllZero();
-  Known.insertBits(HighZeros, 8);
+  Known.insertBits(HighZeros, ElementBitWidth);
 }
 
 void NVPTXTargetLowering::computeKnownBitsForTargetNode(
@@ -6569,7 +6569,8 @@ void NVPTXTargetLowering::computeKnownBitsForTargetNode(
     break;
   case NVPTXISD::LoadV2:
   case NVPTXISD::LoadV4:
-    computeKnownBitsFori8VLoad(Op, Known);
+  case NVPTXISD::LoadV8:
+    computeKnownBitsForVLoad(Op, Known);
     break;
   default:
     break;
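getFromTypeWidthForLoad above divides the memory width by getNumValues() - 1 because the final result of these load nodes is the chain rather than data. A worked example of the arithmetic under that assumption (values are hypothetical):

    #include <cassert>

    int main() {
      // An NVPTXISD::LoadV4 of v4i8: 32 bits of memory feed a node with four
      // data results plus one chain result.
      unsigned totalBits = 32;
      unsigned numResults = 5; // 4 data values + 1 chain
      unsigned elementBits = totalBits / (numResults - 1);
      assert(elementBits == 8);
      return 0;
    }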
From ef99d229ac828cae59088ff0d7b788b3c7e7987f Mon Sep 17 00:00:00 2001
From: Kevin McAfee
Date: Wed, 20 Aug 2025 18:24:27 +0000
Subject: [PATCH 5/6] Refactor/expand use of getFromTypeWidthForLoad

---
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 30 +++++++--------------
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp |  4 +--
 2 files changed, 11 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index d86c0905cf943..3300ed9a5a81c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -1150,15 +1150,12 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
   return true;
 }
 
-static unsigned getLoadStoreVectorNumElts(SDNode *N) {
+static unsigned getStoreVectorNumElts(SDNode *N) {
   switch (N->getOpcode()) {
-  case NVPTXISD::LoadV2:
   case NVPTXISD::StoreV2:
     return 2;
-  case NVPTXISD::LoadV4:
   case NVPTXISD::StoreV4:
     return 4;
-  case NVPTXISD::LoadV8:
   case NVPTXISD::StoreV8:
     return 8;
   default:
@@ -1171,7 +1168,6 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
   const EVT MemEVT = LD->getMemoryVT();
   if (!MemEVT.isSimple())
     return false;
-  const MVT MemVT = MemEVT.getSimpleVT();
 
   // Address Space Setting
   const auto CodeAddrSpace = getAddrSpace(LD);
@@ -1191,18 +1187,15 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
   //   Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
   // Read at least 8 bits (predicates are stored as 8-bit values)
   // The last operand holds the original LoadSDNode::getExtensionType() value
-  const unsigned TotalWidth = MemVT.getSizeInBits();
   const unsigned ExtensionType =
       N->getConstantOperandVal(N->getNumOperands() - 1);
   const unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
                                 ? NVPTX::PTXLdStInstCode::Signed
                                 : NVPTX::PTXLdStInstCode::Untyped;
-  const unsigned FromTypeWidth = TotalWidth / getLoadStoreVectorNumElts(N);
+  const unsigned FromTypeWidth = getFromTypeWidthForLoad(LD);
 
   assert(!(EltVT.isVector() && ExtensionType != ISD::NON_EXTLOAD));
-  assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
-         FromTypeWidth <= 128 && TotalWidth <= 256 && "Invalid width for load");
 
   const auto [Base, Offset] = selectADDR(N->getOperand(1), CurDAG);
   SDValue Ops[] = {getI32Imm(Ordering, DL),
@@ -1247,30 +1240,23 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
   const EVT LoadedEVT = LD->getMemoryVT();
   if (!LoadedEVT.isSimple())
     return false;
-  const MVT LoadedVT = LoadedEVT.getSimpleVT();
 
   SDLoc DL(LD);
 
-  const unsigned TotalWidth = LoadedVT.getSizeInBits();
   unsigned ExtensionType;
-  unsigned NumElts;
   if (const auto *Load = dyn_cast<LoadSDNode>(LD)) {
     ExtensionType = Load->getExtensionType();
-    NumElts = 1;
   } else {
     ExtensionType = LD->getConstantOperandVal(LD->getNumOperands() - 1);
-    NumElts = getLoadStoreVectorNumElts(LD);
   }
   const unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
                                 ? NVPTX::PTXLdStInstCode::Signed
                                 : NVPTX::PTXLdStInstCode::Untyped;
-  const unsigned FromTypeWidth = TotalWidth / NumElts;
+  const unsigned FromTypeWidth = getFromTypeWidthForLoad(LD);
 
   assert(!(LD->getSimpleValueType(0).isVector() &&
            ExtensionType != ISD::NON_EXTLOAD));
-  assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
-         FromTypeWidth <= 128 && TotalWidth <= 256 && "Invalid width for load");
 
   const auto [Base, Offset] = selectADDR(LD->getOperand(1), CurDAG);
   SDValue Ops[] = {getI32Imm(FromType, DL), getI32Imm(FromTypeWidth, DL), Base,
@@ -1310,8 +1296,12 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
 }
 
 unsigned NVPTXDAGToDAGISel::getFromTypeWidthForLoad(const MemSDNode *Mem) {
-  EVT MemVT = Mem->getMemoryVT();
-  auto ElementBitWidth = MemVT.getSizeInBits() / (Mem->getNumValues() - 1);
+  auto TotalWidth = Mem->getMemoryVT().getSizeInBits();
+  auto NumElts = Mem->getNumValues() - 1;
+  auto ElementBitWidth = TotalWidth / NumElts;
+  assert(isPowerOf2_32(ElementBitWidth) && ElementBitWidth >= 8 &&
+         ElementBitWidth <= 128 && TotalWidth <= 256 &&
+         "Invalid width for load");
   return ElementBitWidth;
 }
 
@@ -1434,7 +1424,7 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
   //   - for integer type, always use 'u'
   const unsigned TotalWidth = StoreVT.getSimpleVT().getSizeInBits();
 
-  const unsigned NumElts = getLoadStoreVectorNumElts(ST);
+  const unsigned NumElts = getStoreVectorNumElts(ST);
 
   SmallVector<SDValue> Ops;
   for (auto &V : ST->ops().slice(1, NumElts))
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 0d96c9353141a..c1d3e9567d85d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -6553,9 +6553,7 @@ static void computeKnownBitsForVLoad(const SDValue Op, KnownBits &Known) {
 
   assert(Known.getBitWidth() == DestVT.getSizeInBits());
   auto ElementBitWidth = NVPTXDAGToDAGISel::getFromTypeWidthForLoad(LD);
-  KnownBits HighZeros(Known.getBitWidth() - ElementBitWidth);
-  HighZeros.setAllZero();
-  Known.insertBits(HighZeros, ElementBitWidth);
+  Known.Zero.setHighBits(Known.getBitWidth() - ElementBitWidth);
 }
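The last hunk above relies on an equivalence: inserting an all-zero KnownBits of width BitWidth - ElementBitWidth at bit ElementBitWidth marks the same bits as setting the high bits of Known.Zero directly. A standalone check with plain masks (assuming a 32-bit value and 8-bit elements):

    #include <cassert>
    #include <cstdint>

    int main() {
      unsigned elementBits = 8;
      // insertBits(HighZeros, elementBits) marks bits [31:8] as zero...
      uint32_t inserted = ~uint32_t(0) << elementBits;
      // ...and setHighBits(32 - elementBits) marks exactly the same bits.
      uint32_t setHigh = ~((uint32_t(1) << elementBits) - 1);
      assert(inserted == setHigh);
      return 0;
    }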
From 4098eee4606980e84a2fe6e6baae74ca2496898b Mon Sep 17 00:00:00 2001
From: Kevin McAfee
Date: Wed, 20 Aug 2025 18:27:47 +0000
Subject: [PATCH 6/6] rename

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index c1d3e9567d85d..ad56d2f12caf6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -6538,7 +6538,7 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
   }
 }
 
-static void computeKnownBitsForVLoad(const SDValue Op, KnownBits &Known) {
+static void computeKnownBitsForLoadV(const SDValue Op, KnownBits &Known) {
   MemSDNode *LD = cast<MemSDNode>(Op);
 
   // We can't do anything without knowing the sign bit.
@@ -6568,7 +6568,7 @@ void NVPTXTargetLowering::computeKnownBitsForTargetNode(
   case NVPTXISD::LoadV2:
   case NVPTXISD::LoadV4:
   case NVPTXISD::LoadV8:
-    computeKnownBitsForVLoad(Op, Known);
+    computeKnownBitsForLoadV(Op, Known);
     break;
   default:
     break;
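Taken together, the series swaps a hand-rolled AND combine for generic known-bits information, so the redundant masking disappears for scalar and vector extending loads alike. A sketch in plain C++ of the end state the updated checks encode (hypothetical values; the authoritative output is the PTX in test_uitofp_2xi8):

    #include <cassert>
    #include <cstdint>

    int main() {
      // Each byte lane converts directly; no mask of the packed register is
      // needed because the lanes were zero-extended by the load.
      uint8_t a0 = 200, a1 = 7;
      float f0 = static_cast<float>(a0); // corresponds to cvt.rn.f32.u16 per lane
      float f1 = static_cast<float>(a1);
      assert(f0 == 200.0f && f1 == 7.0f);
      return 0;
    }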