Skip to content

Commit 1eec81a

Browse files
authored
[CVP][LVI] Add support for vectors (#97428)
The core change here is to add support for converting vector constants into constant ranges. The rest is just relaxing isIntegerTy() checks and making sure we don't use APIs that assume scalars (e.g. getIntegerBitWidth()). There are a couple of places that don't support vectors yet; most notably, the "simplest" fold (comparisons to a constant) isn't supported yet. I'll leave these to a follow-up.
1 parent b76dd4e commit 1eec81a

File tree

4 files changed

+83
-63
lines changed

4 files changed

+83
-63
lines changed

llvm/lib/Analysis/LazyValueInfo.cpp

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -650,7 +650,7 @@ LazyValueInfoImpl::solveBlockValueImpl(Value *Val, BasicBlock *BB) {
650650
if (PT && isKnownNonZero(BBI, DL))
651651
return ValueLatticeElement::getNot(ConstantPointerNull::get(PT));
652652

653-
if (BBI->getType()->isIntegerTy()) {
653+
if (BBI->getType()->isIntOrIntVectorTy()) {
654654
if (auto *CI = dyn_cast<CastInst>(BBI))
655655
return solveBlockValueCast(CI, BB);
656656

@@ -836,6 +836,24 @@ void LazyValueInfoImpl::intersectAssumeOrGuardBlockValueConstantRange(
836836
}
837837
}
838838

839+
static ConstantRange getConstantRangeFromFixedVector(Constant *C,
840+
FixedVectorType *Ty) {
841+
unsigned BW = Ty->getScalarSizeInBits();
842+
ConstantRange CR = ConstantRange::getEmpty(BW);
843+
for (unsigned I = 0; I < Ty->getNumElements(); ++I) {
844+
Constant *Elem = C->getAggregateElement(I);
845+
if (!Elem)
846+
return ConstantRange::getFull(BW);
847+
if (isa<PoisonValue>(Elem))
848+
continue;
849+
auto *CI = dyn_cast<ConstantInt>(Elem);
850+
if (!CI)
851+
return ConstantRange::getFull(BW);
852+
CR = CR.unionWith(CI->getValue());
853+
}
854+
return CR;
855+
}
856+
839857
static ConstantRange toConstantRange(const ValueLatticeElement &Val,
840858
Type *Ty, bool UndefAllowed = false) {
841859
assert(Ty->isIntOrIntVectorTy() && "Must be integer type");
@@ -844,6 +862,13 @@ static ConstantRange toConstantRange(const ValueLatticeElement &Val,
844862
unsigned BW = Ty->getScalarSizeInBits();
845863
if (Val.isUnknown())
846864
return ConstantRange::getEmpty(BW);
865+
if (Val.isConstant() && Ty->isVectorTy()) {
866+
if (auto *CI = dyn_cast_or_null<ConstantInt>(
867+
Val.getConstant()->getSplatValue(/*AllowPoison=*/true)))
868+
return ConstantRange(CI->getValue());
869+
if (auto *VTy = dyn_cast<FixedVectorType>(Ty))
870+
return getConstantRangeFromFixedVector(Val.getConstant(), VTy);
871+
}
847872
return ConstantRange::getFull(BW);
848873
}
849874

@@ -968,7 +993,7 @@ LazyValueInfoImpl::solveBlockValueCast(CastInst *CI, BasicBlock *BB) {
968993
return std::nullopt;
969994
const ConstantRange &LHSRange = *LHSRes;
970995

971-
const unsigned ResultBitWidth = CI->getType()->getIntegerBitWidth();
996+
const unsigned ResultBitWidth = CI->getType()->getScalarSizeInBits();
972997

973998
// NOTE: We're currently limited by the set of operations that ConstantRange
974999
// can evaluate symbolically. Enhancing that set will allow us to analyze
@@ -1108,7 +1133,7 @@ LazyValueInfoImpl::getValueFromSimpleICmpCondition(CmpInst::Predicate Pred,
11081133
const APInt &Offset,
11091134
Instruction *CxtI,
11101135
bool UseBlockValue) {
1111-
ConstantRange RHSRange(RHS->getType()->getIntegerBitWidth(),
1136+
ConstantRange RHSRange(RHS->getType()->getScalarSizeInBits(),
11121137
/*isFullSet=*/true);
11131138
if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
11141139
RHSRange = ConstantRange(CI->getValue());
@@ -1728,7 +1753,6 @@ Constant *LazyValueInfo::getConstant(Value *V, Instruction *CxtI) {
17281753

17291754
ConstantRange LazyValueInfo::getConstantRange(Value *V, Instruction *CxtI,
17301755
bool UndefAllowed) {
1731-
assert(V->getType()->isIntegerTy());
17321756
BasicBlock *BB = CxtI->getParent();
17331757
ValueLatticeElement Result =
17341758
getOrCreateImpl(BB->getModule()).getValueInBlock(V, BB, CxtI);

llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp

Lines changed: 10 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -288,9 +288,8 @@ static bool processPHI(PHINode *P, LazyValueInfo *LVI, DominatorTree *DT,
288288
}
289289

290290
static bool processICmp(ICmpInst *Cmp, LazyValueInfo *LVI) {
291-
// Only for signed relational comparisons of scalar integers.
292-
if (Cmp->getType()->isVectorTy() ||
293-
!Cmp->getOperand(0)->getType()->isIntegerTy())
291+
// Only for signed relational comparisons of integers.
292+
if (!Cmp->getOperand(0)->getType()->isIntOrIntVectorTy())
294293
return false;
295294

296295
if (!Cmp->isSigned())
@@ -505,12 +504,8 @@ static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI);
505504
// because it is negation-invariant.
506505
static bool processAbsIntrinsic(IntrinsicInst *II, LazyValueInfo *LVI) {
507506
Value *X = II->getArgOperand(0);
508-
Type *Ty = X->getType();
509-
if (!Ty->isIntegerTy())
510-
return false;
511-
512507
bool IsIntMinPoison = cast<ConstantInt>(II->getArgOperand(1))->isOne();
513-
APInt IntMin = APInt::getSignedMinValue(Ty->getScalarSizeInBits());
508+
APInt IntMin = APInt::getSignedMinValue(X->getType()->getScalarSizeInBits());
514509
ConstantRange Range = LVI->getConstantRangeAtUse(
515510
II->getOperandUse(0), /*UndefAllowed*/ IsIntMinPoison);
516511

@@ -679,15 +674,13 @@ static bool processCallSite(CallBase &CB, LazyValueInfo *LVI) {
679674
}
680675

681676
if (auto *WO = dyn_cast<WithOverflowInst>(&CB)) {
682-
if (WO->getLHS()->getType()->isIntegerTy() && willNotOverflow(WO, LVI)) {
677+
if (willNotOverflow(WO, LVI))
683678
return processOverflowIntrinsic(WO, LVI);
684-
}
685679
}
686680

687681
if (auto *SI = dyn_cast<SaturatingInst>(&CB)) {
688-
if (SI->getType()->isIntegerTy() && willNotOverflow(SI, LVI)) {
682+
if (willNotOverflow(SI, LVI))
689683
return processSaturatingInst(SI, LVI);
690-
}
691684
}
692685

693686
bool Changed = false;
@@ -761,11 +754,10 @@ static bool narrowSDivOrSRem(BinaryOperator *Instr, const ConstantRange &LCR,
761754
const ConstantRange &RCR) {
762755
assert(Instr->getOpcode() == Instruction::SDiv ||
763756
Instr->getOpcode() == Instruction::SRem);
764-
assert(!Instr->getType()->isVectorTy());
765757

766758
// Find the smallest power of two bitwidth that's sufficient to hold Instr's
767759
// operands.
768-
unsigned OrigWidth = Instr->getType()->getIntegerBitWidth();
760+
unsigned OrigWidth = Instr->getType()->getScalarSizeInBits();
769761

770762
// What is the smallest bit width that can accommodate the entire value ranges
771763
// of both of the operands?
@@ -788,7 +780,7 @@ static bool narrowSDivOrSRem(BinaryOperator *Instr, const ConstantRange &LCR,
788780

789781
++NumSDivSRemsNarrowed;
790782
IRBuilder<> B{Instr};
791-
auto *TruncTy = Type::getIntNTy(Instr->getContext(), NewWidth);
783+
auto *TruncTy = Instr->getType()->getWithNewBitWidth(NewWidth);
792784
auto *LHS = B.CreateTruncOrBitCast(Instr->getOperand(0), TruncTy,
793785
Instr->getName() + ".lhs.trunc");
794786
auto *RHS = B.CreateTruncOrBitCast(Instr->getOperand(1), TruncTy,
@@ -809,7 +801,6 @@ static bool expandUDivOrURem(BinaryOperator *Instr, const ConstantRange &XCR,
809801
Type *Ty = Instr->getType();
810802
assert(Instr->getOpcode() == Instruction::UDiv ||
811803
Instr->getOpcode() == Instruction::URem);
812-
assert(!Ty->isVectorTy());
813804
bool IsRem = Instr->getOpcode() == Instruction::URem;
814805

815806
Value *X = Instr->getOperand(0);
@@ -892,7 +883,6 @@ static bool narrowUDivOrURem(BinaryOperator *Instr, const ConstantRange &XCR,
892883
const ConstantRange &YCR) {
893884
assert(Instr->getOpcode() == Instruction::UDiv ||
894885
Instr->getOpcode() == Instruction::URem);
895-
assert(!Instr->getType()->isVectorTy());
896886

897887
// Find the smallest power of two bitwidth that's sufficient to hold Instr's
898888
// operands.
@@ -905,12 +895,12 @@ static bool narrowUDivOrURem(BinaryOperator *Instr, const ConstantRange &XCR,
905895

906896
// NewWidth might be greater than OrigWidth if OrigWidth is not a power of
907897
// two.
908-
if (NewWidth >= Instr->getType()->getIntegerBitWidth())
898+
if (NewWidth >= Instr->getType()->getScalarSizeInBits())
909899
return false;
910900

911901
++NumUDivURemsNarrowed;
912902
IRBuilder<> B{Instr};
913-
auto *TruncTy = Type::getIntNTy(Instr->getContext(), NewWidth);
903+
auto *TruncTy = Instr->getType()->getWithNewBitWidth(NewWidth);
914904
auto *LHS = B.CreateTruncOrBitCast(Instr->getOperand(0), TruncTy,
915905
Instr->getName() + ".lhs.trunc");
916906
auto *RHS = B.CreateTruncOrBitCast(Instr->getOperand(1), TruncTy,
@@ -929,9 +919,6 @@ static bool narrowUDivOrURem(BinaryOperator *Instr, const ConstantRange &XCR,
929919
static bool processUDivOrURem(BinaryOperator *Instr, LazyValueInfo *LVI) {
930920
assert(Instr->getOpcode() == Instruction::UDiv ||
931921
Instr->getOpcode() == Instruction::URem);
932-
if (Instr->getType()->isVectorTy())
933-
return false;
934-
935922
ConstantRange XCR = LVI->getConstantRangeAtUse(Instr->getOperandUse(0),
936923
/*UndefAllowed*/ false);
937924
// Allow undef for RHS, as we can assume it is division by zero UB.
@@ -946,7 +933,6 @@ static bool processUDivOrURem(BinaryOperator *Instr, LazyValueInfo *LVI) {
946933
static bool processSRem(BinaryOperator *SDI, const ConstantRange &LCR,
947934
const ConstantRange &RCR, LazyValueInfo *LVI) {
948935
assert(SDI->getOpcode() == Instruction::SRem);
949-
assert(!SDI->getType()->isVectorTy());
950936

951937
if (LCR.abs().icmp(CmpInst::ICMP_ULT, RCR.abs())) {
952938
SDI->replaceAllUsesWith(SDI->getOperand(0));
@@ -1006,7 +992,6 @@ static bool processSRem(BinaryOperator *SDI, const ConstantRange &LCR,
1006992
static bool processSDiv(BinaryOperator *SDI, const ConstantRange &LCR,
1007993
const ConstantRange &RCR, LazyValueInfo *LVI) {
1008994
assert(SDI->getOpcode() == Instruction::SDiv);
1009-
assert(!SDI->getType()->isVectorTy());
1010995

1011996
// Check whether the division folds to a constant.
1012997
ConstantRange DivCR = LCR.sdiv(RCR);
@@ -1064,9 +1049,6 @@ static bool processSDiv(BinaryOperator *SDI, const ConstantRange &LCR,
10641049
static bool processSDivOrSRem(BinaryOperator *Instr, LazyValueInfo *LVI) {
10651050
assert(Instr->getOpcode() == Instruction::SDiv ||
10661051
Instr->getOpcode() == Instruction::SRem);
1067-
if (Instr->getType()->isVectorTy())
1068-
return false;
1069-
10701052
ConstantRange LCR =
10711053
LVI->getConstantRangeAtUse(Instr->getOperandUse(0), /*AllowUndef*/ false);
10721054
// Allow undef for RHS, as we can assume it is division by zero UB.
@@ -1085,12 +1067,9 @@ static bool processSDivOrSRem(BinaryOperator *Instr, LazyValueInfo *LVI) {
10851067
}
10861068

10871069
static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) {
1088-
if (SDI->getType()->isVectorTy())
1089-
return false;
1090-
10911070
ConstantRange LRange =
10921071
LVI->getConstantRangeAtUse(SDI->getOperandUse(0), /*UndefAllowed*/ false);
1093-
unsigned OrigWidth = SDI->getType()->getIntegerBitWidth();
1072+
unsigned OrigWidth = SDI->getType()->getScalarSizeInBits();
10941073
ConstantRange NegOneOrZero =
10951074
ConstantRange(APInt(OrigWidth, (uint64_t)-1, true), APInt(OrigWidth, 1));
10961075
if (NegOneOrZero.contains(LRange)) {
@@ -1117,9 +1096,6 @@ static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) {
11171096
}
11181097

11191098
static bool processSExt(SExtInst *SDI, LazyValueInfo *LVI) {
1120-
if (SDI->getType()->isVectorTy())
1121-
return false;
1122-
11231099
const Use &Base = SDI->getOperandUse(0);
11241100
if (!LVI->getConstantRangeAtUse(Base, /*UndefAllowed*/ false)
11251101
.isAllNonNegative())
@@ -1138,9 +1114,6 @@ static bool processSExt(SExtInst *SDI, LazyValueInfo *LVI) {
11381114
}
11391115

11401116
static bool processPossibleNonNeg(PossiblyNonNegInst *I, LazyValueInfo *LVI) {
1141-
if (I->getType()->isVectorTy())
1142-
return false;
1143-
11441117
if (I->hasNonNeg())
11451118
return false;
11461119

@@ -1164,9 +1137,6 @@ static bool processUIToFP(UIToFPInst *UIToFP, LazyValueInfo *LVI) {
11641137
}
11651138

11661139
static bool processSIToFP(SIToFPInst *SIToFP, LazyValueInfo *LVI) {
1167-
if (SIToFP->getType()->isVectorTy())
1168-
return false;
1169-
11701140
const Use &Base = SIToFP->getOperandUse(0);
11711141
if (!LVI->getConstantRangeAtUse(Base, /*UndefAllowed*/ false)
11721142
.isAllNonNegative())
@@ -1187,9 +1157,6 @@ static bool processSIToFP(SIToFPInst *SIToFP, LazyValueInfo *LVI) {
11871157
static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI) {
11881158
using OBO = OverflowingBinaryOperator;
11891159

1190-
if (BinOp->getType()->isVectorTy())
1191-
return false;
1192-
11931160
bool NSW = BinOp->hasNoSignedWrap();
11941161
bool NUW = BinOp->hasNoUnsignedWrap();
11951162
if (NSW && NUW)

llvm/test/Transforms/CorrelatedValuePropagation/icmp.ll

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1246,13 +1246,11 @@ define i1 @non_const_range_minmax(i8 %a, i8 %b) {
12461246
ret i1 %cmp1
12471247
}
12481248

1249-
; FIXME: Also support vectors.
12501249
define <2 x i1> @non_const_range_minmax_vec(<2 x i8> %a, <2 x i8> %b) {
12511250
; CHECK-LABEL: @non_const_range_minmax_vec(
12521251
; CHECK-NEXT: [[A2:%.*]] = call <2 x i8> @llvm.umin.v2i8(<2 x i8> [[A:%.*]], <2 x i8> <i8 10, i8 10>)
12531252
; CHECK-NEXT: [[B2:%.*]] = call <2 x i8> @llvm.umax.v2i8(<2 x i8> [[B:%.*]], <2 x i8> <i8 11, i8 11>)
1254-
; CHECK-NEXT: [[CMP1:%.*]] = icmp ult <2 x i8> [[A2]], [[B2]]
1255-
; CHECK-NEXT: ret <2 x i1> [[CMP1]]
1253+
; CHECK-NEXT: ret <2 x i1> <i1 true, i1 true>
12561254
;
12571255
%a2 = call <2 x i8> @llvm.umin.v2i8(<2 x i8> %a, <2 x i8> <i8 10, i8 10>)
12581256
%b2 = call <2 x i8> @llvm.umax.v2i8(<2 x i8> %b, <2 x i8> <i8 11, i8 11>)

0 commit comments

Comments
 (0)