Skip to content

Commit 77eb056

Browse files
authored
[InstCombine] Simplify select using KnownBits of condition (#95923)
Simplify the arms of a select based on the KnownBits implied by its condition. For now this only handles the case where the select arm folds to a constant, but this can be generalized to handle other patterns by using SimplifyDemandedBits instead (in that case we would also have to limit to non-undef conditions). This is implemented by adding a new member to SimplifyQuery that can be used to inject an additional condition. The affected values are pre-computed and we don't call computeKnownBits() if the select arms don't contain affected values. This reduces the cost in some pathological cases.
1 parent 6859e5a commit 77eb056

File tree

7 files changed

+165
-110
lines changed

7 files changed

+165
-110
lines changed

llvm/include/llvm/Analysis/SimplifyQuery.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef LLVM_ANALYSIS_SIMPLIFYQUERY_H
1010
#define LLVM_ANALYSIS_SIMPLIFYQUERY_H
1111

12+
#include "llvm/ADT/SmallPtrSet.h"
1213
#include "llvm/IR/Operator.h"
1314

1415
namespace llvm {
@@ -57,13 +58,23 @@ struct InstrInfoQuery {
5758
}
5859
};
5960

61+
/// Evaluate query assuming this condition holds.
///
/// Carries an extra condition, plus the set of values whose known bits that
/// condition may refine, into a SimplifyQuery (see SimplifyQuery::CC).
/// computeKnownBitsFromContext() consults it only for values present in
/// AffectedValues.
62+
struct CondContext {
63+
/// The injected condition.
Value *Cond;
64+
/// Forwarded as the final argument of computeKnownBitsFromCond().
/// visitSelectInst() sets it to true before analysing the select's false
/// arm, i.e. the arm chosen when Cond does not hold.
bool Invert = false;
65+
/// Values that Cond may constrain. The constructor leaves this empty; the
/// creator of the context is expected to fill it (visitSelectInst() does so
/// through findValuesAffectedByCondition()).
SmallPtrSet<Value *, 4> AffectedValues;
66+
67+
/// Wraps \p Cond with Invert == false and no affected values.
/// NOTE(review): not `explicit`, so a Value* converts implicitly to a
/// CondContext -- confirm that this is intended.
CondContext(Value *Cond) : Cond(Cond) {}
68+
};
69+
6070
struct SimplifyQuery {
6171
const DataLayout &DL;
6272
const TargetLibraryInfo *TLI = nullptr;
6373
const DominatorTree *DT = nullptr;
6474
AssumptionCache *AC = nullptr;
6575
const Instruction *CxtI = nullptr;
6676
const DomConditionCache *DC = nullptr;
77+
const CondContext *CC = nullptr;
6778

6879
// Wrapper to query additional information for instructions like metadata or
6980
// keywords like nsw, which provides conservative results if those cannot
@@ -113,6 +124,12 @@ struct SimplifyQuery {
113124
Copy.DC = nullptr;
114125
return Copy;
115126
}
127+
128+
/// Returns a copy of this query whose CC member points at \p CC.
///
/// Only the address of \p CC is stored, so the CondContext must outlive the
/// returned query. Later changes to \p CC (e.g. toggling Invert) are visible
/// through the returned copy.
SimplifyQuery getWithCondContext(const CondContext &CC) const {
129+
SimplifyQuery Copy(*this);
130+
// Stores a pointer to the caller's object; no copy of CC is made.
Copy.CC = &CC;
131+
return Copy;
132+
}
116133
};
117134

118135
} // end namespace llvm

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -771,6 +771,10 @@ static void computeKnownBitsFromCond(const Value *V, Value *Cond,
771771

772772
void llvm::computeKnownBitsFromContext(const Value *V, KnownBits &Known,
773773
unsigned Depth, const SimplifyQuery &Q) {
774+
// Handle injected condition.
775+
if (Q.CC && Q.CC->AffectedValues.contains(V))
776+
computeKnownBitsFromCond(V, Q.CC->Cond, Known, Depth, Q, Q.CC->Invert);
777+
774778
if (!Q.CxtI)
775779
return;
776780

llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3519,6 +3519,33 @@ static bool matchFMulByZeroIfResultEqZero(InstCombinerImpl &IC, Value *Cmp0,
35193519
return false;
35203520
}
35213521

3522+
/// Check whether the KnownBits of a select arm may be affected by the
3523+
/// select condition.
///
/// Walks the operand tree of \p V looking for a member of \p Affected.
/// \p Depth is the current recursion depth; callers start at 0. Returns true
/// as soon as an affected value is found strictly below the root, and false
/// if none is found before MaxAnalysisRecursionDepth is reached.
3524+
static bool hasAffectedValue(Value *V, SmallPtrSetImpl<Value *> &Affected,
3525+
unsigned Depth) {
3526+
// Recursion budget exhausted: report "not affected".
if (Depth == MaxAnalysisRecursionDepth)
3527+
return false;
3528+
3529+
// Ignore the case where the select arm itself is affected. These cases
3530+
// are handled more efficiently by other optimizations.
3531+
if (Depth != 0 && Affected.contains(V))
3532+
return true;
3533+
3534+
// Only instructions have operands to descend into; any other value that
// was not matched above falls through to the final `return false`.
if (auto *I = dyn_cast<Instruction>(V)) {
3535+
if (isa<PHINode>(I)) {
3536+
// A PHI reached at the last usable depth is not inspected at all.
if (Depth == MaxAnalysisRecursionDepth - 1)
3537+
return false;
3538+
// Clamp the depth so the PHI's operands are visited at
// MaxAnalysisRecursionDepth - 1: they are tested against Affected, but
// the walk goes no further below them.
Depth = MaxAnalysisRecursionDepth - 2;
3539+
}
3540+
// Follow integer and integer-vector operands only; operands of any other
// type are skipped.
return any_of(I->operands(), [&](Value *Op) {
3541+
return Op->getType()->isIntOrIntVectorTy() &&
3542+
hasAffectedValue(Op, Affected, Depth + 1);
3543+
});
3544+
}
3545+
3546+
return false;
3547+
}
3548+
35223549
Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
35233550
Value *CondVal = SI.getCondition();
35243551
Value *TrueVal = SI.getTrueValue();
@@ -4016,5 +4043,33 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
40164043
if (CondVal->getType() == SI.getType() && isKnownInversion(FalseVal, TrueVal))
40174044
return BinaryOperator::CreateXor(CondVal, FalseVal);
40184045

4046+
if (SelType->isIntOrIntVectorTy() &&
4047+
(!isa<Constant>(TrueVal) || !isa<Constant>(FalseVal))) {
4048+
// Try to simplify select arms based on KnownBits implied by the condition.
4049+
CondContext CC(CondVal);
4050+
findValuesAffectedByCondition(CondVal, /*IsAssume=*/false, [&](Value *V) {
4051+
CC.AffectedValues.insert(V);
4052+
});
4053+
SimplifyQuery Q = SQ.getWithInstruction(&SI).getWithCondContext(CC);
4054+
if (!CC.AffectedValues.empty()) {
4055+
if (!isa<Constant>(TrueVal) &&
4056+
hasAffectedValue(TrueVal, CC.AffectedValues, /*Depth=*/0)) {
4057+
KnownBits Known = llvm::computeKnownBits(TrueVal, /*Depth=*/0, Q);
4058+
if (Known.isConstant())
4059+
return replaceOperand(SI, 1,
4060+
ConstantInt::get(SelType, Known.getConstant()));
4061+
}
4062+
4063+
CC.Invert = true;
4064+
if (!isa<Constant>(FalseVal) &&
4065+
hasAffectedValue(FalseVal, CC.AffectedValues, /*Depth=*/0)) {
4066+
KnownBits Known = llvm::computeKnownBits(FalseVal, /*Depth=*/0, Q);
4067+
if (Known.isConstant())
4068+
return replaceOperand(SI, 2,
4069+
ConstantInt::get(SelType, Known.getConstant()));
4070+
}
4071+
}
4072+
}
4073+
40194074
return nullptr;
40204075
}

llvm/test/Transforms/InstCombine/select-binop-cmp.ll

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -571,10 +571,7 @@ define <2 x i8> @select_xor_icmp_vec_bad(<2 x i8> %x, <2 x i8> %y, <2 x i8> %z)
571571

572572
define <2 x i32> @vec_select_no_equivalence(<2 x i32> %x) {
573573
; CHECK-LABEL: @vec_select_no_equivalence(
574-
; CHECK-NEXT: [[X10:%.*]] = shufflevector <2 x i32> [[X:%.*]], <2 x i32> poison, <2 x i32> <i32 1, i32 0>
575-
; CHECK-NEXT: [[COND:%.*]] = icmp eq <2 x i32> [[X]], zeroinitializer
576-
; CHECK-NEXT: [[S:%.*]] = select <2 x i1> [[COND]], <2 x i32> [[X10]], <2 x i32> [[X]]
577-
; CHECK-NEXT: ret <2 x i32> [[S]]
574+
; CHECK-NEXT: ret <2 x i32> [[X:%.*]]
578575
;
579576
%x10 = shufflevector <2 x i32> %x, <2 x i32> undef, <2 x i32> <i32 1, i32 0>
580577
%cond = icmp eq <2 x i32> %x, zeroinitializer

llvm/test/Transforms/InstCombine/select-of-bittest.ll

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -588,11 +588,9 @@ define i32 @n4(i32 %arg) {
588588

589589
define i32 @n5(i32 %arg) {
590590
; CHECK-LABEL: @n5(
591-
; CHECK-NEXT: [[T:%.*]] = and i32 [[ARG:%.*]], 2
592-
; CHECK-NEXT: [[T1:%.*]] = icmp eq i32 [[T]], 0
593-
; CHECK-NEXT: [[T2:%.*]] = and i32 [[ARG]], 2
594-
; CHECK-NEXT: [[T3:%.*]] = select i1 [[T1]], i32 [[T2]], i32 1
595-
; CHECK-NEXT: ret i32 [[T3]]
591+
; CHECK-NEXT: [[T:%.*]] = lshr i32 [[ARG:%.*]], 1
592+
; CHECK-NEXT: [[T_LOBIT:%.*]] = and i32 [[T]], 1
593+
; CHECK-NEXT: ret i32 [[T_LOBIT]]
596594
;
597595
%t = and i32 %arg, 2
598596
%t1 = icmp eq i32 %t, 0

llvm/test/Transforms/InstCombine/select.ll

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3807,9 +3807,8 @@ define i32 @src_and_eq_neg1_or_xor(i32 %x, i32 %y) {
38073807
; CHECK-NEXT: entry:
38083808
; CHECK-NEXT: [[AND:%.*]] = and i32 [[Y:%.*]], [[X:%.*]]
38093809
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[AND]], -1
3810-
; CHECK-NEXT: [[OR:%.*]] = or i32 [[Y]], [[X]]
38113810
; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y]], [[X]]
3812-
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 [[OR]], i32 [[XOR]]
3811+
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 -1, i32 [[XOR]]
38133812
; CHECK-NEXT: ret i32 [[COND]]
38143813
;
38153814
entry:
@@ -3827,9 +3826,8 @@ define i32 @src_and_eq_neg1_xor_or(i32 %x, i32 %y) {
38273826
; CHECK-NEXT: entry:
38283827
; CHECK-NEXT: [[AND:%.*]] = and i32 [[Y:%.*]], [[X:%.*]]
38293828
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[AND]], -1
3830-
; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y]], [[X]]
38313829
; CHECK-NEXT: [[OR:%.*]] = or i32 [[Y]], [[X]]
3832-
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 [[XOR]], i32 [[OR]]
3830+
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 0, i32 [[OR]]
38333831
; CHECK-NEXT: ret i32 [[COND]]
38343832
;
38353833
entry:
@@ -3942,9 +3940,8 @@ define i32 @src_or_eq_0_and_xor(i32 %x, i32 %y) {
39423940
; CHECK-NEXT: entry:
39433941
; CHECK-NEXT: [[OR:%.*]] = or i32 [[Y:%.*]], [[X:%.*]]
39443942
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[OR]], 0
3945-
; CHECK-NEXT: [[AND:%.*]] = and i32 [[Y]], [[X]]
39463943
; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y]], [[X]]
3947-
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 [[AND]], i32 [[XOR]]
3944+
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 0, i32 [[XOR]]
39483945
; CHECK-NEXT: ret i32 [[COND]]
39493946
;
39503947
entry:
@@ -3962,9 +3959,8 @@ define i32 @src_or_eq_0_xor_and(i32 %x, i32 %y) {
39623959
; CHECK-NEXT: entry:
39633960
; CHECK-NEXT: [[OR:%.*]] = or i32 [[Y:%.*]], [[X:%.*]]
39643961
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[OR]], 0
3965-
; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y]], [[X]]
39663962
; CHECK-NEXT: [[AND:%.*]] = and i32 [[Y]], [[X]]
3967-
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 [[XOR]], i32 [[AND]]
3963+
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 0, i32 [[AND]]
39683964
; CHECK-NEXT: ret i32 [[COND]]
39693965
;
39703966
entry:
@@ -4474,10 +4470,7 @@ define i32 @src_no_trans_select_or_eq0_or_xor(i32 %x, i32 %y) {
44744470
define i32 @src_no_trans_select_or_eq0_and_or(i32 %x, i32 %y) {
44754471
; CHECK-LABEL: @src_no_trans_select_or_eq0_and_or(
44764472
; CHECK-NEXT: [[OR:%.*]] = or i32 [[X:%.*]], [[Y:%.*]]
4477-
; CHECK-NEXT: [[OR0:%.*]] = icmp eq i32 [[OR]], 0
4478-
; CHECK-NEXT: [[AND:%.*]] = and i32 [[X]], [[Y]]
4479-
; CHECK-NEXT: [[COND:%.*]] = select i1 [[OR0]], i32 [[AND]], i32 [[OR]]
4480-
; CHECK-NEXT: ret i32 [[COND]]
4473+
; CHECK-NEXT: ret i32 [[OR]]
44814474
;
44824475
%or = or i32 %x, %y
44834476
%or0 = icmp eq i32 %or, 0
@@ -4489,10 +4482,7 @@ define i32 @src_no_trans_select_or_eq0_and_or(i32 %x, i32 %y) {
44894482
define i32 @src_no_trans_select_or_eq0_xor_or(i32 %x, i32 %y) {
44904483
; CHECK-LABEL: @src_no_trans_select_or_eq0_xor_or(
44914484
; CHECK-NEXT: [[OR:%.*]] = or i32 [[X:%.*]], [[Y:%.*]]
4492-
; CHECK-NEXT: [[OR0:%.*]] = icmp eq i32 [[OR]], 0
4493-
; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[X]], [[Y]]
4494-
; CHECK-NEXT: [[COND:%.*]] = select i1 [[OR0]], i32 [[XOR]], i32 [[OR]]
4495-
; CHECK-NEXT: ret i32 [[COND]]
4485+
; CHECK-NEXT: ret i32 [[OR]]
44964486
;
44974487
%or = or i32 %x, %y
44984488
%or0 = icmp eq i32 %or, 0

0 commit comments

Comments
 (0)