Skip to content

Commit c41b4b6

Browse files
committed
[InstCombine] Make flag drop during select equiv fold more generic
Instead of unsetting flags on the instruction, attempting the fold, and the resetting the flags if it failed, add support to simplifyWithOpReplaced() to ignore poison-generating flags/metadata and collect all instructions where they may need to be dropped. This allows us to perform the fold a) with poison-generating metadata, which was previously not handled and b) poison-generating flags/metadata that are not on the root instruction. Proof for the ctpop case: https://alive2.llvm.org/ce/z/3H3HFs Fixes #62450.
1 parent f502ab7 commit c41b4b6

File tree

7 files changed

+41
-51
lines changed

7 files changed

+41
-51
lines changed

llvm/include/llvm/Analysis/InstructionSimplify.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,8 +339,14 @@ simplifyInstructionWithOperands(Instruction *I, ArrayRef<Value *> NewOps,
339339
/// AllowRefinement specifies whether the simplification can be a refinement
340340
/// (e.g. 0 instead of poison), or whether it needs to be strictly identical.
341341
/// Op and RepOp can be assumed to not be poison when determining refinement.
342-
Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
343-
const SimplifyQuery &Q, bool AllowRefinement);
342+
///
343+
/// If DropFlags is passed, then the replacement result is only valid if
344+
/// poison-generating flags/metadata on those instructions are dropped. This
345+
/// is only useful in conjunction with AllowRefinement=false.
346+
Value *
347+
simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
348+
const SimplifyQuery &Q, bool AllowRefinement,
349+
SmallVectorImpl<Instruction *> *DropFlags = nullptr);
344350

345351
/// Replace all uses of 'I' with 'SimpleV' and simplify the uses recursively.
346352
///

llvm/lib/Analysis/InstructionSimplify.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4299,6 +4299,7 @@ Value *llvm::simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
42994299
static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
43004300
const SimplifyQuery &Q,
43014301
bool AllowRefinement,
4302+
SmallVectorImpl<Instruction *> *DropFlags,
43024303
unsigned MaxRecurse) {
43034304
// Trivial replacement.
43044305
if (V == Op)
@@ -4333,7 +4334,7 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
43334334
bool AnyReplaced = false;
43344335
for (Value *InstOp : I->operands()) {
43354336
if (Value *NewInstOp = simplifyWithOpReplaced(
4336-
InstOp, Op, RepOp, Q, AllowRefinement, MaxRecurse)) {
4337+
InstOp, Op, RepOp, Q, AllowRefinement, DropFlags, MaxRecurse)) {
43374338
NewOps.push_back(NewInstOp);
43384339
AnyReplaced = InstOp != NewInstOp;
43394340
} else {
@@ -4427,16 +4428,23 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
44274428
// will be done in InstCombine).
44284429
// TODO: This may be unsound, because it only catches some forms of
44294430
// refinement.
4430-
if (!AllowRefinement && canCreatePoison(cast<Operator>(I)))
4431-
return nullptr;
4431+
if (!AllowRefinement) {
4432+
if (canCreatePoison(cast<Operator>(I), !DropFlags))
4433+
return nullptr;
4434+
Constant *Res = ConstantFoldInstOperands(I, ConstOps, Q.DL, Q.TLI);
4435+
if (DropFlags && Res && I->hasPoisonGeneratingFlagsOrMetadata())
4436+
DropFlags->push_back(I);
4437+
return Res;
4438+
}
44324439

44334440
return ConstantFoldInstOperands(I, ConstOps, Q.DL, Q.TLI);
44344441
}
44354442

44364443
Value *llvm::simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
44374444
const SimplifyQuery &Q,
4438-
bool AllowRefinement) {
4439-
return ::simplifyWithOpReplaced(V, Op, RepOp, Q, AllowRefinement,
4445+
bool AllowRefinement,
4446+
SmallVectorImpl<Instruction *> *DropFlags) {
4447+
return ::simplifyWithOpReplaced(V, Op, RepOp, Q, AllowRefinement, DropFlags,
44404448
RecursionLimit);
44414449
}
44424450

@@ -4569,11 +4577,11 @@ static Value *simplifySelectWithICmpEq(Value *CmpLHS, Value *CmpRHS,
45694577
unsigned MaxRecurse) {
45704578
if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q,
45714579
/* AllowRefinement */ false,
4572-
MaxRecurse) == TrueVal)
4580+
/* DropFlags */ nullptr, MaxRecurse) == TrueVal)
45734581
return FalseVal;
45744582
if (simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q,
45754583
/* AllowRefinement */ true,
4576-
MaxRecurse) == FalseVal)
4584+
/* DropFlags */ nullptr, MaxRecurse) == FalseVal)
45774585
return FalseVal;
45784586

45794587
return nullptr;

llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1309,45 +1309,28 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
13091309
return nullptr;
13101310

13111311
// InstSimplify already performed this fold if it was possible subject to
1312-
// current poison-generating flags. Try the transform again with
1313-
// poison-generating flags temporarily dropped.
1314-
bool WasNUW = false, WasNSW = false, WasExact = false, WasInBounds = false;
1315-
if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(FalseVal)) {
1316-
WasNUW = OBO->hasNoUnsignedWrap();
1317-
WasNSW = OBO->hasNoSignedWrap();
1318-
FalseInst->setHasNoUnsignedWrap(false);
1319-
FalseInst->setHasNoSignedWrap(false);
1320-
}
1321-
if (auto *PEO = dyn_cast<PossiblyExactOperator>(FalseVal)) {
1322-
WasExact = PEO->isExact();
1323-
FalseInst->setIsExact(false);
1324-
}
1325-
if (auto *GEP = dyn_cast<GetElementPtrInst>(FalseVal)) {
1326-
WasInBounds = GEP->isInBounds();
1327-
GEP->setIsInBounds(false);
1328-
}
1312+
// current poison-generating flags. Check whether dropping poison-generating
1313+
// flags enables the transform.
13291314

13301315
// Try each equivalence substitution possibility.
13311316
// We have an 'EQ' comparison, so the select's false value will propagate.
13321317
// Example:
13331318
// (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1
1319+
SmallVector<Instruction *> DropFlags;
13341320
if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ,
1335-
/* AllowRefinement */ false) == TrueVal ||
1321+
/* AllowRefinement */ false,
1322+
&DropFlags) == TrueVal ||
13361323
simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ,
1337-
/* AllowRefinement */ false) == TrueVal) {
1324+
/* AllowRefinement */ false,
1325+
&DropFlags) == TrueVal) {
1326+
for (Instruction *I : DropFlags) {
1327+
I->dropPoisonGeneratingFlagsAndMetadata();
1328+
Worklist.add(I);
1329+
}
1330+
13381331
return replaceInstUsesWith(Sel, FalseVal);
13391332
}
13401333

1341-
// Restore poison-generating flags if the transform did not apply.
1342-
if (WasNUW)
1343-
FalseInst->setHasNoUnsignedWrap();
1344-
if (WasNSW)
1345-
FalseInst->setHasNoSignedWrap();
1346-
if (WasExact)
1347-
FalseInst->setIsExact();
1348-
if (WasInBounds)
1349-
cast<GetElementPtrInst>(FalseInst)->setIsInBounds();
1350-
13511334
return nullptr;
13521335
}
13531336

llvm/test/Transforms/InstCombine/bit_ceil.ll

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,9 @@ define i32 @bit_ceil_commuted_operands(i32 %x) {
148148
; CHECK-LABEL: @bit_ceil_commuted_operands(
149149
; CHECK-NEXT: [[DEC:%.*]] = add i32 [[X:%.*]], -1
150150
; CHECK-NEXT: [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[DEC]], i1 false), !range [[RNG0]]
151-
; CHECK-NEXT: [[TMP1:%.*]] = sub nsw i32 0, [[CTLZ]]
152-
; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], 31
153-
; CHECK-NEXT: [[SEL:%.*]] = shl nuw i32 1, [[TMP2]]
154-
; CHECK-NEXT: ret i32 [[SEL]]
151+
; CHECK-NEXT: [[SUB:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
152+
; CHECK-NEXT: [[SHL:%.*]] = shl nuw i32 1, [[SUB]]
153+
; CHECK-NEXT: ret i32 [[SHL]]
155154
;
156155
%dec = add i32 %x, -1
157156
%ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)

llvm/test/Transforms/InstCombine/ctpop.ll

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -479,9 +479,7 @@ define i32 @parity_xor_extra_use2(i32 %arg, i32 %arg1) {
479479
define i32 @select_ctpop_zero(i32 %x) {
480480
; CHECK-LABEL: @select_ctpop_zero(
481481
; CHECK-NEXT: [[CTPOP:%.*]] = call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG1]]
482-
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[X]], 0
483-
; CHECK-NEXT: [[RES:%.*]] = select i1 [[CMP]], i32 0, i32 [[CTPOP]]
484-
; CHECK-NEXT: ret i32 [[RES]]
482+
; CHECK-NEXT: ret i32 [[CTPOP]]
485483
;
486484
%ctpop = call i32 @llvm.ctpop.i32(i32 %x)
487485
%cmp = icmp eq i32 %x, 0

llvm/test/Transforms/InstCombine/ispow2.ll

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -345,9 +345,7 @@ define i1 @is_pow2_ctpop_wrong_pred1_logical(i32 %x) {
345345
; CHECK-LABEL: @is_pow2_ctpop_wrong_pred1_logical(
346346
; CHECK-NEXT: [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]]
347347
; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[T0]], 2
348-
; CHECK-NEXT: [[NOTZERO:%.*]] = icmp ne i32 [[X]], 0
349-
; CHECK-NEXT: [[R:%.*]] = select i1 [[NOTZERO]], i1 [[CMP]], i1 false
350-
; CHECK-NEXT: ret i1 [[R]]
348+
; CHECK-NEXT: ret i1 [[CMP]]
351349
;
352350
%t0 = tail call i32 @llvm.ctpop.i32(i32 %x)
353351
%cmp = icmp ugt i32 %t0, 2

llvm/test/Transforms/LoopVectorize/reduction-inloop.ll

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,12 +1122,10 @@ define i32 @predicated_not_dominates_reduction(ptr nocapture noundef readonly %h
11221122
; CHECK-NEXT: [[TMP0:%.*]] = sext i32 [[INDEX]] to i64
11231123
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[H:%.*]], i64 [[TMP0]]
11241124
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i8>, ptr [[TMP1]], align 1
1125-
; CHECK-NEXT: [[DOTNOT:%.*]] = icmp eq <4 x i8> [[WIDE_LOAD]], zeroinitializer
11261125
; CHECK-NEXT: [[TMP2:%.*]] = udiv <4 x i8> [[WIDE_LOAD]], <i8 31, i8 31, i8 31, i8 31>
11271126
; CHECK-NEXT: [[TMP3:%.*]] = shl nuw nsw <4 x i8> [[TMP2]], <i8 3, i8 3, i8 3, i8 3>
11281127
; CHECK-NEXT: [[TMP4:%.*]] = udiv <4 x i8> [[TMP3]], <i8 31, i8 31, i8 31, i8 31>
1129-
; CHECK-NEXT: [[NARROW:%.*]] = select <4 x i1> [[DOTNOT]], <4 x i8> zeroinitializer, <4 x i8> [[TMP4]]
1130-
; CHECK-NEXT: [[TMP5:%.*]] = zext <4 x i8> [[NARROW]] to <4 x i32>
1128+
; CHECK-NEXT: [[TMP5:%.*]] = zext <4 x i8> [[TMP4]] to <4 x i32>
11311129
; CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP5]])
11321130
; CHECK-NEXT: [[TMP7]] = add i32 [[TMP6]], [[VEC_PHI]]
11331131
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4

0 commit comments

Comments
 (0)