Multiply by a power of 2 and ctz+shift should often be interchangeable #84763

Closed
Validark opened this issue Mar 11, 2024 · 9 comments · Fixed by #85066

@Validark

The compiler should probably be able to optimize either of these functions to the other one, depending on cost for the particular hardware (Godbolt link):

export fn foo(x: u64, y: u64) u64 {
    return x *% (y & (~y +% 1));
}

export fn bar(x: u64, y: u64) u64 {
    if (y == 0) return 0;
    return x << @intCast(@ctz(y));
}
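Here is a minimal C sketch of the same two functions for anyone who wants to check the equivalence outside Zig. It assumes uint64_t arithmetic as a stand-in for Zig's wrapping u64 operators. ctz64 and check_foo_bar are hypothetical helpers. ctz64 is written as a plain loop, so the example does not depend on compiler builtins such as __builtin_ctzll.

```c
#include <assert.h>
#include <stdint.h>

/* Portable count-trailing-zeros standing in for Zig's @ctz; 64 for v == 0. */
static unsigned ctz64(uint64_t v) {
    unsigned n = 0;
    if (v == 0) return 64;
    while ((v & 1) == 0) { v >>= 1; n++; }
    return n;
}

/* foo: multiply x by the lowest set bit of y. uint64_t arithmetic wraps
   modulo 2^64 like Zig's *% and +%; ~y + 1 is -y in two's complement. */
uint64_t foo(uint64_t x, uint64_t y) {
    return x * (y & (~y + 1));
}

/* bar: shift x left by the trailing-zero count of y. Shifting a uint64_t
   by 64 is undefined in C, so y == 0 is handled before the shift. */
uint64_t bar(uint64_t x, uint64_t y) {
    if (y == 0) return 0;
    return x << ctz64(y);
}

/* Spot-check that the two functions agree for a given x. */
static void check_foo_bar(uint64_t x) {
    for (uint64_t y = 0; y < 4096; y++)
        assert(foo(x, y) == bar(x, y));
    for (unsigned i = 0; i < 64; i++)
        assert(foo(x, UINT64_C(1) << i) == bar(x, UINT64_C(1) << i));
}
```

For example, foo(3, 12) and bar(3, 12) both return 12, because the lowest set bit of 12 is 4 and 3 << 2 == 3 * 4.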

x86 znver4 emit:

foo:
        blsi    rax, rsi
        imul    rax, rdi
        ret

bar:
        tzcnt   rax, rsi
        shlx    rax, rdi, rax
        cmovb   rax, rsi
        ret
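The znver4 version of bar has no branch. The three instructions work together as follows. tzcnt returns 64 and sets the carry flag when its source is zero. shlx leaves the flags unchanged and uses only the low 6 bits of its count. cmovb then takes rsi, which is 0 in that case. The C model below reproduces that behavior. The function names are made up for illustration, and tzcnt is modeled with a portable loop.

```c
#include <assert.h>
#include <stdint.h>

/* Models tzcnt: trailing-zero count, 64 when v == 0. */
static unsigned tzcnt64(uint64_t v) {
    unsigned n = 0;
    if (v == 0) return 64;
    while ((v & 1) == 0) { v >>= 1; n++; }
    return n;
}

/* Branchless model of the znver4 bar (x in rdi, y in rsi). */
static uint64_t bar_znver4_model(uint64_t x, uint64_t y) {
    unsigned tz = tzcnt64(y);          /* tzcnt rax, rsi */
    int cf = (y == 0);                 /* tzcnt sets CF for a zero source */
    uint64_t shifted = x << (tz & 63); /* shlx rax, rdi, rax: count mod 64 */
    return cf ? y : shifted;           /* cmovb rax, rsi: y is 0 here */
}

/* The model should agree with foo's x * (y & -y) for every y tried. */
static void check_znver4_model(uint64_t x) {
    for (uint64_t y = 0; y < 4096; y++)
        assert(bar_znver4_model(x, y) == x * (y & (0 - y)));
}
```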

RISC-V sifive_u74 emit:

foo:
        neg     a2, a1
        and     a1, a1, a2
        mul     a0, a1, a0
        ret

.LCPI1_0:
        .quad   151050438420815295
.LCPI1_1:
        .ascii  "\000\001\002\007\003\r\b\023\004\031\016\034\t\"\024(\005\021\032&\017.\0350\n\037#6\0252)9?\006\f\022\030\033!'\020%-/\036518>\013\027 $,47=\026+3<*;:"
bar:
        beqz    a1, .LBB1_2
        lui     a2, %hi(.LCPI1_0)
        neg     a3, a1
        and     a1, a1, a3
        ld      a2, %lo(.LCPI1_0)(a2)
        mul     a1, a1, a2
        lui     a2, %hi(.LCPI1_1)
        srli    a1, a1, 58
        addi    a2, a2, %lo(.LCPI1_1)
        add     a1, a1, a2
        lbu     a1, 0(a1)
        sll     a0, a0, a1
        ret
.LBB1_2:
        li      a0, 0
        ret
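This RISC-V emit has no count-trailing-zeros instruction, so bar expands cttz with a de Bruijn multiply-and-lookup. The steps are:

1. neg/and isolate the lowest set bit.
2. mul multiplies by the constant in .LCPI1_0 (151050438420815295, i.e. 0x0218A392CD3D5DBF).
3. srli 58 keeps the top 6 bits.
4. lbu reads the answer from the byte table in .LCPI1_1.

The expansion already contains foo's neg, and and mul, then adds a load, a shift and a table lookup on top.

Below is a C sketch of the same technique. The names are hypothetical, and it builds the table at run time instead of copying the .ascii bytes.

```c
#include <assert.h>
#include <stdint.h>

/* The 64-bit de Bruijn constant loaded from .LCPI1_0. */
#define DEBRUIJN64 UINT64_C(0x0218A392CD3D5DBF) /* == 151050438420815295 */

static unsigned char debruijn_table[64]; /* plays the role of .LCPI1_1 */

/* The top 6 bits of (C << i) differ for every i in 0..63, so each i
   gets its own slot in the table. */
static void init_debruijn_table(void) {
    for (unsigned i = 0; i < 64; i++)
        debruijn_table[(DEBRUIJN64 << i) >> 58] = (unsigned char)i;
}

/* cttz by multiply-and-lookup, step by step like the RISC-V bar.
   Precondition: y != 0 (bar branches to .LBB1_2 in that case). */
static unsigned ctz64_debruijn(uint64_t y) {
    uint64_t lowest = y & (0 - y);         /* neg a3, a1; and a1, a1, a3 */
    uint64_t window = lowest * DEBRUIJN64; /* mul: equals C << cttz(y) */
    return debruijn_table[window >> 58];   /* srli 58, then lbu */
}

/* Every single-bit input must map back to its bit position. */
static void check_debruijn(void) {
    for (unsigned i = 0; i < 64; i++)
        assert(ctz64_debruijn(UINT64_C(1) << i) == i);
}
```

Built this way, the table should match the bytes in .LCPI1_1. For instance, entries 3, 5 and 6 come out as 7, 13 and 8, i.e. \007, \r and \b.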
@Sirraide
Member

Could you also include the LLVM IR that gets generated for both functions (and ideally also an alive2 proof that the transformation is correct)?

@nikic
Contributor

nikic commented Mar 11, 2024

I think the ask is for this transform? https://alive2.llvm.org/ce/z/Aht97G

@nikic
Contributor

nikic commented Mar 11, 2024

@dtcxzyw @goldsteinn I'm a bit unsure on whether this is better suited as a middle-end or back-end transform. Any thoughts?

@Validark
Author

> Could you also include the LLVM IR that gets generated for both functions (and ideally also an alive2 proof that the transformation is correct)?

I think @nikic already provided what you were asking for, but here is why this works. The idiom (y & (~y +% 1)), or y & -y in C, isolates the lowest set bit of y, hence the blsi emit on x86. At the bit level, the wrapping product a * b equals (a << 0)*b₀ + (a << 1)*b₁ + (a << 2)*b₂ + (a << 3)*b₃ + ... + (a << 63)*b₆₃ (mod 2⁶⁴), where b₀ is the least significant bit of b. Here exactly one bit of b is a 1, so only one term of that sum survives, namely a << @ctz(b). That's why a << @ctz(b) is equivalent to a * b when the popcount of b is 1. The one remaining case is y == 0. Then y & -y is 0 and the product is 0, but @ctz(y) is 64, which does not fit in the u6 shift amount of a u64. That is why bar has to return 0 for it explicitly.
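The sum above is ordinary shift-and-add multiplication. Below is a minimal C sketch of it. mul_shift_add and check_single_term are illustrative names, not from the thread. The sketch also counts how many shifted copies of a get added; when b = y & -y with y != 0, that count is always 1.

```c
#include <assert.h>
#include <stdint.h>

/* Shift-and-add multiplication: a * b == sum of (a << i) * b_i over
   i = 0..63, modulo 2^64. *terms receives how many b_i were 1. */
static uint64_t mul_shift_add(uint64_t a, uint64_t b, unsigned *terms) {
    uint64_t acc = 0;
    *terms = 0;
    for (unsigned i = 0; i < 64; i++) {
        if ((b >> i) & 1) {  /* b_i == 1: this term contributes */
            acc += a << i;   /* unsigned addition wraps like +% */
            (*terms)++;
        }
    }
    return acc;
}

/* With b = y & -y (y != 0), exactly one term survives. */
static void check_single_term(uint64_t a, uint64_t y) {
    unsigned terms;
    uint64_t b = y & (0 - y);
    uint64_t product = mul_shift_add(a, b, &terms);
    assert(terms == 1);
    assert(product == a * b);
}
```

Two examples:

- mul_shift_add(7, 10, &t) adds 7 << 1 and 7 << 3 to get 70, with t == 2.
- With b = 0xB0 & -0xB0 = 0x10, only 0x1234 << 4 is added.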

@Validark
Author

Related #76810

@dtcxzyw
Member

dtcxzyw commented Mar 12, 2024

> @dtcxzyw @goldsteinn I'm a bit unsure on whether this is better suited as a middle-end or back-end transform. Any thoughts?

We can canonicalize x * (y & -y) into x << cttz(y) in InstCombine. Then we reverse the transform in CGP (CodeGenPrepare) if cttz is not natively supported by the target.

Alive2: https://alive2.llvm.org/ce/z/vhskjz
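The zero case is why the two directions need different preconditions. Per the LLVM LangRef:

- llvm.cttz with is_zero_poison set to false returns the bit width for a zero input.
- shl by an amount at least the bit width returns poison.
- mul x, (y & -y) simply returns 0 when y == 0.

A rewrite may turn poison into a concrete value, but not the other way round.

The following exhaustive model over i8 illustrates this. The val8 type and every helper are hypothetical, written for illustration, and are not LLVM APIs. It suggests that mul -> shl(cttz) needs y to be known non-zero, while shl(cttz) -> mul holds for all inputs.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

/* Hypothetical model of an LLVM i8 value that may be poison. */
typedef struct { bool poison; uint8_t v; } val8;

static val8 defined8(uint8_t v) { val8 r = { false, v }; return r; }
static val8 poison8(void) { val8 r = { true, 0 }; return r; }

/* llvm.cttz.i8(y, false): the bit width (8) when y == 0. */
static uint8_t cttz8(uint8_t y) {
    uint8_t n = 0;
    if (y == 0) return 8;
    while ((y & 1) == 0) { y >>= 1; n++; }
    return n;
}

/* shl i8 x, amt: poison when amt >= 8. */
static val8 shl8(uint8_t x, uint8_t amt) {
    return amt >= 8 ? poison8() : defined8((uint8_t)(x << amt));
}

/* mul i8 x, (y & -y) without nsw/nuw: always defined, wraps mod 256. */
static val8 mul_lowbit8(uint8_t x, uint8_t y) {
    uint8_t low = (uint8_t)(y & (uint8_t)(0u - y));
    return defined8((uint8_t)(x * low));
}

/* tgt may replace src if src is poison, or both are defined and equal. */
static bool refines(val8 src, val8 tgt) {
    return src.poison || (!tgt.poison && src.v == tgt.v);
}

/* Counts (x, y) pairs where the rewrite would be invalid. forward:
   mul -> shl(cttz), as in InstCombine; otherwise shl(cttz) -> mul,
   as in the CGP fallback. */
static unsigned count_failures(bool forward) {
    unsigned failures = 0;
    for (unsigned x = 0; x < 256; x++) {
        for (unsigned y = 0; y < 256; y++) {
            val8 m = mul_lowbit8((uint8_t)x, (uint8_t)y);
            val8 s = shl8((uint8_t)x, cttz8((uint8_t)y));
            bool ok = forward ? refines(m, s) : refines(s, m);
            if (!ok) failures++;
        }
    }
    return failures;
}

static void check_refinement(void) {
    assert(count_failures(true) == 256); /* exactly the pairs with y == 0 */
    assert(count_failures(false) == 0);  /* valid for every input */
}
```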

@dtcxzyw
Copy link
Member

dtcxzyw commented Mar 12, 2024

Confirmed that the pattern x * (y & -y) exists in openmpi :)

dtcxzyw self-assigned this Mar 12, 2024
@dtcxzyw
Member

dtcxzyw commented Mar 12, 2024

> We can canonicalize x * (y & -y) into x << cttz(y) in InstCombine.

Looks like the canonicalization is valueless :(

diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 371ad41ee965..7cfd040e37b9 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -2817,6 +2817,10 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
     return false;
   }
 
+  // See if the operator is a power of 2 (e.g., X & -X).
+  if (isKnownToBeAPowerOfTwo(I, /*OrZero*/ false, Depth, Q))
+    return true;
+
   KnownBits Known(BitWidth);
   computeKnownBits(I, DemandedElts, Known, Depth, Q);
   return Known.One != 0;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 278be6233f4b..7ec87a3ab4bd 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -505,18 +505,20 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
   //        (shl Op1, Log2(Op0))
   //    if Log2(Op1) folds away ->
   //        (shl Op0, Log2(Op1))
-  if (takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ false,
+  bool Op0IsNonZero = isKnownNonZero(Op0, DL, /*Depth=*/0, &AC, &I, &DT);
+  if (takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ Op0IsNonZero,
                /*DoFold*/ false)) {
-    Value *Res = takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ false,
+    Value *Res = takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ Op0IsNonZero,
                           /*DoFold*/ true);
     BinaryOperator *Shl = BinaryOperator::CreateShl(Op1, Res);
     // We can only propegate nuw flag.
     Shl->setHasNoUnsignedWrap(HasNUW);
     return Shl;
   }
-  if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ false,
+  bool Op1IsNonZero = isKnownNonZero(Op1, DL, /*Depth=*/0, &AC, &I, &DT);
+  if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ Op1IsNonZero,
                /*DoFold*/ false)) {
-    Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ false,
+    Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ Op1IsNonZero,
                           /*DoFold*/ true);
     BinaryOperator *Shl = BinaryOperator::CreateShl(Op0, Res);
     // We can only propegate nuw flag.
@@ -1328,6 +1330,13 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
         });
   }
 
+  // log2(X & -X) -> cttz(X)
+  // FIXME: Require one use?
+  if (AssumeNonZero && match(Op, m_c_And(m_Value(X), m_Neg(m_Deferred(X)))))
+    return IfFold([&]() {
+      return Builder.CreateBinaryIntrinsic(Intrinsic::cttz, X, Builder.getTrue());
+    });
+
   return nullptr;
 }

I am preparing a patch for CGP.
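The diff above leans on two facts about X & -X:

- A known power of two is non-zero. This is the ValueTracking hunk.
- For non-zero X, X & -X is a power of two whose log2 equals cttz(X). This is the new takeLog2 case, which is why it is guarded by AssumeNonZero.

Below is a quick C check of both over 32-bit values. The loop-based helpers are hypothetical stand-ins for LLVM's intrinsics.

```c
#include <assert.h>
#include <stdint.h>

/* floor(log2(v)) for v != 0: position of the highest set bit. */
static unsigned log2_floor32(uint32_t v) {
    unsigned r = 0;
    while (v >>= 1) r++;
    return r;
}

/* Trailing-zero count; 32 for v == 0. */
static unsigned cttz32(uint32_t v) {
    unsigned n = 0;
    if (v == 0) return 32;
    while ((v & 1) == 0) { v >>= 1; n++; }
    return n;
}

/* Set-bit count (clears the lowest set bit each round). */
static unsigned popcount32(uint32_t v) {
    unsigned n = 0;
    for (; v != 0; v &= v - 1) n++;
    return n;
}

/* The facts the diff relies on, for a non-zero x. */
static void check_lowbit_identities(uint32_t x) {
    uint32_t low = x & (0u - x);
    assert(x != 0);                         /* AssumeNonZero precondition */
    assert(popcount32(low) == 1);           /* a power of two, so non-zero */
    assert(log2_floor32(low) == cttz32(x)); /* log2(X & -X) == cttz(X) */
}
```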

@Explorer09

Is this the same issue that I reported in the GCC bug tracker? GCC bug 114341

Just for reference :)

dtcxzyw added a commit that referenced this issue May 29, 2024
…is unsupported (#85066)

This patch folds `shl X, cttz(Y)` to `mul (Y & -Y), X` if cttz is
unsupported by the target.
Alive2: https://alive2.llvm.org/ce/z/AtLN5Y
Fixes #84763.
vg0204 pushed a commit to vg0204/llvm-project that referenced this issue May 29, 2024
EugeneZelenko added the llvm:SelectionDAG label and removed the llvm:optimizations label May 29, 2024