@@ -13706,8 +13706,8 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
1370613706 if (VT != Subtarget.getXLenVT())
1370713707 return SDValue();
1370813708
13709- if (!Subtarget.hasStdExtZba() && !Subtarget.hasVendorXTHeadBa())
13710- return SDValue ();
13709+ const bool HasShlAdd =
13710+ Subtarget.hasStdExtZba() || Subtarget.hasVendorXTHeadBa ();
1371113711
1371213712 ConstantSDNode *CNode = dyn_cast<ConstantSDNode>(N->getOperand(1));
1371313713 if (!CNode)
@@ -13720,107 +13720,123 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
1372013720 // other target properly freezes X in these cases either.
1372113721 SDValue X = N->getOperand(0);
1372213722
13723- for (uint64_t Divisor : {3, 5, 9}) {
13724- if (MulAmt % Divisor != 0)
13725- continue;
13726- uint64_t MulAmt2 = MulAmt / Divisor;
13727- // 3/5/9 * 2^N -> shl (shXadd X, X), N
13728- if (isPowerOf2_64(MulAmt2)) {
13729- SDLoc DL(N);
13730- SDValue X = N->getOperand(0);
13731- // Put the shift first if we can fold a zext into the
13732- // shift forming a slli.uw.
13733- if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) &&
13734- X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) {
13735- SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, X,
13736- DAG.getConstant(Log2_64(MulAmt2), DL, VT));
13737- return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl,
13738- DAG.getConstant(Log2_64(Divisor - 1), DL, VT), Shl);
13723+ if (HasShlAdd) {
13724+ for (uint64_t Divisor : {3, 5, 9}) {
13725+ if (MulAmt % Divisor != 0)
13726+ continue;
13727+ uint64_t MulAmt2 = MulAmt / Divisor;
13728+ // 3/5/9 * 2^N -> shl (shXadd X, X), N
13729+ if (isPowerOf2_64(MulAmt2)) {
13730+ SDLoc DL(N);
13731+ SDValue X = N->getOperand(0);
13732+ // Put the shift first if we can fold a zext into the
13733+ // shift forming a slli.uw.
13734+ if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) &&
13735+ X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) {
13736+ SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, X,
13737+ DAG.getConstant(Log2_64(MulAmt2), DL, VT));
13738+ return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl,
13739+ DAG.getConstant(Log2_64(Divisor - 1), DL, VT),
13740+ Shl);
13741+ }
13742+ // Otherwise, put the shl second so that it can fold with following
13743+ // instructions (e.g. sext or add).
13744+ SDValue Mul359 =
13745+ DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13746+ DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
13747+ return DAG.getNode(ISD::SHL, DL, VT, Mul359,
13748+ DAG.getConstant(Log2_64(MulAmt2), DL, VT));
13749+ }
13750+
13751+ // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X)
13752+ if (MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9) {
13753+ SDLoc DL(N);
13754+ SDValue Mul359 =
13755+ DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13756+ DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
13757+ return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
13758+ DAG.getConstant(Log2_64(MulAmt2 - 1), DL, VT),
13759+ Mul359);
1373913760 }
13740- // Otherwise, put rhe shl second so that it can fold with following
13741- // instructions (e.g. sext or add).
13742- SDValue Mul359 =
13743- DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13744- DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
13745- return DAG.getNode(ISD::SHL, DL, VT, Mul359,
13746- DAG.getConstant(Log2_64(MulAmt2), DL, VT));
1374713761 }
1374813762
13749- // 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X)
13750- if (MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9) {
13751- SDLoc DL(N);
13752- SDValue Mul359 =
13753- DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13754- DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
13755- return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
13756- DAG.getConstant(Log2_64(MulAmt2 - 1), DL, VT),
13757- Mul359);
13758- }
13759- }
13760-
13761- // If this is a power 2 + 2/4/8, we can use a shift followed by a single
13762- // shXadd. First check if this a sum of two power of 2s because that's
13763- // easy. Then count how many zeros are up to the first bit.
13764- if (isPowerOf2_64(MulAmt & (MulAmt - 1))) {
13765- unsigned ScaleShift = llvm::countr_zero(MulAmt);
13766- if (ScaleShift >= 1 && ScaleShift < 4) {
13767- unsigned ShiftAmt = Log2_64((MulAmt & (MulAmt - 1)));
13768- SDLoc DL(N);
13769- SDValue Shift1 =
13770- DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT));
13771- return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13772- DAG.getConstant(ScaleShift, DL, VT), Shift1);
13763+ // If this is a power 2 + 2/4/8, we can use a shift followed by a single
13764+ // shXadd. First check if this is a sum of two powers of 2 because that's
13765+ // easy. Then count how many zeros are up to the first bit.
13766+ if (isPowerOf2_64(MulAmt & (MulAmt - 1))) {
13767+ unsigned ScaleShift = llvm::countr_zero(MulAmt);
13768+ if (ScaleShift >= 1 && ScaleShift < 4) {
13769+ unsigned ShiftAmt = Log2_64((MulAmt & (MulAmt - 1)));
13770+ SDLoc DL(N);
13771+ SDValue Shift1 =
13772+ DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT));
13773+ return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13774+ DAG.getConstant(ScaleShift, DL, VT), Shift1);
13775+ }
1377313776 }
13774- }
1377513777
13776- // 2^(1,2,3) * 3,5,9 + 1 -> (shXadd (shYadd x, x), x)
13777- // This is the two instruction form, there are also three instruction
13778- // variants we could implement. e.g.
13779- // (2^(1,2,3) * 3,5,9 + 1) << C2
13780- // 2^(C1>3) * 3,5,9 +/- 1
13781- for (uint64_t Divisor : {3, 5, 9}) {
13782- uint64_t C = MulAmt - 1;
13783- if (C <= Divisor)
13784- continue;
13785- unsigned TZ = llvm::countr_zero(C);
13786- if ((C >> TZ) == Divisor && (TZ == 1 || TZ == 2 || TZ == 3)) {
13787- SDLoc DL(N);
13788- SDValue Mul359 =
13789- DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13790- DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
13791- return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
13792- DAG.getConstant(TZ, DL, VT), X);
13778+ // 2^(1,2,3) * 3,5,9 + 1 -> (shXadd (shYadd x, x), x)
13779+ // This is the two instruction form, there are also three instruction
13780+ // variants we could implement. e.g.
13781+ // (2^(1,2,3) * 3,5,9 + 1) << C2
13782+ // 2^(C1>3) * 3,5,9 +/- 1
13783+ for (uint64_t Divisor : {3, 5, 9}) {
13784+ uint64_t C = MulAmt - 1;
13785+ if (C <= Divisor)
13786+ continue;
13787+ unsigned TZ = llvm::countr_zero(C);
13788+ if ((C >> TZ) == Divisor && (TZ == 1 || TZ == 2 || TZ == 3)) {
13789+ SDLoc DL(N);
13790+ SDValue Mul359 =
13791+ DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13792+ DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
13793+ return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
13794+ DAG.getConstant(TZ, DL, VT), X);
13795+ }
1379313796 }
13794- }
1379513797
13796- // 2^n + 2/4/8 + 1 -> (add (shl X, C1), (shXadd X, X))
13797- if (MulAmt > 2 && isPowerOf2_64((MulAmt - 1) & (MulAmt - 2))) {
13798- unsigned ScaleShift = llvm::countr_zero(MulAmt - 1);
13799- if (ScaleShift >= 1 && ScaleShift < 4) {
13800- unsigned ShiftAmt = Log2_64(((MulAmt - 1) & (MulAmt - 2)));
13801- SDLoc DL(N);
13802- SDValue Shift1 =
13803- DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT));
13804- return DAG.getNode(ISD::ADD, DL, VT, Shift1,
13805- DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13806- DAG.getConstant(ScaleShift, DL, VT), X));
13798+ // 2^n + 2/4/8 + 1 -> (add (shl X, C1), (shXadd X, X))
13799+ if (MulAmt > 2 && isPowerOf2_64((MulAmt - 1) & (MulAmt - 2))) {
13800+ unsigned ScaleShift = llvm::countr_zero(MulAmt - 1);
13801+ if (ScaleShift >= 1 && ScaleShift < 4) {
13802+ unsigned ShiftAmt = Log2_64(((MulAmt - 1) & (MulAmt - 2)));
13803+ SDLoc DL(N);
13804+ SDValue Shift1 =
13805+ DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT));
13806+ return DAG.getNode(ISD::ADD, DL, VT, Shift1,
13807+ DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13808+ DAG.getConstant(ScaleShift, DL, VT), X));
13809+ }
1380713810 }
13808- }
1380913811
13810- // 2^N - 3/5/9 --> (sub (shl X, C1), (shXadd X, x))
13811- for (uint64_t Offset : {3, 5, 9}) {
13812- if (isPowerOf2_64(MulAmt + Offset)) {
13813- SDLoc DL(N);
13814- SDValue Shift1 =
13815- DAG.getNode(ISD::SHL, DL, VT, X,
13816- DAG.getConstant(Log2_64(MulAmt + Offset), DL, VT));
13817- SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13818- DAG.getConstant(Log2_64(Offset - 1), DL, VT),
13819- X);
13820- return DAG.getNode(ISD::SUB, DL, VT, Shift1, Mul359);
13812+ // 2^N - 3/5/9 --> (sub (shl X, C1), (shXadd X, x))
13813+ for (uint64_t Offset : {3, 5, 9}) {
13814+ if (isPowerOf2_64(MulAmt + Offset)) {
13815+ SDLoc DL(N);
13816+ SDValue Shift1 =
13817+ DAG.getNode(ISD::SHL, DL, VT, X,
13818+ DAG.getConstant(Log2_64(MulAmt + Offset), DL, VT));
13819+ SDValue Mul359 =
13820+ DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
13821+ DAG.getConstant(Log2_64(Offset - 1), DL, VT), X);
13822+ return DAG.getNode(ISD::SUB, DL, VT, Shift1, Mul359);
13823+ }
1382113824 }
1382213825 }
1382313826
13827+ // 2^N - 2^M -> (sub (shl X, C1), (shl X, C2))
13828+ uint64_t MulAmtLowBit = MulAmt & (-MulAmt);
13829+ if (isPowerOf2_64(MulAmt + MulAmtLowBit)) {
13830+ uint64_t ShiftAmt1 = MulAmt + MulAmtLowBit;
13831+ SDLoc DL(N);
13832+ SDValue Shift1 = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
13833+ DAG.getConstant(Log2_64(ShiftAmt1), DL, VT));
13834+ SDValue Shift2 =
13835+ DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
13836+ DAG.getConstant(Log2_64(MulAmtLowBit), DL, VT));
13837+ return DAG.getNode(ISD::SUB, DL, VT, Shift1, Shift2);
13838+ }
13839+
1382413840 return SDValue();
1382513841}
1382613842
0 commit comments