Skip to content

[RISCV] Optimize divide by constant for VP intrinsics #125991

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -5108,6 +5108,10 @@ class TargetLowering : public TargetLoweringBase {
SDValue buildSDIVPow2WithCMov(SDNode *N, const APInt &Divisor,
SelectionDAG &DAG,
SmallVectorImpl<SDNode *> &Created) const;
SDValue BuildVPSDIV(SDNode *N, SelectionDAG &DAG, bool IsAfterLegalization,
SmallVectorImpl<SDNode *> &Created) const;
SDValue BuildVPUDIV(SDNode *N, SelectionDAG &DAG, bool IsAfterLegalization,
SmallVectorImpl<SDNode *> &Created) const;

/// Targets may override this function to provide custom SDIV lowering for
/// power-of-2 denominators. If the target returns an empty SDValue, LLVM
Expand Down
5 changes: 5 additions & 0 deletions llvm/include/llvm/IR/VPIntrinsics.def
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,11 @@ HELPER_REGISTER_BINARY_INT_VP(vp_xor, VP_XOR, Xor, XOR)

#undef HELPER_REGISTER_BINARY_INT_VP

BEGIN_REGISTER_VP_SDNODE(VP_MULHU, -1, vp_mulhs, 2, 3)
END_REGISTER_VP_SDNODE(VP_MULHU)
BEGIN_REGISTER_VP_SDNODE(VP_MULHS, -1, vp_mulhs, 2, 3)
END_REGISTER_VP_SDNODE(VP_MULHS)

// llvm.vp.smin(x,y,mask,vlen)
BEGIN_REGISTER_VP(vp_smin, 2, 3, VP_SMIN, -1)
VP_PROPERTY_BINARYOP
Expand Down
308 changes: 306 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,12 @@ namespace {
SDValue visitFSUBForFMACombine(SDNode *N);
SDValue visitFMULForFMADistributiveCombine(SDNode *N);

SDValue visitVPUDIV(SDNode *N);
SDValue visitVPUDIVLike(SDValue N0, SDValue N1, SDNode *N);
SDValue visitVPSDIV(SDNode *N);
SDValue visitVPSDIVLike(SDValue N0, SDValue N1, SDNode *N);
SDValue visitVPREM(SDNode *N);

SDValue XformToShuffleWithZero(SDNode *N);
bool reassociationCanBreakAddressingModePattern(unsigned Opc,
const SDLoc &DL,
Expand Down Expand Up @@ -5161,6 +5167,59 @@ SDValue DAGCombiner::visitREM(SDNode *N) {
return SDValue();
}

// handles ISD::VP_SREM and ISD::VP_UREM
SDValue DAGCombiner::visitVPREM(SDNode *N) {
unsigned Opcode = N->getOpcode();
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
SDValue Mask = N->getOperand(2);
SDValue VL = N->getOperand(3);
EVT VT = N->getValueType(0);
EVT CCVT =
EVT::getVectorVT(*DAG.getContext(), MVT::i1, VT.getVectorElementCount());

bool IsSigned = (Opcode == ISD::VP_SREM);
SDLoc DL(N);

// fold (vp.urem X, -1) -> select(FX == -1, 0, FX)
// Freeze the numerator to avoid a miscompile with an undefined value.
if (!IsSigned && llvm::isAllOnesOrAllOnesSplat(N1, /*AllowUndefs*/ false)) {
SDValue F0 = DAG.getFreeze(N0);
SDValue EqualsNeg1 = DAG.getSetCCVP(DL, CCVT, F0, N1, ISD::SETEQ, Mask, VL);
return DAG.getNode(ISD::VP_SELECT, DL, VT, EqualsNeg1,
DAG.getConstant(0, DL, VT), F0, VL);
}

AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();

// If X/C can be simplified by the division-by-constant logic, lower
// X%C to the equivalent of X-X/C*C.
// Reuse the SDIVLike/UDIVLike combines - to avoid mangling nodes, the
// speculative DIV must not cause a DIVREM conversion. We guard against this
// by skipping the simplification if isIntDivCheap(). When div is not cheap,
// combine will not return a DIVREM. Regardless, checking cheapness here
// makes sense since the simplification results in fatter code.
if (DAG.isKnownNeverZero(N1) && !TLI.isIntDivCheap(VT, Attr)) {
SDValue OptimizedDiv =
IsSigned ? visitVPSDIVLike(N0, N1, N) : visitVPUDIVLike(N0, N1, N);
if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != N) {
// If the equivalent Div node also exists, update its users.
unsigned DivOpcode = IsSigned ? ISD::VP_SDIV : ISD::VP_UDIV;
if (SDNode *DivNode = DAG.getNodeIfExists(DivOpcode, N->getVTList(),
{N0, N1, Mask, VL}))
CombineTo(DivNode, OptimizedDiv);
SDValue Mul =
DAG.getNode(ISD::VP_MUL, DL, VT, OptimizedDiv, N1, Mask, VL);
SDValue Sub = DAG.getNode(ISD::VP_SUB, DL, VT, N0, Mul, Mask, VL);
AddToWorklist(OptimizedDiv.getNode());
AddToWorklist(Mul.getNode());
return Sub;
}
}

return SDValue();
}

SDValue DAGCombiner::visitMULHS(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
Expand Down Expand Up @@ -27219,6 +27278,232 @@ SDValue DAGCombiner::visitVP_FSUB(SDNode *N) {
return SDValue();
}

SDValue DAGCombiner::visitVPUDIV(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
SDValue Mask = N->getOperand(2);
SDValue VL = N->getOperand(3);
EVT VT = N->getValueType(0);
SDLoc DL(N);

ConstantSDNode *N1C = isConstOrConstSplat(N1);
// fold (vp.udiv X, -1) -> vp.select(X == -1, 1, 0)
if (N1C && N1C->isAllOnes()) {
EVT CCVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
VT.getVectorElementCount());
return DAG.getNode(ISD::VP_SELECT, DL, VT,
DAG.getSetCCVP(DL, CCVT, N0, N1, ISD::SETEQ, Mask, VL),
DAG.getConstant(1, DL, VT), DAG.getConstant(0, DL, VT),
VL);
}

if (SDValue V = visitVPUDIVLike(N0, N1, N)) {
// If the corresponding remainder node exists, update its users with
// (Dividend - (Quotient * Divisor).
if (SDNode *RemNode = DAG.getNodeIfExists(ISD::VP_UREM, N->getVTList(),
{N0, N1, Mask, VL})) {
SDValue Mul = DAG.getNode(ISD::VP_MUL, DL, VT, V, N1, Mask, VL);
SDValue Sub = DAG.getNode(ISD::VP_SUB, DL, VT, N0, Mul, Mask, VL);
AddToWorklist(Mul.getNode());
AddToWorklist(Sub.getNode());
CombineTo(RemNode, Sub);
}
return V;
}

return SDValue();
}

SDValue DAGCombiner::visitVPUDIVLike(SDValue N0, SDValue N1, SDNode *N) {
SDLoc DL(N);
SDValue Mask = N->getOperand(2);
SDValue VL = N->getOperand(3);
EVT VT = N->getValueType(0);

// fold (vp.udiv x, (1 << c)) -> vp.lshr(x, c)
if (isConstantOrConstantVector(N1, /*NoOpaques=*/true) &&
DAG.isKnownToBeAPowerOfTwo(N1)) {
SDValue LogBase2 = BuildLogBase2(N1, DL);
AddToWorklist(LogBase2.getNode());

EVT ShiftVT = getShiftAmountTy(N0.getValueType());
SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ShiftVT);
AddToWorklist(Trunc.getNode());
return DAG.getNode(ISD::VP_SRL, DL, VT, N0, Trunc, Mask, VL);
}

// fold (vp.udiv x, (vp.shl c, y)) -> vp.lshr(x, vp.add(log2(c)+y)) iff c is
// power of 2
if (N1.getOpcode() == ISD::VP_SHL && N1->getOperand(2) == Mask &&
N1->getOperand(3) == VL) {
SDValue N10 = N1.getOperand(0);
if (isConstantOrConstantVector(N10, /*NoOpaques=*/true) &&
DAG.isKnownToBeAPowerOfTwo(N10)) {
SDValue LogBase2 = BuildLogBase2(N10, DL);
AddToWorklist(LogBase2.getNode());

EVT ADDVT = N1.getOperand(1).getValueType();
SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
AddToWorklist(Trunc.getNode());
SDValue Add = DAG.getNode(ISD::VP_ADD, DL, ADDVT, N1.getOperand(1), Trunc,
Mask, VL);
AddToWorklist(Add.getNode());
return DAG.getNode(ISD::VP_SRL, DL, VT, N0, Add, Mask, VL);
}
}

// fold (vp.udiv x, Splat(shl c, y)) -> vp.lshr(x, add(log2(c)+y)) iff c is
// power of 2
if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
SDValue N10 = N1.getOperand(0);
if (N10.getOpcode() == ISD::SHL) {
SDValue N0SHL = N10.getOperand(0);
if (isa<ConstantSDNode>(N0SHL) && DAG.isKnownToBeAPowerOfTwo(N0SHL)) {
SDValue LogBase2 = BuildLogBase2(N0SHL, DL);
AddToWorklist(LogBase2.getNode());

EVT ADDVT = N10.getOperand(1).getValueType();
SDValue Trunc = DAG.getZExtOrTrunc(LogBase2, DL, ADDVT);
AddToWorklist(Trunc.getNode());
SDValue Add =
DAG.getNode(ISD::ADD, DL, ADDVT, N10.getOperand(1), Trunc);
AddToWorklist(Add.getNode());
SDValue Splat = DAG.getSplatVector(VT, DL, Add);
AddToWorklist(Splat.getNode());
return DAG.getNode(ISD::VP_SRL, DL, VT, N0, Splat, Mask, VL);
}
}
}

// fold (udiv x, c) -> alternate
AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
if (isConstantOrConstantVector(N1) &&
!TLI.isIntDivCheap(N->getValueType(0), Attr)) {
if (SDValue Op = BuildUDIV(N))
return Op;
}
return SDValue();
}

SDValue DAGCombiner::visitVPSDIV(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
SDValue Mask = N->getOperand(2);
SDValue VL = N->getOperand(3);
EVT VT = N->getValueType(0);
SDLoc DL(N);

// fold (vp.sdiv X, -1) -> 0-X
ConstantSDNode *N1C = isConstOrConstSplat(N1);
if (N1C && N1C->isAllOnes())
return DAG.getNode(ISD::VP_SUB, DL, VT, DAG.getConstant(0, DL, VT), N0,
Mask, VL);

// fold (vp.sdiv X, MIN_SIGNED) -> select(X == MIN_SIGNED, 1, 0)
if (N1C && N1C->getAPIntValue().isMinSignedValue()) {
EVT CCVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
VT.getVectorElementCount());
return DAG.getNode(ISD::VP_SELECT, DL, VT,
DAG.getSetCCVP(DL, CCVT, N0, N1, ISD::SETEQ, Mask, VL),
DAG.getConstant(1, DL, VT), DAG.getConstant(0, DL, VT),
VL);
}

// If we know the sign bits of both operands are zero, strength reduce to a
// vp.udiv instead. Handles (X&15) /s 4 -> X&15 >> 2
if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
return DAG.getNode(ISD::VP_UDIV, DL, N1.getValueType(), N0, N1, Mask, VL);

if (SDValue V = visitVPSDIVLike(N0, N1, N)) {
// If the corresponding remainder node exists, update its users with
// (Dividend - (Quotient * Divisor).
if (SDNode *RemNode = DAG.getNodeIfExists(ISD::VP_SREM, N->getVTList(),
{N0, N1, Mask, VL})) {
SDValue Mul = DAG.getNode(ISD::VP_MUL, DL, VT, V, N1, Mask, VL);
SDValue Sub = DAG.getNode(ISD::VP_SUB, DL, VT, N0, Mul, Mask, VL);
AddToWorklist(Mul.getNode());
AddToWorklist(Sub.getNode());
CombineTo(RemNode, Sub);
}
return V;
}
return SDValue();
}

SDValue DAGCombiner::visitVPSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
SDLoc DL(N);
SDValue Mask = N->getOperand(2);
SDValue VL = N->getOperand(3);
EVT VT = N->getValueType(0);
unsigned BitWidth = VT.getScalarSizeInBits();

// fold (vp.sdiv X, V of pow 2)
if (N1.getOpcode() == ISD::SPLAT_VECTOR &&
isDivisorPowerOfTwo(N1.getOperand(0))) {
// Create constants that are functions of the shift amount value.
SDValue N = N1.getOperand(0);
EVT CCVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
VT.getVectorElementCount());
EVT ScalarShiftAmtTy =
getShiftAmountTy(N0.getValueType().getVectorElementType());
SDValue Bits = DAG.getConstant(BitWidth, DL, ScalarShiftAmtTy);
SDValue C1 = DAG.getNode(ISD::CTTZ, DL, VT.getVectorElementType(), N);
C1 = DAG.getZExtOrTrunc(C1, DL, ScalarShiftAmtTy);
SDValue Inexact = DAG.getNode(ISD::SUB, DL, ScalarShiftAmtTy, Bits, C1);
if (!isa<ConstantSDNode>(Inexact))
return SDValue();

// Splat the sign bit into the register
EVT VecShiftAmtTy = EVT::getVectorVT(*DAG.getContext(), ScalarShiftAmtTy,
VT.getVectorElementCount());
SDValue Sign =
DAG.getNode(ISD::VP_SRA, DL, VT, N0,
DAG.getConstant(BitWidth - 1, DL, VecShiftAmtTy), Mask, VL);
AddToWorklist(Sign.getNode());

// Add N0, ((N0 < 0) ? abs(N1) - 1 : 0);
Inexact = DAG.getSplat(VT, DL, Inexact);
C1 = DAG.getSplat(VT, DL, C1);
SDValue Srl = DAG.getNode(ISD::VP_SRL, DL, VT, Sign, Inexact, Mask, VL);
AddToWorklist(Srl.getNode());
SDValue Add = DAG.getNode(ISD::VP_ADD, DL, VT, N0, Srl, Mask, VL);
AddToWorklist(Add.getNode());
SDValue Sra = DAG.getNode(ISD::VP_SRA, DL, VT, Add, C1, Mask, VL);
AddToWorklist(Sra.getNode());

// Special case: (sdiv X, 1) -> X
// Special Case: (sdiv X, -1) -> 0-X
SDValue One = DAG.getConstant(1, DL, VT);
SDValue AllOnes = DAG.getAllOnesConstant(DL, VT);
SDValue IsOne = DAG.getSetCCVP(DL, CCVT, N1, One, ISD::SETEQ, Mask, VL);
SDValue IsAllOnes =
DAG.getSetCCVP(DL, CCVT, N1, AllOnes, ISD::SETEQ, Mask, VL);
SDValue IsOneOrAllOnes =
DAG.getNode(ISD::VP_OR, DL, CCVT, IsOne, IsAllOnes, Mask, VL);
Sra = DAG.getNode(ISD::VP_SELECT, DL, VT, IsOneOrAllOnes, N0, Sra, VL);

// If dividing by a positive value, we're done. Otherwise, the result must
// be negated.
SDValue Zero = DAG.getConstant(0, DL, VT);
SDValue Sub = DAG.getNode(ISD::VP_SUB, DL, VT, Zero, Sra, Mask, VL);

SDValue IsNeg = DAG.getSetCCVP(DL, CCVT, N1, Zero, ISD::SETLT, Mask, VL);
SDValue Res = DAG.getNode(ISD::VP_SELECT, DL, VT, IsNeg, Sub, Sra, VL);
return Res;
}

// If integer divide is expensive and we satisfy the requirements, emit an
// alternate sequence. Targets may check function attributes for size/speed
// trade-offs.
AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
if (isConstantOrConstantVector(N1) &&
!TLI.isIntDivCheap(N->getValueType(0), Attr))
if (SDValue Op = BuildSDIV(N))
return Op;

return SDValue();
}

SDValue DAGCombiner::visitVPOp(SDNode *N) {

if (N->getOpcode() == ISD::VP_GATHER)
Expand Down Expand Up @@ -27262,6 +27547,13 @@ SDValue DAGCombiner::visitVPOp(SDNode *N) {
return visitMUL<VPMatchContext>(N);
case ISD::VP_SUB:
return foldSubCtlzNot<VPMatchContext>(N, DAG);
case ISD::VP_UDIV:
return visitVPUDIV(N);
case ISD::VP_SDIV:
return visitVPSDIV(N);
case ISD::VP_UREM:
case ISD::VP_SREM:
return visitVPREM(N);
default:
break;
}
Expand Down Expand Up @@ -28309,7 +28601,13 @@ SDValue DAGCombiner::BuildSDIV(SDNode *N) {
return SDValue();

SmallVector<SDNode *, 8> Built;
if (SDValue S = TLI.BuildSDIV(N, DAG, LegalOperations, LegalTypes, Built)) {
SDValue S;
if (N->isVPOpcode())
S = TLI.BuildVPSDIV(N, DAG, LegalOperations, Built);
else
S = TLI.BuildSDIV(N, DAG, LegalOperations, LegalTypes, Built);

if (S) {
for (SDNode *N : Built)
AddToWorklist(N);
return S;
Expand Down Expand Up @@ -28350,7 +28648,13 @@ SDValue DAGCombiner::BuildUDIV(SDNode *N) {
return SDValue();

SmallVector<SDNode *, 8> Built;
if (SDValue S = TLI.BuildUDIV(N, DAG, LegalOperations, LegalTypes, Built)) {
SDValue S;
if (N->isVPOpcode())
S = TLI.BuildVPUDIV(N, DAG, LegalOperations, Built);
else
S = TLI.BuildUDIV(N, DAG, LegalOperations, LegalTypes, Built);

if (S) {
for (SDNode *N : Built)
AddToWorklist(N);
return S;
Expand Down
8 changes: 4 additions & 4 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1277,8 +1277,8 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
case ISD::ADD: case ISD::VP_ADD:
case ISD::SUB: case ISD::VP_SUB:
case ISD::MUL: case ISD::VP_MUL:
case ISD::MULHS:
case ISD::MULHU:
case ISD::MULHS: case ISD::VP_MULHS:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this tested?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't find a way to test this. The div by constant optimization that creates MULH won't fire on illegal vector types.

I guess we can leave it so we dont't forget it if we add more VP_MULH combines in the future. Its sufficiently simililar to VP_MUL that it should just work.

case ISD::MULHU: case ISD::VP_MULHU:
case ISD::ABDS:
case ISD::ABDU:
case ISD::AVGCEILS:
Expand Down Expand Up @@ -4552,8 +4552,8 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
case ISD::ADD: case ISD::VP_ADD:
case ISD::AND: case ISD::VP_AND:
case ISD::MUL: case ISD::VP_MUL:
case ISD::MULHS:
case ISD::MULHU:
case ISD::MULHS: case ISD::VP_MULHS:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this tested?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't find a way to test this. The div by constant optimization that creates MULH won't fire on illegal vector types.

I guess we can leave it so we dont't forget it if we add more VP_MULH combines in the future. Its sufficiently simililar to VP_MUL that it should just work.

case ISD::MULHU: case ISD::VP_MULHU:
case ISD::ABDS:
case ISD::ABDU:
case ISD::OR: case ISD::VP_OR:
Expand Down
Loading
Loading