
Conversation

@kalxr (Contributor) commented Aug 18, 2025

Remove the AND combines, as they are no longer needed after this change.

@llvmbot (Member) commented Aug 18, 2025

@llvm/pr-subscribers-backend-nvptx

Author: Kevin McAfee (kalxr)

Changes

The logic is the same as the original combine with some minor adaptations to enable reuse for vectors.


Full diff: https://github.com/llvm/llvm-project/pull/154165.diff

3 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+70-44)
  • (modified) llvm/test/CodeGen/NVPTX/i8x2-instructions.ll (+29)
  • (modified) llvm/test/CodeGen/NVPTX/shift-opt.ll (+4-5)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 74e6c139c610d..8190c23407250 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5242,6 +5242,58 @@ static SDValue PerformFADDCombine(SDNode *N,
   return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
 }
 
+// Helper function to check if an AND operation on a load can be eliminated.
+// Returns a replacement value if the load can be eliminated, else nullopt.
+static std::optional<SDValue>
+canEliminateLoadAND(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+                    SDValue Val, SDValue Mask, SDValue AExt) {
+  if (Val->getOpcode() != NVPTXISD::LoadV2 &&
+      Val->getOpcode() != NVPTXISD::LoadV4) {
+    return std::nullopt;
+  }
+
+  ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
+  if (!MaskCnst) {
+    // Not an AND with a constant
+    return std::nullopt;
+  }
+
+  uint64_t MaskVal = MaskCnst->getZExtValue();
+  if (MaskVal != 0xff) {
+    // Not an AND that chops off top 8 bits
+    return std::nullopt;
+  }
+
+  MemSDNode *Mem = dyn_cast<MemSDNode>(Val);
+  if (!Mem) {
+    // Not a MemSDNode
+    return std::nullopt;
+  }
+
+  EVT MemVT = Mem->getMemoryVT();
+  if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
+    // We only handle the i8 case
+    return std::nullopt;
+  }
+
+  unsigned ExtType = Val->getConstantOperandVal(Val->getNumOperands() - 1);
+  if (ExtType == ISD::SEXTLOAD) {
+    // If the load is a sextload, the AND is needed to zero out the high 8 bits
+    return std::nullopt;
+  }
+
+  SDValue Result = Val;
+
+  if (AExt) {
+    // Re-insert the ext as a zext.
+    Result =
+        DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), AExt.getValueType(), Val);
+  }
+
+  // If we get here, the AND is unnecessary. Replace it with the load.
+  return Result;
+}
+
 static SDValue PerformANDCombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI) {
   // The type legalizer turns a vector load of i8 values into a zextload to i16
@@ -5252,9 +5304,8 @@ static SDValue PerformANDCombine(SDNode *N,
   SDValue Val = N->getOperand(0);
   SDValue Mask = N->getOperand(1);
 
-  if (isa<ConstantSDNode>(Val)) {
+  if (isa<ConstantSDNode>(Val))
     std::swap(Val, Mask);
-  }
 
   SDValue AExt;
 
@@ -5264,49 +5315,24 @@ static SDValue PerformANDCombine(SDNode *N,
     Val = Val->getOperand(0);
   }
 
-  if (Val->getOpcode() == NVPTXISD::LoadV2 ||
-      Val->getOpcode() == NVPTXISD::LoadV4) {
-    ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
-    if (!MaskCnst) {
-      // Not an AND with a constant
-      return SDValue();
-    }
-
-    uint64_t MaskVal = MaskCnst->getZExtValue();
-    if (MaskVal != 0xff) {
-      // Not an AND that chops off top 8 bits
-      return SDValue();
-    }
-
-    MemSDNode *Mem = dyn_cast<MemSDNode>(Val);
-    if (!Mem) {
-      // Not a MemSDNode?!?
-      return SDValue();
-    }
-
-    EVT MemVT = Mem->getMemoryVT();
-    if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
-      // We only handle the i8 case
-      return SDValue();
-    }
-
-    unsigned ExtType = Val->getConstantOperandVal(Val->getNumOperands() - 1);
-    if (ExtType == ISD::SEXTLOAD) {
-      // If for some reason the load is a sextload, the and is needed to zero
-      // out the high 8 bits
-      return SDValue();
-    }
-
-    bool AddTo = false;
-    if (AExt.getNode() != nullptr) {
-      // Re-insert the ext as a zext.
-      Val = DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N),
-                            AExt.getValueType(), Val);
-      AddTo = true;
+  if (Val.getOpcode() == ISD::BUILD_VECTOR &&
+      Mask.getOpcode() == ISD::BUILD_VECTOR) {
+    assert(Val->getNumOperands() == Mask->getNumOperands() && !AExt);
+    for (unsigned I = 0; I < Val->getNumOperands(); ++I) {
+      // We know that the AExt is null and therefore the result of this call
+      // will be the BUILD_VECTOR operand or nullopt. Rather than create a new
+      // BUILD_VECTOR with the collection of operands, we can just use the
+      // original and ignore the result.
+      if (!canEliminateLoadAND(N, DCI, Val->getOperand(I), Mask->getOperand(I),
+                               AExt)
+               .has_value())
+        return SDValue();
     }
-
-    // If we get here, the AND is unnecessary.  Just replace it with the load
-    DCI.CombineTo(N, Val, AddTo);
+    DCI.CombineTo(N, Val, false);
+  } else {
+    auto Result = canEliminateLoadAND(N, DCI, Val, Mask, AExt);
+    if (Result.has_value())
+      DCI.CombineTo(N, Result.value(), AExt.getNode() != nullptr);
   }
 
   return SDValue();
diff --git a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
index 98f94bb7b3ac1..77398b5fa41cb 100644
--- a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
@@ -103,5 +103,34 @@ define <2 x i8> @test_call_2xi8(<2 x i8> %a) {
   %res = call <2 x i8> @test_call_2xi8(<2 x i8> %a)
   ret <2 x i8> %res
 }
+
+define <2 x float> @test_uitofp_2xi8(<2 x i8> %a) {
+; O0-LABEL: test_uitofp_2xi8(
+; O0:       {
+; O0-NEXT:    .reg .b16 %rs<3>;
+; O0-NEXT:    .reg .b32 %r<4>;
+; O0-EMPTY:
+; O0-NEXT:  // %bb.0:
+; O0-NEXT:    ld.param.v2.b8 {%rs1, %rs2}, [test_uitofp_2xi8_param_0];
+; O0-NEXT:    mov.b32 %r1, {%rs1, %rs2};
+; O0-NEXT:    cvt.rn.f32.u16 %r2, %rs2;
+; O0-NEXT:    cvt.rn.f32.u16 %r3, %rs1;
+; O0-NEXT:    st.param.v2.b32 [func_retval0], {%r3, %r2};
+; O0-NEXT:    ret;
+;
+; O3-LABEL: test_uitofp_2xi8(
+; O3:       {
+; O3-NEXT:    .reg .b16 %rs<3>;
+; O3-NEXT:    .reg .b32 %r<3>;
+; O3-EMPTY:
+; O3-NEXT:  // %bb.0:
+; O3-NEXT:    ld.param.v2.b8 {%rs1, %rs2}, [test_uitofp_2xi8_param_0];
+; O3-NEXT:    cvt.rn.f32.u16 %r1, %rs2;
+; O3-NEXT:    cvt.rn.f32.u16 %r2, %rs1;
+; O3-NEXT:    st.param.v2.b32 [func_retval0], {%r2, %r1};
+; O3-NEXT:    ret;
+  %1 = uitofp <2 x i8> %a to <2 x float>
+  ret <2 x float> %1
+}
 ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
 ; COMMON: {{.*}}
diff --git a/llvm/test/CodeGen/NVPTX/shift-opt.ll b/llvm/test/CodeGen/NVPTX/shift-opt.ll
index e7866b01064c7..e0d22c62993ba 100644
--- a/llvm/test/CodeGen/NVPTX/shift-opt.ll
+++ b/llvm/test/CodeGen/NVPTX/shift-opt.ll
@@ -71,18 +71,17 @@ define <2 x i16> @test_vec(<2 x i16> %x, <2 x i8> %y) {
 ; CHECK-LABEL: test_vec(
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .b16 %rs<7>;
-; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b32 %r<4>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.v2.b16 {%rs1, %rs2}, [test_vec_param_0];
 ; CHECK-NEXT:    ld.param.v2.b8 {%rs3, %rs4}, [test_vec_param_1];
 ; CHECK-NEXT:    mov.b32 %r1, {%rs3, %rs4};
-; CHECK-NEXT:    and.b32 %r2, %r1, 16711935;
 ; CHECK-NEXT:    shr.u16 %rs5, %rs2, 5;
 ; CHECK-NEXT:    shr.u16 %rs6, %rs1, 5;
-; CHECK-NEXT:    mov.b32 %r3, {%rs6, %rs5};
-; CHECK-NEXT:    or.b32 %r4, %r3, %r2;
-; CHECK-NEXT:    st.param.b32 [func_retval0], %r4;
+; CHECK-NEXT:    mov.b32 %r2, {%rs6, %rs5};
+; CHECK-NEXT:    or.b32 %r3, %r2, %r1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
 ; CHECK-NEXT:    ret;
   %ext = zext <2 x i8> %y to <2 x i16>
   %shl = shl <2 x i16> %ext, splat(i16 5)

@kalxr changed the title from "[NVPTX] Support vectors for AND combine" to "[NVPTX] Implement computeKnownBitsForTargetNode for LoadV2/4" on Aug 18, 2025
@AlexMaclean (Member) left a comment
Nice, this looks simpler. I'd prefer that we get even a little more generic still, though. I think currently the only time we'll ever generate extending vector loads is for i8, but that might change in the future, and I think it would not be too complex to implement a generic version of this function that does something like the following: get the size in bits of the memVT and divide that by the number of elements loaded, then get the difference between that value and the actual size of the produced VT, and clear that many high bits.
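
A minimal sketch of the generic computation described above, framed as the computeKnownBitsForTargetNode hook that the updated PR title refers to. This is not the patch that landed; the exact opcodes handled, the way the extension kind is read off the node, and the per-result handling are assumptions made for illustration only.

// Sketch: report which high bits of an extending NVPTXISD::LoadV* result are
// known to be zero, computed from the memory VT instead of hard-coding the
// i8 case.
void NVPTXTargetLowering::computeKnownBitsForTargetNode(
    const SDValue Op, KnownBits &Known, const APInt &DemandedElts,
    const SelectionDAG &DAG, unsigned Depth) const {
  Known.resetAll();
  unsigned Opcode = Op.getOpcode();
  if (Opcode != NVPTXISD::LoadV2 && Opcode != NVPTXISD::LoadV4)
    return;
  // Skip non-integer results such as the chain.
  if (!Op.getValueType().isInteger())
    return;
  const SDNode *N = Op.getNode();
  // Only a zero-extending load guarantees zeros in the high bits. The
  // extension kind is assumed to be the trailing constant operand, as in the
  // combine above.
  if (N->getConstantOperandVal(N->getNumOperands() - 1) != ISD::ZEXTLOAD)
    return;
  // In-memory element width = size of the memory VT divided by the number of
  // elements loaded, e.g. 8 for a v4i8 LoadV4.
  unsigned NumElts = Opcode == NVPTXISD::LoadV2 ? 2 : 4;
  unsigned MemBits =
      cast<MemSDNode>(N)->getMemoryVT().getSizeInBits().getFixedValue() /
      NumElts;
  // Everything above the in-memory width was filled with zeros by the
  // extension, e.g. bits 8..15 of an i16 result loaded from v4i8.
  if (Known.getBitWidth() > MemBits)
    Known.Zero.setBitsFrom(MemBits);
}

With known bits reported this way, the generic combiner can prove an and with 0xff (or any mask covering only the extended bits) redundant on its own, which is presumably what allows the hand-written AND combines to be removed, as the PR description notes.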

@AlexMaclean (Member) left a comment
I have a couple minor nits but this looks good to me. Thanks!

@kalxr changed the title from "[NVPTX] Implement computeKnownBitsForTargetNode for LoadV2/4" to "[NVPTX] Implement computeKnownBitsForTargetNode for LoadV" on Aug 20, 2025
@kalxr force-pushed the nvptx-and-combine branch from 5469878 to 4098eee on August 20, 2025 18:29
@kalxr enabled auto-merge (squash) on August 20, 2025 18:30
@kalxr merged commit 691ccf2 into llvm:main on Aug 20, 2025
9 checks passed
@kalxr deleted the nvptx-and-combine branch on August 20, 2025 19:39