[InstCombine] Factorise Add and Min/Max using Distributivity #101717
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository, in which case you can tag reviewers by name in a comment instead. If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.
@llvm/pr-subscribers-llvm-ir @llvm/pr-subscribers-llvm-transforms

Author: Jorge Botto (jf-botto)

Changes: This PR fixes part of #92433. It specifically adds the 4 cases mentioned in #92433 (comment). I've added 8 positive tests, 4 of which are mentioned in the comment above and 4 which are their commutative equivalents. Alive proof: https://alive2.llvm.org/ce/z/z6eFTb

Full diff: https://github.com/llvm/llvm-project/pull/101717.diff — 3 Files Affected:
diff --git a/llvm/include/llvm/IR/Operator.h b/llvm/include/llvm/IR/Operator.h
index f63f54ef94107..ec8b3f4b6318f 100644
--- a/llvm/include/llvm/IR/Operator.h
+++ b/llvm/include/llvm/IR/Operator.h
@@ -123,6 +123,9 @@ class OverflowingBinaryOperator : public Operator {
return NoWrapKind;
}
+ /// Return true if the instruction is commutative:
+ bool isCommutative() const { return Instruction::isCommutative(getOpcode()); }
+
static bool classof(const Instruction *I) {
return I->getOpcode() == Instruction::Add ||
I->getOpcode() == Instruction::Sub ||
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index cc68fd4cf1c1b..8944eec2d63d4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1505,6 +1505,97 @@ foldMinimumOverTrailingOrLeadingZeroCount(Value *I0, Value *I1,
ConstantInt::getTrue(ZeroUndef->getType()));
}
+/// Return whether "X LOp (Y ROp Z)" is always equal to
+/// "(X LOp Y) ROp (X LOp Z)".
+static bool leftDistributesOverRightIntrinsic(Instruction::BinaryOps LOp,
+ bool hasNUW, bool hasNSW,
+ Intrinsic::ID ROp) {
+ switch (ROp) {
+ case Intrinsic::umax:
+ return hasNUW && LOp == Instruction::Add;
+ case Intrinsic::umin:
+ return hasNUW && LOp == Instruction::Add;
+ case Intrinsic::smax:
+ return hasNSW && LOp == Instruction::Add;
+ case Intrinsic::smin:
+ return hasNSW && LOp == Instruction::Add;
+ default:
+ return false;
+ }
+}
+
+// Attempts to factorise a common term in an instruction that has the form
+// "(A op' B) op (C op' D)", where op is a min/max intrinsic and op' is a
+// binop.
+static Value *
+foldIntrinsicUsingDistributiveLaws(IntrinsicInst *II,
+ InstCombiner::BuilderTy &Builder) {
+ Value *LHS = II->getOperand(0), *RHS = II->getOperand(1);
+ Intrinsic::ID TopLevelOpcode = II->getIntrinsicID();
+
+ OverflowingBinaryOperator *Op0 = dyn_cast<OverflowingBinaryOperator>(LHS);
+ OverflowingBinaryOperator *Op1 = dyn_cast<OverflowingBinaryOperator>(RHS);
+
+ if (!Op0 || !Op1)
+ return nullptr;
+
+ if (Op0->getOpcode() != Op1->getOpcode())
+ return nullptr;
+
+ if (!(Op0->hasNoUnsignedWrap() == Op1->hasNoUnsignedWrap()) ||
+ !(Op0->hasNoSignedWrap() == Op1->hasNoSignedWrap()))
+ return nullptr;
+
+ if (!Op0->hasOneUse() || !Op1->hasOneUse())
+ return nullptr;
+
+ Instruction::BinaryOps InnerOpcode =
+ static_cast<Instruction::BinaryOps>(Op0->getOpcode());
+ bool HasNUW = Op0->hasNoUnsignedWrap();
+ bool HasNSW = Op0->hasNoSignedWrap();
+
+ if (!InnerOpcode)
+ return nullptr;
+
+ if (!leftDistributesOverRightIntrinsic(InnerOpcode, HasNUW, HasNSW,
+ TopLevelOpcode))
+ return nullptr;
+
+ assert(II->isCommutative() && Op0->isCommutative() &&
+ "Only inner and outer commutative op codes are supported.");
+
+ Value *A = Op0->getOperand(0);
+ Value *B = Op0->getOperand(1);
+ Value *C = Op1->getOperand(0);
+ Value *D = Op1->getOperand(1);
+
+ if (A == C || A == D) {
+ if (A != C)
+ std::swap(C, D);
+
+ Value *NewIntrinsic = Builder.CreateBinaryIntrinsic(TopLevelOpcode, B, D);
+ BinaryOperator *NewBinop =
+ cast<BinaryOperator>(Builder.CreateBinOp(InnerOpcode, NewIntrinsic, A));
+ NewBinop->setHasNoSignedWrap(HasNSW);
+ NewBinop->setHasNoUnsignedWrap(HasNUW);
+ return NewBinop;
+ }
+
+ if (B == D || B == C) {
+ if (B != D)
+ std::swap(C, D);
+
+ Value *NewIntrinsic = Builder.CreateBinaryIntrinsic(TopLevelOpcode, A, C);
+ BinaryOperator *NewBinop =
+ cast<BinaryOperator>(Builder.CreateBinOp(InnerOpcode, NewIntrinsic, B));
+ NewBinop->setHasNoSignedWrap(HasNSW);
+ NewBinop->setHasNoUnsignedWrap(HasNUW);
+ return NewBinop;
+ }
+
+ return nullptr;
+}
+
/// CallInst simplification. This mostly only handles folding of intrinsic
/// instructions. For normal calls, it allows visitCallBase to do the heavy
/// lifting.
@@ -1929,6 +2020,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}
}
+ if (Value *V = foldIntrinsicUsingDistributiveLaws(II, Builder))
+ return replaceInstUsesWith(*II, V);
+
break;
}
case Intrinsic::bitreverse: {
diff --git a/llvm/test/Transforms/InstCombine/intrinsic-distributive.ll b/llvm/test/Transforms/InstCombine/intrinsic-distributive.ll
new file mode 100644
index 0000000000000..f58ce04cb6711
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/intrinsic-distributive.ll
@@ -0,0 +1,228 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes=instcombine < %s 2>&1 | FileCheck %s
+
+
+define i8 @umax_of_add_nuw(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: define i8 @umax_of_add_nuw(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.umax.i8(i8 [[B]], i8 [[C]])
+; CHECK-NEXT: [[MAX:%.*]] = add nuw i8 [[TMP1]], [[A]]
+; CHECK-NEXT: ret i8 [[MAX]]
+;
+ %add1 = add nuw i8 %b, %a
+ %add2 = add nuw i8 %c, %a
+ %max = call i8 @llvm.umax.i8(i8 %add1, i8 %add2)
+ ret i8 %max
+}
+
+define i8 @umax_of_add_nuw_comm(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: define i8 @umax_of_add_nuw_comm(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.umax.i8(i8 [[B]], i8 [[C]])
+; CHECK-NEXT: [[MAX:%.*]] = add nuw i8 [[TMP1]], [[A]]
+; CHECK-NEXT: ret i8 [[MAX]]
+;
+ %add1 = add nuw i8 %a, %b
+ %add2 = add nuw i8 %a, %c
+ %max = call i8 @llvm.umax.i8(i8 %add1, i8 %add2)
+ ret i8 %max
+}
+
+
+; negative test
+define i8 @umax_of_add_nsw(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: define i8 @umax_of_add_nsw(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
+; CHECK-NEXT: [[ADD1:%.*]] = add nsw i8 [[B]], [[A]]
+; CHECK-NEXT: [[ADD2:%.*]] = add nsw i8 [[C]], [[A]]
+; CHECK-NEXT: [[MAX:%.*]] = call i8 @llvm.umax.i8(i8 [[ADD1]], i8 [[ADD2]])
+; CHECK-NEXT: ret i8 [[MAX]]
+;
+ %add1 = add nsw i8 %b, %a
+ %add2 = add nsw i8 %c, %a
+ %max = call i8 @llvm.umax.i8(i8 %add1, i8 %add2)
+ ret i8 %max
+}
+
+; negative test
+define i8 @umax_of_add(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: define i8 @umax_of_add(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
+; CHECK-NEXT: [[ADD1:%.*]] = add i8 [[B]], [[A]]
+; CHECK-NEXT: [[ADD2:%.*]] = add i8 [[C]], [[A]]
+; CHECK-NEXT: [[MAX:%.*]] = call i8 @llvm.umax.i8(i8 [[ADD1]], i8 [[ADD2]])
+; CHECK-NEXT: ret i8 [[MAX]]
+;
+ %add1 = add i8 %b, %a
+ %add2 = add i8 %c, %a
+ %max = call i8 @llvm.umax.i8(i8 %add1, i8 %add2)
+ ret i8 %max
+}
+
+define i8 @umin_of_add_nuw(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: define i8 @umin_of_add_nuw(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.umin.i8(i8 [[B]], i8 [[C]])
+; CHECK-NEXT: [[MIN:%.*]] = add nuw i8 [[TMP1]], [[A]]
+; CHECK-NEXT: ret i8 [[MIN]]
+;
+ %add1 = add nuw i8 %b, %a
+ %add2 = add nuw i8 %c, %a
+ %min = call i8 @llvm.umin.i8(i8 %add1, i8 %add2)
+ ret i8 %min
+}
+
+define i8 @umin_of_add_nuw_comm(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: define i8 @umin_of_add_nuw_comm(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.umin.i8(i8 [[B]], i8 [[C]])
+; CHECK-NEXT: [[MIN:%.*]] = add nuw i8 [[TMP1]], [[A]]
+; CHECK-NEXT: ret i8 [[MIN]]
+;
+ %add1 = add nuw i8 %a, %b
+ %add2 = add nuw i8 %a, %c
+ %min = call i8 @llvm.umin.i8(i8 %add1, i8 %add2)
+ ret i8 %min
+}
+
+; negative test
+define i8 @umin_of_add_nsw(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: define i8 @umin_of_add_nsw(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
+; CHECK-NEXT: [[ADD1:%.*]] = add nsw i8 [[B]], [[A]]
+; CHECK-NEXT: [[ADD2:%.*]] = add nsw i8 [[C]], [[A]]
+; CHECK-NEXT: [[MIN:%.*]] = call i8 @llvm.umin.i8(i8 [[ADD1]], i8 [[ADD2]])
+; CHECK-NEXT: ret i8 [[MIN]]
+;
+ %add1 = add nsw i8 %b, %a
+ %add2 = add nsw i8 %c, %a
+ %min = call i8 @llvm.umin.i8(i8 %add1, i8 %add2)
+ ret i8 %min
+}
+
+; negative test
+define i8 @umin_of_add(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: define i8 @umin_of_add(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
+; CHECK-NEXT: [[ADD1:%.*]] = add i8 [[B]], [[A]]
+; CHECK-NEXT: [[ADD2:%.*]] = add i8 [[C]], [[A]]
+; CHECK-NEXT: [[MIN:%.*]] = call i8 @llvm.umin.i8(i8 [[ADD1]], i8 [[ADD2]])
+; CHECK-NEXT: ret i8 [[MIN]]
+;
+ %add1 = add i8 %b, %a
+ %add2 = add i8 %c, %a
+ %min = call i8 @llvm.umin.i8(i8 %add1, i8 %add2)
+ ret i8 %min
+}
+
+; negative test
+define i8 @smax_of_add_nuw(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: define i8 @smax_of_add_nuw(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
+; CHECK-NEXT: [[ADD1:%.*]] = add nuw i8 [[B]], [[A]]
+; CHECK-NEXT: [[ADD2:%.*]] = add nuw i8 [[C]], [[A]]
+; CHECK-NEXT: [[MAX:%.*]] = call i8 @llvm.smax.i8(i8 [[ADD1]], i8 [[ADD2]])
+; CHECK-NEXT: ret i8 [[MAX]]
+;
+ %add1 = add nuw i8 %b, %a
+ %add2 = add nuw i8 %c, %a
+ %max = call i8 @llvm.smax.i8(i8 %add1, i8 %add2)
+ ret i8 %max
+}
+
+define i8 @smax_of_add_nsw(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: define i8 @smax_of_add_nsw(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.smax.i8(i8 [[B]], i8 [[C]])
+; CHECK-NEXT: [[MAX:%.*]] = add nsw i8 [[TMP1]], [[A]]
+; CHECK-NEXT: ret i8 [[MAX]]
+;
+ %add1 = add nsw i8 %b, %a
+ %add2 = add nsw i8 %c, %a
+ %max = call i8 @llvm.smax.i8(i8 %add1, i8 %add2)
+ ret i8 %max
+}
+
+define i8 @smax_of_add_nsw_comm(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: define i8 @smax_of_add_nsw_comm(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.smax.i8(i8 [[B]], i8 [[C]])
+; CHECK-NEXT: [[MAX:%.*]] = add nsw i8 [[TMP1]], [[A]]
+; CHECK-NEXT: ret i8 [[MAX]]
+;
+ %add1 = add nsw i8 %a, %b
+ %add2 = add nsw i8 %a, %c
+ %max = call i8 @llvm.smax.i8(i8 %add1, i8 %add2)
+ ret i8 %max
+}
+
+; negative test
+define i8 @smax_of_add(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: define i8 @smax_of_add(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
+; CHECK-NEXT: [[ADD1:%.*]] = add i8 [[B]], [[A]]
+; CHECK-NEXT: [[ADD2:%.*]] = add i8 [[C]], [[A]]
+; CHECK-NEXT: [[MAX:%.*]] = call i8 @llvm.smax.i8(i8 [[ADD1]], i8 [[ADD2]])
+; CHECK-NEXT: ret i8 [[MAX]]
+;
+ %add1 = add i8 %b, %a
+ %add2 = add i8 %c, %a
+ %max = call i8 @llvm.smax.i8(i8 %add1, i8 %add2)
+ ret i8 %max
+}
+
+; negative test
+define i8 @smin_of_add_nuw(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: define i8 @smin_of_add_nuw(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
+; CHECK-NEXT: [[ADD1:%.*]] = add nuw i8 [[B]], [[A]]
+; CHECK-NEXT: [[ADD2:%.*]] = add nuw i8 [[C]], [[A]]
+; CHECK-NEXT: [[MIN:%.*]] = call i8 @llvm.smin.i8(i8 [[ADD1]], i8 [[ADD2]])
+; CHECK-NEXT: ret i8 [[MIN]]
+;
+ %add1 = add nuw i8 %b, %a
+ %add2 = add nuw i8 %c, %a
+ %min = call i8 @llvm.smin.i8(i8 %add1, i8 %add2)
+ ret i8 %min
+}
+
+define i8 @smin_of_add_nsw(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: define i8 @smin_of_add_nsw(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.smin.i8(i8 [[B]], i8 [[C]])
+; CHECK-NEXT: [[MIN:%.*]] = add nsw i8 [[TMP1]], [[A]]
+; CHECK-NEXT: ret i8 [[MIN]]
+;
+ %add1 = add nsw i8 %b, %a
+ %add2 = add nsw i8 %c, %a
+ %min = call i8 @llvm.smin.i8(i8 %add1, i8 %add2)
+ ret i8 %min
+}
+
+define i8 @smin_of_add_nsw_comm(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: define i8 @smin_of_add_nsw_comm(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.smin.i8(i8 [[B]], i8 [[C]])
+; CHECK-NEXT: [[MIN:%.*]] = add nsw i8 [[TMP1]], [[A]]
+; CHECK-NEXT: ret i8 [[MIN]]
+;
+ %add1 = add nsw i8 %a, %b
+ %add2 = add nsw i8 %a, %c
+ %min = call i8 @llvm.smin.i8(i8 %add1, i8 %add2)
+ ret i8 %min
+}
+
+; negative test
+define i8 @smin_of_add(i8 %a, i8 %b, i8 %c) {
+; CHECK-LABEL: define i8 @smin_of_add(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
+; CHECK-NEXT: [[ADD1:%.*]] = add i8 [[B]], [[A]]
+; CHECK-NEXT: [[ADD2:%.*]] = add i8 [[C]], [[A]]
+; CHECK-NEXT: [[MIN:%.*]] = call i8 @llvm.smin.i8(i8 [[ADD1]], i8 [[ADD2]])
+; CHECK-NEXT: ret i8 [[MIN]]
+;
+ %add1 = add i8 %b, %a
+ %add2 = add i8 %c, %a
+ %min = call i8 @llvm.smin.i8(i8 %add1, i8 %add2)
+ ret i8 %min
+}
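For a quick sanity check of the identities the new fold relies on (complementing the Alive2 proofs linked above), here is a minimal standalone C++ sketch, not part of the patch, that brute-forces the add-nuw/umax case at i8 width:

#include <algorithm>
#include <cassert>

// Exhaustively checks umax(a+b, a+c) == a + umax(b, c) at 8 bits, restricted
// to inputs where both adds have no unsigned wrap (the nuw precondition).
int main() {
  for (unsigned a = 0; a < 256; ++a)
    for (unsigned b = 0; b < 256; ++b)
      for (unsigned c = 0; c < 256; ++c) {
        if (a + b > 255 || a + c > 255)
          continue; // an add would wrap: excluded by nuw
        assert(std::max(a + b, a + c) == a + std::max(b, c));
      }
  return 0;
}

The other three cases (umin with nuw, smax/smin with nsw) can be checked the same way by swapping the comparison and the wrap condition.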
@dtcxzyw Here's the PR with a fix for the 4 cases you mentioned in that comment.
cast<BinaryOperator>(Builder.CreateBinOp(InnerOpcode, NewIntrinsic, B));
NewBinop->setHasNoSignedWrap(HasNSW);
NewBinop->setHasNoUnsignedWrap(HasNUW);
return NewBinop;
Think the return code has become complex enough to warrant updating the detection logic to:

if (A != C && A != D)
  std::swap(A, B);
if (B == D || B == C)
  ...
Thanks, totally get it. I've simplified the boolean logic into fewer/simpler arguments.
if (Op0->hasNoUnsignedWrap() != Op1->hasNoUnsignedWrap() ||
    Op0->hasNoSignedWrap() != Op1->hasNoSignedWrap())
  return nullptr;

if (!Op0->hasOneUse() || !Op1->hasOneUse())
  return nullptr;

Instruction::BinaryOps InnerOpcode =
    static_cast<Instruction::BinaryOps>(Op0->getOpcode());
bool HasNUW = Op0->hasNoUnsignedWrap();
bool HasNSW = Op0->hasNoSignedWrap();
if (Op0->hasNoUnsignedWrap() != Op1->hasNoUnsignedWrap() || | |
Op0->hasNoSignedWrap() != Op1->hasNoSignedWrap()) | |
return nullptr; | |
if (!Op0->hasOneUse() || !Op1->hasOneUse()) | |
return nullptr; | |
Instruction::BinaryOps InnerOpcode = | |
static_cast<Instruction::BinaryOps>(Op0->getOpcode()); | |
bool HasNUW = Op0->hasNoUnsignedWrap(); | |
bool HasNSW = Op0->hasNoSignedWrap(); | |
if (!Op0->hasOneUse() || !Op1->hasOneUse()) | |
return nullptr; | |
Instruction::BinaryOps InnerOpcode = | |
static_cast<Instruction::BinaryOps>(Op0->getOpcode()); | |
bool HasNUW = Op0->hasNoUnsignedWrap() && Op1->hasNoUnsignedWrap(); | |
bool HasNSW = Op0->hasNoSignedWrap() && Op1->hasNoUnsignedWrap(); |
It is too strict. Please add a test for smin((add nuw nsw X, Y), (add nsw X, Z)).
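(To see why the current check is only too strict rather than unsound — i.e. why intersecting the flags works — the requested mixed-flags case can be brute-forced at 8 bits; a standalone sketch under the stated preconditions, not part of the patch:)

#include <algorithm>
#include <cassert>

// Checks smin(x+y, x+z) == x + smin(y, z) at i8, requiring only that both
// adds are nsw; one add additionally carrying nuw (as in the requested test)
// only shrinks the input set, so the identity continues to hold.
int main() {
  for (int x = -128; x < 128; ++x)
    for (int y = -128; y < 128; ++y)
      for (int z = -128; z < 128; ++z) {
        int s1 = x + y, s2 = x + z;
        if (s1 < -128 || s1 > 127 || s2 < -128 || s2 > 127)
          continue; // signed wrap: excluded by the nsw precondition
        assert(std::min(s1, s2) == x + std::min(y, z));
      }
  return 0;
}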
Thank you. Added the test for smin and the other 3 intrinsics with various flag combinations.
if (A != C) {
  std::swap(C, D);

if (A != D)
It is confusing. A != D here always evaluates to true.
I've rewritten the logic in a clearer way.
; CHECK-NEXT: [[MAX:%.*]] = add nuw i8 [[TMP1]], [[A]]
; CHECK-NEXT: ret i8 [[MAX]]
;
%add1 = add nuw i8 %a, %b
You need a trick to avoid complexity-based canonicalization :) https://llvm.org/docs/InstCombineContributorGuide.html#add-commuted-tests
Force-pushed from 4f9908a to 656f78f.
}

if (B == D || B == C)
  std::swap(A, B);
Think you need to still swap this and the above A != C check. I would just rewrite all the matching as:

if (A != C && A != D)
  std::swap(A, B);
if (A == C || A == D) {
  if (A != C)
    std::swap(C, D);
  // Return NewBinop
}
return nullptr;
Ah! Now I know what you meant before. Thanks.
if (!InnerOpcode)
  return nullptr;

It is just a no-op.
Thank you.
ping
// Attempts to swap variables such that A always equals C
if (A != C && A != D)
  std::swap(A, B);
if (A == C || A == D) {
Question: I'm just curious, but does this work for constant A/C operands or splat vectors? For example,

define i8 @f(i8 %x, i8 %y) {
  %add1 = add nuw i8 %x, 42
  %add2 = add nuw i8 %y, 42
  %umin = call i8 @llvm.umin.i8(i8 %add1, i8 %add2)
  ret i8 %umin
}

and

define <4 x i8> @src(<4 x i8> %x, <4 x i8> %y) {
  %add1 = add nuw <4 x i8> %x, <i8 42, i8 42, i8 42, i8 42>
  %add2 = add nuw <4 x i8> %y, <i8 42, i8 42, i8 42, i8 42>
  %umin = call <4 x i8> @llvm.umin.v4i8(<4 x i8> %add1, <4 x i8> %add2)
  ret <4 x i8> %umin
}

It might be a good idea to add such a test to the precommitted tests.
It would work for constants/splats because the optimisation itself doesn't distinguish between different types of operands. Sure. Will add a test.
case Intrinsic::umax:
  return hasNUW && LOp == Instruction::Add;
case Intrinsic::umin:
  return hasNUW && LOp == Instruction::Add;
I think we can combine these cases, since there are no functions that distribute over umax but not umin (or vice versa):

case Intrinsic::umax:
case Intrinsic::umin:
  return hasNUW && LOp == Instruction::Add;

Proof sketch: Let f be an arbitrary binary function and x, y, z be arbitrary bit vectors. Suppose that (f u (umax v w)) = (umax (f u v) (f u w)) for all u, v, w. Observe that (umin u v) = (xor u v (umax u v)) for all u, v. Then (umin (f x y) (f x z)) = (xor (f x y) (f x z) (umax (f x y) (f x z))) = (xor (f x y) (f x z) (f x (umax y z))). The case (f x y) = (f x z) is trivial, hence suppose they are not equal. Then (f x (umax y z)) is equal to either (f x y) or (f x z), leaving the other as the result of the xor, which equals (f x (umin y z)), as required. (Similar for smin/smax.)
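(The xor identity in the proof above is also easy to confirm exhaustively at 8 bits; a standalone sketch, not part of the review:)

#include <algorithm>
#include <cassert>
#include <cstdint>

// Verifies (umin u v) == (xor u v (umax u v)) for all i8 pairs:
// {u, v} = {min, max}, so xor-ing both inputs with the max leaves the min.
int main() {
  for (unsigned u = 0; u < 256; ++u)
    for (unsigned v = 0; v < 256; ++v) {
      uint8_t a = u, b = v;
      assert(std::min(a, b) == (uint8_t)(a ^ b ^ std::max(a, b)));
    }
  return 0;
}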
Completely understand. Fixed.
@@ -1505,6 +1505,80 @@ foldMinimumOverTrailingOrLeadingZeroCount(Value *I0, Value *I1,
  ConstantInt::getTrue(ZeroUndef->getType()));
}

/// Return whether "X LOp (Y ROp Z)" is always equal to
/// "(X LOp Y) ROp (X LOp Z)".
static bool foldIntrinsicUsingDistributiveLaws(Instruction::BinaryOps LOp,
In my opinion, the name of the function is a bit misleading since it doesn't fold anything but rather checks whether we can apply the transformation.
Thanks. 100% agree. Fixed.
✅ With the latest revision this PR passed the C/C++ code formatter.
LGTM
LGTM with a test nit.
; CHECK-NEXT: [[MAX:%.*]] = add nuw i8 [[TMP1]], [[A]]
; CHECK-NEXT: ret i8 [[MAX]]
;
%add1 = add nuw i8 %b, %a ; thwart complexity-based canonicalization
These "thwart" comments don't make sense to me -- the add here is part of the folded pattern.
I've removed them.
Thank you @nikic, much appreciated. Would you mind merging it, as I don't have write access?
LGTM. Thank you!
Co-authored-by: Yingwei Zheng <[email protected]>
@jf-botto Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR. Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues. How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/169/builds/4907

Here is the relevant piece of the build log for reference.
…1717) This PR fixes part of llvm#92433. It specifically adds the 4 cases mentioned in llvm#92433 (comment). I've added 8 positive tests, 4 of which are mentioned in the comment above and 4 which are their commutative equivalents. Alive proof: https://alive2.llvm.org/ce/z/z6eFTb I've also added 8 negative tests, because we want to make sure we do not optimise when the relevant flags are absent, as the optimisation wouldn't be sound. Alive proof that the optimisation is invalid: https://alive2.llvm.org/ce/z/NvNjTD I did have to make the integer types `i4` to make Alive not time out and to fit them all on one page.
This PR fixes part of #92433.
It specifically adds the 4 cases mentioned in #92433 (comment).
I've added 8 positive tests, 4 of which are mentioned in the comment above and 4 which are their commutative equivalents. Alive proof: https://alive2.llvm.org/ce/z/z6eFTb
I've also added 8 negative tests, because we want to make sure we do not optimise when the relevant flags are absent, as the optimisation wouldn't be sound. Alive proof that the optimisation is invalid: https://alive2.llvm.org/ce/z/NvNjTD
I did have to make the integer types i4 to make Alive not time out and to fit them all on one page.
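As a concrete illustration of why the flags matter (a standalone sketch, not part of the patch): without nuw, the unsigned fold is unsound because a wrapping add can change which operand is the maximum. At i8, a = 255, b = 0, c = 1 is such a counterexample:

#include <algorithm>
#include <cassert>
#include <cstdint>

// Shows umax(a+b, a+c) != a + umax(b, c) once an add wraps at 8 bits.
int main() {
  uint8_t a = 255, b = 0, c = 1;
  uint8_t lhs = std::max<uint8_t>(a + b, a + c); // max(255, 0) == 255
  uint8_t rhs = a + std::max(b, c);              // 255 + 1 wraps to 0
  assert(lhs == 255 && rhs == 0); // the fold would change the result
  return 0;
}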