Skip to content

Commit c912f88

Browse files
committed
[InstCombine] Remove false commutativity from processUMulZExtIdiom() (NFCI)
This fold requires a fold against a constant, which will always be on the RHS. If the swapped fold actually did trigger, it would result in a miscompile, because it did not work with the swapped predicate when swapping operands.
1 parent 5640d28 commit c912f88

File tree

1 file changed

+17
-41
lines changed

1 file changed

+17
-41
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 17 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -5807,15 +5807,13 @@ bool InstCombinerImpl::OptimizeOverflowCheck(Instruction::BinaryOps BinaryOp,
58075807
/// \returns Instruction which must replace the compare instruction, NULL if no
58085808
/// replacement required.
58095809
static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
5810-
Value *OtherVal,
5810+
const APInt *OtherVal,
58115811
InstCombinerImpl &IC) {
58125812
// Don't bother doing this transformation for pointers, don't do it for
58135813
// vectors.
58145814
if (!isa<IntegerType>(MulVal->getType()))
58155815
return nullptr;
58165816

5817-
assert(I.getOperand(0) == MulVal || I.getOperand(1) == MulVal);
5818-
assert(I.getOperand(0) == OtherVal || I.getOperand(1) == OtherVal);
58195817
auto *MulInstr = dyn_cast<Instruction>(MulVal);
58205818
if (!MulInstr)
58215819
return nullptr;
@@ -5875,28 +5873,26 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
58755873

58765874
// Recognize patterns
58775875
switch (I.getPredicate()) {
5878-
case ICmpInst::ICMP_UGT:
5876+
case ICmpInst::ICMP_UGT: {
58795877
// Recognize pattern:
58805878
// mulval = mul(zext A, zext B)
58815879
// cmp ugt mulval, max
5882-
if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
5883-
APInt MaxVal = APInt::getMaxValue(MulWidth);
5884-
MaxVal = MaxVal.zext(CI->getBitWidth());
5885-
if (MaxVal.eq(CI->getValue()))
5886-
break; // Recognized
5887-
}
5880+
APInt MaxVal = APInt::getMaxValue(MulWidth);
5881+
MaxVal = MaxVal.zext(OtherVal->getBitWidth());
5882+
if (MaxVal.eq(*OtherVal))
5883+
break; // Recognized
58885884
return nullptr;
5885+
}
58895886

5890-
case ICmpInst::ICMP_ULT:
5887+
case ICmpInst::ICMP_ULT: {
58915888
// Recognize pattern:
58925889
// mulval = mul(zext A, zext B)
58935890
// cmp ule mulval, max + 1
5894-
if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
5895-
APInt MaxVal = APInt::getOneBitSet(CI->getBitWidth(), MulWidth);
5896-
if (MaxVal.eq(CI->getValue()))
5897-
break; // Recognized
5898-
}
5891+
APInt MaxVal = APInt::getOneBitSet(OtherVal->getBitWidth(), MulWidth);
5892+
if (MaxVal.eq(*OtherVal))
5893+
break; // Recognized
58995894
return nullptr;
5895+
}
59005896

59015897
default:
59025898
return nullptr;
@@ -5922,7 +5918,7 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
59225918
if (MulVal->hasNUsesOrMore(2)) {
59235919
Value *Mul = Builder.CreateExtractValue(Call, 0, "umul.value");
59245920
for (User *U : make_early_inc_range(MulVal->users())) {
5925-
if (U == &I || U == OtherVal)
5921+
if (U == &I)
59265922
continue;
59275923
if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
59285924
if (TI->getType()->getPrimitiveSizeInBits() == MulWidth)
@@ -5943,27 +5939,10 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
59435939
IC.addToWorklist(cast<Instruction>(U));
59445940
}
59455941
}
5946-
if (isa<Instruction>(OtherVal))
5947-
IC.addToWorklist(cast<Instruction>(OtherVal));
59485942

59495943
// The original icmp gets replaced with the overflow value, maybe inverted
59505944
// depending on predicate.
5951-
bool Inverse = false;
5952-
switch (I.getPredicate()) {
5953-
case ICmpInst::ICMP_UGT:
5954-
if (I.getOperand(0) == MulVal)
5955-
break;
5956-
Inverse = true;
5957-
break;
5958-
case ICmpInst::ICMP_ULT:
5959-
if (I.getOperand(1) == MulVal)
5960-
break;
5961-
Inverse = true;
5962-
break;
5963-
default:
5964-
llvm_unreachable("Unexpected predicate");
5965-
}
5966-
if (Inverse) {
5945+
if (I.getPredicate() == ICmpInst::ICMP_ULT) {
59675946
Value *Res = Builder.CreateExtractValue(Call, 1);
59685947
return BinaryOperator::CreateNot(Res);
59695948
}
@@ -7083,12 +7062,9 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
70837062
}
70847063

70857064
// (zext a) * (zext b) --> llvm.umul.with.overflow.
7086-
if (match(Op0, m_NUWMul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) {
7087-
if (Instruction *R = processUMulZExtIdiom(I, Op0, Op1, *this))
7088-
return R;
7089-
}
7090-
if (match(Op1, m_NUWMul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) {
7091-
if (Instruction *R = processUMulZExtIdiom(I, Op1, Op0, *this))
7065+
if (match(Op0, m_NUWMul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B)))) &&
7066+
match(Op1, m_APInt(C))) {
7067+
if (Instruction *R = processUMulZExtIdiom(I, Op0, C, *this))
70927068
return R;
70937069
}
70947070

0 commit comments

Comments
 (0)