-
Notifications
You must be signed in to change notification settings - Fork 15.5k
[NVPTX] Implement computeKnownBitsForTargetNode for LoadV #154165
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-backend-nvptx Author: Kevin McAfee (kalxr) Changes: The logic is the same as the original combine, with some minor adaptations to enable reuse for vectors. Full diff: https://github.com/llvm/llvm-project/pull/154165.diff — 3 Files Affected:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 74e6c139c610d..8190c23407250 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5242,6 +5242,58 @@ static SDValue PerformFADDCombine(SDNode *N,
return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
}
+// Helper function to check if an AND operation on a load can be eliminated.
+// Returns a replacement value if the load can be eliminated, else nullopt.
+static std::optional<SDValue>
+canEliminateLoadAND(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+ SDValue Val, SDValue Mask, SDValue AExt) {
+ if (Val->getOpcode() != NVPTXISD::LoadV2 &&
+ Val->getOpcode() != NVPTXISD::LoadV4) {
+ return std::nullopt;
+ }
+
+ ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
+ if (!MaskCnst) {
+ // Not an AND with a constant
+ return std::nullopt;
+ }
+
+ uint64_t MaskVal = MaskCnst->getZExtValue();
+ if (MaskVal != 0xff) {
+ // Not an AND that chops off top 8 bits
+ return std::nullopt;
+ }
+
+ MemSDNode *Mem = dyn_cast<MemSDNode>(Val);
+ if (!Mem) {
+ // Not a MemSDNode
+ return std::nullopt;
+ }
+
+ EVT MemVT = Mem->getMemoryVT();
+ if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
+ // We only handle the i8 case
+ return std::nullopt;
+ }
+
+ unsigned ExtType = Val->getConstantOperandVal(Val->getNumOperands() - 1);
+ if (ExtType == ISD::SEXTLOAD) {
+ // If the load is a sextload, the AND is needed to zero out the high 8 bits
+ return std::nullopt;
+ }
+
+ SDValue Result = Val;
+
+ if (AExt) {
+ // Re-insert the ext as a zext.
+ Result =
+ DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), AExt.getValueType(), Val);
+ }
+
+ // If we get here, the AND is unnecessary. Replace it with the load.
+ return Result;
+}
+
static SDValue PerformANDCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI) {
// The type legalizer turns a vector load of i8 values into a zextload to i16
@@ -5252,9 +5304,8 @@ static SDValue PerformANDCombine(SDNode *N,
SDValue Val = N->getOperand(0);
SDValue Mask = N->getOperand(1);
- if (isa<ConstantSDNode>(Val)) {
+ if (isa<ConstantSDNode>(Val))
std::swap(Val, Mask);
- }
SDValue AExt;
@@ -5264,49 +5315,24 @@ static SDValue PerformANDCombine(SDNode *N,
Val = Val->getOperand(0);
}
- if (Val->getOpcode() == NVPTXISD::LoadV2 ||
- Val->getOpcode() == NVPTXISD::LoadV4) {
- ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
- if (!MaskCnst) {
- // Not an AND with a constant
- return SDValue();
- }
-
- uint64_t MaskVal = MaskCnst->getZExtValue();
- if (MaskVal != 0xff) {
- // Not an AND that chops off top 8 bits
- return SDValue();
- }
-
- MemSDNode *Mem = dyn_cast<MemSDNode>(Val);
- if (!Mem) {
- // Not a MemSDNode?!?
- return SDValue();
- }
-
- EVT MemVT = Mem->getMemoryVT();
- if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
- // We only handle the i8 case
- return SDValue();
- }
-
- unsigned ExtType = Val->getConstantOperandVal(Val->getNumOperands() - 1);
- if (ExtType == ISD::SEXTLOAD) {
- // If for some reason the load is a sextload, the and is needed to zero
- // out the high 8 bits
- return SDValue();
- }
-
- bool AddTo = false;
- if (AExt.getNode() != nullptr) {
- // Re-insert the ext as a zext.
- Val = DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N),
- AExt.getValueType(), Val);
- AddTo = true;
+ if (Val.getOpcode() == ISD::BUILD_VECTOR &&
+ Mask.getOpcode() == ISD::BUILD_VECTOR) {
+ assert(Val->getNumOperands() == Mask->getNumOperands() && !AExt);
+ for (unsigned I = 0; I < Val->getNumOperands(); ++I) {
+ // We know that the AExt is null and therefore the result of this call
+ // will be the BUILD_VECTOR operand or nullopt. Rather than create a new
+ // BUILD_VECTOR with the collection of operands, we can just use the
+ // original and ignore the result.
+ if (!canEliminateLoadAND(N, DCI, Val->getOperand(I), Mask->getOperand(I),
+ AExt)
+ .has_value())
+ return SDValue();
}
-
- // If we get here, the AND is unnecessary. Just replace it with the load
- DCI.CombineTo(N, Val, AddTo);
+ DCI.CombineTo(N, Val, false);
+ } else {
+ auto Result = canEliminateLoadAND(N, DCI, Val, Mask, AExt);
+ if (Result.has_value())
+ DCI.CombineTo(N, Result.value(), AExt.getNode() != nullptr);
}
return SDValue();
diff --git a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
index 98f94bb7b3ac1..77398b5fa41cb 100644
--- a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
@@ -103,5 +103,34 @@ define <2 x i8> @test_call_2xi8(<2 x i8> %a) {
%res = call <2 x i8> @test_call_2xi8(<2 x i8> %a)
ret <2 x i8> %res
}
+
+define <2 x float> @test_uitofp_2xi8(<2 x i8> %a) {
+; O0-LABEL: test_uitofp_2xi8(
+; O0: {
+; O0-NEXT: .reg .b16 %rs<3>;
+; O0-NEXT: .reg .b32 %r<4>;
+; O0-EMPTY:
+; O0-NEXT: // %bb.0:
+; O0-NEXT: ld.param.v2.b8 {%rs1, %rs2}, [test_uitofp_2xi8_param_0];
+; O0-NEXT: mov.b32 %r1, {%rs1, %rs2};
+; O0-NEXT: cvt.rn.f32.u16 %r2, %rs2;
+; O0-NEXT: cvt.rn.f32.u16 %r3, %rs1;
+; O0-NEXT: st.param.v2.b32 [func_retval0], {%r3, %r2};
+; O0-NEXT: ret;
+;
+; O3-LABEL: test_uitofp_2xi8(
+; O3: {
+; O3-NEXT: .reg .b16 %rs<3>;
+; O3-NEXT: .reg .b32 %r<3>;
+; O3-EMPTY:
+; O3-NEXT: // %bb.0:
+; O3-NEXT: ld.param.v2.b8 {%rs1, %rs2}, [test_uitofp_2xi8_param_0];
+; O3-NEXT: cvt.rn.f32.u16 %r1, %rs2;
+; O3-NEXT: cvt.rn.f32.u16 %r2, %rs1;
+; O3-NEXT: st.param.v2.b32 [func_retval0], {%r2, %r1};
+; O3-NEXT: ret;
+ %1 = uitofp <2 x i8> %a to <2 x float>
+ ret <2 x float> %1
+}
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
; COMMON: {{.*}}
diff --git a/llvm/test/CodeGen/NVPTX/shift-opt.ll b/llvm/test/CodeGen/NVPTX/shift-opt.ll
index e7866b01064c7..e0d22c62993ba 100644
--- a/llvm/test/CodeGen/NVPTX/shift-opt.ll
+++ b/llvm/test/CodeGen/NVPTX/shift-opt.ll
@@ -71,18 +71,17 @@ define <2 x i16> @test_vec(<2 x i16> %x, <2 x i8> %y) {
; CHECK-LABEL: test_vec(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<7>;
-; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b32 %r<4>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.v2.b16 {%rs1, %rs2}, [test_vec_param_0];
; CHECK-NEXT: ld.param.v2.b8 {%rs3, %rs4}, [test_vec_param_1];
; CHECK-NEXT: mov.b32 %r1, {%rs3, %rs4};
-; CHECK-NEXT: and.b32 %r2, %r1, 16711935;
; CHECK-NEXT: shr.u16 %rs5, %rs2, 5;
; CHECK-NEXT: shr.u16 %rs6, %rs1, 5;
-; CHECK-NEXT: mov.b32 %r3, {%rs6, %rs5};
-; CHECK-NEXT: or.b32 %r4, %r3, %r2;
-; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
+; CHECK-NEXT: mov.b32 %r2, {%rs6, %rs5};
+; CHECK-NEXT: or.b32 %r3, %r2, %r1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
%ext = zext <2 x i8> %y to <2 x i16>
%shl = shl <2 x i16> %ext, splat(i16 5)
|
AlexMaclean
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, this looks simpler. I'd prefer that we get even a little more generic still, though. I think currently the only time we'll ever generate extending vector loads is for i8, but that might change in the future, and I think it would not be too complex to implement a generic version of this function which does something like the following: get the size in bits of the memVT and divide that by the number of elements loaded, then get the difference between that value and the actual size of the produced VT, and clear that many high bits.
AlexMaclean
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have a couple minor nits but this looks good to me. Thanks!
5469878 to
4098eee
Compare
Remove AND combines as they are no longer needed after this.