[RISCV] Optimize divide by constant for VP intrinsics #125991

Open: wants to merge 13 commits into base: main

jaidTw
Contributor

@jaidTw jaidTw commented Feb 6, 2025

  • Introduces ISD opcodes VP_MULHU and VP_MULHS
  • Implements divide-by-constant folds for vp.udiv/vp.sdiv and vp.urem/vp.srem, as well as some minor folds such as division by a power of 2, division by INT_MAX, etc.
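
As a rough per-lane illustration of the divide-by-constant fold, the scalar C++ sketch below replaces an unsigned division by 5 with a multiply-high and a shift, and recovers the remainder as x - (x / 5) * 5. The helper names (mulhu32, udivBy5, uremBy5, checkUdivBy5) and the divisor are illustrative assumptions, not code from this patch, which operates on VP_* SelectionDAG nodes that also carry mask and EVL operands.

```cpp
#include <cassert>
#include <cstdint>

// High 32 bits of the 64-bit product: a scalar stand-in for one MULHU lane.
constexpr uint32_t mulhu32(uint32_t a, uint32_t b) {
  return static_cast<uint32_t>((static_cast<uint64_t>(a) * b) >> 32);
}

// x / 5 without a divide: multiply-high by ceil(2^34 / 5) = 0xCCCCCCCD,
// then shift right by 2 (34 - 32).
constexpr uint32_t udivBy5(uint32_t x) { return mulhu32(x, 0xCCCCCCCDu) >> 2; }

// x % 5 recovered from the quotient as x - (x / 5) * 5.
constexpr uint32_t uremBy5(uint32_t x) { return x - udivBy5(x) * 5u; }

// Compares the sketch against the built-in operators on edge values and on a
// strided sweep of the 32-bit range.
inline void checkUdivBy5() {
  const uint32_t Edges[] = {0u,          1u,          4u,
                            5u,          6u,          0x7FFFFFFFu,
                            0x80000000u, 0xFFFFFFFEu, 0xFFFFFFFFu};
  for (uint32_t X : Edges) {
    assert(udivBy5(X) == X / 5u);
    assert(uremBy5(X) == X % 5u);
  }
  for (uint64_t I = 0; I <= 0xFFFFFFFFull; I += 65521) {
    uint32_t X = static_cast<uint32_t>(I);
    assert(udivBy5(X) == X / 5u);
    assert(uremBy5(X) == X % 5u);
  }
}
```

Calling checkUdivBy5() exercises both helpers; a different divisor needs its own magic constant and shift amount.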

yetingk and others added 5 commits February 6, 2025 02:08
Add ISD opcodes VP_MULHU/VP_MULHS which could be used by VP optimizations.
This patch implemented divide by constants foldings for vp.u(s)div and vp.u(s)rem as well as some other minor foldings such as div by pow of 2, div by INT_MAX, etc.
* Set VP_MULHU/VP_MULHS with i64 vector input to Expand on Zve64*
* Moved forward the IsOperationLegalOrCustom check in BuildSDIV/BuildUDIV
We can't constant fold VP_MUL yet or combine (VP_SUB 0, X) and
VP_ADD.

Add some flags to keep track of when we need to emit VP_MUL/VP_ADD/VP_SUB.
@jaidTw jaidTw requested a review from topperc February 6, 2025 03:30
@llvmbot
Member

llvmbot commented Feb 6, 2025

@llvm/pr-subscribers-llvm-selectiondag
@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-backend-risc-v

Author: Jesse Huang (jaidTw)

Changes
  • Introduces ISD opcodes VP_MULHU and VP_MULHS
  • Implements divide-by-constant folds for vp.udiv/vp.sdiv and vp.urem/vp.srem, as well as some minor folds such as division by a power of 2, division by INT_MAX, etc.

Patch is 121.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125991.diff

8 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+4)
  • (modified) llvm/include/llvm/IR/VPIntrinsics.def (+5)
  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+330)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+4-4)
  • (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+268)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+15)
  • (added) llvm/test/CodeGen/RISCV/rvv/vpdiv-by-const-zve64.ll (+113)
  • (added) llvm/test/CodeGen/RISCV/rvv/vpdiv-by-const.ll (+1755)
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 04ee24c0916e5f5..6447752c451d88f 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5108,6 +5108,10 @@ class TargetLowering : public TargetLoweringBase {
   SDValue buildSDIVPow2WithCMov(SDNode *N, const APInt &Divisor,
                                 SelectionDAG &DAG,
                                 SmallVectorImpl<SDNode *> &Created) const;
+  SDValue BuildVPSDIV(SDNode *N, SelectionDAG &DAG, bool IsAfterLegalization,
+                      SmallVectorImpl<SDNode *> &Created) const;
+  SDValue BuildVPUDIV(SDNode *N, SelectionDAG &DAG, bool IsAfterLegalization,
+                      SmallVectorImpl<SDNode *> &Created) const;
 
   /// Targets may override this function to provide custom SDIV lowering for
   /// power-of-2 denominators.  If the target returns an empty SDValue, LLVM
diff --git a/llvm/include/llvm/IR/VPIntrinsics.def b/llvm/include/llvm/IR/VPIntrinsics.def
index 55f4719da7c8b1a..e71ca44779adb03 100644
--- a/llvm/include/llvm/IR/VPIntrinsics.def
+++ b/llvm/include/llvm/IR/VPIntrinsics.def
@@ -180,6 +180,11 @@ HELPER_REGISTER_BINARY_INT_VP(vp_xor, VP_XOR, Xor, XOR)
 
 #undef HELPER_REGISTER_BINARY_INT_VP
 
+BEGIN_REGISTER_VP_SDNODE(VP_MULHU, -1, vp_mulhu, 2, 3)
+END_REGISTER_VP_SDNODE(VP_MULHU)
+BEGIN_REGISTER_VP_SDNODE(VP_MULHS, -1, vp_mulhs, 2, 3)
+END_REGISTER_VP_SDNODE(VP_MULHS)
+
 // llvm.vp.smin(x,y,mask,vlen)
 BEGIN_REGISTER_VP(vp_smin, 2, 3, VP_SMIN, -1)
 VP_PROPERTY_BINARYOP
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 8858c2012c70671..74ab35f8c5f0583 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -564,6 +564,14 @@ namespace {
     SDValue visitFSUBForFMACombine(SDNode *N);
     SDValue visitFMULForFMADistributiveCombine(SDNode *N);
 
+    SDValue visitVPUDIV(SDNode *N);
+    SDValue visitVPUDIVLike(SDValue N0, SDValue N1, SDNode *N);
+    SDValue BuildVPUDIV(SDNode *N);
+    SDValue visitVPSDIV(SDNode *N);
+    SDValue visitVPSDIVLike(SDValue N0, SDValue N1, SDNode *N);
+    SDValue BuildVPSDIV(SDNode *N);
+    SDValue visitVPREM(SDNode *N);
+
     SDValue XformToShuffleWithZero(SDNode *N);
     bool reassociationCanBreakAddressingModePattern(unsigned Opc,
                                                     const SDLoc &DL,
@@ -5161,6 +5169,59 @@ SDValue DAGCombiner::visitREM(SDNode *N) {
   return SDValue();
 }
 
+// handles ISD::VP_SREM and ISD::VP_UREM
+SDValue DAGCombiner::visitVPREM(SDNode *N) {
+  unsigned Opcode = N->getOpcode();
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  SDValue Mask = N->getOperand(2);
+  SDValue VL = N->getOperand(3);
+  EVT VT = N->getValueType(0);
+  EVT CCVT =
+      EVT::getVectorVT(*DAG.getContext(), MVT::i1, VT.getVectorElementCount());
+
+  bool IsSigned = (Opcode == ISD::VP_SREM);
+  SDLoc DL(N);
+
+  // fold (vp.urem X, -1) -> select(FX == -1, 0, FX)
+  // Freeze the numerator to avoid a miscompile with an undefined value.
+  if (!IsSigned && llvm::isAllOnesOrAllOnesSplat(N1, /*AllowUndefs*/ false)) {
+    SDValue F0 = DAG.getFreeze(N0);
+    SDValue EqualsNeg1 = DAG.getSetCCVP(DL, CCVT, F0, N1, ISD::SETEQ, Mask, VL);
+    return DAG.getNode(ISD::VP_SELECT, DL, VT, EqualsNeg1,
+                       DAG.getConstant(0, DL, VT), F0, VL);
+  }
+
+  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
+
+  // If X/C can be simplified by the division-by-constant logic, lower
+  // X%C to the equivalent of X-X/C*C.
+  // Reuse the SDIVLike/UDIVLike combines - to avoid mangling nodes, the
+  // speculative DIV must not cause a DIVREM conversion.  We guard against this
+  // by skipping the simplification if isIntDivCheap().  When div is not cheap,
+  // combine will not return a DIVREM.  Regardless, checking cheapness here
+  // makes sense since the simplification results in fatter code.
+  if (DAG.isKnownNeverZero(N1) && !TLI.isIntDivCheap(VT, Attr)) {
+    SDValue OptimizedDiv =
+        IsSigned ? visitVPSDIVLike(N0, N1, N) : visitVPUDIVLike(N0, N1, N);
+    if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != N) {
+      // If the equivalent Div node also exists, update its users.
+      unsigned DivOpcode = IsSigned ? ISD::VP_SDIV : ISD::VP_UDIV;
+      if (SDNode *DivNode = DAG.getNodeIfExists(DivOpcode, N->getVTList(),
+                                                {N0, N1, Mask, VL}))
+        CombineTo(DivNode, OptimizedDiv);
+      SDValue Mul =
+          DAG.getNode(ISD::VP_MUL, DL, VT, OptimizedDiv, N1, Mask, VL);
+      SDValue Sub = DAG.getNode(ISD::VP_SUB, DL, VT, N0, Mul, Mask, VL);
+      AddToWorklist(OptimizedDiv.getNode());
+      AddToWorklist(Mul.getNode());
+      return Sub;
+    }
+  }
+
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitMULHS(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
@@ -27219,6 +27280,268 @@ SDValue DAGCombiner::visitVP_FSUB(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGCombiner::BuildVPUDIV(SDNode *N) {
+  // when optimising for minimum size, we don't want to expand a div to a mul
+  // and a shift.
+  if (DAG.getMachineFunction().getFunction().hasMinSize())
+    return SDValue();
+
+  SmallVector<SDNode *, 8> Built;
+  if (SDValue S = TLI.BuildVPUDIV(N, DAG, LegalOperations, Built)) {
+    for (SDNode *N : Built)
+      AddToWorklist(N);
+    return S;
+  }
+
+  return SDValue();
+}
+
+/// Given an ISD::VP_SDIV node expressing a divide by constant, return
+/// a DAG expression to select that will generate the same value by multiplying
+/// by a magic number.
+/// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
+SDValue DAGCombiner::BuildVPSDIV(SDNode *N) {
+  // when optimising for minimum size, we don't want to expand a div to a mul
+  // and a shift.
+  if (DAG.getMachineFunction().getFunction().hasMinSize())
+    return SDValue();
+
+  SmallVector<SDNode *, 8> Built;
+  if (SDValue S = TLI.BuildVPSDIV(N, DAG, LegalOperations, Built)) {
+    for (SDNode *N : Built)
+      AddToWorklist(N);
+    return S;
+  }
+
+  return SDValue();
+}
+
+SDValue DAGCombiner::visitVPUDIV(SDNode *N) {
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  SDValue Mask = N->getOperand(2);
+  SDValue VL = N->getOperand(3);
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+
+  ConstantSDNode *N1C = isConstOrConstSplat(N1);
+  // fold (vp.udiv X, -1) -> vp.select(X == -1, 1, 0)
+  if (N1C && N1C->isAllOnes()) {
+    EVT CCVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+                                VT.getVectorElementCount());
+    return DAG.getNode(ISD::VP_SELECT, DL, VT,
+                       DAG.getSetCCVP(DL, CCVT, N0, N1, ISD::SETEQ, Mask, VL),
+                       DAG.getConstant(1, DL, VT), DAG.getConstant(0, DL, VT),
+                       VL);
+  }
+
+  if (SDValue V = visitVPUDIVLike(N0, N1, N)) {
+    // If the corresponding remainder node exists, update its users with
+    // (Dividend - (Quotient * Divisor)).
+    if (SDNode *RemNode = DAG.getNodeIfExists(ISD::VP_UREM, N->getVTList(),
+                                              {N0, N1, Mask, VL})) {
+      SDValue Mul = DAG.getNode(ISD::VP_MUL, DL, VT, V, N1, Mask, VL);
+      SDValue Sub = DAG.getNode(ISD::VP_SUB, DL, VT, N0, Mul, Mask, VL);
+      AddToWorklist(Mul.getNode());
+      AddToWorklist(Sub.getNode());
+      CombineTo(RemNode, Sub);
+    }
+    return V;
+  }
+
+  return SDValue();
+}
+
+SDValue DAGCombiner::visitVPUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
+  SDLoc DL(N);
+  SDValue Mask = N->getOperand(2);
+  SDValue VL = N->getOperand(3);
+  EVT VT = N->getValueType(0);
+
+  // fold (vp.udiv x, (1 << c)) -> vp.lshr(x, c)
+  if (isConstantOrConstantVector(N1, /*NoOpaques*/ true) &&
+      DAG.isKnownToBeAPowerOfTwo(N1)) {
+    SDValue LogBase2 = BuildLogBase2(N1, DL);
+    AddToWorklist(LogBase2.getNode());
+
+    EVT ShiftVT = getShiftAmountTy(N0.getValueType());
+    SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
+    AddToWorklist(Trunc.getNode());
+    return DAG.getNode(ISD::VP_SRL, DL, VT, N0, Trunc, Mask, VL);
+  }
+
+  // fold (vp.udiv x, (vp.shl c, y)) -> vp.lshr(x, vp.add(log2(c)+y)) iff c is
+  // power of 2
+  if (N1.getOpcode() == ISD::VP_SHL && N1->getOperand(2) == Mask &&
+      N1->getOperand(3) == VL) {
+    SDValue N10 = N1.getOperand(0);
+    if (isConstantOrConstantVector(N10, /*NoOpaques*/ true) &&
+        DAG.isKnownToBeAPowerOfTwo(N10)) {
+      SDValue LogBase2 = BuildLogBase2(N10, DL);
+      AddToWorklist(LogBase2.getNode());
+
+      EVT ADDVT = N1.getOperand(1).getValueType();
+      SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
+      AddToWorklist(Trunc.getNode());
+      SDValue Add = DAG.getNode(ISD::VP_ADD, DL, ADDVT, N1.getOperand(1), Trunc,
+                                Mask, VL);
+      AddToWorklist(Add.getNode());
+      return DAG.getNode(ISD::VP_SRL, DL, VT, N0, Add, Mask, VL);
+    }
+  }
+
+  // fold (vp.udiv x, Splat(shl c, y)) -> vp.lshr(x, add(log2(c)+y)) iff c is
+  // power of 2
+  if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
+    SDValue N10 = N1.getOperand(0);
+    if (N10.getOpcode() == ISD::SHL) {
+      SDValue N0SHL = N10.getOperand(0);
+      if (isa<ConstantSDNode>(N0SHL) && DAG.isKnownToBeAPowerOfTwo(N0SHL)) {
+        SDValue LogBase2 = BuildLogBase2(N0SHL, DL);
+        AddToWorklist(LogBase2.getNode());
+
+        EVT ADDVT = N10.getOperand(1).getValueType();
+        SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
+        AddToWorklist(Trunc.getNode());
+        SDValue Add =
+            DAG.getNode(ISD::ADD, DL, ADDVT, N10.getOperand(1), Trunc);
+        AddToWorklist(Add.getNode());
+        SDValue Splat = DAG.getSplatVector(VT, DL, Add);
+        AddToWorklist(Splat.getNode());
+        return DAG.getNode(ISD::VP_SRL, DL, VT, N0, Splat, Mask, VL);
+      }
+    }
+  }
+
+  // fold (udiv x, c) -> alternate
+  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
+  if (isConstantOrConstantVector(N1) &&
+      !TLI.isIntDivCheap(N->getValueType(0), Attr))
+    if (SDValue Op = BuildVPUDIV(N))
+      return Op;
+  return SDValue();
+}
+
+SDValue DAGCombiner::visitVPSDIV(SDNode *N) {
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  SDValue Mask = N->getOperand(2);
+  SDValue VL = N->getOperand(3);
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+
+  // fold (vp.sdiv X, -1) -> 0-X
+  ConstantSDNode *N1C = isConstOrConstSplat(N1);
+  if (N1C && N1C->isAllOnes())
+    return DAG.getNode(ISD::VP_SUB, DL, VT, DAG.getConstant(0, DL, VT), N0,
+                       Mask, VL);
+
+  // fold (vp.sdiv X, MIN_SIGNED) -> select(X == MIN_SIGNED, 1, 0)
+  if (N1C && N1C->getAPIntValue().isMinSignedValue()) {
+    EVT CCVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+                                VT.getVectorElementCount());
+    return DAG.getNode(ISD::VP_SELECT, DL, VT,
+                       DAG.getSetCCVP(DL, CCVT, N0, N1, ISD::SETEQ, Mask, VL),
+                       DAG.getConstant(1, DL, VT), DAG.getConstant(0, DL, VT),
+                       VL);
+  }
+
+  // If we know the sign bits of both operands are zero, strength reduce to a
+  // vp.udiv instead.  Handles (X&15) /s 4 -> X&15 >> 2
+  if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
+    return DAG.getNode(ISD::VP_UDIV, DL, N1.getValueType(), N0, N1, Mask, VL);
+
+  if (SDValue V = visitVPSDIVLike(N0, N1, N)) {
+    // If the corresponding remainder node exists, update its users with
+    // (Dividend - (Quotient * Divisor)).
+    if (SDNode *RemNode = DAG.getNodeIfExists(ISD::VP_SREM, N->getVTList(),
+                                              {N0, N1, Mask, VL})) {
+      SDValue Mul = DAG.getNode(ISD::VP_MUL, DL, VT, V, N1, Mask, VL);
+      SDValue Sub = DAG.getNode(ISD::VP_SUB, DL, VT, N0, Mul, Mask, VL);
+      AddToWorklist(Mul.getNode());
+      AddToWorklist(Sub.getNode());
+      CombineTo(RemNode, Sub);
+    }
+    return V;
+  }
+  return SDValue();
+}
+
+SDValue DAGCombiner::visitVPSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
+  SDLoc DL(N);
+  SDValue Mask = N->getOperand(2);
+  SDValue VL = N->getOperand(3);
+  EVT VT = N->getValueType(0);
+  unsigned BitWidth = VT.getScalarSizeInBits();
+
+  // fold (vp.sdiv X, V of pow 2)
+  if (N1.getOpcode() == ISD::SPLAT_VECTOR &&
+      isDivisorPowerOfTwo(N1.getOperand(0))) {
+    // Create constants that are functions of the shift amount value.
+    SDValue N = N1.getOperand(0);
+    EVT CCVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+                                VT.getVectorElementCount());
+    EVT ScalarShiftAmtTy =
+        getShiftAmountTy(N0.getValueType().getVectorElementType());
+    SDValue Bits = DAG.getConstant(BitWidth, DL, ScalarShiftAmtTy);
+    SDValue C1 = DAG.getNode(ISD::CTTZ, DL, VT.getVectorElementType(), N);
+    C1 = DAG.getZExtOrTrunc(C1, DL, ScalarShiftAmtTy);
+    SDValue Inexact = DAG.getNode(ISD::SUB, DL, ScalarShiftAmtTy, Bits, C1);
+    if (!isa<ConstantSDNode>(Inexact))
+      return SDValue();
+
+    // Splat the sign bit into the register
+    EVT VecShiftAmtTy = EVT::getVectorVT(*DAG.getContext(), ScalarShiftAmtTy,
+                                         VT.getVectorElementCount());
+    SDValue Sign =
+        DAG.getNode(ISD::VP_SRA, DL, VT, N0,
+                    DAG.getConstant(BitWidth - 1, DL, VecShiftAmtTy), Mask, VL);
+    AddToWorklist(Sign.getNode());
+
+    // Add N0, ((N0 < 0) ? abs(N1) - 1 : 0);
+    Inexact = DAG.getSplat(VT, DL, Inexact);
+    C1 = DAG.getSplat(VT, DL, C1);
+    SDValue Srl = DAG.getNode(ISD::VP_SRL, DL, VT, Sign, Inexact, Mask, VL);
+    AddToWorklist(Srl.getNode());
+    SDValue Add = DAG.getNode(ISD::VP_ADD, DL, VT, N0, Srl, Mask, VL);
+    AddToWorklist(Add.getNode());
+    SDValue Sra = DAG.getNode(ISD::VP_SRA, DL, VT, Add, C1, Mask, VL);
+    AddToWorklist(Sra.getNode());
+
+    // Special case: (sdiv X, 1) -> X
+    // Special Case: (sdiv X, -1) -> 0-X
+    SDValue One = DAG.getConstant(1, DL, VT);
+    SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
+    SDValue IsOne = DAG.getSetCCVP(DL, CCVT, N1, One, ISD::SETEQ, Mask, VL);
+    SDValue IsAllOnes =
+        DAG.getSetCCVP(DL, CCVT, N1, AllOnes, ISD::SETEQ, Mask, VL);
+    SDValue IsOneOrAllOnes =
+        DAG.getNode(ISD::VP_OR, DL, CCVT, IsOne, IsAllOnes, Mask, VL);
+    Sra = DAG.getNode(ISD::VP_SELECT, DL, VT, IsOneOrAllOnes, N0, Sra, VL);
+
+    // If dividing by a positive value, we're done. Otherwise, the result must
+    // be negated.
+    SDValue Zero = DAG.getConstant(0, DL, VT);
+    SDValue Sub = DAG.getNode(ISD::VP_SUB, DL, VT, Zero, Sra, Mask, VL);
+
+    // FIXME: Use SELECT_CC once we improve SELECT_CC constant-folding.
+    SDValue IsNeg = DAG.getSetCCVP(DL, CCVT, N1, Zero, ISD::SETLT, Mask, VL);
+    SDValue Res = DAG.getNode(ISD::VP_SELECT, DL, VT, IsNeg, Sub, Sra, VL);
+    return Res;
+  }
+
+  // If integer divide is expensive and we satisfy the requirements, emit an
+  // alternate sequence.  Targets may check function attributes for size/speed
+  // trade-offs.
+  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
+  if (isConstantOrConstantVector(N1) &&
+      !TLI.isIntDivCheap(N->getValueType(0), Attr))
+    if (SDValue Op = BuildVPSDIV(N))
+      return Op;
+
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitVPOp(SDNode *N) {
 
   if (N->getOpcode() == ISD::VP_GATHER)
@@ -27262,6 +27585,13 @@ SDValue DAGCombiner::visitVPOp(SDNode *N) {
       return visitMUL<VPMatchContext>(N);
     case ISD::VP_SUB:
       return foldSubCtlzNot<VPMatchContext>(N, DAG);
+    case ISD::VP_UDIV:
+      return visitVPUDIV(N);
+    case ISD::VP_SDIV:
+      return visitVPSDIV(N);
+    case ISD::VP_UREM:
+    case ISD::VP_SREM:
+      return visitVPREM(N);
     default:
       break;
     }
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 1000235ab4061f7..6e2f37d7c3dd43d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1277,8 +1277,8 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::ADD: case ISD::VP_ADD:
   case ISD::SUB: case ISD::VP_SUB:
   case ISD::MUL: case ISD::VP_MUL:
-  case ISD::MULHS:
-  case ISD::MULHU:
+  case ISD::MULHS: case ISD::VP_MULHS:
+  case ISD::MULHU: case ISD::VP_MULHU:
   case ISD::ABDS:
   case ISD::ABDU:
   case ISD::AVGCEILS:
@@ -4552,8 +4552,8 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::ADD: case ISD::VP_ADD:
   case ISD::AND: case ISD::VP_AND:
   case ISD::MUL: case ISD::VP_MUL:
-  case ISD::MULHS:
-  case ISD::MULHU:
+  case ISD::MULHS: case ISD::VP_MULHS:
+  case ISD::MULHU: case ISD::VP_MULHU:
   case ISD::ABDS:
   case ISD::ABDU:
   case ISD::OR: case ISD::VP_OR:
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index adfb96041c5c06b..b7846212a94abee 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -6492,6 +6492,137 @@ SDValue TargetLowering::BuildSDIV(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(ISD::ADD, dl, VT, Q, T);
 }
 
+/// Given an ISD::VP_SDIV node expressing a divide by constant,
+/// return a DAG expression to select that will generate the same value by
+/// multiplying by a magic number.
+/// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
+SDValue TargetLowering::BuildVPSDIV(SDNode *N, SelectionDAG &DAG,
+                                    bool IsAfterLegalization,
+                                    SmallVectorImpl<SDNode *> &Created) const {
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  EVT SVT = VT.getScalarType();
+  EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout());
+  EVT ShSVT = ShVT.getScalarType();
+  unsigned EltBits = VT.getScalarSizeInBits();
+
+  // Check to see if we can do this.
+  if (!isTypeLegal(VT) ||
+      !isOperationLegalOrCustom(ISD::VP_MULHS, VT, IsAfterLegalization))
+    return SDValue();
+
+  bool AnyFactorOne = false;
+  bool AnyFactorNegOne = false;
+
+  SmallVector<SDValue, 16> MagicFactors, Factors, Shifts, ShiftMasks;
+
+  auto BuildSDIVPattern = [&](ConstantSDNode *C) {
+    if (C->isZero())
+      return false;
+
+    const APInt &Divisor = C->getAPIntValue();
+    SignedDivisionByConstantInfo magics =
+        SignedDivisionByConstantInfo::get(Divisor);
+    int NumeratorFactor = 0;
+    int ShiftMask = -1;
+
+    if (Divisor.isOne() || Divisor.isAllOnes()) {
+      // If d is +1/-1, we just multiply the numerator by +1/-1.
+      NumeratorFactor = Divisor.getSExtValue();
+      magics.Magic = 0;
+      magics.ShiftAmount = 0;
+      ShiftMask = 0;
+      AnyFactorOne |= Divisor.isOne();
+      AnyFactorNegOne |= Divisor.isAllOnes();
+    } else if (Divisor.isStrictlyPositive() && magics.Magic.isNegative()) {
+      // If d > 0 and m < 0, add the numerator.
+      NumeratorFactor = 1;
+      AnyFactorOne = true;
+    } else if (Divisor.isNegative() && magics.Magic.isStrictlyPositive()) {
+      // If d < 0 and m > 0, subtract the numerator.
+      NumeratorFactor = -1;
+      AnyFactorNegOne = true;
+    }
+
+    MagicFactors.push_back(DAG.getConstant(magics.Magic, DL, SVT));
+    Factors.push_back(DAG.getSignedConstant(NumeratorFactor, DL, SVT));
+    Shifts.push_back(DAG.getConstant(magics.ShiftAmount, DL, ShSVT));
+    ShiftMasks.push_back(DAG.getSignedConstant(ShiftMask, DL, SVT));
+    return true;
+  };
+
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  SDValue Mask = N->getOperand(2);
+  SDValue VL = N->getOperand(3);
+
+  // Collect the shifts / magic values from each element.
+  if (!ISD::matchUnaryPredicate(N1, BuildSDIVPattern))
+    return SDValue();
+
+  SDValue MagicFactor, Factor, Shift, ShiftMask;
+  if (N1.getOpcode() == ISD::BUILD_VECTOR) {
+    MagicFactor = DAG.getBuildVector(VT, DL, MagicFactors);
+    Factor = DAG.getBuildVector(VT, DL, Factors);
+    Shift = DAG.getBuildVector(ShVT, DL, Shifts);
+    ShiftMask = DAG.getBuildVector(VT, DL, ShiftMasks);
+  } else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
+    assert(MagicFac...
[truncated]
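
The signed power-of-two path in visitVPSDIVLike (sign splat via VP_SRA, bias via VP_SRL, then VP_ADD and a final VP_SRA) can be sketched per lane in scalar C++. This is a minimal sketch under stated assumptions: the function names are hypothetical, only divisors 2^K with 1 <= K <= 30 are covered, and right shifts of negative values are assumed to be arithmetic. The patch itself emits masked VP nodes and handles divisors of 1 and -1 through separate selects.

```cpp
#include <cassert>
#include <cstdint>

// Signed x / 2^K with round-toward-zero semantics, for 1 <= K <= 30.
// Assumes >> on a negative int32_t is an arithmetic shift (guaranteed since
// C++20, and the behaviour of mainstream compilers before that).
inline int32_t sdivByPow2(int32_t X, unsigned K) {
  assert(K >= 1 && K <= 30 && "sketch only covers divisors 2..2^30");
  // Splat the sign bit: all ones for negative X, zero otherwise.
  int32_t Sign = X >> 31;
  // A logical shift leaves 2^K - 1 for negative X and 0 otherwise.
  uint32_t Bias = static_cast<uint32_t>(Sign) >> (32 - K);
  // Biasing negative numerators makes the arithmetic shift round toward zero.
  int32_t Add = X + static_cast<int32_t>(Bias);
  return Add >> K;
}

// Dividing by -(2^K) yields the same quotient, negated.
inline int32_t sdivByNegPow2(int32_t X, unsigned K) {
  return 0 - sdivByPow2(X, K);
}
```

Without the bias, an arithmetic shift alone would round negative quotients toward negative infinity, which does not match sdiv semantics.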

github-actions bot commented Feb 6, 2025

✅ With the latest revision this PR passed the undef deprecator.

github-actions bot commented Feb 6, 2025

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff fe7e280820c8f4a46f49357097d7f6897bd31d41 f8c91e958b9252fa8f9542ec2364476e20a45f34 --extensions h,cpp -- llvm/include/llvm/CodeGen/TargetLowering.h llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp llvm/lib/Target/RISCV/RISCVISelLowering.cpp
View the diff from clang-format here.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 6e2f37d7c3..16bdf6516a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1277,8 +1277,10 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::ADD: case ISD::VP_ADD:
   case ISD::SUB: case ISD::VP_SUB:
   case ISD::MUL: case ISD::VP_MUL:
-  case ISD::MULHS: case ISD::VP_MULHS:
-  case ISD::MULHU: case ISD::VP_MULHU:
+  case ISD::MULHS:
+  case ISD::VP_MULHS:
+  case ISD::MULHU:
+  case ISD::VP_MULHU:
   case ISD::ABDS:
   case ISD::ABDU:
   case ISD::AVGCEILS:
@@ -4552,8 +4554,10 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::ADD: case ISD::VP_ADD:
   case ISD::AND: case ISD::VP_AND:
   case ISD::MUL: case ISD::VP_MUL:
-  case ISD::MULHS: case ISD::VP_MULHS:
-  case ISD::MULHU: case ISD::VP_MULHU:
+  case ISD::MULHS:
+  case ISD::VP_MULHS:
+  case ISD::MULHU:
+  case ISD::VP_MULHU:
   case ISD::ABDS:
   case ISD::ABDU:
   case ISD::OR: case ISD::VP_OR:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index ba4aaf36a0..1951531657 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -684,23 +684,51 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setOperationAction({ISD::INTRINSIC_W_CHAIN, ISD::INTRINSIC_VOID},
                        MVT::Other, Custom);
 
-    static const unsigned IntegerVPOps[] = {
-        ISD::VP_ADD,         ISD::VP_SUB,         ISD::VP_MUL,
-        ISD::VP_SDIV,        ISD::VP_UDIV,        ISD::VP_SREM,
-        ISD::VP_UREM,        ISD::VP_AND,         ISD::VP_OR,
-        ISD::VP_XOR,         ISD::VP_SRA,         ISD::VP_SRL,
-        ISD::VP_SHL,         ISD::VP_REDUCE_ADD,  ISD::VP_REDUCE_AND,
-        ISD::VP_REDUCE_OR,   ISD::VP_REDUCE_XOR,  ISD::VP_REDUCE_SMAX,
-        ISD::VP_REDUCE_SMIN, ISD::VP_REDUCE_UMAX, ISD::VP_REDUCE_UMIN,
-        ISD::VP_MERGE,       ISD::VP_SELECT,      ISD::VP_FP_TO_SINT,
-        ISD::VP_FP_TO_UINT,  ISD::VP_SETCC,       ISD::VP_SIGN_EXTEND,
-        ISD::VP_ZERO_EXTEND, ISD::VP_TRUNCATE,    ISD::VP_SMIN,
-        ISD::VP_SMAX,        ISD::VP_UMIN,        ISD::VP_UMAX,
-        ISD::VP_MULHU, ISD::VP_MULHS,
-        ISD::VP_ABS, ISD::EXPERIMENTAL_VP_REVERSE, ISD::EXPERIMENTAL_VP_SPLICE,
-        ISD::VP_SADDSAT,     ISD::VP_UADDSAT,     ISD::VP_SSUBSAT,
-        ISD::VP_USUBSAT,     ISD::VP_CTTZ_ELTS,   ISD::VP_CTTZ_ELTS_ZERO_UNDEF,
-        ISD::EXPERIMENTAL_VP_SPLAT};
+    static const unsigned IntegerVPOps[] = {ISD::VP_ADD,
+                                            ISD::VP_SUB,
+                                            ISD::VP_MUL,
+                                            ISD::VP_SDIV,
+                                            ISD::VP_UDIV,
+                                            ISD::VP_SREM,
+                                            ISD::VP_UREM,
+                                            ISD::VP_AND,
+                                            ISD::VP_OR,
+                                            ISD::VP_XOR,
+                                            ISD::VP_SRA,
+                                            ISD::VP_SRL,
+                                            ISD::VP_SHL,
+                                            ISD::VP_REDUCE_ADD,
+                                            ISD::VP_REDUCE_AND,
+                                            ISD::VP_REDUCE_OR,
+                                            ISD::VP_REDUCE_XOR,
+                                            ISD::VP_REDUCE_SMAX,
+                                            ISD::VP_REDUCE_SMIN,
+                                            ISD::VP_REDUCE_UMAX,
+                                            ISD::VP_REDUCE_UMIN,
+                                            ISD::VP_MERGE,
+                                            ISD::VP_SELECT,
+                                            ISD::VP_FP_TO_SINT,
+                                            ISD::VP_FP_TO_UINT,
+                                            ISD::VP_SETCC,
+                                            ISD::VP_SIGN_EXTEND,
+                                            ISD::VP_ZERO_EXTEND,
+                                            ISD::VP_TRUNCATE,
+                                            ISD::VP_SMIN,
+                                            ISD::VP_SMAX,
+                                            ISD::VP_UMIN,
+                                            ISD::VP_UMAX,
+                                            ISD::VP_MULHU,
+                                            ISD::VP_MULHS,
+                                            ISD::VP_ABS,
+                                            ISD::EXPERIMENTAL_VP_REVERSE,
+                                            ISD::EXPERIMENTAL_VP_SPLICE,
+                                            ISD::VP_SADDSAT,
+                                            ISD::VP_UADDSAT,
+                                            ISD::VP_SSUBSAT,
+                                            ISD::VP_USUBSAT,
+                                            ISD::VP_CTTZ_ELTS,
+                                            ISD::VP_CTTZ_ELTS_ZERO_UNDEF,
+                                            ISD::EXPERIMENTAL_VP_SPLAT};
 
     static const unsigned FloatingPointVPOps[] = {
         ISD::VP_FADD,        ISD::VP_FSUB,        ISD::VP_FMUL,

@jaidTw
Contributor Author

jaidTw commented Feb 25, 2025

All comments above are fixed

@topperc topperc requested review from preames and lukel97 February 25, 2025 22:44
@topperc
Collaborator

topperc commented Feb 25, 2025

@lukel97 if the vectorizer is moving away from using VP intrinsics, I guess we don't need this patch anymore?

@wangpc-pp
Contributor

@lukel97 if the vectorizer is moving away from using VP intrinsics, I guess we don't need this patch anymore?

IIUC, it is temporary? We may add it back in some day.

@lukel97
Contributor

lukel97 commented Feb 26, 2025

@lukel97 if the vectorizer is moving away from using VP intrinsics, I guess we don't need this patch anymore?

IIUC, it is temporary? We may add it back in some day.

Users can also still write VP intrinsics by hand, though I'm not sure how common this is. I don't have a particularly strong opinion either way. Lifting InstCombine/DAGCombiner is still on the VP roadmap: https://llvm.org/docs/Proposals/VectorPredication.html

Contributor

@lukel97 lukel97 left a comment

Was there a particular reason for doing this in DAGCombiner as opposed to InstCombine? E.g. I think InstCombine currently handles the udiv of max case:

define <vscale x 2 x i64> @f(<vscale x 2 x i64> %v) {
  %x = udiv <vscale x 2 x i64> %v, splat (i64 18446744073709551615)
  ret <vscale x 2 x i64> %x
}

->

define <vscale x 2 x i64> @f(<vscale x 2 x i64> %v) {
  %1 = icmp eq <vscale x 2 x i64> %v, splat (i64 -1)
  %x = zext <vscale x 2 x i1> %1 to <vscale x 2 x i64>
  ret <vscale x 2 x i64> %x
}
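
The rewrite above rests on a simple identity: dividing an unsigned value by the all-ones constant gives 1 only when the value is itself all ones, and 0 otherwise. A scalar C++ sketch of that identity, and of the matching remainder identity used by the patch's (vp.urem X, -1) fold, follows. The function names are hypothetical and chosen only for this illustration.

```cpp
#include <cassert>
#include <cstdint>

// x / UINT64_MAX is 1 only when x == UINT64_MAX, and 0 otherwise.
constexpr uint64_t udivByAllOnes(uint64_t X) {
  return X == UINT64_MAX ? 1 : 0;
}

// x % UINT64_MAX is 0 when x == UINT64_MAX, and x itself otherwise.
constexpr uint64_t uremByAllOnes(uint64_t X) {
  return X == UINT64_MAX ? 0 : X;
}

// Compile-time spot checks against the built-in operators.
static_assert(udivByAllOnes(UINT64_MAX) == UINT64_MAX / UINT64_MAX,
              "all-ones numerator divides to 1");
static_assert(udivByAllOnes(12345) == 12345 / UINT64_MAX,
              "any smaller numerator divides to 0");
static_assert(uremByAllOnes(12345) == 12345 % UINT64_MAX,
              "any smaller numerator is its own remainder");
```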

@topperc
Collaborator

topperc commented Feb 26, 2025

@lukel97 if the vectorizer is moving away from using VP intrinsics, I guess we don't need this patch anymore?

IIUC, it is temporary? We may add it back in some day.

Users can also still write VP intrinsics by hand, though I'm not sure how common this is.

I don't think it's common.

Lifting InstCombine/DAGCombiner is still on the VP roadmap: https://llvm.org/docs/Proposals/VectorPredication.html

RISC-V seems to be the only active development of VP intrinsics in tree. We, the RISC-V community, should agree on what's important to us and prioritize where we spend effort.

@topperc
Collaborator

topperc commented Feb 26, 2025

Was there a particular reason for doing this in DAGCombiner as opposed to InstCombine? E.g. I think InstCombine currently handles the udiv of max case:

define <vscale x 2 x i64> @f(<vscale x 2 x i64> %v) {
  %x = udiv <vscale x 2 x i64> %v, splat (i64 18446744073709551615)
  ret <vscale x 2 x i64> %x
}

->

define <vscale x 2 x i64> @f(<vscale x 2 x i64> %v) {
  %1 = icmp eq <vscale x 2 x i64> %v, splat (i64 -1)
  %x = zext <vscale x 2 x i1> %1 to <vscale x 2 x i64>
  ret <vscale x 2 x i64> %x
}

Is this question about the max case or the entire patch?

@lukel97
Contributor

lukel97 commented Feb 26, 2025

RISC-V seems to be the only active development of VP intrinsics in tree. We, the RISC-V community, should agree on what's important to us and prioritize where we spend effort.

Agreed. If the only use case for non-trapping VP intrinsics was EVL tail folding in the loop vectorizer then adding optimisations for them is less of a priority. We should probably talk about this in the RISC-V LLVM sync up on Thursday.

It's worth noting though that the loop vectorizer still emits load/store VP intrinsics, and they might be missing combines. Is this something you've looked at already in your downstream?

Is this question about the max case or the entire patch?

The entire patch. I just chose a random case to check whether InstCombine did any strength reduction.

@topperc
Collaborator

topperc commented Feb 26, 2025

RISC-V seems to be the only active development of VP intrinsics in tree. We, the RISC-V community, should agree on what's important to us and prioritize where we spend effort.

Agreed. If the only use case for non-trapping VP intrinsics was EVL tail folding in the loop vectorizer then adding optimisations for them is less of a priority. We should probably talk about this in the RISC-V LLVM sync up on Thursday.

It's worth noting though that the loop vectorizer still emits load/store VP intrinsics, and they might be missing combines. Is this something you've looked at already in your downstream?

Is this question about the max case or the entire patch?

The entire patch. I just chose a random case to check whether InstCombine did any strength reduction.

InstCombine definitely does not do the BuildVPUDIV/BuildVPSDIV part of this code, which performs the main optimization for arbitrary constants.
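
For context, the arbitrary-constant case is the classic multiply-high strength reduction. The sketch below is a scalar C++ model of one 32-bit lane, not the patch's code: `mulhu32`, `udiv3` and `udiv7` are hypothetical names, the magic constants are the standard ones for 32-bit unsigned division by 3 and 7, and it assumes the DAG combine emits the equivalent multiply-high, shift, add and subtract nodes per lane.

```cpp
#include <cstdint>

// High 32 bits of the 64-bit product: a scalar stand-in for MULHU on one lane.
constexpr std::uint32_t mulhu32(std::uint32_t a, std::uint32_t b) {
  return static_cast<std::uint32_t>((static_cast<std::uint64_t>(a) * b) >> 32);
}

// x / 3: the magic constant ceil(2^33 / 3) fits in 32 bits, so one
// multiply-high and one shift suffice.
constexpr std::uint32_t udiv3(std::uint32_t x) {
  return mulhu32(x, 0xAAAAAAABu) >> 1;
}

// x / 7: the magic constant ceil(2^35 / 7) needs 33 bits, so only its low
// 32 bits are multiplied and the missing top bit is restored with a
// subtract, shift and add fixup.
constexpr std::uint32_t udiv7(std::uint32_t x) {
  std::uint32_t q = mulhu32(x, 0x24924925u);
  std::uint32_t t = ((x - q) >> 1) + q; // equals (x + q) / 2 without overflow
  return t >> 2;
}

// Compile-time checks against the real division.
static_assert(udiv3(9) == 3 && udiv3(0xFFFFFFFFu) == 0xFFFFFFFFu / 3, "udiv3");
static_assert(udiv7(6) == 0 && udiv7(7) == 1, "udiv7 small values");
static_assert(udiv7(0xFFFFFFFFu) == 0xFFFFFFFFu / 7, "udiv7 max value");
```

The divide-by-7 case shows why the expansion may need extra add and subtract nodes in addition to the multiply-high, depending on the divisor.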

@@ -1277,8 +1277,8 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::ADD: case ISD::VP_ADD:
   case ISD::SUB: case ISD::VP_SUB:
   case ISD::MUL: case ISD::VP_MUL:
-  case ISD::MULHS:
-  case ISD::MULHU:
+  case ISD::MULHS: case ISD::VP_MULHS:
Collaborator

Is this tested?

Collaborator

I can't find a way to test this. The div by constant optimization that creates MULH won't fire on illegal vector types.

I guess we can leave it so we don't forget it if we add more VP_MULH combines in the future. It's sufficiently similar to VP_MUL that it should just work.

@@ -4552,8 +4552,8 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::ADD: case ISD::VP_ADD:
   case ISD::AND: case ISD::VP_AND:
   case ISD::MUL: case ISD::VP_MUL:
-  case ISD::MULHS:
-  case ISD::MULHU:
+  case ISD::MULHS: case ISD::VP_MULHS:
Collaborator

Is this tested?

Collaborator

I can't find a way to test this. The div by constant optimization that creates MULH won't fire on illegal vector types.

I guess we can leave it so we don't forget it if we add more VP_MULH combines in the future. It's sufficiently similar to VP_MUL that it should just work.

Collaborator

@topperc topperc left a comment

LGTM

Collaborator

@topperc topperc left a comment

LGTM
