diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 179a2c38d9d3c..388fd5bd7326f 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -65,6 +65,117 @@ enum SCEVTypes : unsigned short;
 extern bool VerifySCEV;
 
+class SCEV;
+
+class SCEVUse : public PointerIntPair<const SCEV *, 2> {
+  bool computeIsCanonical() const;
+  const SCEV *computeCanonical(ScalarEvolution &SE) const;
+
+public:
+  SCEVUse() : PointerIntPair(nullptr, 0) {}
+  SCEVUse(const SCEV *S) : PointerIntPair(S, 0) {}
+  SCEVUse(const SCEV *S, int Flags) : PointerIntPair(S, Flags) {
+    if (Flags > 0)
+      setInt(Flags | 1);
+  }
+
+  operator const SCEV *() const { return getPointer(); }
+  const SCEV *operator->() const { return getPointer(); }
+  const SCEV *operator->() { return getPointer(); }
+
+  void *getRawPointer() const { return getOpaqueValue(); }
+
+  bool isCanonical() const {
+    assert(((getFlags() & 1) != 0 || computeIsCanonical()) &&
+           "Canonical bit set incorrectly");
+    return (getFlags() & 1) == 0;
+  }
+
+  const SCEV *getCanonical(ScalarEvolution &SE) {
+    if (isCanonical())
+      return getPointer();
+    return computeCanonical(SE);
+  }
+
+  unsigned getFlags() const { return getInt(); }
+
+  bool operator==(const SCEVUse &RHS) const;
+  bool operator==(const SCEV *RHS) const;
+
+  /// Print out the internal representation of this scalar to the specified
+  /// stream. This should really only be used for debugging purposes.
+  void print(raw_ostream &OS) const;
+
+  /// This method is used for debugging.
+  void dump() const;
+};
+
+/// Provide PointerLikeTypeTraits for SCEVUse, so it can be used with
+/// SmallPtrSet, among others.
+template <> struct PointerLikeTypeTraits<SCEVUse> {
+  static inline void *getAsVoidPointer(SCEVUse U) { return U.getOpaqueValue(); }
+  static inline SCEVUse getFromVoidPointer(void *P) {
+    SCEVUse U;
+    U.setFromOpaqueValue(P);
+    return U;
+  }
+
+  /// The Low bits are used by the PointerIntPair.
+  static constexpr int NumLowBitsAvailable = 0;
+};
+
+template <> struct DenseMapInfo<SCEVUse> {
+  // The following should hold, but it would require T to be complete:
+  // static_assert(alignof(T) <= (1 << Log2MaxAlign),
+  //               "DenseMap does not support pointer keys requiring more than "
+  //               "Log2MaxAlign bits of alignment");
+  static constexpr uintptr_t Log2MaxAlign = 12;
+
+  static inline SCEVUse getEmptyKey() {
+    uintptr_t Val = static_cast<uintptr_t>(-1);
+    Val <<= Log2MaxAlign;
+    return PointerLikeTypeTraits<SCEVUse>::getFromVoidPointer((void *)Val);
+  }
+
+  static inline SCEVUse getTombstoneKey() {
+    uintptr_t Val = static_cast<uintptr_t>(-2);
+    Val <<= Log2MaxAlign;
+    return PointerLikeTypeTraits<SCEVUse>::getFromVoidPointer((void *)Val);
+  }
+
+  static unsigned getHashValue(SCEVUse U) {
+    void *PtrVal = PointerLikeTypeTraits<SCEVUse>::getAsVoidPointer(U);
+    return (unsigned((uintptr_t)PtrVal) >> 4) ^
+           (unsigned((uintptr_t)PtrVal) >> 9);
+  }
+
+  static bool isEqual(const SCEVUse LHS, const SCEVUse RHS) {
+    return LHS.getRawPointer() == RHS.getRawPointer();
+  }
+};
+
+template <typename To>
+[[nodiscard]] inline decltype(auto) dyn_cast(SCEVUse U) {
+  assert(detail::isPresent(U.getPointer()) &&
+         "dyn_cast on a non-existent value");
+  return CastInfo<To, const SCEV *>::doCastIfPossible(U.getPointer());
+}
+
+template <typename To> [[nodiscard]] inline decltype(auto) cast(SCEVUse U) {
+  assert(detail::isPresent(U.getPointer()) &&
+         "cast on a non-existent value");
+  return CastInfo<To, const SCEV *>::doCast(U.getPointer());
+}
+
+template <typename To> [[nodiscard]] inline bool isa(SCEVUse U) {
+  return CastInfo<To, const SCEV *>::isPossible(U.getPointer());
+}
+
+template <typename To> auto dyn_cast_or_null(SCEVUse U) {
+  const SCEV *Val = U.getPointer();
+  if (!detail::isPresent(Val))
+    return CastInfo<To, const SCEV *>::castFailed();
+  return CastInfo<To, const SCEV *>::doCastIfPossible(detail::unwrapValue(Val));
+}
+
 /// This class represents an analyzed expression in the program. These are
 /// opaque objects that the client is not allowed to do much with directly.
 ///
@@ -143,7 +254,7 @@ class SCEV : public FoldingSetNode {
   Type *getType() const;
 
   /// Return operands of this SCEV expression.
-  ArrayRef<const SCEV *> operands() const;
+  ArrayRef<SCEVUse> operands() const;
 
   /// Return true if the expression is a constant zero.
   bool isZero() const;
@@ -198,6 +309,11 @@ inline raw_ostream &operator<<(raw_ostream &OS, const SCEV &S) {
   return OS;
 }
 
+inline raw_ostream &operator<<(raw_ostream &OS, const SCEVUse &S) {
+  S.print(OS);
+  return OS;
+}
+
 /// An object of this class is returned by queries that could not be answered.
 /// For example, if you ask for the number of iterations of a linked-list
 /// traversal loop, you will get one of these. None of the standard SCEV
@@ -207,6 +323,7 @@ struct SCEVCouldNotCompute : public SCEV {
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const SCEV *S);
+  static bool classof(const SCEVUse *U) { return classof(U->getPointer()); }
 };
 
 /// This class represents an assumption made using SCEV expressions which can
 /// be checked at run-time.
@@ -277,13 +394,13 @@ struct FoldingSetTrait<SCEVPredicate> : DefaultFoldingSetTrait<SCEVPredicate> {
 class SCEVComparePredicate final : public SCEVPredicate {
   /// We assume that LHS Pred RHS is true.
   const ICmpInst::Predicate Pred;
-  const SCEV *LHS;
-  const SCEV *RHS;
+  SCEVUse LHS;
+  SCEVUse RHS;
 
 public:
   SCEVComparePredicate(const FoldingSetNodeIDRef ID,
-                       const ICmpInst::Predicate Pred,
-                       const SCEV *LHS, const SCEV *RHS);
+                       const ICmpInst::Predicate Pred, SCEVUse LHS,
+                       SCEVUse RHS);
 
   /// Implementation of the SCEVPredicate interface
   bool implies(const SCEVPredicate *N) const override;
@@ -293,10 +410,10 @@ class SCEVComparePredicate final : public SCEVPredicate {
   ICmpInst::Predicate getPredicate() const { return Pred; }
 
   /// Returns the left hand side of the predicate.
-  const SCEV *getLHS() const { return LHS; }
+  SCEVUse getLHS() const { return LHS; }
 
   /// Returns the right hand side of the predicate.
-  const SCEV *getRHS() const { return RHS; }
+  SCEVUse getRHS() const { return RHS; }
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const SCEVPredicate *P) {
@@ -411,8 +528,7 @@ class SCEVWrapPredicate final : public SCEVPredicate {
 /// ScalarEvolution::Preds folding set. This is why the \c add function is
 /// sound.
 class SCEVUnionPredicate final : public SCEVPredicate {
 private:
-  using PredicateMap =
-      DenseMap<const SCEV *, SmallVector<const SCEVPredicate *, 4>>;
+  using PredicateMap = DenseMap<SCEVUse, SmallVector<const SCEVPredicate *, 4>>;
 
   /// Vector with references to all predicates in this union.
   SmallVector<const SCEVPredicate *, 16> Preds;
@@ -519,18 +635,17 @@ class ScalarEvolution {
   ///     loop { v2 = load @global2; }
   ///   }
   /// No SCEV with operand V1, and v2 can exist in this program.
-  bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B);
+  bool instructionCouldExistWithOperands(SCEVUse A, SCEVUse B);
 
   /// Return true if the SCEV is a scAddRecExpr or it contains
   /// scAddRecExpr. The result will be cached in HasRecMap.
-  bool containsAddRecurrence(const SCEV *S);
+  bool containsAddRecurrence(SCEVUse S);
 
   /// Does operation \p BinOp between \p LHS and \p RHS provably not have
   /// a signed/unsigned overflow (\p Signed)? If \p CtxI is specified, the
   /// no-overflow fact should be true in the context of this instruction.
-  bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
-                       const SCEV *LHS, const SCEV *RHS,
-                       const Instruction *CtxI = nullptr);
+  bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, SCEVUse LHS,
+                       SCEVUse RHS, const Instruction *CtxI = nullptr);
 
   /// Parse NSW/NUW flags from add/sub/mul IR binary operation \p Op into
   /// SCEV no-wrap flags, and deduce flag[s] that aren't known yet.
@@ -541,78 +656,84 @@ class ScalarEvolution {
   getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO);
 
   /// Notify this ScalarEvolution that \p User directly uses SCEVs in \p Ops.
-  void registerUser(const SCEV *User, ArrayRef<const SCEV *> Ops);
+  void registerUser(SCEVUse User, ArrayRef<SCEVUse> Ops);
 
   /// Return true if the SCEV expression contains an undef value.
-  bool containsUndefs(const SCEV *S) const;
+  bool containsUndefs(SCEVUse S) const;
 
   /// Return true if the SCEV expression contains a Value that has been
   /// optimised out and is now a nullptr.
-  bool containsErasedValue(const SCEV *S) const;
+  bool containsErasedValue(SCEVUse S) const;
 
   /// Return a SCEV expression for the full generality of the specified
   /// expression.
-  const SCEV *getSCEV(Value *V);
+  SCEVUse getSCEV(Value *V, bool UseCtx = false);
 
   /// Return an existing SCEV for V if there is one, otherwise return nullptr.
-  const SCEV *getExistingSCEV(Value *V);
-
-  const SCEV *getConstant(ConstantInt *V);
-  const SCEV *getConstant(const APInt &Val);
-  const SCEV *getConstant(Type *Ty, uint64_t V, bool isSigned = false);
-  const SCEV *getLosslessPtrToIntExpr(const SCEV *Op, unsigned Depth = 0);
-  const SCEV *getPtrToIntExpr(const SCEV *Op, Type *Ty);
-  const SCEV *getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth = 0);
-  const SCEV *getVScale(Type *Ty);
-  const SCEV *getElementCount(Type *Ty, ElementCount EC);
-  const SCEV *getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth = 0);
-  const SCEV *getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
-                                    unsigned Depth = 0);
-  const SCEV *getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth = 0);
-  const SCEV *getSignExtendExprImpl(const SCEV *Op, Type *Ty,
-                                    unsigned Depth = 0);
-  const SCEV *getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty);
-  const SCEV *getAnyExtendExpr(const SCEV *Op, Type *Ty);
-  const SCEV *getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
-                         SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap,
-                         unsigned Depth = 0);
-  const SCEV *getAddExpr(const SCEV *LHS, const SCEV *RHS,
-                         SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap,
-                         unsigned Depth = 0) {
-    SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
+  SCEVUse getExistingSCEV(Value *V);
+
+  SCEVUse getConstant(ConstantInt *V);
+  SCEVUse getConstant(const APInt &Val);
+  SCEVUse getConstant(Type *Ty, uint64_t V, bool isSigned = false);
+  SCEVUse getLosslessPtrToIntExpr(SCEVUse Op, unsigned Depth = 0);
+  SCEVUse getPtrToIntExpr(SCEVUse Op, Type *Ty);
+  SCEVUse getTruncateExpr(SCEVUse Op, Type *Ty, unsigned Depth = 0);
+  SCEVUse getVScale(Type *Ty);
+  SCEVUse getElementCount(Type *Ty, ElementCount EC);
+  SCEVUse getZeroExtendExpr(SCEVUse Op, Type *Ty, unsigned Depth = 0);
+  SCEVUse getZeroExtendExprImpl(SCEVUse Op, Type *Ty, unsigned Depth = 0);
+  SCEVUse getSignExtendExpr(SCEVUse Op, Type *Ty, unsigned Depth = 0);
+  SCEVUse getSignExtendExprImpl(SCEVUse Op, Type *Ty, unsigned Depth = 0);
+  SCEVUse getCastExpr(SCEVTypes Kind, SCEVUse Op, Type *Ty);
+  SCEVUse getAnyExtendExpr(SCEVUse Op, Type *Ty);
+  SCEVUse getAddExpr(ArrayRef<SCEVUse> Ops,
+                     SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap,
+                     unsigned Depth = 0);
+  SCEVUse getAddExpr(SmallVectorImpl<SCEVUse> &Ops,
+                     SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap,
+                     unsigned Depth = 0);
+  SCEVUse getAddExpr(SCEVUse LHS, SCEVUse RHS,
+                     SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap,
+                     unsigned Depth = 0) {
+    SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
     return getAddExpr(Ops, Flags, Depth);
   }
-  const SCEV *getAddExpr(const SCEV *Op0, const SCEV *Op1, const SCEV *Op2,
-                         SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap,
-                         unsigned Depth = 0) {
-    SmallVector<const SCEV *, 3> Ops = {Op0, Op1, Op2};
+  SCEVUse getAddExpr(SCEVUse Op0, SCEVUse Op1, SCEVUse Op2,
+                     SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap,
+                     unsigned Depth = 0) {
+    SmallVector<SCEVUse, 3> Ops = {Op0, Op1, Op2};
     return getAddExpr(Ops, Flags, Depth);
   }
-  const SCEV *getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
-                         SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap,
-                         unsigned Depth = 0);
-  const SCEV *getMulExpr(const SCEV *LHS, const SCEV *RHS,
-                         SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap,
-                         unsigned Depth = 0) {
-    SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
+  SCEVUse getMulExpr(ArrayRef<SCEVUse> Ops,
+                     SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap,
+                     unsigned Depth = 0);
+  SCEVUse getMulExpr(SmallVectorImpl<SCEVUse> &Ops,
+                     SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap,
+                     unsigned Depth = 0);
+  SCEVUse getMulExpr(SCEVUse LHS, SCEVUse RHS,
+                     SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap,
+                     unsigned Depth = 0) {
+    SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
     return getMulExpr(Ops, Flags, Depth);
   }
-  const SCEV *getMulExpr(const SCEV *Op0, const SCEV *Op1, const SCEV *Op2,
-                         SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap,
-                         unsigned Depth = 0) {
-    SmallVector<const SCEV *, 3> Ops = {Op0, Op1, Op2};
+  SCEVUse getMulExpr(SCEVUse Op0, SCEVUse Op1, SCEVUse Op2,
+                     SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap,
+                     unsigned Depth = 0) {
+    SmallVector<SCEVUse, 3> Ops = {Op0, Op1, Op2};
     return getMulExpr(Ops, Flags, Depth);
   }
-  const SCEV *getUDivExpr(const SCEV *LHS, const SCEV *RHS);
-  const SCEV *getUDivExactExpr(const SCEV *LHS, const SCEV *RHS);
-  const SCEV *getURemExpr(const SCEV *LHS, const SCEV *RHS);
-  const SCEV *getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L,
-                            SCEV::NoWrapFlags Flags);
-  const SCEV *getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
-                            const Loop *L, SCEV::NoWrapFlags Flags);
-  const SCEV *getAddRecExpr(const SmallVectorImpl<const SCEV *> &Operands,
-                            const Loop *L, SCEV::NoWrapFlags Flags) {
-    SmallVector<const SCEV *, 4> NewOp(Operands.begin(), Operands.end());
+  SCEVUse getUDivExpr(SCEVUse LHS, SCEVUse RHS);
+  SCEVUse getUDivExactExpr(SCEVUse LHS, SCEVUse RHS);
+  SCEVUse getURemExpr(SCEVUse LHS, SCEVUse RHS);
+  SCEVUse getAddRecExpr(SCEVUse Start, SCEVUse Step, const Loop *L,
+                        SCEV::NoWrapFlags Flags);
+  SCEVUse getAddRecExpr(ArrayRef<SCEVUse> Operands, const Loop *L,
+                        SCEV::NoWrapFlags Flags);
+  SCEVUse getAddRecExpr(SmallVectorImpl<SCEVUse> &Operands, const Loop *L,
+                        SCEV::NoWrapFlags Flags);
+  SCEVUse getAddRecExpr(const SmallVectorImpl<SCEVUse> &Operands, const Loop *L,
+                        SCEV::NoWrapFlags Flags) {
+    SmallVector<SCEVUse, 4> NewOp(Operands.begin(), Operands.end());
     return getAddRecExpr(NewOp, L, Flags);
   }
@@ -620,7 +741,7 @@ class ScalarEvolution {
   /// Predicates. If successful return these <AddRecExpr, Predicates>;
   /// The function is intended to be called from PSCEV (the caller will decide
   /// whether to actually add the predicates and carry out the rewrites).
-  std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+  std::optional<std::pair<SCEVUse, SmallVector<const SCEVPredicate *, 3>>>
   createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI);
 
   /// Returns an expression for a GEP
@@ -628,61 +749,63 @@ class ScalarEvolution {
   /// \p GEP The GEP. The indices contained in the GEP itself are ignored,
   /// instead we use IndexExprs.
   /// \p IndexExprs The expressions for the indices.
-  const SCEV *getGEPExpr(GEPOperator *GEP,
-                         const SmallVectorImpl<const SCEV *> &IndexExprs);
-  const SCEV *getAbsExpr(const SCEV *Op, bool IsNSW);
-  const SCEV *getMinMaxExpr(SCEVTypes Kind,
-                            SmallVectorImpl<const SCEV *> &Operands);
-  const SCEV *getSequentialMinMaxExpr(SCEVTypes Kind,
-                                      SmallVectorImpl<const SCEV *> &Operands);
-  const SCEV *getSMaxExpr(const SCEV *LHS, const SCEV *RHS);
-  const SCEV *getSMaxExpr(SmallVectorImpl<const SCEV *> &Operands);
-  const SCEV *getUMaxExpr(const SCEV *LHS, const SCEV *RHS);
-  const SCEV *getUMaxExpr(SmallVectorImpl<const SCEV *> &Operands);
-  const SCEV *getSMinExpr(const SCEV *LHS, const SCEV *RHS);
-  const SCEV *getSMinExpr(SmallVectorImpl<const SCEV *> &Operands);
-  const SCEV *getUMinExpr(const SCEV *LHS, const SCEV *RHS,
-                          bool Sequential = false);
-  const SCEV *getUMinExpr(SmallVectorImpl<const SCEV *> &Operands,
-                          bool Sequential = false);
-  const SCEV *getUnknown(Value *V);
-  const SCEV *getCouldNotCompute();
+  SCEVUse getGEPExpr(GEPOperator *GEP, ArrayRef<SCEVUse> IndexExprs,
+                     bool UseCtx = false);
+  SCEVUse getGEPExpr(GEPOperator *GEP,
+                     const SmallVectorImpl<SCEVUse> &IndexExprs,
+                     bool UseCtx = false);
+  SCEVUse getAbsExpr(SCEVUse Op, bool IsNSW);
+  SCEVUse getMinMaxExpr(SCEVTypes Kind, ArrayRef<SCEVUse> Operands);
+  SCEVUse getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl<SCEVUse> &Operands);
+  SCEVUse getSequentialMinMaxExpr(SCEVTypes Kind,
+                                  SmallVectorImpl<SCEVUse> &Operands);
+  SCEVUse getSMaxExpr(SCEVUse LHS, SCEVUse RHS);
+  SCEVUse getSMaxExpr(SmallVectorImpl<SCEVUse> &Operands);
+  SCEVUse getUMaxExpr(SCEVUse LHS, SCEVUse RHS);
+  SCEVUse getUMaxExpr(SmallVectorImpl<SCEVUse> &Operands);
+  SCEVUse getSMinExpr(SCEVUse LHS, SCEVUse RHS);
+  SCEVUse getSMinExpr(SmallVectorImpl<SCEVUse> &Operands);
+  SCEVUse getUMinExpr(SCEVUse LHS, SCEVUse RHS, bool Sequential = false);
+  SCEVUse getUMinExpr(SmallVectorImpl<SCEVUse> &Operands,
+                      bool Sequential = false);
+  SCEVUse getUnknown(Value *V);
+  SCEVUse getCouldNotCompute();
 
   /// Return a SCEV for the constant 0 of a specific type.
-  const SCEV *getZero(Type *Ty) { return getConstant(Ty, 0); }
+  SCEVUse getZero(Type *Ty) { return getConstant(Ty, 0); }
 
   /// Return a SCEV for the constant 1 of a specific type.
-  const SCEV *getOne(Type *Ty) { return getConstant(Ty, 1); }
+  SCEVUse getOne(Type *Ty) { return getConstant(Ty, 1); }
 
   /// Return a SCEV for the constant \p Power of two.
-  const SCEV *getPowerOfTwo(Type *Ty, unsigned Power) {
+  SCEVUse getPowerOfTwo(Type *Ty, unsigned Power) {
     assert(Power < getTypeSizeInBits(Ty) && "Power out of range");
     return getConstant(APInt::getOneBitSet(getTypeSizeInBits(Ty), Power));
   }
 
   /// Return a SCEV for the constant -1 of a specific type.
-  const SCEV *getMinusOne(Type *Ty) {
+  SCEVUse getMinusOne(Type *Ty) {
     return getConstant(Ty, -1, /*isSigned=*/true);
   }
 
   /// Return an expression for a TypeSize.
-  const SCEV *getSizeOfExpr(Type *IntTy, TypeSize Size);
+  SCEVUse getSizeOfExpr(Type *IntTy, TypeSize Size);
 
   /// Return an expression for the alloc size of AllocTy that is type IntTy
-  const SCEV *getSizeOfExpr(Type *IntTy, Type *AllocTy);
+  SCEVUse getSizeOfExpr(Type *IntTy, Type *AllocTy);
 
   /// Return an expression for the store size of StoreTy that is type IntTy
-  const SCEV *getStoreSizeOfExpr(Type *IntTy, Type *StoreTy);
+  SCEVUse getStoreSizeOfExpr(Type *IntTy, Type *StoreTy);
 
   /// Return an expression for offsetof on the given field with type IntTy
-  const SCEV *getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo);
+  SCEVUse getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo);
 
   /// Return the SCEV object corresponding to -V.
-  const SCEV *getNegativeSCEV(const SCEV *V,
-                              SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap);
+  SCEVUse getNegativeSCEV(SCEVUse V,
+                          SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap);
 
   /// Return the SCEV object corresponding to ~V.
-  const SCEV *getNotSCEV(const SCEV *V);
+  SCEVUse getNotSCEV(SCEVUse V);
 
   /// Return LHS-RHS. Minus is represented in SCEV as A+B*-1.
   ///
@@ -691,9 +814,9 @@ class ScalarEvolution {
   /// To compute the difference between two unrelated pointers, you can
   /// explicitly convert the arguments using getPtrToIntExpr(), for pointer
   /// types that support it.
-  const SCEV *getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
-                           SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap,
-                           unsigned Depth = 0);
+  SCEVUse getMinusSCEV(SCEVUse LHS, SCEVUse RHS,
+                       SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap,
+                       unsigned Depth = 0);
 
   /// Compute ceil(N / D). N and D are treated as unsigned values.
   ///
@@ -703,59 +826,59 @@ class ScalarEvolution {
   ///   umin(N, 1) + floor((N - umin(N, 1)) / D)
   ///
   /// A denominator of zero or poison is handled the same way as getUDivExpr().
-  const SCEV *getUDivCeilSCEV(const SCEV *N, const SCEV *D);
+  SCEVUse getUDivCeilSCEV(SCEVUse N, SCEVUse D);
 
   /// Return a SCEV corresponding to a conversion of the input value to the
   /// specified type. If the type must be extended, it is zero extended.
-  const SCEV *getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
-                                      unsigned Depth = 0);
+  SCEVUse getTruncateOrZeroExtend(SCEVUse V, Type *Ty, unsigned Depth = 0);
 
   /// Return a SCEV corresponding to a conversion of the input value to the
   /// specified type. If the type must be extended, it is sign extended.
-  const SCEV *getTruncateOrSignExtend(const SCEV *V, Type *Ty,
-                                      unsigned Depth = 0);
+  SCEVUse getTruncateOrSignExtend(SCEVUse V, Type *Ty, unsigned Depth = 0);
 
   /// Return a SCEV corresponding to a conversion of the input value to the
   /// specified type. If the type must be extended, it is zero extended. The
   /// conversion must not be narrowing.
-  const SCEV *getNoopOrZeroExtend(const SCEV *V, Type *Ty);
+  SCEVUse getNoopOrZeroExtend(SCEVUse V, Type *Ty);
 
   /// Return a SCEV corresponding to a conversion of the input value to the
   /// specified type. If the type must be extended, it is sign extended. The
   /// conversion must not be narrowing.
-  const SCEV *getNoopOrSignExtend(const SCEV *V, Type *Ty);
+  SCEVUse getNoopOrSignExtend(SCEVUse V, Type *Ty);
 
   /// Return a SCEV corresponding to a conversion of the input value to the
   /// specified type. If the type must be extended, it is extended with
   /// unspecified bits. The conversion must not be narrowing.
-  const SCEV *getNoopOrAnyExtend(const SCEV *V, Type *Ty);
+  SCEVUse getNoopOrAnyExtend(SCEVUse V, Type *Ty);
 
   /// Return a SCEV corresponding to a conversion of the input value to the
   /// specified type. The conversion must not be widening.
-  const SCEV *getTruncateOrNoop(const SCEV *V, Type *Ty);
+  SCEVUse getTruncateOrNoop(SCEVUse V, Type *Ty);
 
   /// Promote the operands to the wider of the types using zero-extension, and
   /// then perform a umax operation with them.
-  const SCEV *getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS);
+  SCEVUse getUMaxFromMismatchedTypes(SCEVUse LHS, SCEVUse RHS);
 
   /// Promote the operands to the wider of the types using zero-extension, and
   /// then perform a umin operation with them.
-  const SCEV *getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS,
-                                         bool Sequential = false);
+  SCEVUse getUMinFromMismatchedTypes(SCEVUse LHS, SCEVUse RHS,
+                                     bool Sequential = false);
 
   /// Promote the operands to the wider of the types using zero-extension, and
   /// then perform a umin operation with them. N-ary function.
-  const SCEV *getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
-                                         bool Sequential = false);
+  SCEVUse getUMinFromMismatchedTypes(ArrayRef<SCEVUse> Ops,
+                                     bool Sequential = false);
+  SCEVUse getUMinFromMismatchedTypes(SmallVectorImpl<SCEVUse> &Ops,
+                                     bool Sequential = false);
 
   /// Transitively follow the chain of pointer-type operands until reaching a
   /// SCEV that does not have a single pointer operand. This returns a
   /// SCEVUnknown pointer for well-formed pointer-type expressions, but corner
   /// cases do exist.
-  const SCEV *getPointerBase(const SCEV *V);
+  SCEVUse getPointerBase(SCEVUse V);
 
   /// Compute an expression equivalent to S - getPointerBase(S).
-  const SCEV *removePointerBase(const SCEV *S);
+  SCEVUse removePointerBase(SCEVUse S);
 
   /// Return a SCEV expression for the specified value at the specified scope
   /// in the program. The L value specifies a loop nest to evaluate the
@@ -767,31 +890,31 @@ class ScalarEvolution {
   ///
   /// In the case that a relevant loop exit value cannot be computed, the
   /// original value V is returned.
-  const SCEV *getSCEVAtScope(const SCEV *S, const Loop *L);
+  SCEVUse getSCEVAtScope(SCEVUse S, const Loop *L);
 
   /// This is a convenience function which does getSCEVAtScope(getSCEV(V), L).
-  const SCEV *getSCEVAtScope(Value *V, const Loop *L);
+  SCEVUse getSCEVAtScope(Value *V, const Loop *L);
 
   /// Test whether entry to the loop is protected by a conditional between LHS
   /// and RHS. This is used to help avoid max expressions in loop trip
   /// counts, and to eliminate casts.
   bool isLoopEntryGuardedByCond(const Loop *L, ICmpInst::Predicate Pred,
-                                const SCEV *LHS, const SCEV *RHS);
+                                SCEVUse LHS, SCEVUse RHS);
 
   /// Test whether entry to the basic block is protected by a conditional
   /// between LHS and RHS.
   bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
-                                      ICmpInst::Predicate Pred, const SCEV *LHS,
-                                      const SCEV *RHS);
+                                      ICmpInst::Predicate Pred, SCEVUse LHS,
+                                      SCEVUse RHS);
 
   /// Test whether the backedge of the loop is protected by a conditional
   /// between LHS and RHS. This is used to eliminate casts.
   bool isLoopBackedgeGuardedByCond(const Loop *L, ICmpInst::Predicate Pred,
-                                   const SCEV *LHS, const SCEV *RHS);
+                                   SCEVUse LHS, SCEVUse RHS);
 
   /// A version of getTripCountFromExitCount below which always picks an
   /// evaluation type which can not result in overflow.
-  const SCEV *getTripCountFromExitCount(const SCEV *ExitCount);
+  SCEVUse getTripCountFromExitCount(SCEVUse ExitCount);
 
   /// Convert from an "exit count" (i.e. "backedge taken count") to a "trip
   /// count". A "trip count" is the number of times the header of the loop
@@ -800,8 +923,8 @@ class ScalarEvolution {
   /// expression can overflow if ExitCount = UINT_MAX. If EvalTy is not wide
   /// enough to hold the result without overflow, result unsigned wraps with
   /// 2s-complement semantics. ex: EC = 255 (i8), TC = 0 (i8)
-  const SCEV *getTripCountFromExitCount(const SCEV *ExitCount, Type *EvalTy,
-                                        const Loop *L);
+  SCEVUse getTripCountFromExitCount(SCEVUse ExitCount, Type *EvalTy,
+                                    const Loop *L);
 
   /// Returns the exact trip count of the loop if we can compute it, and
   /// the result is a small constant. '0' is used to represent an unknown
@@ -835,8 +958,7 @@ class ScalarEvolution {
   /// unknown or not guaranteed to be the multiple of a constant. Will also
   /// return 1 if the trip count is very large (>= 2^32).
   /// Note that the argument is an exit count for loop L, NOT a trip count.
-  unsigned getSmallConstantTripMultiple(const Loop *L,
-                                        const SCEV *ExitCount);
+  unsigned getSmallConstantTripMultiple(const Loop *L, SCEVUse ExitCount);
 
   /// Returns the largest constant divisor of the trip count of the
   /// loop. Will return 1 if no trip count could be computed, or if a
@@ -871,12 +993,12 @@ class ScalarEvolution {
   /// getBackedgeTakenCount. The loop is guaranteed to exit (via *some* exit)
   /// before the backedge is executed (ExitCount + 1) times. Note that there
   /// is no guarantee about *which* exit is taken on the exiting iteration.
-  const SCEV *getExitCount(const Loop *L, const BasicBlock *ExitingBlock,
-                           ExitCountKind Kind = Exact);
+  SCEVUse getExitCount(const Loop *L, const BasicBlock *ExitingBlock,
+                       ExitCountKind Kind = Exact);
 
   /// Same as above except this uses the predicated backedge taken info and
   /// may require predicates.
-  const SCEV *
+  SCEVUse
   getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock,
                          SmallVectorImpl<const SCEVPredicate *> *Predicates,
                          ExitCountKind Kind = Exact);
@@ -891,20 +1013,20 @@ class ScalarEvolution {
   /// Note that it is not valid to call this method on a loop without a
   /// loop-invariant backedge-taken count (see
   /// hasLoopInvariantBackedgeTakenCount).
-  const SCEV *getBackedgeTakenCount(const Loop *L, ExitCountKind Kind = Exact);
+  SCEVUse getBackedgeTakenCount(const Loop *L, ExitCountKind Kind = Exact);
 
   /// Similar to getBackedgeTakenCount, except it will add a set of
   /// SCEV predicates to Predicates that are required to be true in order for
   /// the answer to be correct. Predicates can be checked with run-time
   /// checks and can be used to perform loop versioning.
-  const SCEV *getPredicatedBackedgeTakenCount(
+  SCEVUse getPredicatedBackedgeTakenCount(
       const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Predicates);
 
   /// When successful, this returns a SCEVConstant that is greater than or equal
   /// to (i.e. a "conservative over-approximation") of the value returned by
   /// getBackedgeTakenCount. If such a value cannot be computed, it returns the
   /// SCEVCouldNotCompute object.
-  const SCEV *getConstantMaxBackedgeTakenCount(const Loop *L) {
+  SCEVUse getConstantMaxBackedgeTakenCount(const Loop *L) {
     return getBackedgeTakenCount(L, ConstantMaximum);
   }
@@ -912,14 +1034,14 @@ class ScalarEvolution {
   /// SCEV predicates to Predicates that are required to be true in order for
   /// the answer to be correct. Predicates can be checked with run-time
   /// checks and can be used to perform loop versioning.
-  const SCEV *getPredicatedConstantMaxBackedgeTakenCount(
+  SCEVUse getPredicatedConstantMaxBackedgeTakenCount(
       const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Predicates);
 
   /// When successful, this returns a SCEV that is greater than or equal
   /// to (i.e. a "conservative over-approximation") of the value returned by
   /// getBackedgeTakenCount. If such a value cannot be computed, it returns the
   /// SCEVCouldNotCompute object.
-  const SCEV *getSymbolicMaxBackedgeTakenCount(const Loop *L) {
+  SCEVUse getSymbolicMaxBackedgeTakenCount(const Loop *L) {
     return getBackedgeTakenCount(L, SymbolicMaximum);
   }
@@ -927,7 +1049,7 @@ class ScalarEvolution {
   /// SCEV predicates to Predicates that are required to be true in order for
   /// the answer to be correct. Predicates can be checked with run-time
   /// checks and can be used to perform loop versioning.
-  const SCEV *getPredicatedSymbolicMaxBackedgeTakenCount(
+  SCEVUse getPredicatedSymbolicMaxBackedgeTakenCount(
       const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Predicates);
 
   /// Return true if the backedge taken count is either the value returned by
@@ -984,60 +1106,60 @@ class ScalarEvolution {
   /// (at every loop iteration). It is, at the same time, the minimum number
   /// of times S is divisible by 2. For example, given {4,+,8} it returns 2.
   /// If S is guaranteed to be 0, it returns the bitwidth of S.
-  uint32_t getMinTrailingZeros(const SCEV *S);
+  uint32_t getMinTrailingZeros(SCEVUse S);
 
   /// Returns the max constant multiple of S.
-  APInt getConstantMultiple(const SCEV *S);
+  APInt getConstantMultiple(SCEVUse S);
 
   // Returns the max constant multiple of S. If S is exactly 0, return 1.
-  APInt getNonZeroConstantMultiple(const SCEV *S);
+  APInt getNonZeroConstantMultiple(SCEVUse S);
 
   /// Determine the unsigned range for a particular SCEV.
   /// NOTE: This returns a copy of the reference returned by getRangeRef.
-  ConstantRange getUnsignedRange(const SCEV *S) {
+  ConstantRange getUnsignedRange(SCEVUse S) {
     return getRangeRef(S, HINT_RANGE_UNSIGNED);
   }
 
   /// Determine the min of the unsigned range for a particular SCEV.
-  APInt getUnsignedRangeMin(const SCEV *S) {
+  APInt getUnsignedRangeMin(SCEVUse S) {
     return getRangeRef(S, HINT_RANGE_UNSIGNED).getUnsignedMin();
   }
 
   /// Determine the max of the unsigned range for a particular SCEV.
-  APInt getUnsignedRangeMax(const SCEV *S) {
+  APInt getUnsignedRangeMax(SCEVUse S) {
     return getRangeRef(S, HINT_RANGE_UNSIGNED).getUnsignedMax();
   }
 
   /// Determine the signed range for a particular SCEV.
   /// NOTE: This returns a copy of the reference returned by getRangeRef.
-  ConstantRange getSignedRange(const SCEV *S) {
+  ConstantRange getSignedRange(SCEVUse S) {
     return getRangeRef(S, HINT_RANGE_SIGNED);
   }
 
   /// Determine the min of the signed range for a particular SCEV.
-  APInt getSignedRangeMin(const SCEV *S) {
+  APInt getSignedRangeMin(SCEVUse S) {
     return getRangeRef(S, HINT_RANGE_SIGNED).getSignedMin();
   }
 
   /// Determine the max of the signed range for a particular SCEV.
-  APInt getSignedRangeMax(const SCEV *S) {
+  APInt getSignedRangeMax(SCEVUse S) {
     return getRangeRef(S, HINT_RANGE_SIGNED).getSignedMax();
   }
 
   /// Test if the given expression is known to be negative.
-  bool isKnownNegative(const SCEV *S);
+  bool isKnownNegative(SCEVUse S);
 
   /// Test if the given expression is known to be positive.
-  bool isKnownPositive(const SCEV *S);
+  bool isKnownPositive(SCEVUse S);
 
   /// Test if the given expression is known to be non-negative.
-  bool isKnownNonNegative(const SCEV *S);
+  bool isKnownNonNegative(SCEVUse S);
 
   /// Test if the given expression is known to be non-positive.
-  bool isKnownNonPositive(const SCEV *S);
+  bool isKnownNonPositive(SCEVUse S);
 
   /// Test if the given expression is known to be non-zero.
-  bool isKnownNonZero(const SCEV *S);
+  bool isKnownNonZero(SCEVUse S);
 
   /// Test if the given expression is known to be a power of 2. OrNegative
   /// allows matching negative power of 2s, and OrZero allows matching 0.
@@ -1060,8 +1182,7 @@ class ScalarEvolution {
   /// 0 (initial value) for the first element and to {1, +, 1} (post
   /// increment value) for the second one. In both cases AddRec expression
   /// related to L2 remains the same.
-  std::pair<const SCEV *, const SCEV *> SplitIntoInitAndPostInc(const Loop *L,
-                                                                const SCEV *S);
+  std::pair<SCEVUse, SCEVUse> SplitIntoInitAndPostInc(const Loop *L, SCEVUse S);
 
   /// We'd like to check the predicate on every iteration of the most dominated
   /// loop between loops used in LHS and RHS.
@@ -1081,46 +1202,43 @@ class ScalarEvolution {
   /// so we can assert on that.
   /// e. Return true if isLoopEntryGuardedByCond(Pred, E(LHS), E(RHS)) &&
   ///                   isLoopBackedgeGuardedByCond(Pred, B(LHS), B(RHS))
-  bool isKnownViaInduction(ICmpInst::Predicate Pred, const SCEV *LHS,
-                           const SCEV *RHS);
+  bool isKnownViaInduction(ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS);
 
   /// Test if the given expression is known to satisfy the condition described
   /// by Pred, LHS, and RHS.
-  bool isKnownPredicate(ICmpInst::Predicate Pred, const SCEV *LHS,
-                        const SCEV *RHS);
+  bool isKnownPredicate(ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS);
 
   /// Check whether the condition described by Pred, LHS, and RHS is true or
   /// false. If we know it, return the evaluation of this condition. If neither
   /// is proved, return std::nullopt.
-  std::optional<bool> evaluatePredicate(ICmpInst::Predicate Pred,
-                                        const SCEV *LHS, const SCEV *RHS);
+  std::optional<bool> evaluatePredicate(ICmpInst::Predicate Pred, SCEVUse LHS,
+                                        SCEVUse RHS);
 
   /// Test if the given expression is known to satisfy the condition described
   /// by Pred, LHS, and RHS in the given Context.
-  bool isKnownPredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS,
-                          const SCEV *RHS, const Instruction *CtxI);
+  bool isKnownPredicateAt(ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS,
+                          const Instruction *CtxI);
 
   /// Check whether the condition described by Pred, LHS, and RHS is true or
   /// false in the given \p Context. If we know it, return the evaluation of
   /// this condition. If neither is proved, return std::nullopt.
-  std::optional<bool> evaluatePredicateAt(ICmpInst::Predicate Pred,
-                                          const SCEV *LHS, const SCEV *RHS,
-                                          const Instruction *CtxI);
+  std::optional<bool> evaluatePredicateAt(ICmpInst::Predicate Pred, SCEVUse LHS,
+                                          SCEVUse RHS, const Instruction *CtxI);
 
   /// Test if the condition described by Pred, LHS, RHS is known to be true on
   /// every iteration of the loop of the recurrency LHS.
   bool isKnownOnEveryIteration(ICmpInst::Predicate Pred,
-                               const SCEVAddRecExpr *LHS, const SCEV *RHS);
+                               const SCEVAddRecExpr *LHS, SCEVUse RHS);
 
   /// Information about the number of loop iterations for which a loop exit's
   /// branch condition evaluates to the not-taken path. This is a temporary
   /// pair of exact and max expressions that are eventually summarized in
   /// ExitNotTakenInfo and BackedgeTakenInfo.
   struct ExitLimit {
-    const SCEV *ExactNotTaken; // The exit is not taken exactly this many times
-    const SCEV *ConstantMaxNotTaken; // The exit is not taken at most this many
-                                     // times
-    const SCEV *SymbolicMaxNotTaken;
+    SCEVUse ExactNotTaken; // The exit is not taken exactly this many times
+    SCEVUse ConstantMaxNotTaken; // The exit is not taken at most this many
+                                 // times
+    SCEVUse SymbolicMaxNotTaken;
 
     // Not taken either exactly ConstantMaxNotTaken or zero times
    bool MaxOrZero = false;
@@ -1133,14 +1251,14 @@ class ScalarEvolution {
     /// Construct either an exact exit limit from a constant, or an unknown
     /// one from a SCEVCouldNotCompute. No other types of SCEVs are allowed
     /// as arguments and asserts enforce that internally.
-    /*implicit*/ ExitLimit(const SCEV *E);
+    /*implicit*/ ExitLimit(SCEVUse E);
 
-    ExitLimit(const SCEV *E, const SCEV *ConstantMaxNotTaken,
-              const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
+    ExitLimit(SCEVUse E, SCEVUse ConstantMaxNotTaken,
+              SCEVUse SymbolicMaxNotTaken, bool MaxOrZero,
               ArrayRef<ArrayRef<const SCEVPredicate *>> PredLists = {});
 
-    ExitLimit(const SCEV *E, const SCEV *ConstantMaxNotTaken,
-              const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
+    ExitLimit(SCEVUse E, SCEVUse ConstantMaxNotTaken,
+              SCEVUse SymbolicMaxNotTaken, bool MaxOrZero,
               ArrayRef<const SCEVPredicate *> PredList);
 
     /// Test whether this ExitLimit contains any computed information, or
@@ -1191,20 +1309,18 @@ class ScalarEvolution {
   struct LoopInvariantPredicate {
     ICmpInst::Predicate Pred;
-    const SCEV *LHS;
-    const SCEV *RHS;
+    SCEVUse LHS;
+    SCEVUse RHS;
 
-    LoopInvariantPredicate(ICmpInst::Predicate Pred, const SCEV *LHS,
-                           const SCEV *RHS)
+    LoopInvariantPredicate(ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS)
         : Pred(Pred), LHS(LHS), RHS(RHS) {}
   };
 
   /// If the result of the predicate LHS `Pred` RHS is loop invariant with
   /// respect to L, return a LoopInvariantPredicate with LHS and RHS being
   /// invariants, available at L's entry. Otherwise, return std::nullopt.
   std::optional<LoopInvariantPredicate>
-  getLoopInvariantPredicate(ICmpInst::Predicate Pred, const SCEV *LHS,
-                            const SCEV *RHS, const Loop *L,
-                            const Instruction *CtxI = nullptr);
+  getLoopInvariantPredicate(ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS,
+                            const Loop *L, const Instruction *CtxI = nullptr);
 
   /// If the result of the predicate LHS `Pred` RHS is loop invariant with
   /// respect to L at given Context during at least first MaxIter iterations,
@@ -1213,59 +1329,61 @@ class ScalarEvolution {
   /// should be the loop's exit condition.
   std::optional<LoopInvariantPredicate>
   getLoopInvariantExitCondDuringFirstIterations(ICmpInst::Predicate Pred,
-                                                const SCEV *LHS,
-                                                const SCEV *RHS, const Loop *L,
+                                                SCEVUse LHS, SCEVUse RHS,
+                                                const Loop *L,
                                                 const Instruction *CtxI,
-                                                const SCEV *MaxIter);
+                                                SCEVUse MaxIter);
 
   std::optional<LoopInvariantPredicate>
-  getLoopInvariantExitCondDuringFirstIterationsImpl(
-      ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
-      const Instruction *CtxI, const SCEV *MaxIter);
+  getLoopInvariantExitCondDuringFirstIterationsImpl(ICmpInst::Predicate Pred,
+                                                    SCEVUse LHS, SCEVUse RHS,
+                                                    const Loop *L,
+                                                    const Instruction *CtxI,
+                                                    SCEVUse MaxIter);
 
   /// Simplify LHS and RHS in a comparison with predicate Pred. Return true
   /// iff any changes were made. If the operands are provably equal or
   /// unequal, LHS and RHS are set to the same value and Pred is set to either
   /// ICMP_EQ or ICMP_NE.
-  bool SimplifyICmpOperands(ICmpInst::Predicate &Pred, const SCEV *&LHS,
-                            const SCEV *&RHS, unsigned Depth = 0);
+  bool SimplifyICmpOperands(ICmpInst::Predicate &Pred, SCEVUse &LHS,
+                            SCEVUse &RHS, unsigned Depth = 0);
 
   /// Return the "disposition" of the given SCEV with respect to the given
   /// loop.
-  LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L);
+  LoopDisposition getLoopDisposition(SCEVUse S, const Loop *L);
 
   /// Return true if the value of the given SCEV is unchanging in the
   /// specified loop.
-  bool isLoopInvariant(const SCEV *S, const Loop *L);
+  bool isLoopInvariant(SCEVUse S, const Loop *L);
 
   /// Determine if the SCEV can be evaluated at loop's entry. It is true if it
   /// doesn't depend on a SCEVUnknown of an instruction which is dominated by
   /// the header of loop L.
-  bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L);
+  bool isAvailableAtLoopEntry(SCEVUse S, const Loop *L);
 
   /// Return true if the given SCEV changes value in a known way in the
   /// specified loop. This property being true implies that the value is
   /// variant in the loop AND that we can emit an expression to compute the
   /// value of the expression at any particular loop iteration.
-  bool hasComputableLoopEvolution(const SCEV *S, const Loop *L);
+  bool hasComputableLoopEvolution(SCEVUse S, const Loop *L);
 
   /// Return the "disposition" of the given SCEV with respect to the given
   /// block.
-  BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB);
+  BlockDisposition getBlockDisposition(SCEVUse S, const BasicBlock *BB);
 
   /// Return true if elements that make up the given SCEV dominate the
   /// specified basic block.
-  bool dominates(const SCEV *S, const BasicBlock *BB);
+  bool dominates(SCEVUse S, const BasicBlock *BB);
 
   /// Return true if elements that make up the given SCEV properly dominate
   /// the specified basic block.
-  bool properlyDominates(const SCEV *S, const BasicBlock *BB);
+  bool properlyDominates(SCEVUse S, const BasicBlock *BB);
 
   /// Test whether the given SCEV has Op as a direct or indirect operand.
-  bool hasOperand(const SCEV *S, const SCEV *Op) const;
+  bool hasOperand(SCEVUse S, SCEVUse Op) const;
 
   /// Return the size of an element read or written by Inst.
-  const SCEV *getElementSize(Instruction *Inst);
+  SCEVUse getElementSize(Instruction *Inst);
 
   void print(raw_ostream &OS) const;
   void verify() const;
@@ -1276,22 +1394,21 @@ class ScalarEvolution {
   /// operating on.
   const DataLayout &getDataLayout() const { return DL; }
 
-  const SCEVPredicate *getEqualPredicate(const SCEV *LHS, const SCEV *RHS);
+  const SCEVPredicate *getEqualPredicate(SCEVUse LHS, SCEVUse RHS);
   const SCEVPredicate *getComparePredicate(ICmpInst::Predicate Pred,
-                                           const SCEV *LHS, const SCEV *RHS);
+                                           SCEVUse LHS, SCEVUse RHS);
 
   const SCEVPredicate *
   getWrapPredicate(const SCEVAddRecExpr *AR,
                    SCEVWrapPredicate::IncrementWrapFlags AddedFlags);
 
   /// Re-writes the SCEV according to the Predicates in \p A.
-  const SCEV *rewriteUsingPredicate(const SCEV *S, const Loop *L,
-                                    const SCEVPredicate &A);
+  SCEVUse rewriteUsingPredicate(SCEVUse S, const Loop *L,
+                                const SCEVPredicate &A);
 
   /// Tries to convert the \p S expression to an AddRec expression,
   /// adding additional predicates to \p Preds as required.
   const SCEVAddRecExpr *convertSCEVToAddRecWithPredicates(
-      const SCEV *S, const Loop *L,
-      SmallVectorImpl<const SCEVPredicate *> &Preds);
+      SCEVUse S, const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds);
 
   /// Compute \p LHS - \p RHS and returns the result as an APInt if it is a
   /// constant, and std::nullopt if it isn't.
   ///
@@ -1300,8 +1417,7 @@ class ScalarEvolution {
   /// frugal here since we just bail out of actually constructing and
   /// canonicalizing an expression in the cases where the result isn't going
   /// to be a constant.
-  std::optional<APInt> computeConstantDifference(const SCEV *LHS,
-                                                 const SCEV *RHS);
+  std::optional<APInt> computeConstantDifference(SCEVUse LHS, SCEVUse RHS);
 
   /// Update no-wrap flags of an AddRec. This may drop the cached info about
   /// this AddRec (such as range info) if the new flags may potentially
   /// sharpen it.
   void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags);
 
   class LoopGuards {
-    DenseMap<const SCEV *, const SCEV *> RewriteMap;
+    DenseMap<SCEVUse, SCEVUse> RewriteMap;
     bool PreserveNUW = false;
     bool PreserveNSW = false;
     ScalarEvolution &SE;
@@ -1326,8 +1442,8 @@ class ScalarEvolution {
   };
 
   /// Try to apply information from loop guards for \p L to \p Expr.
-  const SCEV *applyLoopGuards(const SCEV *Expr, const Loop *L);
-  const SCEV *applyLoopGuards(const SCEV *Expr, const LoopGuards &Guards);
+  SCEVUse applyLoopGuards(const SCEVUse Expr, const Loop *L);
+  SCEVUse applyLoopGuards(const SCEVUse Expr, const LoopGuards &Guards);
 
   /// Return true if the loop has no abnormal exits. That is, if the loop
   /// is not infinite, it must exit through an explicit edge in the CFG.
@@ -1345,22 +1461,22 @@ class ScalarEvolution {
   /// being poison as well. The returned set may be incomplete, i.e. there can
   /// be additional Values that also result in S being poison.
   void getPoisonGeneratingValues(SmallPtrSetImpl<const Value *> &Result,
-                                 const SCEV *S);
+                                 SCEVUse S);
 
   /// Check whether it is poison-safe to represent the expression S using the
   /// instruction I. If such a replacement is performed, the poison flags of
   /// instructions in DropPoisonGeneratingInsts must be dropped.
   bool canReuseInstruction(
-      const SCEV *S, Instruction *I,
+      SCEVUse S, Instruction *I,
       SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts);
 
   class FoldID {
-    const SCEV *Op = nullptr;
+    SCEVUse Op = nullptr;
     const Type *Ty = nullptr;
     unsigned short C;
 
   public:
-    FoldID(SCEVTypes C, const SCEV *Op, const Type *Ty) : Op(Op), Ty(Ty), C(C) {
+    FoldID(SCEVTypes C, SCEVUse Op, const Type *Ty) : Op(Op), Ty(Ty), C(C) {
       assert(Op);
       assert(Ty);
     }
@@ -1369,8 +1485,9 @@ class ScalarEvolution {
 
     unsigned computeHash() const {
       return detail::combineHashValue(
-          C, detail::combineHashValue(reinterpret_cast<uintptr_t>(Op),
-                                      reinterpret_cast<uintptr_t>(Ty)));
+          C, detail::combineHashValue(
+                 reinterpret_cast<uintptr_t>(Op.getRawPointer()),
+                 reinterpret_cast<uintptr_t>(Ty)));
     }
 
     bool operator==(const FoldID &RHS) const {
@@ -1422,14 +1539,14 @@ class ScalarEvolution {
   std::unique_ptr<SCEVCouldNotCompute> CouldNotCompute;
 
   /// The type for HasRecMap.
-  using HasRecMapType = DenseMap<const SCEV *, bool>;
+  using HasRecMapType = DenseMap<SCEVUse, bool>;
 
   /// This is a cache to record whether a SCEV contains any scAddRecExpr.
   HasRecMapType HasRecMap;
 
   /// The type for ExprValueMap.
   using ValueSetVector = SmallSetVector<Value *, 8>;
-  using ExprValueMapType = DenseMap<const SCEV *, ValueSetVector>;
+  using ExprValueMapType = DenseMap<SCEVUse, ValueSetVector>;
 
   /// ExprValueMap -- This map records the original values from which
   /// the SCEV expr is generated from.
@@ -1437,15 +1554,15 @@ class ScalarEvolution {
   /// The type for ValueExprMap.
   using ValueExprMapType =
-      DenseMap<SCEVCallbackVH, const SCEV *, DenseMapInfo<Value *>>;
+      DenseMap<SCEVCallbackVH, SCEVUse, DenseMapInfo<Value *>>;
 
   /// This is a cache of the values we have analyzed so far.
   ValueExprMapType ValueExprMap;
 
   /// This is a cache for expressions that got folded to a different existing
   /// SCEV.
-  DenseMap<FoldID, const SCEV *> FoldCache;
-  DenseMap<const SCEV *, SmallVector<FoldID, 2>> FoldCacheUser;
+  DenseMap<FoldID, SCEVUse> FoldCache;
+  DenseMap<SCEVUse, SmallVector<FoldID, 2>> FoldCacheUser;
 
   /// Mark predicate values currently being processed by isImpliedCond.
   SmallPtrSet<const Value *, 6> PendingLoopPredicates;
@@ -1468,27 +1585,27 @@ class ScalarEvolution {
   bool ProvingSplitPredicate = false;
 
   /// Memoized values for the getConstantMultiple
-  DenseMap<const SCEV *, APInt> ConstantMultipleCache;
+  DenseMap<SCEVUse, APInt> ConstantMultipleCache;
 
   /// Return the Value set from which the SCEV expr is generated.
-  ArrayRef<Value *> getSCEVValues(const SCEV *S);
+  ArrayRef<Value *> getSCEVValues(SCEVUse S);
 
   /// Private helper method for the getConstantMultiple method.
-  APInt getConstantMultipleImpl(const SCEV *S);
+  APInt getConstantMultipleImpl(SCEVUse S);
 
   /// Information about the number of times a particular loop exit may be
   /// reached before exiting the loop.
   struct ExitNotTakenInfo {
     PoisoningVH<BasicBlock> ExitingBlock;
-    const SCEV *ExactNotTaken;
-    const SCEV *ConstantMaxNotTaken;
-    const SCEV *SymbolicMaxNotTaken;
+    SCEVUse ExactNotTaken;
+    SCEVUse ConstantMaxNotTaken;
+    SCEVUse SymbolicMaxNotTaken;
    SmallVector<const SCEVPredicate *, 4> Predicates;
 
     explicit ExitNotTakenInfo(PoisoningVH<BasicBlock> ExitingBlock,
-                              const SCEV *ExactNotTaken,
-                              const SCEV *ConstantMaxNotTaken,
-                              const SCEV *SymbolicMaxNotTaken,
+                              SCEVUse ExactNotTaken,
+                              SCEVUse ConstantMaxNotTaken,
+                              SCEVUse SymbolicMaxNotTaken,
                               ArrayRef<const SCEVPredicate *> Predicates)
         : ExitingBlock(ExitingBlock), ExactNotTaken(ExactNotTaken),
           ConstantMaxNotTaken(ConstantMaxNotTaken),
@@ -1512,7 +1629,7 @@ class ScalarEvolution {
     /// Expression indicating the least constant maximum backedge-taken count of
     /// the loop that is known, or a SCEVCouldNotCompute. This expression is
     /// only valid if the predicates associated with all loop exits are true.
-    const SCEV *ConstantMax = nullptr;
+    SCEVUse ConstantMax = nullptr;
 
     /// Indicating if \c ExitNotTaken has an element for every exiting block in
     /// the loop.
@@ -1520,13 +1637,13 @@ class ScalarEvolution {
     /// Expression indicating the least maximum backedge-taken count of the loop
     /// that is known, or a SCEVCouldNotCompute. Lazily computed on first query.
-    const SCEV *SymbolicMax = nullptr;
+    SCEVUse SymbolicMax = nullptr;
 
     /// True iff the backedge is taken either exactly Max or zero times.
     bool MaxOrZero = false;
 
     bool isComplete() const { return IsComplete; }
-    const SCEV *getConstantMax() const { return ConstantMax; }
+    SCEVUse getConstantMax() const { return ConstantMax; }
 
     const ExitNotTakenInfo *getExitNotTaken(
         const BasicBlock *ExitingBlock,
@@ -1541,7 +1658,7 @@ class ScalarEvolution {
     /// Initialize BackedgeTakenInfo from a list of exact exit counts.
     BackedgeTakenInfo(ArrayRef<EdgeExitInfo> ExitCounts, bool IsComplete,
-                      const SCEV *ConstantMax, bool MaxOrZero);
+                      SCEVUse ConstantMax, bool MaxOrZero);
 
     /// Test whether this BackedgeTakenInfo contains any computed information,
     /// or whether it's all SCEVCouldNotCompute values.
@@ -1571,7 +1688,7 @@ class ScalarEvolution {
     /// If we allowed SCEV predicates to be generated when populating this
     /// vector, this information can contain them and therefore a
     /// SCEVPredicate argument should be added to getExact.
-    const SCEV *getExact(
+    SCEVUse getExact(
         const Loop *L, ScalarEvolution *SE,
         SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const;
 
@@ -1580,7 +1697,7 @@ class ScalarEvolution {
     /// this block before this number of iterations, but may exit via another
     /// block. If \p Predicates is null the function returns CouldNotCompute if
    /// predicates are required, otherwise it fills in the required predicates.
-    const SCEV *getExact(
+    SCEVUse getExact(
        const BasicBlock *ExitingBlock, ScalarEvolution *SE,
        SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const {
      if (auto *ENT = getExitNotTaken(ExitingBlock, Predicates))
@@ -1590,12 +1707,12 @@ class ScalarEvolution {
    }
 
     /// Get the constant max backedge taken count for the loop.
-    const SCEV *getConstantMax(
+    SCEVUse getConstantMax(
        ScalarEvolution *SE,
        SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const;
 
     /// Get the constant max backedge taken count for the particular loop exit.
-    const SCEV *getConstantMax(
+    SCEVUse getConstantMax(
        const BasicBlock *ExitingBlock, ScalarEvolution *SE,
        SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const {
      if (auto *ENT = getExitNotTaken(ExitingBlock, Predicates))
@@ -1605,12 +1722,12 @@ class ScalarEvolution {
    }
 
     /// Get the symbolic max backedge taken count for the loop.
-    const SCEV *getSymbolicMax(
+    SCEVUse getSymbolicMax(
        const Loop *L, ScalarEvolution *SE,
        SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr);
 
     /// Get the symbolic max backedge taken count for the particular loop exit.
-    const SCEV *getSymbolicMax(
+    SCEVUse getSymbolicMax(
        const BasicBlock *ExitingBlock, ScalarEvolution *SE,
        SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const {
      if (auto *ENT = getExitNotTaken(ExitingBlock, Predicates))
@@ -1633,7 +1750,7 @@ class ScalarEvolution {
   DenseMap<const Loop *, BackedgeTakenInfo> PredicatedBackedgeTakenCounts;
 
   /// Loops whose backedge taken counts directly use this non-constant SCEV.
-  DenseMap<const SCEV *, SmallPtrSet<PointerIntPair<const Loop *, 1, bool>, 4>>
+  DenseMap<SCEVUse, SmallPtrSet<PointerIntPair<const Loop *, 1, bool>, 4>>
       BECountUsers;
 
   /// This map contains entries for all of the PHI instructions that we
@@ -1645,16 +1762,16 @@ class ScalarEvolution {
   /// This map contains entries for all the expressions that we attempt to
   /// compute getSCEVAtScope information for, which can be expensive in
   /// extreme cases.
-  DenseMap<const SCEV *, SmallVector<std::pair<const Loop *, const SCEV *>, 2>>
+  DenseMap<SCEVUse, SmallVector<std::pair<const Loop *, SCEVUse>, 2>>
       ValuesAtScopes;
 
   /// Reverse map for invalidation purposes: Stores of which SCEV and which
   /// loop this is the value-at-scope of.
-  DenseMap<const SCEV *, SmallVector<std::pair<const Loop *, const SCEV *>, 2>>
+  DenseMap<SCEVUse, SmallVector<std::pair<const Loop *, SCEVUse>, 2>>
      ValuesAtScopesUsers;
 
   /// Memoized computeLoopDisposition results.
-  DenseMap<const SCEV *,
+  DenseMap<SCEVUse,
            SmallVector<PointerIntPair<const Loop *, 2, LoopDisposition>, 2>>
       LoopDispositions;
@@ -1682,33 +1799,33 @@ class ScalarEvolution {
   }
 
   /// Compute a LoopDisposition value.
-  LoopDisposition computeLoopDisposition(const SCEV *S, const Loop *L);
+  LoopDisposition computeLoopDisposition(SCEVUse S, const Loop *L);
 
   /// Memoized computeBlockDisposition results.
   DenseMap<
-      const SCEV *,
+      SCEVUse,
       SmallVector<PointerIntPair<const BasicBlock *, 2, BlockDisposition>, 2>>
       BlockDispositions;
 
   /// Compute a BlockDisposition value.
-  BlockDisposition computeBlockDisposition(const SCEV *S, const BasicBlock *BB);
+  BlockDisposition computeBlockDisposition(SCEVUse S, const BasicBlock *BB);
 
   /// Stores all SCEV that use a given SCEV as its direct operand.
-  DenseMap<const SCEV *, SmallPtrSet<const SCEV *, 8> > SCEVUsers;
+  DenseMap<SCEVUse, SmallPtrSet<SCEVUse, 8>> SCEVUsers;
 
   /// Memoized results from getRange
-  DenseMap<const SCEV *, ConstantRange> UnsignedRanges;
+  DenseMap<SCEVUse, ConstantRange> UnsignedRanges;
 
   /// Memoized results from getRange
-  DenseMap<const SCEV *, ConstantRange> SignedRanges;
+  DenseMap<SCEVUse, ConstantRange> SignedRanges;
 
   /// Used to parameterize getRange
   enum RangeSignHint { HINT_RANGE_UNSIGNED, HINT_RANGE_SIGNED };
 
   /// Set the memoized range for the given SCEV.
-  const ConstantRange &setRange(const SCEV *S, RangeSignHint Hint,
+  const ConstantRange &setRange(SCEVUse S, RangeSignHint Hint,
                                 ConstantRange CR) {
-    DenseMap<const SCEV *, ConstantRange> &Cache =
+    DenseMap<SCEVUse, ConstantRange> &Cache =
         Hint == HINT_RANGE_UNSIGNED ? UnsignedRanges : SignedRanges;
 
     auto Pair = Cache.insert_or_assign(S, std::move(CR));
@@ -1718,29 +1835,29 @@ class ScalarEvolution {
   /// Determine the range for a particular SCEV.
   /// NOTE: This returns a reference to an entry in a cache. It must be
   /// copied if its needed for longer.
-  const ConstantRange &getRangeRef(const SCEV *S, RangeSignHint Hint,
+  const ConstantRange &getRangeRef(SCEVUse S, RangeSignHint Hint,
                                    unsigned Depth = 0);
 
   /// Determine the range for a particular SCEV, but evaluates ranges for
   /// operands iteratively first.
-  const ConstantRange &getRangeRefIter(const SCEV *S, RangeSignHint Hint);
+  const ConstantRange &getRangeRefIter(SCEVUse S, RangeSignHint Hint);
 
   /// Determines the range for the affine SCEVAddRecExpr {\p Start,+,\p Step}.
   /// Helper for \c getRange.
-  ConstantRange getRangeForAffineAR(const SCEV *Start, const SCEV *Step,
+  ConstantRange getRangeForAffineAR(SCEVUse Start, SCEVUse Step,
                                     const APInt &MaxBECount);
 
   /// Determines the range for the affine non-self-wrapping SCEVAddRecExpr {\p
   /// Start,+,\p Step}.
   ConstantRange getRangeForAffineNoSelfWrappingAR(const SCEVAddRecExpr *AddRec,
-                                                  const SCEV *MaxBECount,
+                                                  SCEVUse MaxBECount,
                                                   unsigned BitWidth,
                                                   RangeSignHint SignHint);
 
   /// Try to compute a range for the affine SCEVAddRecExpr {\p Start,+,\p
   /// Step} by "factoring out" a ternary expression from the add recurrence.
   /// Helper called by \c getRange.
-  ConstantRange getRangeViaFactoring(const SCEV *Start, const SCEV *Step,
+  ConstantRange getRangeViaFactoring(SCEVUse Start, SCEVUse Step,
                                      const APInt &MaxBECount);
 
   /// If the unknown expression U corresponds to a simple recurrence, return
   /// a constant range which represents the entire recurrence.
@@ -1751,55 +1868,54 @@ class ScalarEvolution {
   /// We know that there is no SCEV for the specified value. Analyze the
   /// expression recursively.
-  const SCEV *createSCEV(Value *V);
+  SCEVUse createSCEV(Value *V, bool UseCtx = false);
 
   /// We know that there is no SCEV for the specified value. Create a new SCEV
   /// for \p V iteratively.
-  const SCEV *createSCEVIter(Value *V);
+  SCEVUse createSCEVIter(Value *V, bool UseCtx = false);
 
   /// Collect operands of \p V for which SCEV expressions should be constructed
   /// first. Returns a SCEV directly if it can be constructed trivially for \p
   /// V.
-  const SCEV *getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops);
+  SCEVUse getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops);
 
   /// Provide the special handling we need to analyze PHI SCEVs.
-  const SCEV *createNodeForPHI(PHINode *PN);
+  SCEVUse createNodeForPHI(PHINode *PN);
 
   /// Helper function called from createNodeForPHI.
-  const SCEV *createAddRecFromPHI(PHINode *PN);
+  SCEVUse createAddRecFromPHI(PHINode *PN);
 
   /// A helper function for createAddRecFromPHI to handle simple cases.
-  const SCEV *createSimpleAffineAddRec(PHINode *PN, Value *BEValueV,
-                                       Value *StartValueV);
+  SCEVUse createSimpleAffineAddRec(PHINode *PN, Value *BEValueV,
+                                   Value *StartValueV);
 
   /// Helper function called from createNodeForPHI.
-  const SCEV *createNodeFromSelectLikePHI(PHINode *PN);
+  SCEVUse createNodeFromSelectLikePHI(PHINode *PN);
 
   /// Provide special handling for a select-like instruction (currently this
   /// is either a select instruction or a phi node). \p Ty is the type of the
   /// instruction being processed, that is assumed equivalent to
   /// "Cond ? TrueVal : FalseVal".
-  std::optional<const SCEV *>
+  std::optional<SCEVUse>
   createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty, ICmpInst *Cond,
                                                Value *TrueVal, Value *FalseVal);
 
   /// See if we can model this select-like instruction via umin_seq expression.
-  const SCEV *createNodeForSelectOrPHIViaUMinSeq(Value *I, Value *Cond,
-                                                 Value *TrueVal,
-                                                 Value *FalseVal);
+  SCEVUse createNodeForSelectOrPHIViaUMinSeq(Value *I, Value *Cond,
+                                             Value *TrueVal, Value *FalseVal);
 
   /// Given a value \p V, which is a select-like instruction (currently this is
   /// either a select instruction or a phi node), which is assumed equivalent to
   ///   Cond ? TrueVal : FalseVal
   /// see if we can model it as a SCEV expression.
-  const SCEV *createNodeForSelectOrPHI(Value *V, Value *Cond, Value *TrueVal,
-                                       Value *FalseVal);
+  SCEVUse createNodeForSelectOrPHI(Value *V, Value *Cond, Value *TrueVal,
+                                   Value *FalseVal);
 
   /// Provide the special handling we need to analyze GEP SCEVs.
-  const SCEV *createNodeForGEP(GEPOperator *GEP);
+  SCEVUse createNodeForGEP(GEPOperator *GEP, bool UseCtx = false);
 
   /// Implementation code for getSCEVAtScope; called at most once for each
   /// SCEV+Loop pair.
-  const SCEV *computeSCEVAtScope(const SCEV *S, const Loop *L);
+  SCEVUse computeSCEVAtScope(SCEVUse S, const Loop *L);
 
   /// Return the BackedgeTakenInfo for the given loop, lazily computing new
   /// values if the loop hasn't been analyzed yet. The returned result is
@@ -1881,8 +1997,7 @@ class ScalarEvolution {
   /// return more precise results in some cases and is preferred when caller
   /// has a materialized ICmp.
   ExitLimit computeExitLimitFromICmp(const Loop *L, ICmpInst::Predicate Pred,
-                                     const SCEV *LHS, const SCEV *RHS,
-                                     bool IsSubExpr,
+                                     SCEVUse LHS, SCEVUse RHS, bool IsSubExpr,
                                      bool AllowPredicates = false);
 
   /// Compute the number of times the backedge of the specified loop will
@@ -1908,20 +2023,20 @@ class ScalarEvolution {
   /// of the loop until we get the exit condition gets a value of ExitWhen
   /// (true or false). If we cannot evaluate the exit count of the loop,
   /// return CouldNotCompute.
-  const SCEV *computeExitCountExhaustively(const Loop *L, Value *Cond,
-                                           bool ExitWhen);
+  SCEVUse computeExitCountExhaustively(const Loop *L, Value *Cond,
+                                       bool ExitWhen);
 
   /// Return the number of times an exit condition comparing the specified
   /// value to zero will execute. If not computable, return CouldNotCompute.
   /// If AllowPredicates is set, this call will try to use a minimal set of
   /// SCEV predicates in order to return an exact answer.
-  ExitLimit howFarToZero(const SCEV *V, const Loop *L, bool IsSubExpr,
+  ExitLimit howFarToZero(SCEVUse V, const Loop *L, bool IsSubExpr,
                          bool AllowPredicates = false);
 
   /// Return the number of times an exit condition checking the specified
   /// value for nonzero will execute. If not computable, return
   /// CouldNotCompute.
-  ExitLimit howFarToNonZero(const SCEV *V, const Loop *L);
+  ExitLimit howFarToNonZero(SCEVUse V, const Loop *L);
 
   /// Return the number of times an exit condition containing the specified
   /// less-than comparison will execute. If not computable, return
@@ -1935,11 +2050,11 @@ class ScalarEvolution {
   ///
   /// If \p AllowPredicates is set, this call will try to use a minimal set of
   /// SCEV predicates in order to return an exact answer.
-  ExitLimit howManyLessThans(const SCEV *LHS, const SCEV *RHS, const Loop *L,
+  ExitLimit howManyLessThans(SCEVUse LHS, SCEVUse RHS, const Loop *L,
                              bool isSigned, bool ControlsOnlyExit,
                              bool AllowPredicates = false);
 
-  ExitLimit howManyGreaterThans(const SCEV *LHS, const SCEV *RHS, const Loop *L,
+  ExitLimit howManyGreaterThans(SCEVUse LHS, SCEVUse RHS, const Loop *L,
                                 bool isSigned, bool IsSubExpr,
                                 bool AllowPredicates = false);
 
@@ -1953,7 +2068,7 @@ class ScalarEvolution {
   /// whenever the given FoundCondValue value evaluates to true in given
   /// Context. If Context is nullptr, then the found predicate is true
   /// everywhere. LHS and FoundLHS may have different type width.
-  bool isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
+  bool isImpliedCond(ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS,
                      const Value *FoundCondValue, bool Inverse,
                      const Instruction *Context = nullptr);
 
   /// Test whether the condition described by Pred, LHS, and RHS is true
   /// whenever the given FoundCondValue value evaluates to true in given
   /// Context. If Context is nullptr, then the found predicate is true
   /// everywhere. LHS and FoundLHS must have same type width.
- bool isImpliedCondBalancedTypes(ICmpInst::Predicate Pred, const SCEV *LHS, - const SCEV *RHS, - ICmpInst::Predicate FoundPred, - const SCEV *FoundLHS, const SCEV *FoundRHS, + bool isImpliedCondBalancedTypes(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS, ICmpInst::Predicate FoundPred, + SCEVUse FoundLHS, SCEVUse FoundRHS, const Instruction *CtxI); /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by FoundPred, FoundLHS, FoundRHS is /// true in given Context. If Context is nullptr, then the found predicate is /// true everywhere. - bool isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, - ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, - const SCEV *FoundRHS, - const Instruction *Context = nullptr); + bool isImpliedCond(ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS, + ICmpInst::Predicate FoundPred, SCEVUse FoundLHS, + SCEVUse FoundRHS, const Instruction *Context = nullptr); /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by Pred, FoundLHS, and FoundRHS is /// true in given Context. If Context is nullptr, then the found predicate is /// true everywhere. - bool isImpliedCondOperands(ICmpInst::Predicate Pred, const SCEV *LHS, - const SCEV *RHS, const SCEV *FoundLHS, - const SCEV *FoundRHS, + bool isImpliedCondOperands(ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS, + SCEVUse FoundLHS, SCEVUse FoundRHS, const Instruction *Context = nullptr); /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by Pred, FoundLHS, and FoundRHS is /// true. Here LHS is an operation that includes FoundLHS as one of its /// arguments. - bool isImpliedViaOperations(ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS, - const SCEV *FoundLHS, const SCEV *FoundRHS, + bool isImpliedViaOperations(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS, SCEVUse FoundLHS, SCEVUse FoundRHS, unsigned Depth = 0); /// Test whether the condition described by Pred, LHS, and RHS is true. /// Use only simple non-recursive types of checks, such as range analysis etc. - bool isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS); + bool isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS); /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by Pred, FoundLHS, and FoundRHS is /// true. - bool isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, const SCEV *LHS, - const SCEV *RHS, const SCEV *FoundLHS, - const SCEV *FoundRHS); + bool isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS, SCEVUse FoundLHS, + SCEVUse FoundRHS); /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by Pred, FoundLHS, and FoundRHS is /// true. Utility function used by isImpliedCondOperands. Tries to get /// cases like "X `sgt` 0 => X - 1 `sgt` -1". - bool isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred, const SCEV *LHS, - const SCEV *RHS, + bool isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS, ICmpInst::Predicate FoundPred, - const SCEV *FoundLHS, - const SCEV *FoundRHS); + SCEVUse FoundLHS, SCEVUse FoundRHS); /// Return true if the condition denoted by \p LHS \p Pred \p RHS is implied /// by a call to @llvm.experimental.guard in \p BB. 
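The ranges-based helper above cites the implication "X `sgt` 0 => X - 1 `sgt` -1". A hedged sketch of the same fact phrased through the public API; whether the proof actually succeeds depends on what SE already knows about the range of X:

#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

static bool decrementStaysAboveMinusOne(ScalarEvolution &SE, const SCEV *X) {
  // If SE can bound X to be strictly positive, X - 1 s> -1 follows;
  // isImpliedCondOperandsViaRanges reasons about exactly this shape.
  const SCEV *XMinusOne = SE.getMinusSCEV(X, SE.getOne(X->getType()));
  return SE.isKnownPredicate(ICmpInst::ICMP_SGT, XMinusOne,
                             SE.getMinusOne(X->getType()));
}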
bool isImpliedViaGuard(const BasicBlock *BB, ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS); + SCEVUse LHS, SCEVUse RHS); /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by Pred, FoundLHS, and FoundRHS is @@ -2027,10 +2137,9 @@ class ScalarEvolution { /// /// This routine tries to rule out certain kinds of integer overflow, and /// then tries to reason about arithmetic properties of the predicates. - bool isImpliedCondOperandsViaNoOverflow(ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS, - const SCEV *FoundLHS, - const SCEV *FoundRHS); + bool isImpliedCondOperandsViaNoOverflow(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS, SCEVUse FoundLHS, + SCEVUse FoundRHS); /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by Pred, FoundLHS, and FoundRHS is @@ -2039,9 +2148,8 @@ class ScalarEvolution { /// This routine tries to weaken the known condition based on the fact that /// FoundLHS is an AddRec. bool isImpliedCondOperandsViaAddRecStart(ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS, - const SCEV *FoundLHS, - const SCEV *FoundRHS, + SCEVUse LHS, SCEVUse RHS, + SCEVUse FoundLHS, SCEVUse FoundRHS, const Instruction *CtxI); /// Test whether the condition described by Pred, LHS, and RHS is true @@ -2051,19 +2159,17 @@ class ScalarEvolution { /// This routine tries to figure out the predicate for Phis which are SCEVUnknown /// if it is true for every possible incoming value from their respective /// basic blocks. - bool isImpliedViaMerge(ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS, - const SCEV *FoundLHS, const SCEV *FoundRHS, - unsigned Depth); + bool isImpliedViaMerge(ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS, + SCEVUse FoundLHS, SCEVUse FoundRHS, unsigned Depth); /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by Pred, FoundLHS, and FoundRHS is /// true. /// /// This routine tries to reason about shifts. - bool isImpliedCondOperandsViaShift(ICmpInst::Predicate Pred, const SCEV *LHS, - const SCEV *RHS, const SCEV *FoundLHS, - const SCEV *FoundRHS); + bool isImpliedCondOperandsViaShift(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS, SCEVUse FoundLHS, + SCEVUse FoundRHS); /// If we know that the specified Phi is in the header of its containing /// loop, we know the loop executes a constant number of times, and the PHI @@ -2073,50 +2179,50 @@ class ScalarEvolution { /// Test if the given expression is known to satisfy the condition described /// by Pred and the known constant ranges of LHS and RHS. - bool isKnownPredicateViaConstantRanges(ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS); + bool isKnownPredicateViaConstantRanges(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS); /// Try to prove the condition described by "LHS Pred RHS" by ruling out /// integer overflow. /// /// For instance, this will return true for "A s< (A + C)" if C is /// positive. - bool isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred, const SCEV *LHS, - const SCEV *RHS); + bool isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS); /// Try to split Pred LHS RHS into logical conjunctions (and's) and try to /// prove them individually.
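isKnownPredicateViaNoOverflow's comment gives the canonical instance "A s< (A + C)" for positive C. Below is a sketch of that shape, built with an explicitly no-signed-wrap addition; the helper name is illustrative and not part of the patch:

#include "llvm/Analysis/ScalarEvolution.h"
#include <cassert>
using namespace llvm;

static bool knownLessThanSelfPlusC(ScalarEvolution &SE, const SCEV *A,
                                   uint64_t C) {
  assert(C > 0 && "the cited fact needs a strictly positive increment");
  const SCEV *Sum =
      SE.getAddExpr(A, SE.getConstant(A->getType(), C), SCEV::FlagNSW);
  // Proved by ruling out signed overflow on the addition.
  return SE.isKnownPredicate(ICmpInst::ICMP_SLT, A, Sum);
}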
- bool isKnownPredicateViaSplitting(ICmpInst::Predicate Pred, const SCEV *LHS, - const SCEV *RHS); + bool isKnownPredicateViaSplitting(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS); /// Try to match the Expr as "(L + R)". - bool splitBinaryAdd(const SCEV *Expr, const SCEV *&L, const SCEV *&R, + bool splitBinaryAdd(SCEVUse Expr, SCEVUse &L, SCEVUse &R, SCEV::NoWrapFlags &Flags); /// Forget predicated/non-predicated backedge taken counts for the given loop. void forgetBackedgeTakenCounts(const Loop *L, bool Predicated); /// Drop memoized information for all \p SCEVs. - void forgetMemoizedResults(ArrayRef SCEVs); + void forgetMemoizedResults(ArrayRef SCEVs); /// Helper for forgetMemoizedResults. - void forgetMemoizedResultsImpl(const SCEV *S); + void forgetMemoizedResultsImpl(SCEVUse S); /// Iterate over instructions in \p Worklist and their users. Erase entries /// from ValueExprMap and collect SCEV expressions in \p ToForget void visitAndClearUsers(SmallVectorImpl &Worklist, SmallPtrSetImpl &Visited, - SmallVectorImpl &ToForget); + SmallVectorImpl &ToForget); /// Erase Value from ValueExprMap and ExprValueMap. void eraseValueFromMap(Value *V); /// Insert V to S mapping into ValueExprMap and ExprValueMap. - void insertValueToMap(Value *V, const SCEV *S); + void insertValueToMap(Value *V, SCEVUse S); /// Return false iff the given SCEV contains a SCEVUnknown with a NULL value /// pointer. - bool checkValidity(const SCEV *S) const; + bool checkValidity(SCEVUse S) const; /// Return true if `ExtendOpTy`({`Start`,+,`Step`}) can be proved to be /// equal to {`ExtendOpTy`(`Start`),+,`ExtendOpTy`(`Step`)}. This is /// {`Start`,+,`Step`} if `ExtendOpTy` is `SCEVSignExtendExpr` /// (resp. `SCEVZeroExtendExpr`). template - bool proveNoWrapByVaryingStart(const SCEV *Start, const SCEV *Step, - const Loop *L); + bool proveNoWrapByVaryingStart(SCEVUse Start, SCEVUse Step, const Loop *L); /// Try to prove NSW or NUW on \p AR relying on ConstantRange manipulation. SCEV::NoWrapFlags proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR); @@ -2151,17 +2256,17 @@ class ScalarEvolution { /// 'S'. Specifically, return the first instruction in said bounding scope. /// Return nullptr if the scope is trivial (function entry). /// (See scope definition rules associated with flag discussion above) - const Instruction *getNonTrivialDefiningScopeBound(const SCEV *S); + const Instruction *getNonTrivialDefiningScopeBound(SCEVUse S); /// Return a scope which provides an upper bound on the defining scope for /// a SCEV with the operands in Ops. The outparam Precise is set if the /// bound found is a precise bound (i.e. must be the defining scope.) - const Instruction *getDefiningScopeBound(ArrayRef Ops, + const Instruction *getDefiningScopeBound(ArrayRef Ops, bool &Precise); /// Wrapper around the above for cases which don't care if the bound /// is precise. - const Instruction *getDefiningScopeBound(ArrayRef Ops); + const Instruction *getDefiningScopeBound(ArrayRef Ops); /// Given two instructions in the same function, return true if we can /// prove B must execute given A executes. @@ -2202,7 +2307,7 @@ class ScalarEvolution { /// If the analysis is not successful, a mapping from the \p SymbolicPHI to /// itself (with no predicates) is recorded, and a nullptr with an empty /// predicates vector is returned as a pair.
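checkValidity above guards against SCEVUnknowns whose value handle has been nulled out. The same scan can be phrased with the SCEVUse-based SCEVExprContains from ScalarEvolutionExpressions.h; this is a sketch with an illustrative name, not the in-tree implementation:

#include "llvm/Analysis/ScalarEvolutionExpressions.h"
using namespace llvm;

static bool containsNullUnknown(SCEVUse S) {
  return SCEVExprContains(S, [](SCEVUse U) {
    const auto *Unk = dyn_cast<SCEVUnknown>(U);
    return Unk && !Unk->getValue(); // value handle already dropped
  });
}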
- std::optional>> + std::optional>> createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI); /// Compute the maximum backedge count based on the range of values @@ -2214,47 +2319,44 @@ class ScalarEvolution { /// * the induction variable is assumed not to overflow (i.e. either it /// actually doesn't, or we'd have to immediately execute UB) /// We *don't* assert these preconditions so please be careful. - const SCEV *computeMaxBECountForLT(const SCEV *Start, const SCEV *Stride, - const SCEV *End, unsigned BitWidth, - bool IsSigned); + SCEVUse computeMaxBECountForLT(SCEVUse Start, SCEVUse Stride, SCEVUse End, + unsigned BitWidth, bool IsSigned); /// Verify if a linear IV with positive stride can overflow when in a /// less-than comparison, knowing the invariant term of the comparison and /// the stride. - bool canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, bool IsSigned); + bool canIVOverflowOnLT(SCEVUse RHS, SCEVUse Stride, bool IsSigned); /// Verify if a linear IV with negative stride can overflow when in a /// greater-than comparison, knowing the invariant term of the comparison and /// the stride. - bool canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, bool IsSigned); + bool canIVOverflowOnGT(SCEVUse RHS, SCEVUse Stride, bool IsSigned); /// Get add expr already created or create a new one. - const SCEV *getOrCreateAddExpr(ArrayRef Ops, - SCEV::NoWrapFlags Flags); + SCEVUse getOrCreateAddExpr(ArrayRef Ops, SCEV::NoWrapFlags Flags); /// Get mul expr already created or create a new one. - const SCEV *getOrCreateMulExpr(ArrayRef Ops, - SCEV::NoWrapFlags Flags); + SCEVUse getOrCreateMulExpr(ArrayRef Ops, SCEV::NoWrapFlags Flags); // Get addrec expr already created or create a new one. - const SCEV *getOrCreateAddRecExpr(ArrayRef Ops, - const Loop *L, SCEV::NoWrapFlags Flags); + SCEVUse getOrCreateAddRecExpr(ArrayRef Ops, const Loop *L, + SCEV::NoWrapFlags Flags); /// Return x if \p Val is f(x) where f is a 1-1 function. - const SCEV *stripInjectiveFunctions(const SCEV *Val) const; + SCEVUse stripInjectiveFunctions(SCEVUse Val) const; /// Find all of the loops transitively used in \p S, and fill \p LoopsUsed. /// A loop is considered "used" by an expression if it contains /// an add rec on said loop. - void getUsedLoops(const SCEV *S, SmallPtrSetImpl &LoopsUsed); + void getUsedLoops(SCEVUse S, SmallPtrSetImpl &LoopsUsed); /// Try to match the pattern generated by getURemExpr(A, B). If successful, /// assign A and B to LHS and RHS, respectively. - bool matchURem(const SCEV *Expr, const SCEV *&LHS, const SCEV *&RHS); + bool matchURem(SCEVUse Expr, SCEVUse &LHS, SCEVUse &RHS); /// Look for a SCEV expression with type `SCEVType` and operands `Ops` in /// `UniqueSCEVs`. Return it if found, else nullptr. - SCEV *findExistingSCEVInCache(SCEVTypes SCEVType, ArrayRef Ops); + SCEV *findExistingSCEVInCache(SCEVTypes SCEVType, ArrayRef Ops); /// Get reachable blocks in this function, making limited use of SCEV /// reasoning about conditions. @@ -2263,8 +2365,7 @@ class ScalarEvolution { /// Return the given SCEV expression with a new set of operands. /// This preserves the original nowrap flags. - const SCEV *getWithOperands(const SCEV *S, - SmallVectorImpl &NewOps); + SCEVUse getWithOperands(SCEVUse S, SmallVectorImpl &NewOps); FoldingSet UniqueSCEVs; FoldingSet UniquePreds; @@ -2276,7 +2377,7 @@ class ScalarEvolution { /// Cache tentative mappings from UnknownSCEVs in a Loop, to a SCEV expression /// they can be rewritten into under certain predicates.
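For context on matchURem: SCEV has no native urem node, so getURemExpr(A, B) generally expands the remainder arithmetically, and matchURem undoes that expansion. A sketch of the general shape (special cases, such as constant powers of two, fold differently):

#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

// A urem B  ==>  A - (A /u B) * B
static const SCEV *uremShape(ScalarEvolution &SE, const SCEV *A,
                             const SCEV *B) {
  return SE.getMinusSCEV(A, SE.getMulExpr(SE.getUDivExpr(A, B), B));
}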
DenseMap, - std::pair>> + std::pair>> PredicatedSCEVRewrites; /// Set of AddRecs for which proving NUW via an induction has already been @@ -2368,10 +2469,10 @@ class PredicatedScalarEvolution { /// predicate. The order of transformations applied on the expression of V /// returned by ScalarEvolution is guaranteed to be preserved, even when /// adding new predicates. - const SCEV *getSCEV(Value *V); + SCEVUse getSCEV(Value *V); /// Get the (predicated) backedge count for the analyzed loop. - const SCEV *getBackedgeTakenCount(); + SCEVUse getBackedgeTakenCount(); /// Get the (predicated) symbolic max backedge count for the analyzed loop. const SCEV *getSymbolicMaxBackedgeTakenCount(); @@ -2414,14 +2515,14 @@ class PredicatedScalarEvolution { /// Holds a SCEV and the version number of the SCEV predicate used to /// perform the rewrite of the expression. - using RewriteEntry = std::pair; + using RewriteEntry = std::pair; /// Maps a SCEV to the rewrite result of that SCEV at a certain version /// number. If this number doesn't match the current Generation, we will /// need to do a rewrite. To preserve the transformation order of previous /// rewrites, we will rewrite the previous result instead of the original /// SCEV. - DenseMap RewriteMap; + DenseMap RewriteMap; /// Records what NoWrap flags we've added to a Value *. ValueMap FlagsMap; @@ -2443,10 +2544,10 @@ class PredicatedScalarEvolution { unsigned Generation = 0; /// The backedge taken count. - const SCEV *BackedgeCount = nullptr; + SCEVUse BackedgeCount = nullptr; /// The symbolic backedge taken count. - const SCEV *SymbolicMaxBackedgeCount = nullptr; + SCEVUse SymbolicMaxBackedgeCount = nullptr; }; template <> struct DenseMapInfo { diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h index fd884f2a2f55b..b72d9fbe64fab 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h +++ b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h @@ -90,11 +90,12 @@ class SCEVVScale : public SCEV { /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scVScale; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; -inline unsigned short computeExpressionSize(ArrayRef Args) { +inline unsigned short computeExpressionSize(ArrayRef Args) { APInt Size(16, 1); - for (const auto *Arg : Args) + for (const auto Arg : Args) Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize())); return (unsigned short)Size.getZExtValue(); } @@ -102,19 +103,19 @@ inline unsigned short computeExpressionSize(ArrayRef Args) { /// This is the base class for unary cast operator classes. 
class SCEVCastExpr : public SCEV { protected: - const SCEV *Op; + SCEVUse Op; Type *Ty; - SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, + SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, SCEVUse op, Type *ty); public: - const SCEV *getOperand() const { return Op; } - const SCEV *getOperand(unsigned i) const { + SCEVUse getOperand() const { return Op; } + SCEVUse getOperand(unsigned i) const { assert(i == 0 && "Operand index out of range!"); return Op; } - ArrayRef operands() const { return Op; } + ArrayRef operands() const { return Op; } size_t getNumOperands() const { return 1; } Type *getType() const { return Ty; } @@ -123,6 +124,7 @@ class SCEVCastExpr : public SCEV { return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This class represents a cast from a pointer to a pointer-sized integer @@ -130,18 +132,19 @@ class SCEVCastExpr : public SCEV { class SCEVPtrToIntExpr : public SCEVCastExpr { friend class ScalarEvolution; - SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy); + SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, SCEVUse Op, Type *ITy); public: /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scPtrToInt; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This is the base class for unary integral cast operator classes. class SCEVIntegralCastExpr : public SCEVCastExpr { protected: SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, - const SCEV *op, Type *ty); + SCEVUse op, Type *ty); public: /// Methods for support type inquiry through isa, cast, and dyn_cast: @@ -149,6 +152,7 @@ class SCEVIntegralCastExpr : public SCEVCastExpr { return S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This class represents a truncation of an integer value to a @@ -156,11 +160,12 @@ class SCEVIntegralCastExpr : public SCEVCastExpr { class SCEVTruncateExpr : public SCEVIntegralCastExpr { friend class ScalarEvolution; - SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); + SCEVTruncateExpr(const FoldingSetNodeIDRef ID, SCEVUse op, Type *ty); public: /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scTruncate; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This class represents a zero extension of a small integer value @@ -168,13 +173,14 @@ class SCEVTruncateExpr : public SCEVIntegralCastExpr { class SCEVZeroExtendExpr : public SCEVIntegralCastExpr { friend class ScalarEvolution; - SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); + SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op, Type *ty); public: /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scZeroExtend; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This class represents a sign extension of a small integer value @@ -182,13 +188,14 @@ class SCEVZeroExtendExpr : public SCEVIntegralCastExpr { class SCEVSignExtendExpr : public 
SCEVIntegralCastExpr { friend class ScalarEvolution; - SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); + SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op, Type *ty); public: /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scSignExtend; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This node is a base class providing common functionality for @@ -199,25 +206,23 @@ class SCEVNAryExpr : public SCEV { // arrays with its SCEVAllocator, so this class just needs a simple // pointer rather than a more elaborate vector-like data structure. // This also avoids the need for a non-trivial destructor. - const SCEV *const *Operands; + SCEVUse const *Operands; size_t NumOperands; - SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, - const SCEV *const *O, size_t N) + SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, SCEVUse const *O, + size_t N) : SCEV(ID, T, computeExpressionSize(ArrayRef(O, N))), Operands(O), NumOperands(N) {} public: size_t getNumOperands() const { return NumOperands; } - const SCEV *getOperand(unsigned i) const { + SCEVUse getOperand(unsigned i) const { assert(i < NumOperands && "Operand index out of range!"); return Operands[i]; } - ArrayRef operands() const { - return ArrayRef(Operands, NumOperands); - } + ArrayRef operands() const { return ArrayRef(Operands, NumOperands); } NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const { return (NoWrapFlags)(SubclassData & Mask); @@ -241,13 +246,14 @@ class SCEVNAryExpr : public SCEV { S->getSCEVType() == scSequentialUMinExpr || S->getSCEVType() == scAddRecExpr; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This node is the base class for n'ary commutative operators. class SCEVCommutativeExpr : public SCEVNAryExpr { protected: SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, - const SCEV *const *O, size_t N) + SCEVUse const *O, size_t N) : SCEVNAryExpr(ID, T, O, N) {} public: @@ -257,6 +263,7 @@ class SCEVCommutativeExpr : public SCEVNAryExpr { S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr || S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } /// Set flags for a non-recurrence without clearing previously set flags. void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; } @@ -268,11 +275,10 @@ class SCEVAddExpr : public SCEVCommutativeExpr { Type *Ty; - SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) + SCEVAddExpr(const FoldingSetNodeIDRef ID, SCEVUse const *O, size_t N) : SCEVCommutativeExpr(ID, scAddExpr, O, N) { - auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) { - return Op->getType()->isPointerTy(); - }); + auto *FirstPointerTypedOp = find_if( + operands(), [](SCEVUse Op) { return Op->getType()->isPointerTy(); }); if (FirstPointerTypedOp != operands().end()) Ty = (*FirstPointerTypedOp)->getType(); else @@ -284,13 +290,14 @@ class SCEVAddExpr : public SCEVCommutativeExpr { /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This node represents multiplication of some number of SCEVs. 
class SCEVMulExpr : public SCEVCommutativeExpr { friend class ScalarEvolution; - SCEVMulExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) + SCEVMulExpr(const FoldingSetNodeIDRef ID, SCEVUse const *O, size_t N) : SCEVCommutativeExpr(ID, scMulExpr, O, N) {} public: @@ -298,30 +305,31 @@ class SCEVMulExpr : public SCEVCommutativeExpr { /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scMulExpr; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This class represents a binary unsigned division operation. class SCEVUDivExpr : public SCEV { friend class ScalarEvolution; - std::array Operands; + std::array Operands; - SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs) + SCEVUDivExpr(const FoldingSetNodeIDRef ID, SCEVUse lhs, SCEVUse rhs) : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})) { Operands[0] = lhs; Operands[1] = rhs; } public: - const SCEV *getLHS() const { return Operands[0]; } - const SCEV *getRHS() const { return Operands[1]; } + SCEVUse getLHS() const { return Operands[0]; } + SCEVUse getRHS() const { return Operands[1]; } size_t getNumOperands() const { return 2; } - const SCEV *getOperand(unsigned i) const { + SCEVUse getOperand(unsigned i) const { assert((i == 0 || i == 1) && "Operand index out of range!"); return i == 0 ? getLHS() : getRHS(); } - ArrayRef operands() const { return Operands; } + ArrayRef operands() const { return Operands; } Type *getType() const { // In most cases the types of LHS and RHS will be the same, but in some @@ -334,6 +342,7 @@ class SCEVUDivExpr : public SCEV { /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scUDivExpr; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This node represents a polynomial recurrence on the trip count @@ -349,25 +358,24 @@ class SCEVAddRecExpr : public SCEVNAryExpr { const Loop *L; - SCEVAddRecExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N, + SCEVAddRecExpr(const FoldingSetNodeIDRef ID, SCEVUse const *O, size_t N, const Loop *l) : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {} public: Type *getType() const { return getStart()->getType(); } - const SCEV *getStart() const { return Operands[0]; } + SCEVUse getStart() const { return Operands[0]; } const Loop *getLoop() const { return L; } /// Constructs and returns the recurrence indicating how much this /// expression steps by. If this is a polynomial of degree N, it /// returns a chrec of degree N-1. We cannot determine whether /// the step recurrence has self-wraparound. - const SCEV *getStepRecurrence(ScalarEvolution &SE) const { + SCEVUse getStepRecurrence(ScalarEvolution &SE) const { if (isAffine()) return getOperand(1); - return SE.getAddRecExpr( - SmallVector(operands().drop_front()), getLoop(), - FlagAnyWrap); + return SE.getAddRecExpr(SmallVector(operands().drop_front()), + getLoop(), FlagAnyWrap); } /// Return true if this represents an expression A + B*x where A @@ -394,12 +402,12 @@ class SCEVAddRecExpr : public SCEVNAryExpr { /// Return the value of this chain of recurrences at the specified /// iteration number. - const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const; + SCEVUse evaluateAtIteration(SCEVUse It, ScalarEvolution &SE) const; /// Return the value of this chain of recurrences at the specified iteration /// number. 
Takes an explicit list of operands to represent an AddRec. - static const SCEV *evaluateAtIteration(ArrayRef Operands, - const SCEV *It, ScalarEvolution &SE); + static SCEVUse evaluateAtIteration(ArrayRef Operands, SCEVUse It, + ScalarEvolution &SE); /// Return the number of iterations of this loop that produce /// values in the specified constant range. Another way of /// saying this is that it returns the first iteration number /// where the value is not in the condition, thus computing the /// exit count. If the iteration count can't be computed, an /// instance of SCEVCouldNotCompute is returned. - const SCEV *getNumIterationsInRange(const ConstantRange &Range, - ScalarEvolution &SE) const; + SCEVUse getNumIterationsInRange(const ConstantRange &Range, + ScalarEvolution &SE) const; /// Return an expression representing the value of this expression /// one iteration of the loop ahead. @@ -418,6 +426,7 @@ static bool classof(const SCEV *S) { return S->getSCEVType() == scAddRecExpr; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This node is the base class for min/max selections. @@ -432,7 +441,7 @@ class SCEVMinMaxExpr : public SCEVCommutativeExpr { protected: /// Note: Constructing subclasses via this constructor is allowed SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, - const SCEV *const *O, size_t N) + SCEVUse const *O, size_t N) : SCEVCommutativeExpr(ID, T, O, N) { assert(isMinMaxType(T)); // Min and max never overflow @@ -443,6 +452,7 @@ Type *getType() const { return getOperand(0)->getType(); } static bool classof(const SCEV *S) { return isMinMaxType(S->getSCEVType()); } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } static enum SCEVTypes negate(enum SCEVTypes T) { switch (T) { @@ -464,48 +474,52 @@ class SCEVSMaxExpr : public SCEVMinMaxExpr { friend class ScalarEvolution; - SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) + SCEVSMaxExpr(const FoldingSetNodeIDRef ID, SCEVUse const *O, size_t N) : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {} public: /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scSMaxExpr; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This class represents an unsigned maximum selection. class SCEVUMaxExpr : public SCEVMinMaxExpr { friend class ScalarEvolution; - SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) + SCEVUMaxExpr(const FoldingSetNodeIDRef ID, SCEVUse const *O, size_t N) : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {} public: /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scUMaxExpr; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This class represents a signed minimum selection.
class SCEVSMinExpr : public SCEVMinMaxExpr { friend class ScalarEvolution; - SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) + SCEVSMinExpr(const FoldingSetNodeIDRef ID, SCEVUse const *O, size_t N) : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {} public: /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scSMinExpr; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This class represents an unsigned minimum selection. class SCEVUMinExpr : public SCEVMinMaxExpr { friend class ScalarEvolution; - SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) + SCEVUMinExpr(const FoldingSetNodeIDRef ID, SCEVUse const *O, size_t N) : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {} public: /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scUMinExpr; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This node is the base class for sequential/in-order min/max selections. @@ -526,7 +540,7 @@ class SCEVSequentialMinMaxExpr : public SCEVNAryExpr { protected: /// Note: Constructing subclasses via this constructor is allowed SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, - const SCEV *const *O, size_t N) + SCEVUse const *O, size_t N) : SCEVNAryExpr(ID, T, O, N) { assert(isSequentialMinMaxType(T)); // Min and max never overflow @@ -553,13 +567,14 @@ class SCEVSequentialMinMaxExpr : public SCEVNAryExpr { static bool classof(const SCEV *S) { return isSequentialMinMaxType(S->getSCEVType()); } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This class represents a sequential/in-order unsigned minimum selection. class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr { friend class ScalarEvolution; - SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, + SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, SCEVUse const *O, size_t N) : SCEVSequentialMinMaxExpr(ID, scSequentialUMinExpr, O, N) {} @@ -568,6 +583,7 @@ class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr { static bool classof(const SCEV *S) { return S->getSCEVType() == scSequentialUMinExpr; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This means that we are dealing with an entirely unknown SCEV @@ -600,48 +616,56 @@ class SCEVUnknown final : public SCEV, private CallbackVH { /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEV *S) { return S->getSCEVType() == scUnknown; } + static bool classof(const SCEVUse *U) { return classof(U->getPointer()); } }; /// This class defines a simple visitor class that may be used for /// various SCEV analysis purposes. 
template struct SCEVVisitor { - RetVal visit(const SCEV *S) { + RetVal visit(SCEVUse S) { switch (S->getSCEVType()) { case scConstant: - return ((SC *)this)->visitConstant((const SCEVConstant *)S); + return ((SC *)this)->visitConstant((const SCEVConstant *)S.getPointer()); case scVScale: - return ((SC *)this)->visitVScale((const SCEVVScale *)S); + return ((SC *)this)->visitVScale((const SCEVVScale *)S.getPointer()); case scPtrToInt: - return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S); + return ((SC *)this) + ->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S.getPointer()); case scTruncate: - return ((SC *)this)->visitTruncateExpr((const SCEVTruncateExpr *)S); + return ((SC *)this) + ->visitTruncateExpr((const SCEVTruncateExpr *)S.getPointer()); case scZeroExtend: - return ((SC *)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr *)S); + return ((SC *)this) + ->visitZeroExtendExpr((const SCEVZeroExtendExpr *)S.getPointer()); case scSignExtend: - return ((SC *)this)->visitSignExtendExpr((const SCEVSignExtendExpr *)S); + return ((SC *)this) + ->visitSignExtendExpr((const SCEVSignExtendExpr *)S.getPointer()); case scAddExpr: - return ((SC *)this)->visitAddExpr((const SCEVAddExpr *)S); + return ((SC *)this)->visitAddExpr((const SCEVAddExpr *)S.getPointer()); case scMulExpr: - return ((SC *)this)->visitMulExpr((const SCEVMulExpr *)S); + return ((SC *)this)->visitMulExpr((const SCEVMulExpr *)S.getPointer()); case scUDivExpr: - return ((SC *)this)->visitUDivExpr((const SCEVUDivExpr *)S); + return ((SC *)this)->visitUDivExpr((const SCEVUDivExpr *)S.getPointer()); case scAddRecExpr: - return ((SC *)this)->visitAddRecExpr((const SCEVAddRecExpr *)S); + return ((SC *)this) + ->visitAddRecExpr((const SCEVAddRecExpr *)S.getPointer()); case scSMaxExpr: - return ((SC *)this)->visitSMaxExpr((const SCEVSMaxExpr *)S); + return ((SC *)this)->visitSMaxExpr((const SCEVSMaxExpr *)S.getPointer()); case scUMaxExpr: - return ((SC *)this)->visitUMaxExpr((const SCEVUMaxExpr *)S); + return ((SC *)this)->visitUMaxExpr((const SCEVUMaxExpr *)S.getPointer()); case scSMinExpr: - return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S); + return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S.getPointer()); case scUMinExpr: - return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S); + return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S.getPointer()); case scSequentialUMinExpr: return ((SC *)this) - ->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S); + ->visitSequentialUMinExpr( + (const SCEVSequentialUMinExpr *)S.getPointer()); case scUnknown: - return ((SC *)this)->visitUnknown((const SCEVUnknown *)S); + return ((SC *)this)->visitUnknown((const SCEVUnknown *)S.getPointer()); case scCouldNotCompute: - return ((SC *)this)->visitCouldNotCompute((const SCEVCouldNotCompute *)S); + return ((SC *)this) + ->visitCouldNotCompute((const SCEVCouldNotCompute *)S.getPointer()); } llvm_unreachable("Unknown SCEV kind!"); } @@ -655,15 +679,15 @@ template struct SCEVVisitor { /// /// Visitor implements: /// // return true to follow this node. -/// bool follow(const SCEV *S); +/// bool follow(SCEVUse S); /// // return true to terminate the search. 
/// bool isDone(); template class SCEVTraversal { SV &Visitor; - SmallVector Worklist; - SmallPtrSet Visited; + SmallVector Worklist; + SmallPtrSet Visited; - void push(const SCEV *S) { + void push(SCEVUse S) { if (Visited.insert(S).second && Visitor.follow(S)) Worklist.push_back(S); } @@ -671,10 +695,10 @@ template class SCEVTraversal { public: SCEVTraversal(SV &V) : Visitor(V) {} - void visitAll(const SCEV *Root) { + void visitAll(SCEVUse Root) { push(Root); while (!Worklist.empty() && !Visitor.isDone()) { - const SCEV *S = Worklist.pop_back_val(); + SCEVUse S = Worklist.pop_back_val(); switch (S->getSCEVType()) { case scConstant: @@ -694,7 +718,7 @@ template class SCEVTraversal { case scUMinExpr: case scSequentialUMinExpr: case scAddRecExpr: - for (const auto *Op : S->operands()) { + for (const auto Op : S->operands()) { push(Op); if (Visitor.isDone()) break; @@ -709,21 +733,20 @@ template class SCEVTraversal { }; /// Use SCEVTraversal to visit all nodes in the given expression tree. -template void visitAll(const SCEV *Root, SV &Visitor) { +template void visitAll(SCEVUse Root, SV &Visitor) { SCEVTraversal T(Visitor); T.visitAll(Root); } /// Return true if any node in \p Root satisfies the predicate \p Pred. -template -bool SCEVExprContains(const SCEV *Root, PredTy Pred) { +template bool SCEVExprContains(SCEVUse Root, PredTy Pred) { struct FindClosure { bool Found = false; PredTy Pred; FindClosure(PredTy Pred) : Pred(Pred) {} - bool follow(const SCEV *S) { + bool follow(SCEVUse S) { if (!Pred(S)) return true; @@ -743,7 +766,7 @@ bool SCEVExprContains(const SCEV *Root, PredTy Pred) { /// The result from each visit is cached, so it will return the same /// SCEV for the same input. template -class SCEVRewriteVisitor : public SCEVVisitor { +class SCEVRewriteVisitor : public SCEVVisitor { protected: ScalarEvolution &SE; // Memoize the result of each visit so that we only compute once for @@ -751,84 +774,84 @@ class SCEVRewriteVisitor : public SCEVVisitor { // a SCEV is referenced by multiple SCEVs. Without memoization, this // visit algorithm would have exponential time complexity in the worst // case, causing the compiler to hang on certain tests. - SmallDenseMap RewriteResults; + SmallDenseMap RewriteResults; public: SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {} - const SCEV *visit(const SCEV *S) { + SCEVUse visit(SCEVUse S) { auto It = RewriteResults.find(S); if (It != RewriteResults.end()) return It->second; - auto *Visited = SCEVVisitor::visit(S); + auto Visited = SCEVVisitor::visit(S); auto Result = RewriteResults.try_emplace(S, Visited); assert(Result.second && "Should insert a new entry"); return Result.first->second; } - const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; } + SCEVUse visitConstant(const SCEVConstant *Constant) { return Constant; } - const SCEV *visitVScale(const SCEVVScale *VScale) { return VScale; } + SCEVUse visitVScale(const SCEVVScale *VScale) { return VScale; } - const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { - const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); + SCEVUse visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { + SCEVUse Operand = ((SC *)this)->visit(Expr->getOperand()); return Operand == Expr->getOperand() ? 
Expr : SE.getPtrToIntExpr(Operand, Expr->getType()); } - const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) { - const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); + SCEVUse visitTruncateExpr(const SCEVTruncateExpr *Expr) { + SCEVUse Operand = ((SC *)this)->visit(Expr->getOperand()); return Operand == Expr->getOperand() ? Expr : SE.getTruncateExpr(Operand, Expr->getType()); } - const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { - const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); + SCEVUse visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { + SCEVUse Operand = ((SC *)this)->visit(Expr->getOperand()); return Operand == Expr->getOperand() ? Expr : SE.getZeroExtendExpr(Operand, Expr->getType()); } - const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { - const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); + SCEVUse visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { + SCEVUse Operand = ((SC *)this)->visit(Expr->getOperand()); return Operand == Expr->getOperand() ? Expr : SE.getSignExtendExpr(Operand, Expr->getType()); } - const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { - SmallVector Operands; + SCEVUse visitAddExpr(const SCEVAddExpr *Expr) { + SmallVector Operands; bool Changed = false; - for (const auto *Op : Expr->operands()) { + for (const auto Op : Expr->operands()) { Operands.push_back(((SC *)this)->visit(Op)); Changed |= Op != Operands.back(); } return !Changed ? Expr : SE.getAddExpr(Operands); } - const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { - SmallVector Operands; + SCEVUse visitMulExpr(const SCEVMulExpr *Expr) { + SmallVector Operands; bool Changed = false; - for (const auto *Op : Expr->operands()) { + for (const auto Op : Expr->operands()) { Operands.push_back(((SC *)this)->visit(Op)); Changed |= Op != Operands.back(); } return !Changed ? Expr : SE.getMulExpr(Operands); } - const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) { - auto *LHS = ((SC *)this)->visit(Expr->getLHS()); - auto *RHS = ((SC *)this)->visit(Expr->getRHS()); + SCEVUse visitUDivExpr(const SCEVUDivExpr *Expr) { + auto LHS = ((SC *)this)->visit(Expr->getLHS()); + auto RHS = ((SC *)this)->visit(Expr->getRHS()); bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS(); return !Changed ? Expr : SE.getUDivExpr(LHS, RHS); } - const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { - SmallVector Operands; + SCEVUse visitAddRecExpr(const SCEVAddRecExpr *Expr) { + SmallVector Operands; bool Changed = false; - for (const auto *Op : Expr->operands()) { + for (const auto Op : Expr->operands()) { Operands.push_back(((SC *)this)->visit(Op)); Changed |= Op != Operands.back(); } @@ -837,72 +860,70 @@ class SCEVRewriteVisitor : public SCEVVisitor { Expr->getNoWrapFlags()); } - const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) { - SmallVector Operands; + SCEVUse visitSMaxExpr(const SCEVSMaxExpr *Expr) { + SmallVector Operands; bool Changed = false; - for (const auto *Op : Expr->operands()) { + for (const auto Op : Expr->operands()) { Operands.push_back(((SC *)this)->visit(Op)); Changed |= Op != Operands.back(); } return !Changed ? Expr : SE.getSMaxExpr(Operands); } - const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) { - SmallVector Operands; + SCEVUse visitUMaxExpr(const SCEVUMaxExpr *Expr) { + SmallVector Operands; bool Changed = false; - for (const auto *Op : Expr->operands()) { + for (const auto Op : Expr->operands()) { Operands.push_back(((SC *)this)->visit(Op)); Changed |= Op != Operands.back(); } return !Changed ? 
Expr : SE.getUMaxExpr(Operands); } - const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) { - SmallVector Operands; + SCEVUse visitSMinExpr(const SCEVSMinExpr *Expr) { + SmallVector Operands; bool Changed = false; - for (const auto *Op : Expr->operands()) { + for (const auto Op : Expr->operands()) { Operands.push_back(((SC *)this)->visit(Op)); Changed |= Op != Operands.back(); } return !Changed ? Expr : SE.getSMinExpr(Operands); } - const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) { - SmallVector Operands; + SCEVUse visitUMinExpr(const SCEVUMinExpr *Expr) { + SmallVector Operands; bool Changed = false; - for (const auto *Op : Expr->operands()) { + for (const auto Op : Expr->operands()) { Operands.push_back(((SC *)this)->visit(Op)); Changed |= Op != Operands.back(); } return !Changed ? Expr : SE.getUMinExpr(Operands); } - const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) { - SmallVector Operands; + SCEVUse visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) { + SmallVector Operands; bool Changed = false; - for (const auto *Op : Expr->operands()) { + for (const auto Op : Expr->operands()) { Operands.push_back(((SC *)this)->visit(Op)); Changed |= Op != Operands.back(); } return !Changed ? Expr : SE.getUMinExpr(Operands, /*Sequential=*/true); } - const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; } + SCEVUse visitUnknown(const SCEVUnknown *Expr) { return Expr; } - const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { - return Expr; - } + SCEVUse visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; } }; using ValueToValueMap = DenseMap; -using ValueToSCEVMapTy = DenseMap; +using ValueToSCEVMapTy = DenseMap; /// The SCEVParameterRewriter takes a scalar evolution expression and updates /// the SCEVUnknown components following the Map (Value -> SCEV). class SCEVParameterRewriter : public SCEVRewriteVisitor { public: - static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, - ValueToSCEVMapTy &Map) { + static SCEVUse rewrite(SCEVUse Scev, ScalarEvolution &SE, + ValueToSCEVMapTy &Map) { SCEVParameterRewriter Rewriter(SE, Map); return Rewriter.visit(Scev); } @@ -910,7 +931,7 @@ class SCEVParameterRewriter : public SCEVRewriteVisitor { SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M) : SCEVRewriteVisitor(SE), Map(M) {} - const SCEV *visitUnknown(const SCEVUnknown *Expr) { + SCEVUse visitUnknown(const SCEVUnknown *Expr) { auto I = Map.find(Expr->getValue()); if (I == Map.end()) return Expr; @@ -921,7 +942,7 @@ class SCEVParameterRewriter : public SCEVRewriteVisitor { ValueToSCEVMapTy ⤅ }; -using LoopToScevMapT = DenseMap; +using LoopToScevMapT = DenseMap; /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies /// the Map (Loop -> SCEV) to all AddRecExprs. 
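A usage sketch for the parameter rewriter above: substituting the SCEV of one value for another through the SCEVUse-based entry point. The helper name and arguments are illustrative only:

#include "llvm/Analysis/ScalarEvolutionExpressions.h"
using namespace llvm;

static SCEVUse substituteUnknown(ScalarEvolution &SE, SCEVUse Expr,
                                 Value *From, Value *To) {
  ValueToSCEVMapTy Map;               // DenseMap from Value * to SCEVUse
  Map.insert({From, SE.getSCEV(To)}); // Value -> SCEVUse
  return SCEVParameterRewriter::rewrite(Expr, SE, Map);
}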
@@ -931,15 +952,15 @@ class SCEVLoopAddRecRewriter SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M) : SCEVRewriteVisitor(SE), Map(M) {} - static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map, - ScalarEvolution &SE) { + static SCEVUse rewrite(SCEVUse Scev, LoopToScevMapT &Map, + ScalarEvolution &SE) { SCEVLoopAddRecRewriter Rewriter(SE, Map); return Rewriter.visit(Scev); } - const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { - SmallVector Operands; - for (const SCEV *Op : Expr->operands()) + SCEVUse visitAddRecExpr(const SCEVAddRecExpr *Expr) { + SmallVector Operands; + for (SCEVUse Op : Expr->operands()) Operands.push_back(visit(Op)); const Loop *L = Expr->getLoop(); diff --git a/llvm/lib/Analysis/DependenceAnalysis.cpp b/llvm/lib/Analysis/DependenceAnalysis.cpp index a4a98ea0bae14..ae2573f341d86 100644 --- a/llvm/lib/Analysis/DependenceAnalysis.cpp +++ b/llvm/lib/Analysis/DependenceAnalysis.cpp @@ -1250,10 +1250,12 @@ bool DependenceInfo::strongSIVtest(const SCEV *Coeff, const SCEV *SrcConst, if (const SCEV *UpperBound = collectUpperBound(CurLoop, Delta->getType())) { LLVM_DEBUG(dbgs() << "\t UpperBound = " << *UpperBound); LLVM_DEBUG(dbgs() << ", " << *UpperBound->getType() << "\n"); - const SCEV *AbsDelta = - SE->isKnownNonNegative(Delta) ? Delta : SE->getNegativeSCEV(Delta); - const SCEV *AbsCoeff = - SE->isKnownNonNegative(Coeff) ? Coeff : SE->getNegativeSCEV(Coeff); + const SCEV *AbsDelta = SE->isKnownNonNegative(Delta) + ? Delta + : SE->getNegativeSCEV(Delta).getPointer(); + const SCEV *AbsCoeff = SE->isKnownNonNegative(Coeff) + ? Coeff + : SE->getNegativeSCEV(Coeff).getPointer(); const SCEV *Product = SE->getMulExpr(UpperBound, AbsCoeff); if (isKnownPredicate(CmpInst::ICMP_SGT, AbsDelta, Product)) { // Distance greater than trip count - no dependence @@ -1791,8 +1793,9 @@ bool DependenceInfo::weakZeroSrcSIVtest(const SCEV *DstCoeff, const SCEV *AbsCoeff = SE->isKnownNegative(ConstCoeff) ? SE->getNegativeSCEV(ConstCoeff) : ConstCoeff; - const SCEV *NewDelta = - SE->isKnownNegative(ConstCoeff) ? SE->getNegativeSCEV(Delta) : Delta; + const SCEV *NewDelta = SE->isKnownNegative(ConstCoeff) + ? SE->getNegativeSCEV(Delta).getPointer() + : Delta; // check that Delta/SrcCoeff < iteration count // really check NewDelta < count*AbsCoeff @@ -1900,8 +1903,9 @@ bool DependenceInfo::weakZeroDstSIVtest(const SCEV *SrcCoeff, const SCEV *AbsCoeff = SE->isKnownNegative(ConstCoeff) ? SE->getNegativeSCEV(ConstCoeff) : ConstCoeff; - const SCEV *NewDelta = - SE->isKnownNegative(ConstCoeff) ? SE->getNegativeSCEV(Delta) : Delta; + const SCEV *NewDelta = SE->isKnownNegative(ConstCoeff) + ? SE->getNegativeSCEV(Delta).getPointer() + : Delta; // check that Delta/SrcCoeff < iteration count // really check NewDelta < count*AbsCoeff diff --git a/llvm/lib/Analysis/IVDescriptors.cpp b/llvm/lib/Analysis/IVDescriptors.cpp index 53001421ce6f7..7e708699a81f8 100644 --- a/llvm/lib/Analysis/IVDescriptors.cpp +++ b/llvm/lib/Analysis/IVDescriptors.cpp @@ -1422,7 +1422,7 @@ bool InductionDescriptor::isInductionPHI( return false; // Check that the PHI is consecutive. - const SCEV *PhiScev = Expr ? Expr : SE->getSCEV(Phi); + const SCEV *PhiScev = Expr ? 
Expr : SE->getSCEV(Phi).getPointer(); const SCEVAddRecExpr *AR = dyn_cast(PhiScev); if (!AR) { diff --git a/llvm/lib/Analysis/LoopCacheAnalysis.cpp b/llvm/lib/Analysis/LoopCacheAnalysis.cpp index 7ca9f15ad5fca..3a5ebd316322a 100644 --- a/llvm/lib/Analysis/LoopCacheAnalysis.cpp +++ b/llvm/lib/Analysis/LoopCacheAnalysis.cpp @@ -494,7 +494,8 @@ bool IndexedReference::isConsecutive(const Loop &L, const SCEV *&Stride, SE.getNoopOrSignExtend(ElemSize, WiderType)); const SCEV *CacheLineSize = SE.getConstant(Stride->getType(), CLS); - Stride = SE.isKnownNegative(Stride) ? SE.getNegativeSCEV(Stride) : Stride; + Stride = SE.isKnownNegative(Stride) ? SE.getNegativeSCEV(Stride).getPointer() + : Stride; return SE.isKnownPredicate(ICmpInst::ICMP_ULT, Stride, CacheLineSize); } diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 3d890f05c8ca2..cce648390faa5 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -253,6 +253,75 @@ static cl::opt UseContextForNoWrapFlagInference( // SCEV class definitions //===----------------------------------------------------------------------===// +class SCEVDropFlags : public SCEVRewriteVisitor { + using Base = SCEVRewriteVisitor; + +public: + SCEVDropFlags(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {} + + static SCEVUse rewrite(SCEVUse Scev, ScalarEvolution &SE) { + SCEVDropFlags Rewriter(SE); + return Rewriter.visit(Scev); + } + + SCEVUse visitAddExpr(const SCEVAddExpr *Expr) { + SmallVector Operands; + bool Changed = false; + for (const auto Op : Expr->operands()) { + Operands.push_back(visit(Op)); + Changed |= Op != Operands.back(); + } + return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags()); + } + + SCEVUse visitMulExpr(const SCEVMulExpr *Expr) { + SmallVector Operands; + bool Changed = false; + for (const auto Op : Expr->operands()) { + Operands.push_back(visit(Op)); + Changed |= Op != Operands.back(); + } + return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags()); + } +}; + +const SCEV *SCEVUse::computeCanonical(ScalarEvolution &SE) const { + return SCEVDropFlags::rewrite(*this, SE); +} + +bool SCEVUse::computeIsCanonical() const { + if (!getRawPointer() || + DenseMapInfo::getEmptyKey().getRawPointer() == getRawPointer() || + DenseMapInfo::getTombstoneKey().getRawPointer() == + getRawPointer() || + isa(this)) + return true; + return !SCEVExprContains(*this, [](SCEVUse U) { return U.getFlags() != 0; }); +} + +bool SCEVUse::operator==(const SCEVUse &RHS) const { + assert(isCanonical() && RHS.isCanonical()); + return getPointer() == RHS.getPointer(); +} + +bool SCEVUse::operator==(const SCEV *RHS) const { return getPointer() == RHS; } + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void SCEVUse::dump() const { + print(dbgs()); + dbgs() << '\n'; +} +#endif + +void SCEVUse::print(raw_ostream &OS) const { + getPointer()->print(OS); + SCEV::NoWrapFlags Flags = static_cast(getInt()); + if (Flags & SCEV::FlagNUW) + OS << "(u nuw)"; + if (Flags & SCEV::FlagNSW) + OS << "(u nsw)"; +} + //===----------------------------------------------------------------------===// // Implementation of the SCEV class. 
// @@ -274,28 +343,28 @@ void SCEV::print(raw_ostream &OS) const { return; case scPtrToInt: { const SCEVPtrToIntExpr *PtrToInt = cast(this); - const SCEV *Op = PtrToInt->getOperand(); - OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to " + SCEVUse Op = PtrToInt->getOperand(); + OS << "(ptrtoint " << *Op->getType() << " " << Op << " to " << *PtrToInt->getType() << ")"; return; } case scTruncate: { const SCEVTruncateExpr *Trunc = cast(this); - const SCEV *Op = Trunc->getOperand(); - OS << "(trunc " << *Op->getType() << " " << *Op << " to " + SCEVUse Op = Trunc->getOperand(); + OS << "(trunc " << *Op->getType() << " " << Op << " to " << *Trunc->getType() << ")"; return; } case scZeroExtend: { const SCEVZeroExtendExpr *ZExt = cast(this); - const SCEV *Op = ZExt->getOperand(); - OS << "(zext " << *Op->getType() << " " << *Op << " to " - << *ZExt->getType() << ")"; + SCEVUse Op = ZExt->getOperand(); + OS << "(zext " << *Op->getType() << " " << Op << " to " << *ZExt->getType() + << ")"; return; } case scSignExtend: { const SCEVSignExtendExpr *SExt = cast(this); - const SCEV *Op = SExt->getOperand(); + SCEVUse Op = SExt->getOperand(); OS << "(sext " << *Op->getType() << " " << *Op << " to " << *SExt->getType() << ")"; return; @@ -345,8 +414,8 @@ void SCEV::print(raw_ostream &OS) const { } OS << "("; ListSeparator LS(OpStr); - for (const SCEV *Op : NAry->operands()) - OS << LS << *Op; + for (SCEVUse Op : NAry->operands()) + OS << LS << Op; OS << ")"; switch (NAry->getSCEVType()) { case scAddExpr: @@ -411,7 +480,7 @@ Type *SCEV::getType() const { llvm_unreachable("Unknown SCEV kind!"); } -ArrayRef SCEV::operands() const { +ArrayRef SCEV::operands() const { switch (getSCEVType()) { case scConstant: case scVScale: @@ -476,51 +545,51 @@ bool SCEVCouldNotCompute::classof(const SCEV *S) { return S->getSCEVType() == scCouldNotCompute; } -const SCEV *ScalarEvolution::getConstant(ConstantInt *V) { +SCEVUse ScalarEvolution::getConstant(ConstantInt *V) { FoldingSetNodeID ID; ID.AddInteger(scConstant); ID.AddPointer(V); void *IP = nullptr; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + if (SCEVUse S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + return S; SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V); UniqueSCEVs.InsertNode(S, IP); return S; } -const SCEV *ScalarEvolution::getConstant(const APInt &Val) { +SCEVUse ScalarEvolution::getConstant(const APInt &Val) { return getConstant(ConstantInt::get(getContext(), Val)); } -const SCEV * -ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) { +SCEVUse ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) { IntegerType *ITy = cast(getEffectiveSCEVType(Ty)); return getConstant(ConstantInt::get(ITy, V, isSigned)); } -const SCEV *ScalarEvolution::getVScale(Type *Ty) { +SCEVUse ScalarEvolution::getVScale(Type *Ty) { FoldingSetNodeID ID; ID.AddInteger(scVScale); ID.AddPointer(Ty); void *IP = nullptr; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + if (SCEVUse S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty); UniqueSCEVs.InsertNode(S, IP); return S; } -const SCEV *ScalarEvolution::getElementCount(Type *Ty, ElementCount EC) { - const SCEV *Res = getConstant(Ty, EC.getKnownMinValue()); +SCEVUse ScalarEvolution::getElementCount(Type *Ty, ElementCount EC) { + SCEVUse Res = getConstant(Ty, EC.getKnownMinValue()); if (EC.isScalable()) Res = getMulExpr(Res, getVScale(Ty)); return Res; } 
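A small sketch of getElementCount as rewritten above: a fixed count stays a plain constant, while a scalable count folds to constant-times-vscale. The function name is illustrative:

#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

static void elementCountSketch(ScalarEvolution &SE, LLVMContext &Ctx) {
  Type *I64 = Type::getInt64Ty(Ctx);
  SCEVUse Fixed = SE.getElementCount(I64, ElementCount::getFixed(4));
  SCEVUse Scalable = SE.getElementCount(I64, ElementCount::getScalable(4));
  // Fixed is the constant 4; Scalable should print as (4 * vscale).
  (void)Fixed;
  (void)Scalable;
}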
SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, - const SCEV *op, Type *ty) + SCEVUse op, Type *ty) : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {} -SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, +SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, SCEVUse Op, Type *ITy) : SCEVCastExpr(ID, scPtrToInt, Op, ITy) { assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() && @@ -528,26 +597,26 @@ SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, } SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, - SCEVTypes SCEVTy, const SCEV *op, + SCEVTypes SCEVTy, SCEVUse op, Type *ty) : SCEVCastExpr(ID, SCEVTy, op, ty) {} -SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, +SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, SCEVUse op, Type *ty) : SCEVIntegralCastExpr(ID, scTruncate, op, ty) { assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() && "Cannot truncate non-integer value!"); } -SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, - const SCEV *op, Type *ty) +SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op, + Type *ty) : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) { assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() && "Cannot zero extend non-integer value!"); } -SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, - const SCEV *op, Type *ty) +SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op, + Type *ty) : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) { assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() && "Cannot sign extend non-integer value!"); @@ -555,7 +624,7 @@ SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, void SCEVUnknown::deleted() { // Clear this SCEVUnknown from various maps. - SE->forgetMemoizedResults(this); + SE->forgetMemoizedResults(SCEVUse(this)); // Remove this SCEVUnknown from the uniquing map. SE->UniqueSCEVs.RemoveNode(this); @@ -566,7 +635,7 @@ void SCEVUnknown::deleted() { void SCEVUnknown::allUsesReplacedWith(Value *New) { // Clear this SCEVUnknown from various maps. - SE->forgetMemoizedResults(this); + SE->forgetMemoizedResults(SCEVUse(this)); // Remove this SCEVUnknown from the uniquing map. SE->UniqueSCEVs.RemoveNode(this); @@ -659,11 +728,12 @@ static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, // If the max analysis depth was reached, return std::nullopt, assuming we do // not know if they are equivalent for sure. static std::optional -CompareSCEVComplexity(EquivalenceClasses &EqCacheSCEV, - const LoopInfo *const LI, const SCEV *LHS, - const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) { +CompareSCEVComplexity(EquivalenceClasses &EqCacheSCEV, + const LoopInfo *const LI, SCEVUse LHS, SCEVUse RHS, + DominatorTree &DT, ScalarEvolution &SE, + unsigned Depth = 0) { // Fast-path: SCEVs are uniqued so we can do a quick equality check. - if (LHS == RHS) + if (LHS.getCanonical(SE) == RHS.getCanonical(SE)) return 0; // Primarily, sort the SCEVs by their getSCEVType(). 
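The rewritten fast path in CompareSCEVComplexity compares canonical forms rather than raw pointers. The sketch below shows why, assuming the tag-bit encoding of SCEVUse: two uses of one expression can differ only in their flag bits.

#include "llvm/Analysis/ScalarEvolution.h"
#include <cassert>
using namespace llvm;

static void canonicalEqualitySketch(ScalarEvolution &SE, const SCEV *S) {
  SCEVUse Flagged(S, SCEV::FlagNUW);
  SCEVUse Bare(S);
  assert(Flagged.getRawPointer() != Bare.getRawPointer()); // tag bits differ
  // Both denote the same expression once use-site flags are stripped
  // (assuming S itself contains no flagged operands).
  assert(Flagged.getCanonical(SE) == Bare.getCanonical(SE));
}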
@@ -744,8 +814,8 @@ CompareSCEVComplexity(EquivalenceClasses &EqCacheSCEV, case scSMinExpr: case scUMinExpr: case scSequentialUMinExpr: { - ArrayRef LOps = LHS->operands(); - ArrayRef ROps = RHS->operands(); + ArrayRef LOps = LHS->operands(); + ArrayRef ROps = RHS->operands(); // Lexicographically compare n-ary-like expressions. unsigned LNumOps = LOps.size(), RNumOps = ROps.size(); @@ -753,7 +823,7 @@ CompareSCEVComplexity(EquivalenceClasses &EqCacheSCEV, return (int)LNumOps - (int)RNumOps; for (unsigned i = 0; i != LNumOps; ++i) { - auto X = CompareSCEVComplexity(EqCacheSCEV, LI, LOps[i], ROps[i], DT, + auto X = CompareSCEVComplexity(EqCacheSCEV, LI, LOps[i], ROps[i], DT, SE, Depth + 1); if (X != 0) return X; @@ -777,37 +847,36 @@ CompareSCEVComplexity(EquivalenceClasses &EqCacheSCEV, /// results from this routine. In other words, we don't want the results of /// this to depend on where the addresses of various SCEV objects happened to /// land in memory. -static void GroupByComplexity(SmallVectorImpl &Ops, - LoopInfo *LI, DominatorTree &DT) { +static void GroupByComplexity(SmallVectorImpl &Ops, LoopInfo *LI, + DominatorTree &DT, ScalarEvolution &SE) { if (Ops.size() < 2) return; // Noop - EquivalenceClasses EqCacheSCEV; + EquivalenceClasses EqCacheSCEV; // Whether LHS has provably less complexity than RHS. - auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) { - auto Complexity = CompareSCEVComplexity(EqCacheSCEV, LI, LHS, RHS, DT); + auto IsLessComplex = [&](SCEVUse LHS, SCEVUse RHS) { + auto Complexity = CompareSCEVComplexity(EqCacheSCEV, LI, LHS, RHS, DT, SE); return Complexity && *Complexity < 0; }; if (Ops.size() == 2) { // This is the common case, which also happens to be trivially simple. // Special case it. - const SCEV *&LHS = Ops[0], *&RHS = Ops[1]; + SCEVUse &LHS = Ops[0], &RHS = Ops[1]; if (IsLessComplex(RHS, LHS)) std::swap(LHS, RHS); return; } // Do the rough sort by complexity. - llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) { - return IsLessComplex(LHS, RHS); - }); + llvm::stable_sort( + Ops, [&](SCEVUse LHS, SCEVUse RHS) { return IsLessComplex(LHS, RHS); }); // Now that we are sorted by complexity, group elements of the same // complexity. Note that this is, at worst, N^2, but the vector is likely to // be extremely short in practice. Note that we take this approach because we // do not want to depend on the addresses of the objects we are grouping. for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) { - const SCEV *S = Ops[i]; + SCEVUse S = Ops[i]; unsigned Complexity = S->getSCEVType(); // If there are any objects of the same complexity and same value as this @@ -825,8 +894,8 @@ static void GroupByComplexity(SmallVectorImpl &Ops, /// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at /// least HugeExprThreshold nodes). 
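// [Illustrative sketch; not part of the patch.] The property GroupByComplexity
// (above) needs is determinism: a stable sort keyed on an address-independent
// rank, so equal-rank operands keep their incoming order no matter where the
// nodes happen to be allocated. A stand-alone version with a hypothetical
// Rank field in place of getSCEVType():
#include <algorithm>
#include <string>
#include <vector>

struct Expr { int Rank; std::string Name; };

static void groupByComplexity(std::vector<Expr> &Ops) {
  if (Ops.size() < 2)
    return; // nothing to sort or group
  std::stable_sort(Ops.begin(), Ops.end(),
                   [](const Expr &L, const Expr &R) { return L.Rank < R.Rank; });
  // Equal-rank elements are now adjacent; a follow-up pass can merge
  // identical ones, mirroring the worst-case-N^2 grouping loop above.
}

int main() {
  std::vector<Expr> Ops = {{3, "mul"}, {1, "const"}, {3, "mul2"}, {2, "unknown"}};
  groupByComplexity(Ops);
  // Order is now: const, unknown, mul, mul2 -- stable within equal ranks.
  return 0;
}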
-static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) { - return any_of(Ops, [](const SCEV *S) { +static bool hasHugeExpression(ArrayRef<SCEVUse> Ops) { + return any_of(Ops, [](SCEVUse S) { return S->getExpressionSize() >= HugeExprThreshold; }); } @@ -842,7 +911,7 @@ static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) { template <typename FoldT, typename IsIdentityT, typename IsAbsorberT> static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, - SmallVectorImpl<const SCEV *> &Ops, FoldT Fold, + SmallVectorImpl<SCEVUse> &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) { const SCEVConstant *Folded = nullptr; for (unsigned Idx = 0; Idx < Ops.size();) { @@ -867,7 +936,7 @@ constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, if (Folded && IsAbsorber(Folded->getAPInt())) return Folded; - GroupByComplexity(Ops, &LI, DT); + GroupByComplexity(Ops, &LI, DT, SE); if (Folded && !IsIdentity(Folded->getAPInt())) Ops.insert(Ops.begin(), Folded); @@ -879,9 +948,8 @@ constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, //===----------------------------------------------------------------------===// /// Compute BC(It, K). The result has width W. Assume K > 0. -static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K, - ScalarEvolution &SE, - Type *ResultTy) { +static SCEVUse BinomialCoefficient(SCEVUse It, unsigned K, ScalarEvolution &SE, + Type *ResultTy) { // Handle the simplest case efficiently. if (K == 1) return SE.getTruncateOrZeroExtend(It, ResultTy); @@ -968,15 +1036,15 @@ static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K, // Calculate the product, at width T+W IntegerType *CalculationTy = IntegerType::get(SE.getContext(), CalculationBits); - const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy); + SCEVUse Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy); for (unsigned i = 1; i != K; ++i) { - const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i)); + SCEVUse S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i)); Dividend = SE.getMulExpr(Dividend, SE.getTruncateOrZeroExtend(S, CalculationTy)); } // Divide by 2^T - const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor)); + SCEVUse DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor)); // Truncate the result, and divide by K! / 2^T. @@ -992,21 +1060,20 @@ static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K, /// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3) /// /// where BC(It, k) stands for binomial coefficient. -const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It, - ScalarEvolution &SE) const { +SCEVUse SCEVAddRecExpr::evaluateAtIteration(SCEVUse It, + ScalarEvolution &SE) const { return evaluateAtIteration(operands(), It, SE); } -const SCEV * -SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands, - const SCEV *It, ScalarEvolution &SE) { +SCEVUse SCEVAddRecExpr::evaluateAtIteration(ArrayRef<SCEVUse> Operands, + SCEVUse It, ScalarEvolution &SE) { assert(Operands.size() > 0); - const SCEV *Result = Operands[0]; + SCEVUse Result = Operands[0]; for (unsigned i = 1, e = Operands.size(); i != e; ++i) { // The computation is correct in the face of overflow provided that the // multiplication is performed _after_ the evaluation of the binomial // coefficient.
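// [Illustrative sketch; not part of the patch.] Numerically, BinomialCoefficient
// and evaluateAtIteration compute {A,+,B,+,C,...} at iteration n as
// A*BC(n,0) + B*BC(n,1) + C*BC(n,2) + ...; here is the same polynomial in plain
// 64-bit arithmetic, without the fixed-width overflow care (the 2^T factoring)
// the code above takes:
#include <cassert>
#include <cstdint>
#include <vector>

static uint64_t binomial(uint64_t N, unsigned K) {
  if (K > N)
    return 0;
  uint64_t R = 1;
  for (unsigned I = 0; I < K; ++I)
    R = R * (N - I) / (I + 1); // exact: (I+1)! divides (I+1) consecutive ints
  return R;
}

// Evaluate {Ops[0],+,Ops[1],+,...} at iteration It, as in the formula above.
static uint64_t evaluateAtIteration(const std::vector<uint64_t> &Ops,
                                    uint64_t It) {
  uint64_t Result = 0;
  for (unsigned K = 0; K < Ops.size(); ++K)
    Result += Ops[K] * binomial(It, K);
  return Result;
}

int main() {
  assert(evaluateAtIteration({3, 2}, 10) == 23);   // {3,+,2}: 3 + 2*10
  assert(evaluateAtIteration({0, 1, 2}, 7) == 49); // {0,+,1,+,2}: n^2 at n=7
  return 0;
}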
- const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType()); + SCEVUse Coeff = BinomialCoefficient(It, i, SE, Result->getType()); if (isa(Coeff)) return Coeff; @@ -1019,8 +1086,7 @@ SCEVAddRecExpr::evaluateAtIteration(ArrayRef Operands, // SCEV Expression folder implementations //===----------------------------------------------------------------------===// -const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op, - unsigned Depth) { +SCEVUse ScalarEvolution::getLosslessPtrToIntExpr(SCEVUse Op, unsigned Depth) { assert(Depth <= 1 && "getLosslessPtrToIntExpr() should self-recurse at most once."); @@ -1032,12 +1098,12 @@ const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op, // What would be an ID for such a SCEV cast expression? FoldingSetNodeID ID; ID.AddInteger(scPtrToInt); - ID.AddPointer(Op); + ID.AddPointer(Op.getRawPointer()); void *IP = nullptr; // Is there already an expression for such a cast? - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + if (SCEVUse S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; // It isn't legal for optimizations to construct new ptrtoint expressions @@ -1094,12 +1160,12 @@ const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op, public: SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {} - static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) { + static SCEVUse rewrite(SCEVUse Scev, ScalarEvolution &SE) { SCEVPtrToIntSinkingRewriter Rewriter(SE); return Rewriter.visit(Scev); } - const SCEV *visit(const SCEV *S) { + SCEVUse visit(SCEVUse S) { Type *STy = S->getType(); // If the expression is not pointer-typed, just keep it as-is. if (!STy->isPointerTy()) @@ -1108,27 +1174,27 @@ const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op, return Base::visit(S); } - const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { - SmallVector Operands; + SCEVUse visitAddExpr(const SCEVAddExpr *Expr) { + SmallVector Operands; bool Changed = false; - for (const auto *Op : Expr->operands()) { + for (const auto Op : Expr->operands()) { Operands.push_back(visit(Op)); Changed |= Op != Operands.back(); } return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags()); } - const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { - SmallVector Operands; + SCEVUse visitMulExpr(const SCEVMulExpr *Expr) { + SmallVector Operands; bool Changed = false; - for (const auto *Op : Expr->operands()) { + for (const auto Op : Expr->operands()) { Operands.push_back(visit(Op)); Changed |= Op != Operands.back(); } return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags()); } - const SCEV *visitUnknown(const SCEVUnknown *Expr) { + SCEVUse visitUnknown(const SCEVUnknown *Expr) { assert(Expr->getType()->isPointerTy() && "Should only reach pointer-typed SCEVUnknown's."); return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1); @@ -1136,25 +1202,24 @@ const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op, }; // And actually perform the cast sinking. 
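// [Illustrative sketch; not part of the patch.] The SCEVPtrToIntSinkingRewriter
// above pushes the ptrtoint cast through add and mul until it reaches the
// pointer leaves. The same traversal on a toy expression tree; Node and its
// kinds are hypothetical stand-ins, not SCEV API:
#include <cassert>
#include <memory>
#include <vector>

struct Node {
  enum { Ptr, Int, Add, Mul } Kind;
  std::vector<std::unique_ptr<Node>> Ops;
};

// Casts distribute over add and mul, so only Ptr leaves change kind
// (cf. visitAddExpr / visitMulExpr / visitUnknown above).
static std::unique_ptr<Node> sinkPtrToInt(const Node &N) {
  auto Copy = std::make_unique<Node>();
  if (N.Kind == Node::Ptr) {
    Copy->Kind = Node::Int; // the leaf itself is what gets cast
    return Copy;
  }
  Copy->Kind = N.Kind; // Int leaves and n-ary nodes keep their kind
  for (const auto &Op : N.Ops)
    Copy->Ops.push_back(sinkPtrToInt(*Op));
  return Copy;
}

int main() {
  Node Sum;
  Sum.Kind = Node::Add;
  Sum.Ops.push_back(std::make_unique<Node>());
  Sum.Ops.back()->Kind = Node::Ptr;
  auto R = sinkPtrToInt(Sum); // add survives; the pointer leaf becomes Int
  assert(R->Kind == Node::Add && R->Ops[0]->Kind == Node::Int);
  return 0;
}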
- const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this); + SCEVUse IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this); assert(IntOp->getType()->isIntegerTy() && "We must have succeeded in sinking the cast, " "and ending up with an integer-typed expression!"); return IntOp; } -const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) { +SCEVUse ScalarEvolution::getPtrToIntExpr(SCEVUse Op, Type *Ty) { assert(Ty->isIntegerTy() && "Target type must be an integer type!"); - const SCEV *IntOp = getLosslessPtrToIntExpr(Op); + SCEVUse IntOp = getLosslessPtrToIntExpr(Op); if (isa(IntOp)) return IntOp; return getTruncateOrZeroExtend(IntOp, Ty); } -const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty, - unsigned Depth) { +SCEVUse ScalarEvolution::getTruncateExpr(SCEVUse Op, Type *Ty, unsigned Depth) { assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) && "This is not a truncating conversion!"); assert(isSCEVable(Ty) && @@ -1164,10 +1229,11 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty, FoldingSetNodeID ID; ID.AddInteger(scTruncate); - ID.AddPointer(Op); + ID.AddPointer(Op.getRawPointer()); ID.AddPointer(Ty); void *IP = nullptr; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + if (SCEVUse S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + return S; // Fold if the operand is constant. if (const SCEVConstant *SC = dyn_cast(Op)) @@ -1200,11 +1266,11 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty, // that replace other casts. if (isa(Op) || isa(Op)) { auto *CommOp = cast(Op); - SmallVector Operands; + SmallVector Operands; unsigned numTruncs = 0; for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2; ++i) { - const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1); + SCEVUse S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1); if (!isa(CommOp->getOperand(i)) && isa(S)) numTruncs++; @@ -1220,14 +1286,14 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty, // Although we checked in the beginning that ID is not in the cache, it is // possible that during recursion and different modification ID was inserted // into the cache. So if we find it, just return it. - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + if (SCEVUse S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; } // If the input value is a chrec scev, truncate the chrec's operands. if (const SCEVAddRecExpr *AddRec = dyn_cast(Op)) { - SmallVector Operands; - for (const SCEV *Op : AddRec->operands()) + SmallVector Operands; + for (SCEVUse Op : AddRec->operands()) Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1)); return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap); } @@ -1250,9 +1316,9 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty, // Get the limit of a recurrence such that incrementing by Step cannot cause // signed overflow as long as the value of the recurrence within the // loop does not exceed this limit before incrementing. 
-static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step, - ICmpInst::Predicate *Pred, - ScalarEvolution *SE) { +static SCEVUse getSignedOverflowLimitForStep(SCEVUse Step, + ICmpInst::Predicate *Pred, + ScalarEvolution *SE) { unsigned BitWidth = SE->getTypeSizeInBits(Step->getType()); if (SE->isKnownPositive(Step)) { *Pred = ICmpInst::ICMP_SLT; @@ -1270,9 +1336,9 @@ static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step, // Get the limit of a recurrence such that incrementing by Step cannot cause // unsigned overflow as long as the value of the recurrence within the loop does // not exceed this limit before incrementing. -static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step, - ICmpInst::Predicate *Pred, - ScalarEvolution *SE) { +static SCEVUse getUnsignedOverflowLimitForStep(SCEVUse Step, + ICmpInst::Predicate *Pred, + ScalarEvolution *SE) { unsigned BitWidth = SE->getTypeSizeInBits(Step->getType()); *Pred = ICmpInst::ICMP_ULT; @@ -1283,8 +1349,8 @@ static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step, namespace { struct ExtendOpTraitsBase { - typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *, - unsigned); + typedef SCEVUse (ScalarEvolution::*GetExtendExprTy)(SCEVUse, Type *, + unsigned); }; // Used to make code generic over signed and unsigned overflow. @@ -1295,7 +1361,7 @@ template struct ExtendOpTraits { // // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr; // - // static const SCEV *getOverflowLimitForStep(const SCEV *Step, + // static SCEVUse getOverflowLimitForStep(SCEVUse Step, // ICmpInst::Predicate *Pred, // ScalarEvolution *SE); }; @@ -1306,9 +1372,9 @@ struct ExtendOpTraits : public ExtendOpTraitsBase { static const GetExtendExprTy GetExtendExpr; - static const SCEV *getOverflowLimitForStep(const SCEV *Step, - ICmpInst::Predicate *Pred, - ScalarEvolution *SE) { + static SCEVUse getOverflowLimitForStep(SCEVUse Step, + ICmpInst::Predicate *Pred, + ScalarEvolution *SE) { return getSignedOverflowLimitForStep(Step, Pred, SE); } }; @@ -1322,9 +1388,9 @@ struct ExtendOpTraits : public ExtendOpTraitsBase { static const GetExtendExprTy GetExtendExpr; - static const SCEV *getOverflowLimitForStep(const SCEV *Step, - ICmpInst::Predicate *Pred, - ScalarEvolution *SE) { + static SCEVUse getOverflowLimitForStep(SCEVUse Step, + ICmpInst::Predicate *Pred, + ScalarEvolution *SE) { return getUnsignedOverflowLimitForStep(Step, Pred, SE); } }; @@ -1342,14 +1408,14 @@ const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits< // expression "Step + sext/zext(PreIncAR)" is congruent with // "sext/zext(PostIncAR)" template -static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, - ScalarEvolution *SE, unsigned Depth) { +static SCEVUse getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, + ScalarEvolution *SE, unsigned Depth) { auto WrapType = ExtendOpTraits::WrapType; auto GetExtendExpr = ExtendOpTraits::GetExtendExpr; const Loop *L = AR->getLoop(); - const SCEV *Start = AR->getStart(); - const SCEV *Step = AR->getStepRecurrence(*SE); + SCEVUse Start = AR->getStart(); + SCEVUse Step = AR->getStepRecurrence(*SE); // Check for a simple looking step prior to loop entry. const SCEVAddExpr *SA = dyn_cast(Start); @@ -1360,7 +1426,7 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, // subtraction is expensive. For this purpose, perform a quick and dirty // difference, by checking for Step in the operand list. 
Note that // SA might have repeated ops, like %a + %a + ..., so only remove one. - SmallVector<const SCEV *, 4> DiffOps(SA->operands()); + SmallVector<SCEVUse, 4> DiffOps(SA->operands()); for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It) if (*It == Step) { DiffOps.erase(It); break; } @@ -1376,7 +1442,7 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, // 1. NSW/NUW flags on the step increment. auto PreStartFlags = ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW); - const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags); + SCEVUse PreStart = SE->getAddExpr(DiffOps, PreStartFlags); const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>( SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap)); @@ -1384,7 +1450,7 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, // "S+X does not sign/unsign-overflow". // - const SCEV *BECount = SE->getBackedgeTakenCount(L); + SCEVUse BECount = SE->getBackedgeTakenCount(L); if (PreAR && PreAR->getNoWrapFlags(WrapType) && !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount)) return PreStart; @@ -1392,7 +1458,7 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, // 2. Direct overflow check on the step operation's expression. unsigned BitWidth = SE->getTypeSizeInBits(AR->getType()); Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2); - const SCEV *OperandExtendedStart = + SCEVUse OperandExtendedStart = SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth), (SE->*GetExtendExpr)(Step, WideTy, Depth)); if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) { @@ -1407,7 +1473,7 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, // 3. Loop precondition. ICmpInst::Predicate Pred; - const SCEV *OverflowLimit = + SCEVUse OverflowLimit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE); if (OverflowLimit && @@ -1419,12 +1485,11 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, // Get the normalized zero or sign extended expression for this AddRec's Start. template <typename ExtendOpTy> -static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, - ScalarEvolution *SE, - unsigned Depth) { +static SCEVUse getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, + ScalarEvolution *SE, unsigned Depth) { auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr; - const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth); + SCEVUse PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth); if (!PreStart) return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth); @@ -1466,8 +1531,7 @@ static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, // In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T // is `Delta` (defined below).
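// [Illustrative sketch; not part of the patch.] getSignedOverflowLimitForStep
// (defined a hunk earlier) returns a bound with predicate SLT such that
// V < limit implies V + Step cannot sign-overflow. The same bound at a fixed
// 8-bit width, checked exhaustively; the in-tree code derives it from APInt
// ranges (SignedMax - getSignedRangeMax(Step)) instead:
#include <cassert>
#include <cstdint>

// Largest bound L such that V < L guarantees V + Step fits in int8_t,
// for a known-positive Step.
static int8_t signedOverflowLimit(int8_t Step) {
  assert(Step > 0);
  return static_cast<int8_t>(INT8_MAX - Step);
}

int main() {
  for (int Step = 1; Step <= 8; ++Step)
    for (int V = INT8_MIN; V <= INT8_MAX; ++V)
      if (V < signedOverflowLimit(static_cast<int8_t>(Step)))
        assert(V + Step <= INT8_MAX); // evaluated as int: no wrap happened
  return 0;
}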
template -bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start, - const SCEV *Step, +bool ScalarEvolution::proveNoWrapByVaryingStart(SCEVUse Start, SCEVUse Step, const Loop *L) { auto WrapType = ExtendOpTraits::WrapType; @@ -1482,12 +1546,12 @@ bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start, APInt StartAI = StartC->getAPInt(); for (unsigned Delta : {-2, -1, 1, 2}) { - const SCEV *PreStart = getConstant(StartAI - Delta); + SCEVUse PreStart = getConstant(StartAI - Delta); FoldingSetNodeID ID; ID.AddInteger(scAddRecExpr); - ID.AddPointer(PreStart); - ID.AddPointer(Step); + ID.AddPointer(PreStart.getRawPointer()); + ID.AddPointer(Step.getRawPointer()); ID.AddPointer(L); void *IP = nullptr; const auto *PreAR = @@ -1496,9 +1560,9 @@ bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start, // Give up if we don't already have the add recurrence we need because // actually constructing an add recurrence is relatively expensive. if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2) - const SCEV *DeltaS = getConstant(StartC->getType(), Delta); + SCEVUse DeltaS = getConstant(StartC->getType(), Delta); ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE; - const SCEV *Limit = ExtendOpTraits::getOverflowLimitForStep( + SCEVUse Limit = ExtendOpTraits::getOverflowLimitForStep( DeltaS, &Pred, this); if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1) return true; @@ -1536,7 +1600,7 @@ static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, // ConstantStart, x is an arbitrary \p Step, and n is the loop trip count. static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const APInt &ConstantStart, - const SCEV *Step) { + SCEVUse Step) { const unsigned BitWidth = ConstantStart.getBitWidth(); const uint32_t TZ = SE.getMinTrailingZeros(Step); if (TZ) @@ -1546,10 +1610,9 @@ static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, } static void insertFoldCacheEntry( - const ScalarEvolution::FoldID &ID, const SCEV *S, - DenseMap &FoldCache, - DenseMap> - &FoldCacheUser) { + const ScalarEvolution::FoldID &ID, SCEVUse S, + DenseMap &FoldCache, + DenseMap> &FoldCacheUser) { auto I = FoldCache.insert({ID, S}); if (!I.second) { // Remove FoldCacheUser entry for ID when replacing an existing FoldCache @@ -1567,8 +1630,8 @@ static void insertFoldCacheEntry( FoldCacheUser[S].push_back(ID); } -const SCEV * -ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { +SCEVUse ScalarEvolution::getZeroExtendExpr(SCEVUse Op, Type *Ty, + unsigned Depth) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && @@ -1581,14 +1644,14 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { if (Iter != FoldCache.end()) return Iter->second; - const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth); + SCEVUse S = getZeroExtendExprImpl(Op, Ty, Depth); if (!isa(S)) insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser); return S; } -const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, - unsigned Depth) { +SCEVUse ScalarEvolution::getZeroExtendExprImpl(SCEVUse Op, Type *Ty, + unsigned Depth) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!"); @@ -1606,10 +1669,11 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, // computed a SCEV for this Op and 
Ty. FoldingSetNodeID ID; ID.AddInteger(scZeroExtend); - ID.AddPointer(Op); + ID.AddPointer(Op.getRawPointer()); ID.AddPointer(Ty); void *IP = nullptr; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + if (SCEVUse S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + return S; if (Depth > MaxCastDepth) { SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator), Op, Ty); @@ -1622,7 +1686,7 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) { // It's possible the bits taken off by the truncate were all zero bits. If // so, we should be able to simplify this further. - const SCEV *X = ST->getOperand(); + SCEVUse X = ST->getOperand(); ConstantRange CR = getUnsignedRange(X); unsigned TruncBits = getTypeSizeInBits(ST->getType()); unsigned NewBits = getTypeSizeInBits(Ty); @@ -1637,8 +1701,8 @@ // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; } if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) if (AR->isAffine()) { - const SCEV *Start = AR->getStart(); - const SCEV *Step = AR->getStepRecurrence(*this); + SCEVUse Start = AR->getStart(); + SCEVUse Step = AR->getStepRecurrence(*this); unsigned BitWidth = getTypeSizeInBits(AR->getType()); const Loop *L = AR->getLoop(); @@ -1659,34 +1723,33 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, // in infinite recursion. In the latter case, the analysis code will // cope with a conservative value, and it will take care to purge // that value once it has finished. - const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L); + SCEVUse MaxBECount = getConstantMaxBackedgeTakenCount(L); if (!isa<SCEVCouldNotCompute>(MaxBECount)) { // Manually compute the final value for AR, checking for overflow. // Check whether the backedge-taken count can be losslessly cast to // the addrec's type. The count is always unsigned. - const SCEV *CastedMaxBECount = + SCEVUse CastedMaxBECount = getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth); - const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend( + SCEVUse RecastedMaxBECount = getTruncateOrZeroExtend( CastedMaxBECount, MaxBECount->getType(), Depth); if (MaxBECount == RecastedMaxBECount) { Type *WideTy = IntegerType::get(getContext(), BitWidth * 2); // Check whether Start+Step*MaxBECount has no unsigned overflow.
- const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step, - SCEV::FlagAnyWrap, Depth + 1); - const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul, - SCEV::FlagAnyWrap, - Depth + 1), - WideTy, Depth + 1); - const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1); - const SCEV *WideMaxBECount = - getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1); - const SCEV *OperandExtendedAdd = - getAddExpr(WideStart, - getMulExpr(WideMaxBECount, - getZeroExtendExpr(Step, WideTy, Depth + 1), - SCEV::FlagAnyWrap, Depth + 1), - SCEV::FlagAnyWrap, Depth + 1); + SCEVUse ZMul = + getMulExpr(CastedMaxBECount, Step, SCEV::FlagAnyWrap, Depth + 1); + SCEVUse ZAdd = getZeroExtendExpr( + getAddExpr(Start, ZMul, SCEV::FlagAnyWrap, Depth + 1), WideTy, + Depth + 1); + SCEVUse WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1); + SCEVUse WideMaxBECount = + getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1); + SCEVUse OperandExtendedAdd = + getAddExpr(WideStart, + getMulExpr(WideMaxBECount, + getZeroExtendExpr(Step, WideTy, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1); if (ZAdd == OperandExtendedAdd) { // Cache knowledge of AR NUW, which is propagated to this AddRec. setNoWrapFlags(const_cast(AR), SCEV::FlagNUW); @@ -1743,8 +1806,8 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, // For a negative step, we can extend the operands iff doing so only // traverses values in the range zext([0,UINT_MAX]). if (isKnownNegative(Step)) { - const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) - - getSignedRangeMin(Step)); + SCEVUse N = getConstant(APInt::getMaxValue(BitWidth) - + getSignedRangeMin(Step)); if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) || isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) { // Cache knowledge of AR NW, which is propagated to this @@ -1767,10 +1830,10 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, const APInt &C = SC->getAPInt(); const APInt &D = extractConstantWithoutWrapping(*this, C, Step); if (D != 0) { - const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth); - const SCEV *SResidual = + SCEVUse SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth); + SCEVUse SResidual = getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags()); - const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1); + SCEVUse SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1); return getAddExpr(SZExtD, SZExtR, (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW), Depth + 1); @@ -1788,8 +1851,8 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, // zext(A % B) --> zext(A) % zext(B) { - const SCEV *LHS; - const SCEV *RHS; + SCEVUse LHS; + SCEVUse RHS; if (matchURem(Op, LHS, RHS)) return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1), getZeroExtendExpr(RHS, Ty, Depth + 1)); @@ -1805,8 +1868,8 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, if (SA->hasNoUnsignedWrap()) { // If the addition does not unsign overflow then we can, by definition, // commute the zero extension with the addition operation. 
- SmallVector Ops; - for (const auto *Op : SA->operands()) + SmallVector Ops; + for (const auto Op : SA->operands()) Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1)); return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1); } @@ -1822,10 +1885,10 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, if (const auto *SC = dyn_cast(SA->getOperand(0))) { const APInt &D = extractConstantWithoutWrapping(*this, SC, SA); if (D != 0) { - const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth); - const SCEV *SResidual = + SCEVUse SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth); + SCEVUse SResidual = getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth); - const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1); + SCEVUse SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1); return getAddExpr(SZExtD, SZExtR, (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW), Depth + 1); @@ -1838,8 +1901,8 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, if (SM->hasNoUnsignedWrap()) { // If the multiply does not unsign overflow then we can, by definition, // commute the zero extension with the multiply operation. - SmallVector Ops; - for (const auto *Op : SM->operands()) + SmallVector Ops; + for (const auto Op : SM->operands()) Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1)); return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1); } @@ -1875,8 +1938,8 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, // zext(umax(x, y)) -> umax(zext(x), zext(y)) if (isa(Op) || isa(Op)) { auto *MinMax = cast(Op); - SmallVector Operands; - for (auto *Operand : MinMax->operands()) + SmallVector Operands; + for (auto Operand : MinMax->operands()) Operands.push_back(getZeroExtendExpr(Operand, Ty)); if (isa(MinMax)) return getUMinExpr(Operands); @@ -1886,15 +1949,16 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y)) if (auto *MinMax = dyn_cast(Op)) { assert(isa(MinMax) && "Not supported!"); - SmallVector Operands; - for (auto *Operand : MinMax->operands()) + SmallVector Operands; + for (auto Operand : MinMax->operands()) Operands.push_back(getZeroExtendExpr(Operand, Ty)); return getUMinExpr(Operands, /*Sequential*/ true); } // The cast wasn't folded; create an explicit cast node. // Recompute the insert position, as it may have been invalidated. 
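// [Illustrative sketch; not part of the patch.] The hasNoUnsignedWrap fold
// above is the node-level form of a bit-level fact: zext distributes over an
// add that provably does not wrap unsigned. At 8 bits this is exhaustively
// checkable:
#include <cassert>
#include <cstdint>

int main() {
  for (unsigned A = 0; A <= 255; ++A)
    for (unsigned B = 0; B <= 255; ++B) {
      bool NoWrap = A + B <= 255; // "+nuw" holds in 8 bits
      uint16_t Narrow = static_cast<uint16_t>(static_cast<uint8_t>(A + B));
      uint16_t Wide = static_cast<uint16_t>(A) + static_cast<uint16_t>(B);
      if (NoWrap)
        assert(Narrow == Wide); // zext(a +nuw b) == zext(a) + zext(b)
    }
  return 0;
}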
- if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + if (SCEVUse S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + return S; SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator), Op, Ty); UniqueSCEVs.InsertNode(S, IP); @@ -1902,8 +1966,8 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, return S; } -const SCEV * -ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { +SCEVUse ScalarEvolution::getSignExtendExpr(SCEVUse Op, Type *Ty, + unsigned Depth) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && @@ -1916,14 +1980,14 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { if (Iter != FoldCache.end()) return Iter->second; - const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth); + SCEVUse S = getSignExtendExprImpl(Op, Ty, Depth); if (!isa(S)) insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser); return S; } -const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, - unsigned Depth) { +SCEVUse ScalarEvolution::getSignExtendExprImpl(SCEVUse Op, Type *Ty, + unsigned Depth) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!"); @@ -1946,10 +2010,11 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, // computed a SCEV for this Op and Ty. FoldingSetNodeID ID; ID.AddInteger(scSignExtend); - ID.AddPointer(Op); + ID.AddPointer(Op.getRawPointer()); ID.AddPointer(Ty); void *IP = nullptr; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + if (SCEVUse S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + return S; // Limit recursion depth. if (Depth > MaxCastDepth) { SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator), @@ -1963,7 +2028,7 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, if (const SCEVTruncateExpr *ST = dyn_cast(Op)) { // It's possible the bits taken off by the truncate were all sign bits. If // so, we should be able to simplify this further. - const SCEV *X = ST->getOperand(); + SCEVUse X = ST->getOperand(); ConstantRange CR = getSignedRange(X); unsigned TruncBits = getTypeSizeInBits(ST->getType()); unsigned NewBits = getTypeSizeInBits(Ty); @@ -1977,8 +2042,8 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, if (SA->hasNoSignedWrap()) { // If the addition does not sign overflow then we can, by definition, // commute the sign extension with the addition operation. 
- SmallVector<const SCEV *, 4> Ops; - for (const auto *Op : SA->operands()) + SmallVector<SCEVUse, 4> Ops; + for (const auto Op : SA->operands()) Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1)); return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1); } @@ -1995,10 +2060,10 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) { const APInt &D = extractConstantWithoutWrapping(*this, SC, SA); if (D != 0) { - const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth); - const SCEV *SResidual = + SCEVUse SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth); + SCEVUse SResidual = getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth); - const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1); + SCEVUse SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1); return getAddExpr(SSExtD, SSExtR, (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW), Depth + 1); @@ -2011,8 +2076,8 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, // this: for (signed char X = 0; X < 100; ++X) { int Y = X; } if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) if (AR->isAffine()) { - const SCEV *Start = AR->getStart(); - const SCEV *Step = AR->getStepRecurrence(*this); + SCEVUse Start = AR->getStart(); + SCEVUse Step = AR->getStepRecurrence(*this); unsigned BitWidth = getTypeSizeInBits(AR->getType()); const Loop *L = AR->getLoop(); @@ -2033,35 +2098,34 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, // in infinite recursion. In the latter case, the analysis code will // cope with a conservative value, and it will take care to purge // that value once it has finished. - const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L); + SCEVUse MaxBECount = getConstantMaxBackedgeTakenCount(L); if (!isa<SCEVCouldNotCompute>(MaxBECount)) { // Manually compute the final value for AR, checking for // overflow. // Check whether the backedge-taken count can be losslessly cast to // the addrec's type. The count is always unsigned. - const SCEV *CastedMaxBECount = + SCEVUse CastedMaxBECount = getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth); - const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend( + SCEVUse RecastedMaxBECount = getTruncateOrZeroExtend( CastedMaxBECount, MaxBECount->getType(), Depth); if (MaxBECount == RecastedMaxBECount) { Type *WideTy = IntegerType::get(getContext(), BitWidth * 2); // Check whether Start+Step*MaxBECount has no signed overflow.
- const SCEV *SMul = getMulExpr(CastedMaxBECount, Step, - SCEV::FlagAnyWrap, Depth + 1); - const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul, - SCEV::FlagAnyWrap, - Depth + 1), - WideTy, Depth + 1); - const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1); - const SCEV *WideMaxBECount = - getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1); - const SCEV *OperandExtendedAdd = - getAddExpr(WideStart, - getMulExpr(WideMaxBECount, - getSignExtendExpr(Step, WideTy, Depth + 1), - SCEV::FlagAnyWrap, Depth + 1), - SCEV::FlagAnyWrap, Depth + 1); + SCEVUse SMul = + getMulExpr(CastedMaxBECount, Step, SCEV::FlagAnyWrap, Depth + 1); + SCEVUse SAdd = getSignExtendExpr( + getAddExpr(Start, SMul, SCEV::FlagAnyWrap, Depth + 1), WideTy, + Depth + 1); + SCEVUse WideStart = getSignExtendExpr(Start, WideTy, Depth + 1); + SCEVUse WideMaxBECount = + getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1); + SCEVUse OperandExtendedAdd = + getAddExpr(WideStart, + getMulExpr(WideMaxBECount, + getSignExtendExpr(Step, WideTy, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1), + SCEV::FlagAnyWrap, Depth + 1); if (SAdd == OperandExtendedAdd) { // Cache knowledge of AR NSW, which is propagated to this AddRec. setNoWrapFlags(const_cast(AR), SCEV::FlagNSW); @@ -2119,10 +2183,10 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, const APInt &C = SC->getAPInt(); const APInt &D = extractConstantWithoutWrapping(*this, C, Step); if (D != 0) { - const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth); - const SCEV *SResidual = + SCEVUse SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth); + SCEVUse SResidual = getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags()); - const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1); + SCEVUse SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1); return getAddExpr(SSExtD, SSExtR, (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW), Depth + 1); @@ -2147,8 +2211,8 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, // sext(smax(x, y)) -> smax(sext(x), sext(y)) if (isa(Op) || isa(Op)) { auto *MinMax = cast(Op); - SmallVector Operands; - for (auto *Operand : MinMax->operands()) + SmallVector Operands; + for (auto Operand : MinMax->operands()) Operands.push_back(getSignExtendExpr(Operand, Ty)); if (isa(MinMax)) return getSMinExpr(Operands); @@ -2157,7 +2221,8 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, // The cast wasn't folded; create an explicit cast node. // Recompute the insert position, as it may have been invalidated. - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + if (SCEVUse S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + return S; SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator), Op, Ty); UniqueSCEVs.InsertNode(S, IP); @@ -2165,8 +2230,7 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty, return S; } -const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op, - Type *Ty) { +SCEVUse ScalarEvolution::getCastExpr(SCEVTypes Kind, SCEVUse Op, Type *Ty) { switch (Kind) { case scTruncate: return getTruncateExpr(Op, Ty); @@ -2183,8 +2247,7 @@ const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op, /// getAnyExtendExpr - Return a SCEV for the given operand extended with /// unspecified bits out to the given type. 
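// [Illustrative sketch; not part of the patch.] insertFoldCacheEntry (earlier
// in this patch) pairs the forward FoldCache with a reverse FoldCacheUser map
// so that invalidating one expression removes exactly the entries that
// produced it. The same bookkeeping with hypothetical string keys standing in
// for FoldID (std::erase on a vector requires C++20):
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

using Key = std::string; // ~ ScalarEvolution::FoldID
using Val = const int *; // ~ SCEVUse

struct FoldCache {
  std::unordered_map<Key, Val> Cache;              // ID -> result
  std::unordered_map<Val, std::vector<Key>> Users; // result -> IDs using it

  void insert(const Key &K, Val V) {
    auto [It, Inserted] = Cache.insert({K, V});
    if (!Inserted) {
      std::erase(Users[It->second], K); // unregister from the old value
      It->second = V;
    }
    Users[V].push_back(K);
  }

  void invalidate(Val V) { // e.g. when a SCEVUnknown's IR value dies
    for (const Key &K : Users[V])
      Cache.erase(K);
    Users.erase(V);
  }
};

int main() {
  int A = 0, B = 0;
  FoldCache FC;
  FC.insert("zext(x,i64)", &A);
  FC.insert("zext(x,i64)", &B); // replacement drops the stale user entry
  FC.invalidate(&B);            // removes every entry that produced &B
  assert(FC.Cache.empty());
  return 0;
}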
-const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op, - Type *Ty) { +SCEVUse ScalarEvolution::getAnyExtendExpr(SCEVUse Op, Type *Ty) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && "This is not an extending conversion!"); assert(isSCEVable(Ty) && @@ -2198,26 +2261,26 @@ const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op, // Peel off a truncate cast. if (const SCEVTruncateExpr *T = dyn_cast(Op)) { - const SCEV *NewOp = T->getOperand(); + SCEVUse NewOp = T->getOperand(); if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty)) return getAnyExtendExpr(NewOp, Ty); return getTruncateOrNoop(NewOp, Ty); } // Next try a zext cast. If the cast is folded, use it. - const SCEV *ZExt = getZeroExtendExpr(Op, Ty); + SCEVUse ZExt = getZeroExtendExpr(Op, Ty); if (!isa(ZExt)) return ZExt; // Next try a sext cast. If the cast is folded, use it. - const SCEV *SExt = getSignExtendExpr(Op, Ty); + SCEVUse SExt = getSignExtendExpr(Op, Ty); if (!isa(SExt)) return SExt; // Force the cast to be folded into the operands of an addrec. if (const SCEVAddRecExpr *AR = dyn_cast(Op)) { - SmallVector Ops; - for (const SCEV *Op : AR->operands()) + SmallVector Ops; + for (SCEVUse Op : AR->operands()) Ops.push_back(getAnyExtendExpr(Op, Ty)); return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW); } @@ -2253,12 +2316,12 @@ const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op, /// may be exposed. This helps getAddRecExpr short-circuit extra work in /// the common case where no interesting opportunities are present, and /// is also used as a check to avoid infinite recursion. -static bool -CollectAddOperandsWithScales(SmallDenseMap &M, - SmallVectorImpl &NewOps, - APInt &AccumulatedConstant, - ArrayRef Ops, const APInt &Scale, - ScalarEvolution &SE) { +static bool CollectAddOperandsWithScales(SmallDenseMap &M, + SmallVectorImpl &NewOps, + APInt &AccumulatedConstant, + ArrayRef Ops, + const APInt &Scale, + ScalarEvolution &SE) { bool Interesting = false; // Iterate over the add operands. They are sorted, with constants first. @@ -2287,8 +2350,8 @@ CollectAddOperandsWithScales(SmallDenseMap &M, } else { // A multiplication of a constant with some other value. Update // the map. - SmallVector MulOps(drop_begin(Mul->operands())); - const SCEV *Key = SE.getMulExpr(MulOps); + SmallVector MulOps(drop_begin(Mul->operands())); + SCEVUse Key = SE.getMulExpr(MulOps); auto Pair = M.insert({Key, NewScale}); if (Pair.second) { NewOps.push_back(Pair.first->first); @@ -2301,7 +2364,7 @@ CollectAddOperandsWithScales(SmallDenseMap &M, } } else { // An ordinary operand. Update the map. 
- std::pair::iterator, bool> Pair = + std::pair::iterator, bool> Pair = M.insert({Ops[i], Scale}); if (Pair.second) { NewOps.push_back(Pair.first->first); @@ -2318,10 +2381,10 @@ CollectAddOperandsWithScales(SmallDenseMap &M, } bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, - const SCEV *LHS, const SCEV *RHS, + SCEVUse LHS, SCEVUse RHS, const Instruction *CtxI) { - const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *, - SCEV::NoWrapFlags, unsigned); + SCEVUse (ScalarEvolution::*Operation)(SCEVUse, SCEVUse, SCEV::NoWrapFlags, + unsigned); switch (BinOp) { default: llvm_unreachable("Unsupported binary op"); @@ -2336,7 +2399,7 @@ bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, break; } - const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) = + SCEVUse (ScalarEvolution::*Extension)(SCEVUse, Type *, unsigned) = Signed ? &ScalarEvolution::getSignExtendExpr : &ScalarEvolution::getZeroExtendExpr; @@ -2345,11 +2408,11 @@ bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, auto *WideTy = IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2); - const SCEV *A = (this->*Extension)( + SCEVUse A = (this->*Extension)( (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0); - const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0); - const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0); - const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0); + SCEVUse LHSB = (this->*Extension)(LHS, WideTy, 0); + SCEVUse RHSB = (this->*Extension)(RHS, WideTy, 0); + SCEVUse B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0); if (A == B) return true; // Can we use context to prove the fact we need? @@ -2414,8 +2477,8 @@ ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp( OBO->getOpcode() != Instruction::Mul) return std::nullopt; - const SCEV *LHS = getSCEV(OBO->getOperand(0)); - const SCEV *RHS = getSCEV(OBO->getOperand(1)); + SCEVUse LHS = getSCEV(OBO->getOperand(0)); + SCEVUse RHS = getSCEV(OBO->getOperand(1)); const Instruction *CtxI = UseContextForNoWrapFlagInference ? dyn_cast(OBO) : nullptr; @@ -2441,10 +2504,10 @@ ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp( // We're trying to construct a SCEV of type `Type' with `Ops' as operands and // `OldFlags' as can't-wrap behavior. Infer a more aggressive set of // can't-overflow flags for the operation if possible. -static SCEV::NoWrapFlags -StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, - const ArrayRef Ops, - SCEV::NoWrapFlags Flags) { +static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, + SCEVTypes Type, + const ArrayRef Ops, + SCEV::NoWrapFlags Flags) { using namespace std::placeholders; using OBO = OverflowingBinaryOperator; @@ -2459,7 +2522,7 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, ScalarEvolution::maskFlags(Flags, SignOrUnsignMask); // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. 
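// [Illustrative sketch; not part of the patch.] willNotOverflow above compares
// op-then-extend against extend-then-op at twice the bit width; the two agree
// exactly when the narrow operation does not overflow. Exhaustive check for
// signed 8-bit addition (8-bit wrap on conversion is well defined since C++20):
#include <cassert>
#include <cstdint>

static bool willNotOverflowSignedAdd(int8_t A, int8_t B) {
  int16_t Narrow = static_cast<int16_t>(static_cast<int8_t>(A + B)); // op, then sext
  int16_t Wide = static_cast<int16_t>(A) + static_cast<int16_t>(B);  // sext, then op
  return Narrow == Wide;
}

int main() {
  for (int A = INT8_MIN; A <= INT8_MAX; ++A)
    for (int B = INT8_MIN; B <= INT8_MAX; ++B) {
      int Sum = A + B; // reference result in int, never overflows
      bool Fits = Sum >= INT8_MIN && Sum <= INT8_MAX;
      assert(willNotOverflowSignedAdd(static_cast<int8_t>(A),
                                      static_cast<int8_t>(B)) == Fits);
    }
  return 0;
}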
- auto IsKnownNonNegative = [&](const SCEV *S) { + auto IsKnownNonNegative = [&](SCEVUse S) { return SE->isKnownNonNegative(S); }; @@ -2524,14 +2587,20 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, return Flags; } -bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) { +bool ScalarEvolution::isAvailableAtLoopEntry(SCEVUse S, const Loop *L) { return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader()); } +SCEVUse ScalarEvolution::getAddExpr(ArrayRef Ops, + SCEV::NoWrapFlags Flags, unsigned Depth) { + SmallVector Ops2(Ops.begin(), Ops.end()); + return getAddExpr(Ops2, Flags, Depth); +} + /// Get a canonical add expression, or something simpler if possible. -const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, - SCEV::NoWrapFlags OrigFlags, - unsigned Depth) { +SCEVUse ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, + SCEV::NoWrapFlags OrigFlags, + unsigned Depth) { assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) && "only nuw or nsw allowed"); assert(!Ops.empty() && "Cannot get empty add!"); @@ -2541,8 +2610,8 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, for (unsigned i = 1, e = Ops.size(); i != e; ++i) assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy && "SCEVAddExpr operand types don't match!"); - unsigned NumPtrs = count_if( - Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); }); + unsigned NumPtrs = + count_if(Ops, [](SCEVUse Op) { return Op->getType()->isPointerTy(); }); assert(NumPtrs <= 1 && "add has at most one pointer operand"); #endif @@ -2557,7 +2626,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, unsigned Idx = isa(Ops[0]) ? 1 : 0; // Delay expensive flag strengthening until necessary. - auto ComputeFlags = [this, OrigFlags](const ArrayRef Ops) { + auto ComputeFlags = [this, OrigFlags](const ArrayRef Ops) { return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags); }; @@ -2570,7 +2639,9 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, SCEVAddExpr *Add = static_cast(S); if (Add->getNoWrapFlags(OrigFlags) != OrigFlags) Add->setNoWrapFlags(ComputeFlags(Ops)); - return S; + bool IsCanonical = all_of(Ops, [](SCEVUse U) { return U.getFlags() == 0; }); + int UseFlags = IsCanonical ? 0 : 1; + return {S, UseFlags}; } // Okay, check to see if the same value occurs in the operand list more than @@ -2579,14 +2650,15 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, Type *Ty = Ops[0]->getType(); bool FoundMatch = false; for (unsigned i = 0, e = Ops.size(); i != e-1; ++i) - if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2 + if (Ops[i].getCanonical(*this) == + Ops[i + 1].getCanonical(*this)) { // X + Y + Y --> X + Y*2 // Scan ahead to count how many equal operands there are. unsigned Count = 2; while (i+Count != e && Ops[i+Count] == Ops[i]) ++Count; // Merge the values into a multiply. 
- const SCEV *Scale = getConstant(Ty, Count); - const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1); + SCEVUse Scale = getConstant(Ty, Count); + SCEVUse Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1); if (Ops.size() == Count) return Mul; Ops[i] = Mul; @@ -2609,14 +2681,14 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, if (auto *T = dyn_cast(Ops[Idx])) return T->getOperand()->getType(); if (const auto *Mul = dyn_cast(Ops[Idx])) { - const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1); + const auto LastOp = Mul->getOperand(Mul->getNumOperands() - 1); if (const auto *T = dyn_cast(LastOp)) return T->getOperand()->getType(); } return nullptr; }; if (auto *SrcType = FindTruncSrcType()) { - SmallVector LargeOps; + SmallVector LargeOps; bool Ok = true; // Check all the operands to see if they can be represented in the // source type of the truncate. @@ -2630,7 +2702,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, } else if (const SCEVConstant *C = dyn_cast(Op)) { LargeOps.push_back(getAnyExtendExpr(C, SrcType)); } else if (const SCEVMulExpr *M = dyn_cast(Op)) { - SmallVector LargeMulOps; + SmallVector LargeMulOps; for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) { if (const SCEVTruncateExpr *T = dyn_cast(M->getOperand(j))) { @@ -2655,7 +2727,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, } if (Ok) { // Evaluate the expression in the larger type. - const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1); + SCEVUse Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1); // If it folds to something simple, use it. Otherwise, don't. if (isa(Fold) || isa(Fold)) return getTruncateExpr(Fold, Ty); @@ -2666,8 +2738,8 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, // Check if we have an expression of the form ((X + C1) - C2), where C1 and // C2 can be folded in a way that allows retaining wrapping flags of (X + // C1). - const SCEV *A = Ops[0]; - const SCEV *B = Ops[1]; + SCEVUse A = Ops[0]; + SCEVUse B = Ops[1]; auto *AddExpr = dyn_cast(B); auto *C = dyn_cast(A); if (AddExpr && C && isa(AddExpr->getOperand(0))) { @@ -2694,7 +2766,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, } if (PreservedFlags != SCEV::FlagAnyWrap) { - SmallVector NewOps(AddExpr->operands()); + SmallVector NewOps(AddExpr->operands()); NewOps[0] = getConstant(ConstAdd); return getAddExpr(NewOps, PreservedFlags); } @@ -2706,8 +2778,8 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, const SCEVMulExpr *Mul = dyn_cast(Ops[0]); if (Mul && Mul->getNumOperands() == 2 && Mul->getOperand(0)->isAllOnesValue()) { - const SCEV *X; - const SCEV *Y; + SCEVUse X; + SCEVUse Y; if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) { return getMulExpr(Y, getUDivExpr(X, Y)); } @@ -2731,6 +2803,8 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, break; // If we have an add, expand the add operands onto the end of the operands // list. + // CommonFlags = maskFlags(CommonFlags, setFlags(Add->getNoWrapFlags(), + // static_cast(Ops[Idx].getInt()))); Ops.erase(Ops.begin()+Idx); append_range(Ops, Add->operands()); DeletedAdd = true; @@ -2752,8 +2826,8 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, // operands multiplied by constant values. 
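// [Illustrative sketch; not part of the patch.] The hunk that follows feeds
// every add operand through a scale map (CollectAddOperandsWithScales), which
// amounts to reassociation such as 2*a + 3*b + 1*a --> 3*a + 3*b. A map-based
// sketch with hypothetical string-named terms:
#include <map>
#include <string>
#include <vector>

struct Term { long Scale; std::string Expr; };

static std::vector<Term> collectWithScales(const std::vector<Term> &Ops) {
  std::map<std::string, long> Scales; // Expr -> accumulated scale
  for (const Term &T : Ops)
    Scales[T.Expr] += T.Scale; // mirrors the M.insert / NewScale accumulation
  std::vector<Term> Out;
  for (const auto &[Expr, Scale] : Scales)
    if (Scale != 0) // drop terms that cancelled out entirely
      Out.push_back({Scale, Expr});
  return Out;
}

int main() {
  auto Out = collectWithScales({{2, "a"}, {3, "b"}, {1, "a"}});
  // Out is now {3, "a"}, {3, "b"}.
  return static_cast<int>(Out.size()) - 2;
}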
if (Idx < Ops.size() && isa<SCEVConstant>(Ops[Idx])) { uint64_t BitWidth = getTypeSizeInBits(Ty); - SmallDenseMap<const SCEV *, APInt, 16> M; - SmallVector<const SCEV *, 8> NewOps; + SmallDenseMap<SCEVUse, APInt, 16> M; + SmallVector<SCEVUse, 8> NewOps; APInt AccumulatedConstant(BitWidth, 0); if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant, Ops, APInt(BitWidth, 1), *this)) { @@ -2766,8 +2840,8 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, // Some interesting folding opportunity is present, so it's worthwhile to // re-generate the operands list. Group the operands by constant scale, // to avoid multiplying by the same constant scale multiple times. - std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists; - for (const SCEV *NewOp : NewOps) + std::map<APInt, SmallVector<SCEVUse, 4>, APIntCompare> MulOpLists; + for (SCEVUse NewOp : NewOps) MulOpLists[M.find(NewOp)->second].push_back(NewOp); // Re-generate the operands list. Ops.clear(); @@ -2797,25 +2871,24 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) { const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]); for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) { - const SCEV *MulOpSCEV = Mul->getOperand(MulOp); + SCEVUse MulOpSCEV = Mul->getOperand(MulOp); if (isa<SCEVConstant>(MulOpSCEV)) continue; for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp) - if (MulOpSCEV == Ops[AddOp]) { + if (MulOpSCEV.getCanonical(*this) == Ops[AddOp].getCanonical(*this)) { // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1)) - const SCEV *InnerMul = Mul->getOperand(MulOp == 0); + SCEVUse InnerMul = Mul->getOperand(MulOp == 0); if (Mul->getNumOperands() != 2) { // If the multiply has more than two operands, we must get the // Y*Z term. - SmallVector<const SCEV *, 4> MulOps( - Mul->operands().take_front(MulOp)); + SmallVector<SCEVUse, 4> MulOps(Mul->operands().take_front(MulOp)); append_range(MulOps, Mul->operands().drop_front(MulOp + 1)); InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1); } - SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul}; - const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); - const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV, - SCEV::FlagAnyWrap, Depth + 1); + SmallVector<SCEVUse, 2> TwoOps = {getOne(Ty), InnerMul}; + SCEVUse AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); + SCEVUse OuterMul = + getMulExpr(AddOne, MulOpSCEV, SCEV::FlagAnyWrap, Depth + 1); if (Ops.size() == 2) return OuterMul; if (AddOp < Idx) { Ops.erase(Ops.begin()+AddOp); @@ -2839,25 +2912,24 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, OMulOp != e; ++OMulOp) if (OtherMul->getOperand(OMulOp) == MulOpSCEV) { // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E)) - const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0); + SCEVUse InnerMul1 = Mul->getOperand(MulOp == 0); if (Mul->getNumOperands() != 2) { - SmallVector<const SCEV *, 4> MulOps( - Mul->operands().take_front(MulOp)); + SmallVector<SCEVUse, 4> MulOps(Mul->operands().take_front(MulOp)); append_range(MulOps, Mul->operands().drop_front(MulOp+1)); InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1); } - const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0); + SCEVUse InnerMul2 = OtherMul->getOperand(OMulOp == 0); if (OtherMul->getNumOperands() != 2) { - SmallVector<const SCEV *, 4> MulOps( + SmallVector<SCEVUse, 4> MulOps( OtherMul->operands().take_front(OMulOp)); append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1)); InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1); } - SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2}; - const SCEV *InnerMulSum = + SmallVector<SCEVUse, 2> TwoOps = {InnerMul1, InnerMul2}; + SCEVUse InnerMulSum = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); -
const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum, - SCEV::FlagAnyWrap, Depth + 1); + SCEVUse OuterMul = getMulExpr(MulOpSCEV, InnerMulSum, + SCEV::FlagAnyWrap, Depth + 1); if (Ops.size() == 2) return OuterMul; Ops.erase(Ops.begin()+Idx); Ops.erase(Ops.begin()+OtherMulIdx-1); @@ -2878,7 +2950,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, for (; Idx < Ops.size() && isa(Ops[Idx]); ++Idx) { // Scan all of the other operands to this add and add them to the vector if // they are loop invariant w.r.t. the recurrence. - SmallVector LIOps; + SmallVector LIOps; const SCEVAddRecExpr *AddRec = cast(Ops[Idx]); const Loop *AddRecLoop = AddRec->getLoop(); for (unsigned i = 0, e = Ops.size(); i != e; ++i) @@ -2900,7 +2972,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step} LIOps.push_back(AddRec->getStart()); - SmallVector AddRecOps(AddRec->operands()); + SmallVector AddRecOps(AddRec->operands()); // It is not in general safe to propagate flags valid on an add within // the addrec scope to one outside it. We must prove that the inner @@ -2925,7 +2997,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, // outer add and the inner addrec are guaranteed to have no overflow. // Always propagate NW. Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW)); - const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags); + SCEVUse NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags); // If all of the other operands were loop invariant, we are done. if (Ops.size() == 1) return NewRec; @@ -2953,7 +3025,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, "AddRecExprs are not sorted in reverse dominance order?"); if (AddRecLoop == cast(Ops[OtherIdx])->getLoop()) { // Other + {A,+,B} + {C,+,D} --> Other + {A+C,+,B+D} - SmallVector AddRecOps(AddRec->operands()); + SmallVector AddRecOps(AddRec->operands()); for (; OtherIdx != Ops.size() && isa(Ops[OtherIdx]); ++OtherIdx) { const auto *OtherAddRec = cast(Ops[OtherIdx]); @@ -2964,8 +3036,8 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, append_range(AddRecOps, OtherAddRec->operands().drop_front(i)); break; } - SmallVector TwoOps = { - AddRecOps[i], OtherAddRec->getOperand(i)}; + SmallVector TwoOps = {AddRecOps[i], + OtherAddRec->getOperand(i)}; AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); } Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; @@ -2986,18 +3058,17 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, return getOrCreateAddExpr(Ops, ComputeFlags(Ops)); } -const SCEV * -ScalarEvolution::getOrCreateAddExpr(ArrayRef Ops, - SCEV::NoWrapFlags Flags) { +SCEVUse ScalarEvolution::getOrCreateAddExpr(ArrayRef Ops, + SCEV::NoWrapFlags Flags) { FoldingSetNodeID ID; ID.AddInteger(scAddExpr); - for (const SCEV *Op : Ops) - ID.AddPointer(Op); + for (SCEVUse Op : Ops) + ID.AddPointer(Op.getRawPointer()); void *IP = nullptr; SCEVAddExpr *S = static_cast(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); if (!S) { - const SCEV **O = SCEVAllocator.Allocate(Ops.size()); + SCEVUse *O = SCEVAllocator.Allocate(Ops.size()); std::uninitialized_copy(Ops.begin(), Ops.end(), O); S = new (SCEVAllocator) SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size()); @@ -3005,22 +3076,24 @@ ScalarEvolution::getOrCreateAddExpr(ArrayRef Ops, registerUser(S, Ops); } S->setNoWrapFlags(Flags); - return S; + bool IsCanonical = all_of(Ops, [](SCEVUse U) { return U.getFlags() == 0; }); + int UseFlags = IsCanonical ? 
0 : 1; + return {S, UseFlags}; } -const SCEV * -ScalarEvolution::getOrCreateAddRecExpr(ArrayRef Ops, - const Loop *L, SCEV::NoWrapFlags Flags) { +SCEVUse ScalarEvolution::getOrCreateAddRecExpr(ArrayRef Ops, + const Loop *L, + SCEV::NoWrapFlags Flags) { FoldingSetNodeID ID; ID.AddInteger(scAddRecExpr); - for (const SCEV *Op : Ops) - ID.AddPointer(Op); + for (SCEVUse Op : Ops) + ID.AddPointer(Op.getRawPointer()); ID.AddPointer(L); void *IP = nullptr; SCEVAddRecExpr *S = static_cast(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); if (!S) { - const SCEV **O = SCEVAllocator.Allocate(Ops.size()); + SCEVUse *O = SCEVAllocator.Allocate(Ops.size()); std::uninitialized_copy(Ops.begin(), Ops.end(), O); S = new (SCEVAllocator) SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L); @@ -3032,18 +3105,17 @@ ScalarEvolution::getOrCreateAddRecExpr(ArrayRef Ops, return S; } -const SCEV * -ScalarEvolution::getOrCreateMulExpr(ArrayRef Ops, - SCEV::NoWrapFlags Flags) { +SCEVUse ScalarEvolution::getOrCreateMulExpr(ArrayRef Ops, + SCEV::NoWrapFlags Flags) { FoldingSetNodeID ID; ID.AddInteger(scMulExpr); - for (const SCEV *Op : Ops) - ID.AddPointer(Op); + for (SCEVUse Op : Ops) + ID.AddPointer(Op.getRawPointer()); void *IP = nullptr; SCEVMulExpr *S = static_cast(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); if (!S) { - const SCEV **O = SCEVAllocator.Allocate(Ops.size()); + SCEVUse *O = SCEVAllocator.Allocate(Ops.size()); std::uninitialized_copy(Ops.begin(), Ops.end(), O); S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator), O, Ops.size()); @@ -3051,7 +3123,9 @@ ScalarEvolution::getOrCreateMulExpr(ArrayRef Ops, registerUser(S, Ops); } S->setNoWrapFlags(Flags); - return S; + bool IsCanonical = all_of(Ops, [](SCEVUse U) { return U.getFlags() == 0; }); + int UseFlags = IsCanonical ? 0 : 1; + return {S, UseFlags}; } static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) { @@ -3088,11 +3162,11 @@ static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) { /// Determine if any of the operands in this SCEV are a constant or if /// any of the add or multiply expressions in this SCEV contain a constant. -static bool containsConstantInAddMulChain(const SCEV *StartExpr) { +static bool containsConstantInAddMulChain(SCEVUse StartExpr) { struct FindConstantInAddMulChain { bool FoundConstant = false; - bool follow(const SCEV *S) { + bool follow(SCEVUse S) { FoundConstant |= isa(S); return isa(S) || isa(S); } @@ -3109,9 +3183,16 @@ static bool containsConstantInAddMulChain(const SCEV *StartExpr) { } /// Get a canonical multiply expression, or something simpler if possible. -const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, - SCEV::NoWrapFlags OrigFlags, - unsigned Depth) { +SCEVUse ScalarEvolution::getMulExpr(ArrayRef Ops, + SCEV::NoWrapFlags OrigFlags, + unsigned Depth) { + SmallVector Ops2(Ops); + return getMulExpr(Ops2, OrigFlags, Depth); +} + +SCEVUse ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, + SCEV::NoWrapFlags OrigFlags, + unsigned Depth) { assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) && "only nuw or nsw allowed"); assert(!Ops.empty() && "Cannot get empty mul!"); @@ -3133,7 +3214,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, return Folded; // Delay expensive flag strengthening until necessary. 
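// [Illustrative note, not part of the patch] getOrCreateAddExpr and
// getOrCreateMulExpr above both derive the flag bits of the returned SCEVUse
// from the operand uses: the result is treated as canonical only when every
// operand use carries no flags. A minimal sketch of the pattern (the helper
// name makeUseWithCanonicalBit is hypothetical):
static SCEVUse makeUseWithCanonicalBit(const SCEV *S, ArrayRef<SCEVUse> Ops) {
  // Flag value 0 means canonical; any flagged operand makes the resulting
  // use non-canonical (flag value 1), matching "IsCanonical ? 0 : 1" above.
  bool IsCanonical = all_of(Ops, [](SCEVUse U) { return U.getFlags() == 0; });
  return SCEVUse(S, IsCanonical ? 0 : 1);
}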
- auto ComputeFlags = [this, OrigFlags](const ArrayRef Ops) { + auto ComputeFlags = [this, OrigFlags](const ArrayRef Ops) { return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags); }; @@ -3146,7 +3227,9 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, SCEVMulExpr *Mul = static_cast(S); if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags) Mul->setNoWrapFlags(ComputeFlags(Ops)); - return S; + bool IsCanonical = all_of(Ops, [](SCEVUse U) { return U.getFlags() == 0; }); + int UseFlags = IsCanonical ? 0 : 1; + return {S, UseFlags}; } if (const SCEVConstant *LHSC = dyn_cast(Ops[0])) { @@ -3160,10 +3243,10 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of // this transformation should be narrowed down. if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) { - const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0), - SCEV::FlagAnyWrap, Depth + 1); - const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1), - SCEV::FlagAnyWrap, Depth + 1); + SCEVUse LHS = getMulExpr(LHSC, Add->getOperand(0), SCEV::FlagAnyWrap, + Depth + 1); + SCEVUse RHS = getMulExpr(LHSC, Add->getOperand(1), SCEV::FlagAnyWrap, + Depth + 1); return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1); } @@ -3171,11 +3254,11 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, // If we have a mul by -1 of an add, try distributing the -1 among the // add operands. if (const SCEVAddExpr *Add = dyn_cast(Ops[1])) { - SmallVector NewOps; + SmallVector NewOps; bool AnyFolded = false; - for (const SCEV *AddOp : Add->operands()) { - const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap, - Depth + 1); + for (SCEVUse AddOp : Add->operands()) { + SCEVUse Mul = + getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap, Depth + 1); if (!isa(Mul)) AnyFolded = true; NewOps.push_back(Mul); } @@ -3183,8 +3266,8 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1); } else if (const auto *AddRec = dyn_cast(Ops[1])) { // Negation preserves a recurrence's no self-wrap property. - SmallVector Operands; - for (const SCEV *AddRecOp : AddRec->operands()) + SmallVector Operands; + for (SCEVUse AddRecOp : AddRec->operands()) Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap, Depth + 1)); // Let M be the minimum representable signed value. AddRec with nsw @@ -3241,7 +3324,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, for (; Idx < Ops.size() && isa(Ops[Idx]); ++Idx) { // Scan all of the other operands to this mul and add them to the vector // if they are loop invariant w.r.t. the recurrence. - SmallVector LIOps; + SmallVector LIOps; const SCEVAddRecExpr *AddRec = cast(Ops[Idx]); for (unsigned i = 0, e = Ops.size(); i != e; ++i) if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) { @@ -3253,9 +3336,9 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, // If we found some loop invariants, fold them into the recurrence. if (!LIOps.empty()) { // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step} - SmallVector NewOps; + SmallVector NewOps; NewOps.reserve(AddRec->getNumOperands()); - const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1); + SCEVUse Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1); // If both the mul and addrec are nuw, we can preserve nuw. 
// If both the mul and addrec are nsw, we can only preserve nsw if either @@ -3277,7 +3360,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, } } - const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags); + SCEVUse NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags); // If all of the other operands were loop invariant, we are done. if (Ops.size() == 1) return NewRec; @@ -3323,10 +3406,10 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, bool Overflow = false; Type *Ty = AddRec->getType(); bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64; - SmallVector AddRecOps; + SmallVector AddRecOps; for (int x = 0, xe = AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) { - SmallVector SumOps; + SmallVector SumOps; for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) { uint64_t Coeff1 = Choose(x, 2*x - y, Overflow); for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1), @@ -3338,9 +3421,9 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, Coeff = umul_ov(Coeff1, Coeff2, Overflow); else Coeff = Coeff1*Coeff2; - const SCEV *CoeffTerm = getConstant(Ty, Coeff); - const SCEV *Term1 = AddRec->getOperand(y-z); - const SCEV *Term2 = OtherAddRec->getOperand(z); + SCEVUse CoeffTerm = getConstant(Ty, Coeff); + SCEVUse Term1 = AddRec->getOperand(y - z); + SCEVUse Term2 = OtherAddRec->getOperand(z); SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2, SCEV::FlagAnyWrap, Depth + 1)); } @@ -3350,8 +3433,8 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1)); } if (!Overflow) { - const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(), - SCEV::FlagAnyWrap); + SCEVUse NewAddRec = + getAddRecExpr(AddRecOps, AddRec->getLoop(), SCEV::FlagAnyWrap); if (Ops.size() == 2) return NewAddRec; Ops[Idx] = NewAddRec; Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; @@ -3374,8 +3457,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, } /// Represents an unsigned remainder expression based on unsigned division. -const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS, - const SCEV *RHS) { +SCEVUse ScalarEvolution::getURemExpr(SCEVUse LHS, SCEVUse RHS) { assert(getEffectiveSCEVType(LHS->getType()) == getEffectiveSCEVType(RHS->getType()) && "SCEVURemExpr operand types don't match!"); @@ -3396,15 +3478,14 @@ const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS, } // Fallback to %a == %x urem %y == %x - ((%x udiv %y) * %y) - const SCEV *UDiv = getUDivExpr(LHS, RHS); - const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW); + SCEVUse UDiv = getUDivExpr(LHS, RHS); + SCEVUse Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW); return getMinusSCEV(LHS, Mult, SCEV::FlagNUW); } /// Get a canonical unsigned division expression, or something simpler if /// possible. 
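// [Illustrative note, not part of the patch] The fallback in getURemExpr
// above expresses urem through existing node kinds via the identity
// x urem y == x - (x udiv y) * y; e.g. x = 29, y = 8 gives 29 - 3 * 8 == 5.
// A self-contained check on plain unsigned arithmetic:
static bool checkURemIdentity(unsigned X, unsigned Y) {
  // Holds for all X and non-zero Y; C++ '/' and '%' on unsigned operands are
  // exactly the udiv and urem used in the SCEV expression above.
  return Y == 0 || X % Y == X - (X / Y) * Y;
}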
-const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, - const SCEV *RHS) { +SCEVUse ScalarEvolution::getUDivExpr(SCEVUse LHS, SCEVUse RHS) { assert(!LHS->getType()->isPointerTy() && "SCEVUDivExpr operand can't be pointer!"); assert(LHS->getType() == RHS->getType() && @@ -3412,10 +3493,10 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, FoldingSetNodeID ID; ID.AddInteger(scUDivExpr); - ID.AddPointer(LHS); - ID.AddPointer(RHS); + ID.AddPointer(LHS.getRawPointer()); + ID.AddPointer(RHS.getRawPointer()); void *IP = nullptr; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + if (SCEVUse S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; // 0 udiv Y == 0 @@ -3453,8 +3534,8 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy), getZeroExtendExpr(Step, ExtTy), AR->getLoop(), SCEV::FlagAnyWrap)) { - SmallVector Operands; - for (const SCEV *Op : AR->operands()) + SmallVector Operands; + for (SCEVUse Op : AR->operands()) Operands.push_back(getUDivExpr(Op, RHS)); return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW); } @@ -3470,9 +3551,8 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, const APInt &StartInt = StartC->getAPInt(); const APInt &StartRem = StartInt.urem(StepInt); if (StartRem != 0) { - const SCEV *NewLHS = - getAddRecExpr(getConstant(StartInt - StartRem), Step, - AR->getLoop(), SCEV::FlagNW); + SCEVUse NewLHS = getAddRecExpr(getConstant(StartInt - StartRem), + Step, AR->getLoop(), SCEV::FlagNW); if (LHS != NewLHS) { LHS = NewLHS; @@ -3480,10 +3560,10 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, // already cached. ID.clear(); ID.AddInteger(scUDivExpr); - ID.AddPointer(LHS); - ID.AddPointer(RHS); + ID.AddPointer(LHS.getRawPointer()); + ID.AddPointer(RHS.getRawPointer()); IP = nullptr; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + if (SCEVUse S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; } } @@ -3491,16 +3571,16 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, } // (A*B)/C --> A*(B/C) if safe and B/C can be folded. if (const SCEVMulExpr *M = dyn_cast(LHS)) { - SmallVector Operands; - for (const SCEV *Op : M->operands()) + SmallVector Operands; + for (SCEVUse Op : M->operands()) Operands.push_back(getZeroExtendExpr(Op, ExtTy)); if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands)) // Find an operand that's safely divisible. for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) { - const SCEV *Op = M->getOperand(i); - const SCEV *Div = getUDivExpr(Op, RHSC); + SCEVUse Op = M->getOperand(i); + SCEVUse Div = getUDivExpr(Op, RHSC); if (!isa(Div) && getMulExpr(Div, RHSC) == Op) { - Operands = SmallVector(M->operands()); + Operands = SmallVector(M->operands()); Operands[i] = Div; return getMulExpr(Operands); } @@ -3523,13 +3603,13 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded. 
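// [Illustrative note, not part of the patch] The zero-extend comparison in
// the fold below is what makes it safe: distributing udiv over an add is
// wrong if the add wraps. E.g. in i8, A = 200, B = 100: (A + B) wraps to 44,
// so (A + B) /u 4 == 11, while A /u 4 + B /u 4 == 50 + 25 == 75. Requiring
// that the add computed in the zero-extended (wider) type folds to the same
// expression rules out the wrapping case.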
if (const SCEVAddExpr *A = dyn_cast(LHS)) { - SmallVector Operands; - for (const SCEV *Op : A->operands()) + SmallVector Operands; + for (SCEVUse Op : A->operands()) Operands.push_back(getZeroExtendExpr(Op, ExtTy)); if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) { Operands.clear(); for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) { - const SCEV *Op = getUDivExpr(A->getOperand(i), RHS); + SCEVUse Op = getUDivExpr(A->getOperand(i), RHS); if (isa(Op) || getMulExpr(Op, RHS) != A->getOperand(i)) break; @@ -3565,7 +3645,8 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs // changes). Make sure we get a new one. IP = nullptr; - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + if (SCEVUse S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + return S; SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator), LHS, RHS); UniqueSCEVs.InsertNode(S, IP); @@ -3591,8 +3672,7 @@ APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) { /// possible. There is no representation for an exact udiv in SCEV IR, but we /// can attempt to remove factors from the LHS and RHS. We can't do this when /// it's not exact because the udiv may be clearing bits. -const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS, - const SCEV *RHS) { +SCEVUse ScalarEvolution::getUDivExactExpr(SCEVUse LHS, SCEVUse RHS) { // TODO: we could try to find factors in all sorts of things, but for now we // just deal with u/exact (multiply, constant). See SCEVDivision towards the // end of this file for inspiration. @@ -3606,7 +3686,7 @@ const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS, // first element of the mulexpr. if (const auto *LHSCst = dyn_cast(Mul->getOperand(0))) { if (LHSCst == RHSCst) { - SmallVector Operands(drop_begin(Mul->operands())); + SmallVector Operands(drop_begin(Mul->operands())); return getMulExpr(Operands); } @@ -3619,7 +3699,7 @@ const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS, cast(getConstant(LHSCst->getAPInt().udiv(Factor))); RHSCst = cast(getConstant(RHSCst->getAPInt().udiv(Factor))); - SmallVector Operands; + SmallVector Operands; Operands.push_back(LHSCst); append_range(Operands, Mul->operands().drop_front()); LHS = getMulExpr(Operands); @@ -3633,7 +3713,7 @@ const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS, for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) { if (Mul->getOperand(i) == RHS) { - SmallVector Operands; + SmallVector Operands; append_range(Operands, Mul->operands().take_front(i)); append_range(Operands, Mul->operands().drop_front(i + 1)); return getMulExpr(Operands); @@ -3645,10 +3725,9 @@ const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS, /// Get an add recurrence expression for the specified loop. Simplify the /// expression as much as possible. 
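// [Illustrative note, not part of the patch] In the two-operand form below,
// a step that is itself an addrec on the same loop folds into a single
// higher-order recurrence: {S,+,{A,+,B}<L>}<L> becomes {S,+,A,+,B}<L>.
// For instance Start = 0 with Step = {1,+,2}<L> yields {0,+,1,+,2}<L>,
// i.e. the squares 0, 1, 4, 9, ... across iterations of L.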
-const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step, - const Loop *L, - SCEV::NoWrapFlags Flags) { - SmallVector Operands; +SCEVUse ScalarEvolution::getAddRecExpr(SCEVUse Start, SCEVUse Step, + const Loop *L, SCEV::NoWrapFlags Flags) { + SmallVector Operands; Operands.push_back(Start); if (const SCEVAddRecExpr *StepChrec = dyn_cast(Step)) if (StepChrec->getLoop() == L) { @@ -3660,11 +3739,16 @@ const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step, return getAddRecExpr(Operands, L, Flags); } +SCEVUse ScalarEvolution::getAddRecExpr(ArrayRef Operands, + const Loop *L, SCEV::NoWrapFlags Flags) { + SmallVector Ops2(Operands); + return getAddRecExpr(Ops2, L, Flags); +} + /// Get an add recurrence expression for the specified loop. Simplify the /// expression as much as possible. -const SCEV * -ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, - const Loop *L, SCEV::NoWrapFlags Flags) { +SCEVUse ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, + const Loop *L, SCEV::NoWrapFlags Flags) { if (Operands.size() == 1) return Operands[0]; #ifndef NDEBUG Type *ETy = getEffectiveSCEVType(Operands[0]->getType()); @@ -3698,13 +3782,13 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, ? (L->getLoopDepth() < NestedLoop->getLoopDepth()) : (!NestedLoop->contains(L) && DT.dominates(L->getHeader(), NestedLoop->getHeader()))) { - SmallVector NestedOperands(NestedAR->operands()); + SmallVector NestedOperands(NestedAR->operands()); Operands[0] = NestedAR->getStart(); // AddRecs require their operands be loop-invariant with respect to their // loops. Don't perform this transformation if it would break this // requirement. - bool AllInvariant = all_of( - Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); }); + bool AllInvariant = + all_of(Operands, [&](SCEVUse Op) { return isLoopInvariant(Op, L); }); if (AllInvariant) { // Create a recurrence for the outer loop with the same step size. @@ -3715,7 +3799,7 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags()); NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags); - AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) { + AllInvariant = all_of(NestedOperands, [&](SCEVUse Op) { return isLoopInvariant(Op, NestedLoop); }); @@ -3739,10 +3823,16 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, return getOrCreateAddRecExpr(Operands, L, Flags); } -const SCEV * -ScalarEvolution::getGEPExpr(GEPOperator *GEP, - const SmallVectorImpl &IndexExprs) { - const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand()); +SCEVUse ScalarEvolution::getGEPExpr(GEPOperator *GEP, + ArrayRef IndexExprs, + bool UseCtx) { + return getGEPExpr(GEP, SmallVector(IndexExprs), UseCtx); +} + +SCEVUse ScalarEvolution::getGEPExpr(GEPOperator *GEP, + const SmallVectorImpl &IndexExprs, + bool UseCtx) { + SCEVUse BaseExpr = getSCEV(GEP->getPointerOperand()); // getSCEV(Base)->getType() has the same address space as Base->getType() // because SCEV::getType() preserves the address space. Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType()); @@ -3766,14 +3856,14 @@ ScalarEvolution::getGEPExpr(GEPOperator *GEP, Type *CurTy = GEP->getType(); bool FirstIter = true; - SmallVector Offsets; - for (const SCEV *IndexExpr : IndexExprs) { + SmallVector Offsets; + for (SCEVUse IndexExpr : IndexExprs) { // Compute the (potentially symbolic) offset in bytes for this index. 
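// [Illustrative note, not part of the patch] Each GEP index contributes one
// term to Offsets: a struct index adds the constant field offset from the
// struct layout, while an array index adds IndexExpr * sizeof(element).
// E.g. for gep {i32, [8 x i16]}, ptr %p, i64 0, i32 1, i64 %i the terms are
// 4 (offset of field 1) and %i * 2 (index scaled by the i16 element size).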
if (StructType *STy = dyn_cast(CurTy)) { // For a struct, add the member offset. ConstantInt *Index = cast(IndexExpr)->getValue(); unsigned FieldNo = Index->getZExtValue(); - const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo); + SCEVUse FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo); Offsets.push_back(FieldOffset); // Update CurTy to the type of the field at Index. @@ -3789,12 +3879,12 @@ ScalarEvolution::getGEPExpr(GEPOperator *GEP, CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0); } // For an array, add the element offset, explicitly scaled. - const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy); + SCEVUse ElementSize = getSizeOfExpr(IntIdxTy, CurTy); // Getelementptr indices are signed. IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy); // Multiply the index by the element size to compute the element offset. - const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap); + SCEVUse LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap); Offsets.push_back(LocalOffset); } } @@ -3804,36 +3894,45 @@ ScalarEvolution::getGEPExpr(GEPOperator *GEP, return BaseExpr; // Add the offsets together, assuming nsw if inbounds. - const SCEV *Offset = getAddExpr(Offsets, OffsetWrap); + SCEVUse Offset = getAddExpr(Offsets, OffsetWrap); // Add the base address and the offset. We cannot use the nsw flag, as the // base address is unsigned. However, if we know that the offset is // non-negative, we can use nuw. bool NUW = NW.hasNoUnsignedWrap() || (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Offset)); SCEV::NoWrapFlags BaseWrap = NUW ? SCEV::FlagNUW : SCEV::FlagAnyWrap; - auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap); + auto GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap); assert(BaseExpr->getType() == GEPExpr->getType() && "GEP should not change type mid-flight."); + if (UseCtx && BaseWrap != SCEV::FlagNUW && GEP->isInBounds() && + isKnownNonNegative(Offset)) + GEPExpr = SCEVUse(&*GEPExpr, 2); return GEPExpr; } SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType, - ArrayRef Ops) { + ArrayRef Ops) { FoldingSetNodeID ID; ID.AddInteger(SCEVType); - for (const SCEV *Op : Ops) - ID.AddPointer(Op); + for (SCEVUse Op : Ops) + ID.AddPointer(Op.getRawPointer()); void *IP = nullptr; return UniqueSCEVs.FindNodeOrInsertPos(ID, IP); } -const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) { +SCEVUse ScalarEvolution::getAbsExpr(SCEVUse Op, bool IsNSW) { SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap; return getSMaxExpr(Op, getNegativeSCEV(Op, Flags)); } -const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind, - SmallVectorImpl &Ops) { +SCEVUse ScalarEvolution::getMinMaxExpr(SCEVTypes Kind, + ArrayRef Ops) { + SmallVector Ops2(Ops); + return getMinMaxExpr(Kind, Ops2); +} + +SCEVUse ScalarEvolution::getMinMaxExpr(SCEVTypes Kind, + SmallVectorImpl &Ops) { assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!"); assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!"); if (Ops.size() == 1) return Ops[0]; @@ -3885,7 +3984,7 @@ const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind, return Folded; // Check if we have created the same expression before. - if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) { + if (SCEVUse S = findExistingSCEVInCache(Kind, Ops)) { return S; } @@ -3943,13 +4042,13 @@ const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind, // already have one, otherwise create a new one. 
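// [Illustrative note, not part of the patch] Keying the folding set on
// Op.getRawPointer() below folds the use's flag bits into the profile, so
// the same SCEV pointer wrapped with different flags produces a distinct
// lookup key instead of collapsing onto the plain-pointer entry.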
FoldingSetNodeID ID; ID.AddInteger(Kind); - for (const SCEV *Op : Ops) - ID.AddPointer(Op); + for (SCEVUse Op : Ops) + ID.AddPointer(Op.getRawPointer()); void *IP = nullptr; - const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP); + SCEVUse ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP); if (ExistingSCEV) return ExistingSCEV; - const SCEV **O = SCEVAllocator.Allocate(Ops.size()); + SCEVUse *O = SCEVAllocator.Allocate(Ops.size()); std::uninitialized_copy(Ops.begin(), Ops.end(), O); SCEV *S = new (SCEVAllocator) SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size()); @@ -3963,14 +4062,14 @@ namespace { class SCEVSequentialMinMaxDeduplicatingVisitor final : public SCEVVisitor> { - using RetVal = std::optional; + std::optional> { + using RetVal = std::optional; using Base = SCEVVisitor; ScalarEvolution &SE; const SCEVTypes RootKind; // Must be a sequential min/max expression. const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind. - SmallPtrSet SeenOps; + SmallPtrSet SeenOps; bool canRecurseInto(SCEVTypes Kind) const { // We can only recurse into the SCEV expression of the same effective type @@ -3978,7 +4077,7 @@ class SCEVSequentialMinMaxDeduplicatingVisitor final return RootKind == Kind || NonSequentialRootKind == Kind; }; - RetVal visitAnyMinMaxExpr(const SCEV *S) { + RetVal visitAnyMinMaxExpr(SCEVUse S) { assert((isa(S) || isa(S)) && "Only for min/max expressions."); SCEVTypes Kind = S->getSCEVType(); @@ -3987,7 +4086,7 @@ class SCEVSequentialMinMaxDeduplicatingVisitor final return S; auto *NAry = cast(S); - SmallVector NewOps; + SmallVector NewOps; bool Changed = visit(Kind, NAry->operands(), NewOps); if (!Changed) @@ -4000,7 +4099,7 @@ class SCEVSequentialMinMaxDeduplicatingVisitor final : SE.getMinMaxExpr(Kind, NewOps); } - RetVal visit(const SCEV *S) { + RetVal visit(SCEVUse S) { // Has the whole operand been seen already? if (!SeenOps.insert(S).second) return std::nullopt; @@ -4015,13 +4114,13 @@ class SCEVSequentialMinMaxDeduplicatingVisitor final SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType( RootKind)) {} - bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef OrigOps, - SmallVectorImpl &NewOps) { + bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef OrigOps, + SmallVectorImpl &NewOps) { bool Changed = false; - SmallVector Ops; + SmallVector Ops; Ops.reserve(OrigOps.size()); - for (const SCEV *Op : OrigOps) { + for (SCEVUse Op : OrigOps) { RetVal NewOp = visit(Op); if (NewOp != Op) Changed = true; @@ -4124,7 +4223,7 @@ struct SCEVPoisonCollector { SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking) : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {} - bool follow(const SCEV *S) { + bool follow(SCEVUse S) { if (!LookThroughMaybePoisonBlocking && !scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType())) return false; @@ -4140,7 +4239,7 @@ struct SCEVPoisonCollector { } // namespace /// Return true if V is poison given that AssumedPoison is already poison. -static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) { +static bool impliesPoison(SCEVUse AssumedPoison, SCEVUse S) { // First collect all SCEVs that might result in AssumedPoison to be poison. 
// We need to look through potentially poison-blocking operations here, // because we want to find all SCEVs that *might* result in poison, not only @@ -4165,7 +4264,7 @@ static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) { } void ScalarEvolution::getPoisonGeneratingValues( - SmallPtrSetImpl &Result, const SCEV *S) { + SmallPtrSetImpl &Result, SCEVUse S) { SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false); visitAll(S, PC); for (const SCEVUnknown *SU : PC.MaybePoison) @@ -4173,7 +4272,7 @@ void ScalarEvolution::getPoisonGeneratingValues( } bool ScalarEvolution::canReuseInstruction( - const SCEV *S, Instruction *I, + SCEVUse S, Instruction *I, SmallVectorImpl &DropPoisonGeneratingInsts) { // If the instruction cannot be poison, it's always safe to reuse. if (programUndefinedIfPoison(I)) @@ -4234,9 +4333,9 @@ bool ScalarEvolution::canReuseInstruction( return true; } -const SCEV * +SCEVUse ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind, - SmallVectorImpl &Ops) { + SmallVectorImpl &Ops) { assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) && "Not a SCEVSequentialMinMaxExpr!"); assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!"); @@ -4257,7 +4356,7 @@ ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind, // so we can *NOT* do any kind of sorting of the expressions! // Check if we have created the same expression before. - if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) + if (SCEVUse S = findExistingSCEVInCache(Kind, Ops)) return S; // FIXME: there are *some* simplifications that we can do here. @@ -4291,7 +4390,7 @@ ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind, return getSequentialMinMaxExpr(Kind, Ops); } - const SCEV *SaturationPoint; + SCEVUse SaturationPoint; ICmpInst::Predicate Pred; switch (Kind) { case scSequentialUMinExpr: @@ -4309,7 +4408,7 @@ ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind, if (::impliesPoison(Ops[i], Ops[i - 1]) || isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1], SaturationPoint)) { - SmallVector SeqOps = {Ops[i - 1], Ops[i]}; + SmallVector SeqOps = {Ops[i - 1], Ops[i]}; Ops[i - 1] = getMinMaxExpr( SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind), SeqOps); @@ -4328,14 +4427,14 @@ ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind, // already have one, otherwise create a new one. 
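// [Illustrative note, not part of the patch] umin_seq differs from umin only
// in poison propagation: umin_seq(a, b) does not propagate poison from b when
// a equals the saturation point 0, as if b were never evaluated. The loop
// above can therefore lower a sequential pair to a plain umin when either
// Ops[i] being poison already implies Ops[i - 1] is poison, or Ops[i - 1] is
// provably distinct from the saturation point.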
FoldingSetNodeID ID; ID.AddInteger(Kind); - for (const SCEV *Op : Ops) - ID.AddPointer(Op); + for (SCEVUse Op : Ops) + ID.AddPointer(Op.getRawPointer()); void *IP = nullptr; - const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP); + SCEVUse ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP); if (ExistingSCEV) return ExistingSCEV; - const SCEV **O = SCEVAllocator.Allocate(Ops.size()); + SCEVUse *O = SCEVAllocator.Allocate(Ops.size()); std::uninitialized_copy(Ops.begin(), Ops.end(), O); SCEV *S = new (SCEVAllocator) SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size()); @@ -4345,65 +4444,62 @@ ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind, return S; } -const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) { - SmallVector Ops = {LHS, RHS}; +SCEVUse ScalarEvolution::getSMaxExpr(SCEVUse LHS, SCEVUse RHS) { + SmallVector Ops = {LHS, RHS}; return getSMaxExpr(Ops); } -const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl &Ops) { +SCEVUse ScalarEvolution::getSMaxExpr(SmallVectorImpl &Ops) { return getMinMaxExpr(scSMaxExpr, Ops); } -const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) { - SmallVector Ops = {LHS, RHS}; +SCEVUse ScalarEvolution::getUMaxExpr(SCEVUse LHS, SCEVUse RHS) { + SmallVector Ops = {LHS, RHS}; return getUMaxExpr(Ops); } -const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl &Ops) { +SCEVUse ScalarEvolution::getUMaxExpr(SmallVectorImpl &Ops) { return getMinMaxExpr(scUMaxExpr, Ops); } -const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS, - const SCEV *RHS) { - SmallVector Ops = { LHS, RHS }; +SCEVUse ScalarEvolution::getSMinExpr(SCEVUse LHS, SCEVUse RHS) { + SmallVector Ops = {LHS, RHS}; return getSMinExpr(Ops); } -const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl &Ops) { +SCEVUse ScalarEvolution::getSMinExpr(SmallVectorImpl &Ops) { return getMinMaxExpr(scSMinExpr, Ops); } -const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS, - bool Sequential) { - SmallVector Ops = { LHS, RHS }; +SCEVUse ScalarEvolution::getUMinExpr(SCEVUse LHS, SCEVUse RHS, + bool Sequential) { + SmallVector Ops = {LHS, RHS}; return getUMinExpr(Ops, Sequential); } -const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl &Ops, - bool Sequential) { +SCEVUse ScalarEvolution::getUMinExpr(SmallVectorImpl &Ops, + bool Sequential) { return Sequential ? 
getSequentialMinMaxExpr(scSequentialUMinExpr, Ops) : getMinMaxExpr(scUMinExpr, Ops); } -const SCEV * -ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) { - const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue()); +SCEVUse ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) { + SCEVUse Res = getConstant(IntTy, Size.getKnownMinValue()); if (Size.isScalable()) Res = getMulExpr(Res, getVScale(IntTy)); return Res; } -const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) { +SCEVUse ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) { return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy)); } -const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) { +SCEVUse ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) { return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy)); } -const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy, - StructType *STy, - unsigned FieldNo) { +SCEVUse ScalarEvolution::getOffsetOfExpr(Type *IntTy, StructType *STy, + unsigned FieldNo) { // We can bypass creating a target-independent constant expression and then // folding it back into a ConstantInt. This is just a compile-time // optimization. @@ -4413,7 +4509,7 @@ const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy, return getConstant(IntTy, SL->getElementOffset(FieldNo)); } -const SCEV *ScalarEvolution::getUnknown(Value *V) { +SCEVUse ScalarEvolution::getUnknown(Value *V) { // Don't attempt to do anything other than create a SCEVUnknown object // here. createSCEV only calls getUnknown after checking for all other // interesting possibilities, and any other code that calls getUnknown @@ -4475,8 +4571,7 @@ Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const { return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2; } -bool ScalarEvolution::instructionCouldExistWithOperands(const SCEV *A, - const SCEV *B) { +bool ScalarEvolution::instructionCouldExistWithOperands(SCEVUse A, SCEVUse B) { /// For a valid use point to exist, the defining scope of one operand /// must dominate the other. bool PreciseA, PreciseB; @@ -4489,12 +4584,10 @@ bool ScalarEvolution::instructionCouldExistWithOperands(const SCEV *A, DT.dominates(ScopeB, ScopeA); } -const SCEV *ScalarEvolution::getCouldNotCompute() { - return CouldNotCompute.get(); -} +SCEVUse ScalarEvolution::getCouldNotCompute() { return CouldNotCompute.get(); } -bool ScalarEvolution::checkValidity(const SCEV *S) const { - bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) { +bool ScalarEvolution::checkValidity(SCEVUse S) const { + bool ContainsNulls = SCEVExprContains(S, [](SCEVUse S) { auto *SU = dyn_cast(S); return SU && SU->getValue() == nullptr; }); @@ -4502,20 +4595,20 @@ bool ScalarEvolution::checkValidity(const SCEV *S) const { return !ContainsNulls; } -bool ScalarEvolution::containsAddRecurrence(const SCEV *S) { +bool ScalarEvolution::containsAddRecurrence(SCEVUse S) { HasRecMapType::iterator I = HasRecMap.find(S); if (I != HasRecMap.end()) return I->second; bool FoundAddRec = - SCEVExprContains(S, [](const SCEV *S) { return isa(S); }); + SCEVExprContains(S, [](SCEVUse S) { return isa(S); }); HasRecMap.insert({S, FoundAddRec}); return FoundAddRec; } /// Return the ValueOffsetPair set for \p S. \p S can be represented /// by the value and offset from any ValueOffsetPair in the set. 
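// [Illustrative note, not part of the patch] containsAddRecurrence above
// memoizes its answer in HasRecMap: SCEV expressions form a shared DAG, so
// the SCEVExprContains walk runs at most once per unique (sub)expression and
// later queries are a single hash lookup.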
-ArrayRef ScalarEvolution::getSCEVValues(const SCEV *S) { +ArrayRef ScalarEvolution::getSCEVValues(SCEVUse S) { ExprValueMapType::iterator SI = ExprValueMap.find_as(S); if (SI == ExprValueMap.end()) return {}; @@ -4536,7 +4629,7 @@ void ScalarEvolution::eraseValueFromMap(Value *V) { } } -void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) { +void ScalarEvolution::insertValueToMap(Value *V, SCEVUse S) { // A recursive query may have already computed the SCEV. It should be // equivalent, but may not necessarily be exactly the same, e.g. due to lazily // inferred nowrap flags. @@ -4549,20 +4642,20 @@ void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) { /// Return an existing SCEV if it exists, otherwise analyze the expression and /// create a new one. -const SCEV *ScalarEvolution::getSCEV(Value *V) { +SCEVUse ScalarEvolution::getSCEV(Value *V, bool UseCtx) { assert(isSCEVable(V->getType()) && "Value is not SCEVable!"); - if (const SCEV *S = getExistingSCEV(V)) + if (SCEVUse S = getExistingSCEV(V)) return S; - return createSCEVIter(V); + return createSCEVIter(V, UseCtx); } -const SCEV *ScalarEvolution::getExistingSCEV(Value *V) { +SCEVUse ScalarEvolution::getExistingSCEV(Value *V) { assert(isSCEVable(V->getType()) && "Value is not SCEVable!"); ValueExprMapType::iterator I = ValueExprMap.find_as(V); if (I != ValueExprMap.end()) { - const SCEV *S = I->second; + SCEVUse S = I->second; assert(checkValidity(S) && "existing SCEV has not been properly invalidated"); return S; @@ -4571,8 +4664,7 @@ const SCEV *ScalarEvolution::getExistingSCEV(Value *V) { } /// Return a SCEV corresponding to -V = -1*V -const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V, - SCEV::NoWrapFlags Flags) { +SCEVUse ScalarEvolution::getNegativeSCEV(SCEVUse V, SCEV::NoWrapFlags Flags) { if (const SCEVConstant *VC = dyn_cast(V)) return getConstant( cast(ConstantExpr::getNeg(VC->getValue()))); @@ -4583,7 +4675,7 @@ const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V, } /// If Expr computes ~A, return A else return nullptr -static const SCEV *MatchNotExpr(const SCEV *Expr) { +static SCEVUse MatchNotExpr(SCEVUse Expr) { const SCEVAddExpr *Add = dyn_cast(Expr); if (!Add || Add->getNumOperands() != 2 || !Add->getOperand(0)->isAllOnesValue()) @@ -4598,7 +4690,7 @@ static const SCEV *MatchNotExpr(const SCEV *Expr) { } /// Return a SCEV corresponding to ~V = -1-V -const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) { +SCEVUse ScalarEvolution::getNotSCEV(SCEVUse V) { assert(!V->getType()->isPointerTy() && "Can't negate pointer"); if (const SCEVConstant *VC = dyn_cast(V)) @@ -4608,17 +4700,17 @@ const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) { // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y) if (const SCEVMinMaxExpr *MME = dyn_cast(V)) { auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) { - SmallVector MatchedOperands; - for (const SCEV *Operand : MME->operands()) { - const SCEV *Matched = MatchNotExpr(Operand); + SmallVector MatchedOperands; + for (SCEVUse Operand : MME->operands()) { + SCEVUse Matched = MatchNotExpr(Operand); if (!Matched) - return (const SCEV *)nullptr; + return (SCEVUse) nullptr; MatchedOperands.push_back(Matched); } return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()), MatchedOperands); }; - if (const SCEV *Replaced = MatchMinMaxNegation(MME)) + if (SCEVUse Replaced = MatchMinMaxNegation(MME)) return Replaced; } @@ -4627,12 +4719,12 @@ const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) { return getMinusSCEV(getMinusOne(Ty), V); } -const 
SCEV *ScalarEvolution::removePointerBase(const SCEV *P) { +SCEVUse ScalarEvolution::removePointerBase(SCEVUse P) { assert(P->getType()->isPointerTy()); if (auto *AddRec = dyn_cast(P)) { // The base of an AddRec is the first operand. - SmallVector Ops{AddRec->operands()}; + SmallVector Ops{AddRec->operands()}; Ops[0] = removePointerBase(Ops[0]); // Don't try to transfer nowrap flags for now. We could in some cases // (for example, if pointer operand of the AddRec is a SCEVUnknown). @@ -4640,9 +4732,9 @@ const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) { } if (auto *Add = dyn_cast(P)) { // The base of an Add is the pointer operand. - SmallVector Ops{Add->operands()}; - const SCEV **PtrOp = nullptr; - for (const SCEV *&AddOp : Ops) { + SmallVector Ops{Add->operands()}; + SCEVUse *PtrOp = nullptr; + for (SCEVUse &AddOp : Ops) { if (AddOp->getType()->isPointerTy()) { assert(!PtrOp && "Cannot have multiple pointer ops"); PtrOp = &AddOp; @@ -4657,9 +4749,8 @@ const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) { return getZero(P->getType()); } -const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS, - SCEV::NoWrapFlags Flags, - unsigned Depth) { +SCEVUse ScalarEvolution::getMinusSCEV(SCEVUse LHS, SCEVUse RHS, + SCEV::NoWrapFlags Flags, unsigned Depth) { // Fast path: X - X --> 0. if (LHS == RHS) return getZero(LHS->getType()); @@ -4707,8 +4798,8 @@ const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS, return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth); } -const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty, - unsigned Depth) { +SCEVUse ScalarEvolution::getTruncateOrZeroExtend(SCEVUse V, Type *Ty, + unsigned Depth) { Type *SrcTy = V->getType(); assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() && "Cannot truncate or zero extend with non-integer arguments!"); @@ -4719,8 +4810,8 @@ const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty, return getZeroExtendExpr(V, Ty, Depth); } -const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty, - unsigned Depth) { +SCEVUse ScalarEvolution::getTruncateOrSignExtend(SCEVUse V, Type *Ty, + unsigned Depth) { Type *SrcTy = V->getType(); assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() && "Cannot truncate or zero extend with non-integer arguments!"); @@ -4731,8 +4822,7 @@ const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty, return getSignExtendExpr(V, Ty, Depth); } -const SCEV * -ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) { +SCEVUse ScalarEvolution::getNoopOrZeroExtend(SCEVUse V, Type *Ty) { Type *SrcTy = V->getType(); assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() && "Cannot noop or zero extend with non-integer arguments!"); @@ -4743,8 +4833,7 @@ ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) { return getZeroExtendExpr(V, Ty); } -const SCEV * -ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) { +SCEVUse ScalarEvolution::getNoopOrSignExtend(SCEVUse V, Type *Ty) { Type *SrcTy = V->getType(); assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() && "Cannot noop or sign extend with non-integer arguments!"); @@ -4755,8 +4844,7 @@ ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) { return getSignExtendExpr(V, Ty); } -const SCEV * -ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) { +SCEVUse ScalarEvolution::getNoopOrAnyExtend(SCEVUse V, Type *Ty) { Type *SrcTy = V->getType(); assert(SrcTy->isIntOrPtrTy() && 
Ty->isIntOrPtrTy() && "Cannot noop or any extend with non-integer arguments!"); @@ -4767,8 +4855,7 @@ ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) { return getAnyExtendExpr(V, Ty); } -const SCEV * -ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) { +SCEVUse ScalarEvolution::getTruncateOrNoop(SCEVUse V, Type *Ty) { Type *SrcTy = V->getType(); assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() && "Cannot truncate or noop with non-integer arguments!"); @@ -4779,10 +4866,9 @@ ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) { return getTruncateExpr(V, Ty); } -const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS, - const SCEV *RHS) { - const SCEV *PromotedLHS = LHS; - const SCEV *PromotedRHS = RHS; +SCEVUse ScalarEvolution::getUMaxFromMismatchedTypes(SCEVUse LHS, SCEVUse RHS) { + SCEVUse PromotedLHS = LHS; + SCEVUse PromotedRHS = RHS; if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType())) PromotedRHS = getZeroExtendExpr(RHS, LHS->getType()); @@ -4792,15 +4878,20 @@ const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS, return getUMaxExpr(PromotedLHS, PromotedRHS); } -const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS, - const SCEV *RHS, - bool Sequential) { - SmallVector Ops = { LHS, RHS }; +SCEVUse ScalarEvolution::getUMinFromMismatchedTypes(SCEVUse LHS, SCEVUse RHS, + bool Sequential) { + SmallVector Ops = {LHS, RHS}; return getUMinFromMismatchedTypes(Ops, Sequential); } -const SCEV * -ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl &Ops, +SCEVUse ScalarEvolution::getUMinFromMismatchedTypes(ArrayRef Ops, + bool Sequential) { + SmallVector Ops2(Ops); + return getUMinFromMismatchedTypes(Ops2, Sequential); +} + +SCEVUse +ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl &Ops, bool Sequential) { assert(!Ops.empty() && "At least one operand must be!"); // Trivial case. @@ -4809,7 +4900,7 @@ ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl &Ops, // Find the max type first. Type *MaxType = nullptr; - for (const auto *S : Ops) + for (const auto S : Ops) if (MaxType) MaxType = getWiderType(MaxType, S->getType()); else @@ -4817,15 +4908,15 @@ ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl &Ops, assert(MaxType && "Failed to find maximum type!"); // Extend all ops to max type. - SmallVector PromotedOps; - for (const auto *S : Ops) + SmallVector PromotedOps; + for (const auto S : Ops) PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType)); // Generate umin. return getUMinExpr(PromotedOps, Sequential); } -const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) { +SCEVUse ScalarEvolution::getPointerBase(SCEVUse V) { // A pointer operand may evaluate to a nonpointer expression, such as null. if (!V->getType()->isPointerTy()) return V; @@ -4834,8 +4925,8 @@ const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) { if (auto *AddRec = dyn_cast(V)) { V = AddRec->getStart(); } else if (auto *Add = dyn_cast(V)) { - const SCEV *PtrOp = nullptr; - for (const SCEV *AddOp : Add->operands()) { + SCEVUse PtrOp = nullptr; + for (SCEVUse AddOp : Add->operands()) { if (AddOp->getType()->isPointerTy()) { assert(!PtrOp && "Cannot have multiple pointer ops"); PtrOp = AddOp; @@ -4869,10 +4960,10 @@ namespace { /// If SCEV contains non-invariant unknown SCEV rewrite cannot be done. 
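// [Illustrative note, not part of the patch] getUMinFromMismatchedTypes
// above unifies operands at the widest type before combining, e.g.
// umin(i32 %a, i64 %b) is computed as umin(zext %a to i64, %b).
// Zero-extension is the right promotion because umin is an unsigned
// comparison and zext preserves unsigned order.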
class SCEVInitRewriter : public SCEVRewriteVisitor { public: - static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE, - bool IgnoreOtherLoops = true) { + static SCEVUse rewrite(SCEVUse S, const Loop *L, ScalarEvolution &SE, + bool IgnoreOtherLoops = true) { SCEVInitRewriter Rewriter(L, SE); - const SCEV *Result = Rewriter.visit(S); + SCEVUse Result = Rewriter.visit(S); if (Rewriter.hasSeenLoopVariantSCEVUnknown()) return SE.getCouldNotCompute(); return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops @@ -4880,13 +4971,13 @@ class SCEVInitRewriter : public SCEVRewriteVisitor { : Result; } - const SCEV *visitUnknown(const SCEVUnknown *Expr) { + SCEVUse visitUnknown(const SCEVUnknown *Expr) { if (!SE.isLoopInvariant(Expr, L)) SeenLoopVariantSCEVUnknown = true; return Expr; } - const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { + SCEVUse visitAddRecExpr(const SCEVAddRecExpr *Expr) { // Only re-write AddRecExprs for this loop. if (Expr->getLoop() == L) return Expr->getStart(); @@ -4913,21 +5004,21 @@ class SCEVInitRewriter : public SCEVRewriteVisitor { /// If SCEV contains non-invariant unknown SCEV rewrite cannot be done. class SCEVPostIncRewriter : public SCEVRewriteVisitor { public: - static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) { + static SCEVUse rewrite(SCEVUse S, const Loop *L, ScalarEvolution &SE) { SCEVPostIncRewriter Rewriter(L, SE); - const SCEV *Result = Rewriter.visit(S); + SCEVUse Result = Rewriter.visit(S); return Rewriter.hasSeenLoopVariantSCEVUnknown() ? SE.getCouldNotCompute() : Result; } - const SCEV *visitUnknown(const SCEVUnknown *Expr) { + SCEVUse visitUnknown(const SCEVUnknown *Expr) { if (!SE.isLoopInvariant(Expr, L)) SeenLoopVariantSCEVUnknown = true; return Expr; } - const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { + SCEVUse visitAddRecExpr(const SCEVAddRecExpr *Expr) { // Only re-write AddRecExprs for this loop. if (Expr->getLoop() == L) return Expr->getPostIncExpr(SE); @@ -4954,8 +5045,7 @@ class SCEVPostIncRewriter : public SCEVRewriteVisitor { class SCEVBackedgeConditionFolder : public SCEVRewriteVisitor { public: - static const SCEV *rewrite(const SCEV *S, const Loop *L, - ScalarEvolution &SE) { + static SCEVUse rewrite(SCEVUse S, const Loop *L, ScalarEvolution &SE) { bool IsPosBECond = false; Value *BECond = nullptr; if (BasicBlock *Latch = L->getLoopLatch()) { @@ -4973,8 +5063,8 @@ class SCEVBackedgeConditionFolder return Rewriter.visit(S); } - const SCEV *visitUnknown(const SCEVUnknown *Expr) { - const SCEV *Result = Expr; + SCEVUse visitUnknown(const SCEVUnknown *Expr) { + SCEVUse Result = Expr; bool InvariantF = SE.isLoopInvariant(Expr, L); if (!InvariantF) { @@ -4982,7 +5072,7 @@ class SCEVBackedgeConditionFolder switch (I->getOpcode()) { case Instruction::Select: { SelectInst *SI = cast(I); - std::optional Res = + std::optional Res = compareWithBackedgeCondition(SI->getCondition()); if (Res) { bool IsOne = cast(*Res)->getValue()->isOne(); @@ -4991,7 +5081,7 @@ class SCEVBackedgeConditionFolder break; } default: { - std::optional Res = compareWithBackedgeCondition(I); + std::optional Res = compareWithBackedgeCondition(I); if (Res) Result = *Res; break; @@ -5007,7 +5097,7 @@ class SCEVBackedgeConditionFolder : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond), IsPositiveBECond(IsPosBECond) {} - std::optional compareWithBackedgeCondition(Value *IC); + std::optional compareWithBackedgeCondition(Value *IC); const Loop *L; /// Loop back condition. 
@@ -5016,7 +5106,7 @@ class SCEVBackedgeConditionFolder bool IsPositiveBECond; }; -std::optional +std::optional SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) { // If value matches the backedge condition for loop latch, @@ -5030,21 +5120,20 @@ SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) { class SCEVShiftRewriter : public SCEVRewriteVisitor { public: - static const SCEV *rewrite(const SCEV *S, const Loop *L, - ScalarEvolution &SE) { + static SCEVUse rewrite(SCEVUse S, const Loop *L, ScalarEvolution &SE) { SCEVShiftRewriter Rewriter(L, SE); - const SCEV *Result = Rewriter.visit(S); + SCEVUse Result = Rewriter.visit(S); return Rewriter.isValid() ? Result : SE.getCouldNotCompute(); } - const SCEV *visitUnknown(const SCEVUnknown *Expr) { + SCEVUse visitUnknown(const SCEVUnknown *Expr) { // Only allow AddRecExprs for this loop. if (!SE.isLoopInvariant(Expr, L)) Valid = false; return Expr; } - const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { + SCEVUse visitAddRecExpr(const SCEVAddRecExpr *Expr) { if (Expr->getLoop() == L && Expr->isAffine()) return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE)); Valid = false; @@ -5073,7 +5162,7 @@ ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) { SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap; if (!AR->hasNoSelfWrap()) { - const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop()); + SCEVUse BECount = getConstantMaxBackedgeTakenCount(AR->getLoop()); if (const SCEVConstant *BECountMax = dyn_cast(BECount)) { ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this)); const APInt &BECountAP = BECountMax->getAPInt(); @@ -5121,7 +5210,7 @@ ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) { if (!SignedWrapViaInductionTried.insert(AR).second) return Result; - const SCEV *Step = AR->getStepRecurrence(*this); + SCEVUse Step = AR->getStepRecurrence(*this); const Loop *L = AR->getLoop(); // Check whether the backedge-taken count is SCEVCouldNotCompute. @@ -5132,7 +5221,7 @@ ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) { // in infinite recursion. In the later case, the analysis code will // cope with a conservative value, and it will take care to purge // that value once it has finished. - const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L); + SCEVUse MaxBECount = getConstantMaxBackedgeTakenCount(L); // Normally, in the cases we can prove no-overflow via a // backedge guarding condition, we can also compute a backedge @@ -5151,8 +5240,7 @@ ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) { // start value and the backedge is guarded by a comparison with the post-inc // value, the addrec is safe. ICmpInst::Predicate Pred; - const SCEV *OverflowLimit = - getSignedOverflowLimitForStep(Step, &Pred, this); + SCEVUse OverflowLimit = getSignedOverflowLimitForStep(Step, &Pred, this); if (OverflowLimit && (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) || isKnownOnEveryIteration(Pred, AR, OverflowLimit))) { @@ -5174,7 +5262,7 @@ ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) { if (!UnsignedWrapViaInductionTried.insert(AR).second) return Result; - const SCEV *Step = AR->getStepRecurrence(*this); + SCEVUse Step = AR->getStepRecurrence(*this); unsigned BitWidth = getTypeSizeInBits(AR->getType()); const Loop *L = AR->getLoop(); @@ -5186,7 +5274,7 @@ ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) { // in infinite recursion. 
In the later case, the analysis code will // cope with a conservative value, and it will take care to purge // that value once it has finished. - const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L); + SCEVUse MaxBECount = getConstantMaxBackedgeTakenCount(L); // Normally, in the cases we can prove no-overflow via a // backedge guarding condition, we can also compute a backedge @@ -5205,8 +5293,8 @@ ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) { // start value and the backedge is guarded by a comparison with the post-inc // value, the addrec is safe. if (isKnownPositive(Step)) { - const SCEV *N = getConstant(APInt::getMinValue(BitWidth) - - getUnsignedRangeMax(Step)); + SCEVUse N = + getConstant(APInt::getMinValue(BitWidth) - getUnsignedRangeMax(Step)); if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) || isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) { Result = setFlags(Result, SCEV::FlagNUW); @@ -5355,7 +5443,7 @@ static std::optional MatchBinaryOp(Value *V, const DataLayout &DL, /// we return the type of the truncation operation, and indicate whether the /// truncated type should be treated as signed/unsigned by setting /// \p Signed to true/false, respectively. -static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, +static Type *isSimpleCastedPHI(SCEVUse Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE) { // The case where Op == SymbolicPHI (that is, with no type conversions on // the way) is handled by the regular add recurrence creating logic and @@ -5384,7 +5472,7 @@ static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, : dyn_cast(ZExt->getOperand()); if (!Trunc) return nullptr; - const SCEV *X = Trunc->getOperand(); + SCEVUse X = Trunc->getOperand(); if (X != SymbolicPHI) return nullptr; Signed = SExt != nullptr; @@ -5453,8 +5541,9 @@ static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) { // which correspond to a phi->trunc->add->sext/zext->phi update chain. // // 3) Outline common code with createAddRecFromPHI to avoid duplication. -std::optional>> -ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) { +std::optional>> +ScalarEvolution::createAddRecFromPHIWithCastsImpl( + const SCEVUnknown *SymbolicPHI) { SmallVector Predicates; // *** Part1: Analyze if we have a phi-with-cast pattern for which we can @@ -5487,7 +5576,7 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI if (!BEValueV || !StartValueV) return std::nullopt; - const SCEV *BEValue = getSCEV(BEValueV); + SCEVUse BEValue = getSCEV(BEValueV); // If the value coming around the backedge is an add with the symbolic // value we just inserted, possibly with casts that we can ignore under @@ -5513,11 +5602,11 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI return std::nullopt; // Create an add with everything but the specified operand. - SmallVector Ops; + SmallVector Ops; for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) if (i != FoundIndex) Ops.push_back(Add->getOperand(i)); - const SCEV *Accum = getAddExpr(Ops); + SCEVUse Accum = getAddExpr(Ops); // The runtime checks will not be valid if the step amount is // varying inside the loop. @@ -5575,8 +5664,8 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI // // Create a truncated addrec for which we will add a no overflow check (P1). 
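// [Illustrative note, not part of the patch] The shape being recovered here
// is a phi -> trunc -> add -> sext/zext -> phi update chain: the addrec is
// first built in the narrow TruncTy under predicate P1 (no overflow of the
// truncated addrec), and P2/P3 below additionally require that the start and
// step survive the trunc-then-extend round trip, so the wide addrec
// {StartVal,+,Accum} is valid whenever the predicates hold at runtime.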
- const SCEV *StartVal = getSCEV(StartValueV); - const SCEV *PHISCEV = + SCEVUse StartVal = getSCEV(StartValueV); + SCEVUse PHISCEV = getAddRecExpr(getTruncateExpr(StartVal, TruncTy), getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap); @@ -5603,11 +5692,10 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy) // for each of StartVal and Accum - auto getExtendedExpr = [&](const SCEV *Expr, - bool CreateSignExtend) -> const SCEV * { + auto getExtendedExpr = [&](SCEVUse Expr, bool CreateSignExtend) -> SCEVUse { assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant"); - const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy); - const SCEV *ExtendedExpr = + SCEVUse TruncatedExpr = getTruncateExpr(Expr, TruncTy); + SCEVUse ExtendedExpr = CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType()) : getZeroExtendExpr(TruncatedExpr, Expr->getType()); return ExtendedExpr; @@ -5618,13 +5706,12 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI // = getExtendedExpr(Expr) // Determine whether the predicate P: Expr == ExtendedExpr // is known to be false at compile time - auto PredIsKnownFalse = [&](const SCEV *Expr, - const SCEV *ExtendedExpr) -> bool { + auto PredIsKnownFalse = [&](SCEVUse Expr, SCEVUse ExtendedExpr) -> bool { return Expr != ExtendedExpr && isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr); }; - const SCEV *StartExtended = getExtendedExpr(StartVal, Signed); + SCEVUse StartExtended = getExtendedExpr(StartVal, Signed); if (PredIsKnownFalse(StartVal, StartExtended)) { LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";); return std::nullopt; @@ -5632,14 +5719,13 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI // The Step is always Signed (because the overflow checks are either // NSSW or NUSW) - const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true); + SCEVUse AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true); if (PredIsKnownFalse(Accum, AccumExtended)) { LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";); return std::nullopt; } - auto AppendPredicate = [&](const SCEV *Expr, - const SCEV *ExtendedExpr) -> void { + auto AppendPredicate = [&](SCEVUse Expr, SCEVUse ExtendedExpr) -> void { if (Expr != ExtendedExpr && !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) { const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr); @@ -5655,16 +5741,16 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI // which the casts had been folded away. The caller can rewrite SymbolicPHI // into NewAR if it will also add the runtime overflow checks specified in // Predicates. - auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap); + auto NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap); - std::pair> PredRewrite = + std::pair> PredRewrite = std::make_pair(NewAR, Predicates); // Remember the result of the analysis for this SCEV at this location. PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite; return PredRewrite; } -std::optional>> +std::optional>> ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) { auto *PN = cast(SymbolicPHI->getValue()); const Loop *L = isIntegerLoopHeaderPHI(PN, LI); @@ -5674,7 +5760,7 @@ ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) { // Check to see if we already analyzed this PHI.
auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L}); if (I != PredicatedSCEVRewrites.end()) { - std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite = + std::pair<SCEVUse, SmallVector<const SCEVPredicate *, 3>> Rewrite = I->second; // Analysis was done before and failed to create an AddRec: if (Rewrite.first == SymbolicPHI) @@ -5686,8 +5772,8 @@ ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) { return Rewrite; } - std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>> - Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI); + std::optional<std::pair<SCEVUse, SmallVector<const SCEVPredicate *, 3>>> + Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI); // Record in the cache that the analysis failed if (!Rewrite) { @@ -5710,7 +5796,7 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds( if (AR1 == AR2) return true; - auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool { + auto areExprsEqual = [&](SCEVUse Expr1, SCEVUse Expr2) -> bool { if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) && !Preds->implies(SE.getEqualPredicate(Expr2, Expr1))) return false; @@ -5729,9 +5815,8 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds( /// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)). /// If it fails, createAddRecFromPHI will use a more general, but slow, /// technique for finding the AddRec expression. -const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN, - Value *BEValueV, - Value *StartValueV) { +SCEVUse ScalarEvolution::createSimpleAffineAddRec(PHINode *PN, Value *BEValueV, + Value *StartValueV) { const Loop *L = LI.getLoopFor(PN->getParent()); assert(L && L->getHeader() == PN->getParent()); assert(BEValueV && StartValueV); @@ -5743,7 +5828,7 @@ const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN, if (BO->Opcode != Instruction::Add) return nullptr; - const SCEV *Accum = nullptr; + SCEVUse Accum = nullptr; if (BO->LHS == PN && L->isLoopInvariant(BO->RHS)) Accum = getSCEV(BO->RHS); else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS)) @@ -5758,8 +5843,8 @@ const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN, if (BO->IsNSW) Flags = setFlags(Flags, SCEV::FlagNSW); - const SCEV *StartVal = getSCEV(StartValueV); - const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); + SCEVUse StartVal = getSCEV(StartValueV); + SCEVUse PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); insertValueToMap(PN, PHISCEV); if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) { @@ -5781,7 +5866,7 @@ const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN, return PHISCEV; } -const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { +SCEVUse ScalarEvolution::createAddRecFromPHI(PHINode *PN) { const Loop *L = LI.getLoopFor(PN->getParent()); if (!L || L->getHeader() != PN->getParent()) return nullptr; @@ -5814,16 +5899,16 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { // First, try to find AddRec expression without creating a fictitious symbolic // value for PN. - if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV)) + if (auto S = createSimpleAffineAddRec(PN, BEValueV, StartValueV)) return S; // Handle PHI node value symbolically. - const SCEV *SymbolicName = getUnknown(PN); + SCEVUse SymbolicName = getUnknown(PN); insertValueToMap(PN, SymbolicName); // Using this symbolic name for the PHI, analyze the value coming around // the back-edge. - const SCEV *BEValue = getSCEV(BEValueV); + SCEVUse BEValue = getSCEV(BEValueV); // NOTE: If BEValue is loop invariant, we know that the PHI node just // has a special value for the first iteration of the loop.
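Note: the cache check above ("Analysis was done before and failed") relies on a sentinel: a cached rewrite whose first element is the symbolic PHI itself marks a prior failure, so the expensive analysis is never re-run. A simplified sketch of that negative-result caching convention (std::map and int stand in for the LLVM containers and SCEVs):

#include <map>
#include <optional>
#include <utility>
#include <vector>

using Expr = int; // stands in for a SCEV expression
static std::map<Expr, std::pair<Expr, std::vector<int>>> Cache;

std::optional<std::pair<Expr, std::vector<int>>> analyze(Expr SymbolicPHI) {
  auto It = Cache.find(SymbolicPHI);
  if (It != Cache.end()) {
    if (It->second.first == SymbolicPHI) // sentinel: analysis failed before
      return std::nullopt;
    return It->second;                   // cached successful rewrite
  }
  // ... run the real analysis here; on failure, record the sentinel:
  Cache[SymbolicPHI] = {SymbolicPHI, {}};
  return std::nullopt;
}

int main() { return analyze(1).has_value(); }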
@@ -5843,12 +5928,12 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { if (FoundIndex != Add->getNumOperands()) { // Create an add with everything but the specified operand. - SmallVector Ops; + SmallVector Ops; for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) if (i != FoundIndex) Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i), L, *this)); - const SCEV *Accum = getAddExpr(Ops); + SCEVUse Accum = getAddExpr(Ops); // This is not a valid addrec if the step amount is varying each // loop iteration, but is not itself an addrec in this loop. @@ -5884,8 +5969,8 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { // for instance. } - const SCEV *StartVal = getSCEV(StartValueV); - const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); + SCEVUse StartVal = getSCEV(StartValueV); + SCEVUse PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); // Okay, for the entire analysis of this edge we assumed the PHI // to be symbolic. We now need to go back and purge all of the @@ -5919,11 +6004,11 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { // We can generalize this saying that i is the shifted value of BEValue // by one iteration: // PHI(f(0), f({1,+,1})) --> f({0,+,1}) - const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this); - const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false); + SCEVUse Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this); + SCEVUse Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false); if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute()) { - const SCEV *StartVal = getSCEV(StartValueV); + SCEVUse StartVal = getSCEV(StartValueV); if (Start == StartVal) { // Okay, for the entire analysis of this edge we assumed the PHI // to be symbolic. We now need to go back and purge all of the @@ -5977,7 +6062,7 @@ static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge, return false; } -const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) { +SCEVUse ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) { auto IsReachable = [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); }; if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) { @@ -6009,24 +6094,24 @@ const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) { return nullptr; } -const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { - if (const SCEV *S = createAddRecFromPHI(PN)) +SCEVUse ScalarEvolution::createNodeForPHI(PHINode *PN) { + if (SCEVUse S = createAddRecFromPHI(PN)) return S; if (Value *V = simplifyInstruction(PN, {getDataLayout(), &TLI, &DT, &AC})) return getSCEV(V); - if (const SCEV *S = createNodeFromSelectLikePHI(PN)) + if (SCEVUse S = createNodeFromSelectLikePHI(PN)) return S; // If it's not a loop phi, we can't handle it yet. return getUnknown(PN); } -bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, +bool SCEVMinMaxExprContains(SCEVUse Root, SCEVUse OperandToFind, SCEVTypes RootKind) { struct FindClosure { - const SCEV *OperandToFind; + SCEVUse OperandToFind; const SCEVTypes RootKind; // Must be a sequential min/max expression. const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind. 
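Note: the shift identity used by SCEVShiftRewriter above, PHI(f(0), f({1,+,1})) --> f({0,+,1}), can be checked numerically for an affine f. A minimal sketch with plain integers:

#include <cassert>

long f(long I) { return 3 + 5 * I; } // any loop-invariant affine function

int main() {
  long Phi = f(0);                 // incoming value from the preheader
  for (long I = 0; I < 100; ++I) { // I models the canonical IV {0,+,1}
    assert(Phi == f(I));           // at the top of iteration I, PHI == f(I)
    Phi = f(I + 1);                // backedge value: f at the shifted IV
  }
}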
@@ -6039,13 +6124,13 @@ bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, scZeroExtend == Kind; }; - FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind) + FindClosure(SCEVUse OperandToFind, SCEVTypes RootKind) : OperandToFind(OperandToFind), RootKind(RootKind), NonSequentialRootKind( SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType( RootKind)) {} - bool follow(const SCEV *S) { + bool follow(SCEVUse S) { Found = S == OperandToFind; return !isDone() && canRecurseInto(S->getSCEVType()); @@ -6059,7 +6144,7 @@ bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, return FC.Found; } -std::optional +std::optional ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty, ICmpInst *Cond, Value *TrueVal, @@ -6085,10 +6170,10 @@ ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty, // a > b ? b+x : a+x -> min(a, b)+x if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty)) { bool Signed = ICI->isSigned(); - const SCEV *LA = getSCEV(TrueVal); - const SCEV *RA = getSCEV(FalseVal); - const SCEV *LS = getSCEV(LHS); - const SCEV *RS = getSCEV(RHS); + SCEVUse LA = getSCEV(TrueVal); + SCEVUse RA = getSCEV(FalseVal); + SCEVUse LS = getSCEV(LHS); + SCEVUse RS = getSCEV(RHS); if (LA->getType()->isPointerTy()) { // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA. // Need to make sure we can't produce weird expressions involving @@ -6098,7 +6183,7 @@ ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty, if (LA == RS && RA == LS) return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS); } - auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * { + auto CoerceOperand = [&](SCEVUse Op) -> SCEVUse { if (Op->getType()->isPointerTy()) { Op = getLosslessPtrToIntExpr(Op); if (isa(Op)) @@ -6114,8 +6199,8 @@ ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty, RS = CoerceOperand(RS); if (isa(LS) || isa(RS)) break; - const SCEV *LDiff = getMinusSCEV(LA, LS); - const SCEV *RDiff = getMinusSCEV(RA, RS); + SCEVUse LDiff = getMinusSCEV(LA, LS); + SCEVUse RDiff = getMinusSCEV(RA, RS); if (LDiff == RDiff) return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS), LDiff); @@ -6134,11 +6219,11 @@ ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty, // x == 0 ? 
C+y : x+y -> umax(x, C)+y iff C u<= 1 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty) && isa(RHS) && cast(RHS)->isZero()) { - const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty); - const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y - const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y - const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x - const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y + SCEVUse X = getNoopOrZeroExtend(getSCEV(LHS), Ty); + SCEVUse TrueValExpr = getSCEV(TrueVal); // C+y + SCEVUse FalseValExpr = getSCEV(FalseVal); // x+y + SCEVUse Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x + SCEVUse C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y if (isa(C) && cast(C)->getAPInt().ule(1)) return getAddExpr(getUMaxExpr(X, C), Y); } @@ -6148,11 +6233,11 @@ ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty, // -> umin_seq(x, umin (..., umin_seq(...), ...)) if (isa(RHS) && cast(RHS)->isZero() && isa(TrueVal) && cast(TrueVal)->isZero()) { - const SCEV *X = getSCEV(LHS); + SCEVUse X = getSCEV(LHS); while (auto *ZExt = dyn_cast(X)) X = ZExt->getOperand(); if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) { - const SCEV *FalseValExpr = getSCEV(FalseVal); + SCEVUse FalseValExpr = getSCEV(FalseVal); if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr)) return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr, /*Sequential=*/true); @@ -6166,9 +6251,10 @@ ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty, return std::nullopt; } -static std::optional -createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, - const SCEV *TrueExpr, const SCEV *FalseExpr) { +static std::optional createNodeForSelectViaUMinSeq(ScalarEvolution *SE, + SCEVUse CondExpr, + SCEVUse TrueExpr, + SCEVUse FalseExpr) { assert(CondExpr->getType()->isIntegerTy(1) && TrueExpr->getType() == FalseExpr->getType() && TrueExpr->getType()->isIntegerTy(1) && @@ -6186,7 +6272,7 @@ createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, if (!isa(TrueExpr) && !isa(FalseExpr)) return std::nullopt; - const SCEV *X, *C; + SCEVUse X, C; if (isa(TrueExpr)) { CondExpr = SE->getNotSCEV(CondExpr); X = FalseExpr; @@ -6199,20 +6285,23 @@ createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, /*Sequential=*/true)); } -static std::optional -createNodeForSelectViaUMinSeq(ScalarEvolution *SE, Value *Cond, Value *TrueVal, - Value *FalseVal) { +static std::optional createNodeForSelectViaUMinSeq(ScalarEvolution *SE, + Value *Cond, + Value *TrueVal, + Value *FalseVal) { if (!isa(TrueVal) && !isa(FalseVal)) return std::nullopt; - const auto *SECond = SE->getSCEV(Cond); - const auto *SETrue = SE->getSCEV(TrueVal); - const auto *SEFalse = SE->getSCEV(FalseVal); + const auto SECond = SE->getSCEV(Cond); + const auto SETrue = SE->getSCEV(TrueVal); + const auto SEFalse = SE->getSCEV(FalseVal); return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse); } -const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq( - Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) { +SCEVUse ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(Value *V, + Value *Cond, + Value *TrueVal, + Value *FalseVal) { assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?"); assert(TrueVal->getType() == FalseVal->getType() && V->getType() == TrueVal->getType() && @@ -6222,16 +6311,16 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq( if (!V->getType()->isIntegerTy(1)) return 
getUnknown(V); - if (std::optional<const SCEV *> S = + if (std::optional<SCEVUse> S = createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal)) return *S; return getUnknown(V); } -const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond, - Value *TrueVal, - Value *FalseVal) { +SCEVUse ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond, + Value *TrueVal, + Value *FalseVal) { // Handle "constant" branch or select. This can occur for instance when a // loop pass transforms an inner loop and moves on to process the outer loop. if (auto *CI = dyn_cast<ConstantInt>(Cond)) @@ -6239,7 +6328,7 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond, if (auto *I = dyn_cast<Instruction>(V)) { if (auto *ICI = dyn_cast<ICmpInst>(Cond)) { - if (std::optional<const SCEV *> S = + if (std::optional<SCEVUse> S = createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI, TrueVal, FalseVal)) return *S; @@ -6251,17 +6340,17 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond, /// Expand GEP instructions into add and multiply operations. This allows them /// to be analyzed by regular SCEV code. -const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) { +SCEVUse ScalarEvolution::createNodeForGEP(GEPOperator *GEP, bool UseCtx) { assert(GEP->getSourceElementType()->isSized() && "GEP source element type must be sized"); - SmallVector<const SCEV *, 4> IndexExprs; + SmallVector<SCEVUse, 4> IndexExprs; for (Value *Index : GEP->indices()) IndexExprs.push_back(getSCEV(Index)); - return getGEPExpr(GEP, IndexExprs); + return getGEPExpr(GEP, IndexExprs, UseCtx); } -APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) { +APInt ScalarEvolution::getConstantMultipleImpl(SCEVUse S) { uint64_t BitWidth = getTypeSizeInBits(S->getType()); auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) { return TrailingZeros >= BitWidth @@ -6304,7 +6393,7 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) { if (M->hasNoUnsignedWrap()) { // The result is the product of all operand results. APInt Res = getConstantMultiple(M->getOperand(0)); - for (const SCEV *Operand : M->operands().drop_front()) + for (SCEVUse Operand : M->operands().drop_front()) Res = Res * getConstantMultiple(Operand); return Res; } @@ -6312,7 +6401,7 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) { // If there are no wrap guarantees, find the trailing zeros, which is the // sum of trailing zeros for all its operands. uint32_t TZ = 0; - for (const SCEV *Operand : M->operands()) + for (SCEVUse Operand : M->operands()) TZ += getMinTrailingZeros(Operand); return GetShiftedByZeros(TZ); } @@ -6323,7 +6412,7 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) { return GetGCDMultiple(N); // Find the trailing bits, which is the minimum of its operands.
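Note: the trailing-zero fallback above is sound even when the multiply wraps, because the low bits of a product are exact; only a clamp to the bit width is needed. A small 32-bit sketch (__builtin_ctz assumed, as in GCC/Clang):

#include <cassert>
#include <cstdint>

unsigned ctz32(uint32_t V) { return V ? __builtin_ctz(V) : 32; }

int main() {
  uint32_t A = 0x50000000u, B = 0x60000000u; // product wraps in 32 bits
  unsigned TZ = ctz32(A) + ctz32(B);         // 28 + 29, clamp to width
  if (TZ > 32)
    TZ = 32;
  uint32_t Prod = A * B;                     // low 32 bits are still exact
  assert(TZ == 32 ? Prod == 0 : Prod % (1u << TZ) == 0);
}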
uint32_t TZ = getMinTrailingZeros(N->getOperand(0)); - for (const SCEV *Operand : N->operands().drop_front()) + for (SCEVUse Operand : N->operands().drop_front()) TZ = std::min(TZ, getMinTrailingZeros(Operand)); return GetShiftedByZeros(TZ); } @@ -6347,7 +6436,7 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) { llvm_unreachable("Unknown SCEV kind!"); } -APInt ScalarEvolution::getConstantMultiple(const SCEV *S) { +APInt ScalarEvolution::getConstantMultiple(SCEVUse S) { auto I = ConstantMultipleCache.find(S); if (I != ConstantMultipleCache.end()) return I->second; @@ -6358,12 +6447,12 @@ APInt ScalarEvolution::getConstantMultiple(const SCEV *S) { return InsertPair.first->second; } -APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) { +APInt ScalarEvolution::getNonZeroConstantMultiple(SCEVUse S) { APInt Multiple = getConstantMultiple(S); return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple; } -uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) { +uint32_t ScalarEvolution::getMinTrailingZeros(SCEVUse S) { return std::min(getConstantMultiple(S).countTrailingZeros(), (unsigned)getTypeSizeInBits(S->getType())); } @@ -6514,17 +6603,17 @@ getRangeForUnknownRecurrence(const SCEVUnknown *U) { } const ConstantRange & -ScalarEvolution::getRangeRefIter(const SCEV *S, +ScalarEvolution::getRangeRefIter(SCEVUse S, ScalarEvolution::RangeSignHint SignHint) { - DenseMap &Cache = + DenseMap &Cache = SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges : SignedRanges; - SmallVector WorkList; - SmallPtrSet Seen; + SmallVector WorkList; + SmallPtrSet Seen; // Add Expr to the worklist, if Expr is either an N-ary expression or a // SCEVUnknown PHI node. - auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) { + auto AddToWorklist = [&WorkList, &Seen, &Cache](SCEVUse Expr) { if (!Seen.insert(Expr).second) return; if (Cache.contains(Expr)) @@ -6559,11 +6648,11 @@ ScalarEvolution::getRangeRefIter(const SCEV *S, // Build worklist by queuing operands of N-ary expressions and phi nodes. for (unsigned I = 0; I != WorkList.size(); ++I) { - const SCEV *P = WorkList[I]; + SCEVUse P = WorkList[I]; auto *UnknownS = dyn_cast(P); // If it is not a `SCEVUnknown`, just recurse into operands. if (!UnknownS) { - for (const SCEV *Op : P->operands()) + for (SCEVUse Op : P->operands()) AddToWorklist(Op); continue; } @@ -6580,7 +6669,7 @@ ScalarEvolution::getRangeRefIter(const SCEV *S, // Use getRangeRef to compute ranges for items in the worklist in reverse // order. This will force ranges for earlier operands to be computed before // their users in most cases. - for (const SCEV *P : reverse(drop_begin(WorkList))) { + for (SCEVUse P : reverse(drop_begin(WorkList))) { getRangeRef(P, SignHint); if (auto *UnknownS = dyn_cast(P)) @@ -6595,9 +6684,10 @@ ScalarEvolution::getRangeRefIter(const SCEV *S, /// Determine the range for a particular SCEV. If SignHint is /// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges /// with a "cleaner" unsigned (resp. signed) representation. -const ConstantRange &ScalarEvolution::getRangeRef( - const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) { - DenseMap &Cache = +const ConstantRange & +ScalarEvolution::getRangeRef(SCEVUse S, ScalarEvolution::RangeSignHint SignHint, + unsigned Depth) { + DenseMap &Cache = SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? 
UnsignedRanges : SignedRanges; ConstantRange::PreferredRangeType RangeType = @@ -6605,7 +6695,7 @@ const ConstantRange &ScalarEvolution::getRangeRef( : ConstantRange::Signed; // See if we've computed this range already. - DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S); + DenseMap<SCEVUse, ConstantRange>::iterator I = Cache.find(S); if (I != Cache.end()) return I->second; @@ -6740,8 +6830,7 @@ const ConstantRange &ScalarEvolution::getRangeRef( // TODO: non-affine addrec if (AddRec->isAffine()) { - const SCEV *MaxBEScev = - getConstantMaxBackedgeTakenCount(AddRec->getLoop()); + SCEVUse MaxBEScev = getConstantMaxBackedgeTakenCount(AddRec->getLoop()); if (!isa<SCEVCouldNotCompute>(MaxBEScev)) { APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt(); @@ -6768,7 +6857,7 @@ const ConstantRange &ScalarEvolution::getRangeRef( // Now try symbolic BE count and more powerful methods. if (UseExpensiveRangeSharpening) { - const SCEV *SymbolicMaxBECount = + SCEVUse SymbolicMaxBECount = getSymbolicMaxBackedgeTakenCount(AddRec->getLoop()); if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) && getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth && @@ -6997,8 +7086,7 @@ static ConstantRange getRangeForAffineARHelper(APInt Step, return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper)); } -ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start, - const SCEV *Step, +ConstantRange ScalarEvolution::getRangeForAffineAR(SCEVUse Start, SCEVUse Step, const APInt &MaxBECount) { assert(getTypeSizeInBits(Start->getType()) == getTypeSizeInBits(Step->getType()) && @@ -7027,13 +7115,13 @@ ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start, } ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR( - const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth, + const SCEVAddRecExpr *AddRec, SCEVUse MaxBECount, unsigned BitWidth, ScalarEvolution::RangeSignHint SignHint) { assert(AddRec->isAffine() && "Non-affine AddRecs are not supported!\n"); assert(AddRec->hasNoSelfWrap() && "This only works for non-self-wrapping AddRecs!"); const bool IsSigned = SignHint == HINT_RANGE_SIGNED; - const SCEV *Step = AddRec->getStepRecurrence(*this); + SCEVUse Step = AddRec->getStepRecurrence(*this); // Only deal with constant step to save compile time. if (!isa<SCEVConstant>(Step)) return ConstantRange::getFull(BitWidth); @@ -7046,9 +7134,9 @@ ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR( getTypeSizeInBits(AddRec->getType())) return ConstantRange::getFull(BitWidth); MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType()); - const SCEV *RangeWidth = getMinusOne(AddRec->getType()); - const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step)); - const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs); + SCEVUse RangeWidth = getMinusOne(AddRec->getType()); + SCEVUse StepAbs = getUMinExpr(Step, getNegativeSCEV(Step)); + SCEVUse MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs); if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount, MaxItersWithoutWrap)) return ConstantRange::getFull(BitWidth); @@ -7057,7 +7145,7 @@ ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR( IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; ICmpInst::Predicate GEPred = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; - const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this); + SCEVUse End = AddRec->evaluateAtIteration(MaxBECount, *this); // We know that there is no self-wrap.
Let's take Start and End values and // look at all intermediate values V1, V2, ..., Vn that IndVar takes during @@ -7071,7 +7159,7 @@ ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR( // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that // knowledge, let's try to prove that we are dealing with Case 1. It is so if // Start <= End and step is positive, or Start >= End and step is negative. - const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop()); + SCEVUse Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop()); ConstantRange StartRange = getRangeRef(Start, SignHint); ConstantRange EndRange = getRangeRef(End, SignHint); ConstantRange RangeBetween = StartRange.unionWith(EndRange); @@ -7094,8 +7182,7 @@ ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR( return ConstantRange::getFull(BitWidth); } -ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start, - const SCEV *Step, +ConstantRange ScalarEvolution::getRangeViaFactoring(SCEVUse Start, SCEVUse Step, const APInt &MaxBECount) { // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q}) // == RangeOf({A,+,P}) union RangeOf({B,+,Q}) @@ -7110,8 +7197,7 @@ ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start, APInt TrueValue; APInt FalseValue; - explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth, - const SCEV *S) { + explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth, SCEVUse S) { std::optional CastOp; APInt Offset(BitWidth, 0); @@ -7200,10 +7286,10 @@ ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start, // FIXME: without the explicit `this` receiver below, MSVC errors out with // C2352 and C2512 (otherwise it isn't needed). - const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue); - const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue); - const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue); - const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue); + SCEVUse TrueStart = this->getConstant(StartPattern.TrueValue); + SCEVUse TrueStep = this->getConstant(StepPattern.TrueValue); + SCEVUse FalseStart = this->getConstant(StartPattern.FalseValue); + SCEVUse FalseStep = this->getConstant(StepPattern.FalseValue); ConstantRange TrueRange = this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount); @@ -7229,8 +7315,7 @@ SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) { return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap; } -const Instruction * -ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) { +const Instruction *ScalarEvolution::getNonTrivialDefiningScopeBound(SCEVUse S) { if (auto *AddRec = dyn_cast(S)) return &*AddRec->getLoop()->getHeader()->begin(); if (auto *U = dyn_cast(S)) @@ -7239,14 +7324,13 @@ ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) { return nullptr; } -const Instruction * -ScalarEvolution::getDefiningScopeBound(ArrayRef Ops, - bool &Precise) { +const Instruction *ScalarEvolution::getDefiningScopeBound(ArrayRef Ops, + bool &Precise) { Precise = true; // Do a bounded search of the def relation of the requested SCEVs. - SmallSet Visited; - SmallVector Worklist; - auto pushOp = [&](const SCEV *S) { + SmallSet Visited; + SmallVector Worklist; + auto pushOp = [&](SCEVUse S) { if (!Visited.insert(S).second) return; // Threshold of 30 here is arbitrary. 
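Note: "Case 1" above can be sanity-checked numerically: with a positive step and no wrap, every intermediate value of the recurrence lies between Start and the value at the last iteration. A minimal sketch:

#include <cassert>
#include <cstdint>

int main() {
  uint32_t Start = 10, Step = 3, MaxBECount = 1000;
  uint32_t End = Start + Step * MaxBECount; // evaluateAtIteration(MaxBECount)
  uint32_t V = Start;
  for (uint32_t I = 0; I <= MaxBECount; ++I, V += Step)
    assert(V >= Start && V <= End);         // stays in [Start, End] throughout
}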
@@ -7257,17 +7341,17 @@ ScalarEvolution::getDefiningScopeBound(ArrayRef Ops, Worklist.push_back(S); }; - for (const auto *S : Ops) + for (const auto S : Ops) pushOp(S); const Instruction *Bound = nullptr; while (!Worklist.empty()) { - auto *S = Worklist.pop_back_val(); + auto S = Worklist.pop_back_val(); if (auto *DefI = getNonTrivialDefiningScopeBound(S)) { if (!Bound || DT.dominates(Bound, DefI)) Bound = DefI; } else { - for (const auto *Op : S->operands()) + for (const auto Op : S->operands()) pushOp(Op); } } @@ -7275,7 +7359,7 @@ ScalarEvolution::getDefiningScopeBound(ArrayRef Ops, } const Instruction * -ScalarEvolution::getDefiningScopeBound(ArrayRef Ops) { +ScalarEvolution::getDefiningScopeBound(ArrayRef Ops) { bool Discard; return getDefiningScopeBound(Ops, Discard); } @@ -7315,7 +7399,7 @@ bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) { // executed every time we enter that scope. When the bounding scope is a // loop (the common case), this is equivalent to proving I executes on every // iteration of that loop. - SmallVector SCEVOps; + SmallVector SCEVOps; for (const Use &Op : I->operands()) { // I could be an extractvalue from a call to an overflow intrinsic. // TODO: We can do better here in some cases. @@ -7411,7 +7495,7 @@ bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) { return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L)); } -const SCEV *ScalarEvolution::createSCEVIter(Value *V) { +SCEVUse ScalarEvolution::createSCEVIter(Value *V, bool UseCtx) { // Worklist item with a Value and a bool indicating whether all operands have // been visited already. using PointerTy = PointerIntPair; @@ -7427,10 +7511,10 @@ const SCEV *ScalarEvolution::createSCEVIter(Value *V) { continue; SmallVector Ops; - const SCEV *CreatedSCEV = nullptr; + SCEVUse CreatedSCEV = nullptr; // If all operands have been visited already, create the SCEV. if (E.getInt()) { - CreatedSCEV = createSCEV(CurV); + CreatedSCEV = createSCEV(CurV, UseCtx); } else { // Otherwise get the operands we need to create SCEV's for before creating // the SCEV for CurV. If the SCEV for CurV can be constructed trivially, @@ -7452,8 +7536,8 @@ const SCEV *ScalarEvolution::createSCEVIter(Value *V) { return getExistingSCEV(V); } -const SCEV * -ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl &Ops) { +SCEVUse ScalarEvolution::getOperandsToCreate(Value *V, + SmallVectorImpl &Ops) { if (!isSCEVable(V->getType())) return getUnknown(V); @@ -7639,7 +7723,7 @@ ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl &Ops) { return nullptr; } -const SCEV *ScalarEvolution::createSCEV(Value *V) { +SCEVUse ScalarEvolution::createSCEV(Value *V, bool UseCtx) { if (!isSCEVable(V->getType())) return getUnknown(V); @@ -7657,8 +7741,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { else if (!isa(V)) return getUnknown(V); - const SCEV *LHS; - const SCEV *RHS; + SCEVUse LHS; + SCEVUse RHS; Operator *U = cast(V); if (auto BO = @@ -7671,10 +7755,10 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // because it leads to N-1 getAddExpr calls for N ultimate operands. // Instead, gather up all the operands and make a single getAddExpr call. // LLVM IR canonical form means we need only traverse the left operands. 
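Note: createSCEVIter above replaces recursion with a worklist whose entries carry an "operands already visited" bit (the PointerIntPair). A generic sketch of the same two-phase pattern, with simplified stand-in types:

#include <unordered_map>
#include <utility>
#include <vector>

struct Val { std::vector<Val *> Ops; };
static std::unordered_map<Val *, int> Built;

int build(Val *V); // combines the already-built operands of V

void buildIteratively(Val *Root) {
  std::vector<std::pair<Val *, bool>> Stack{{Root, false}};
  while (!Stack.empty()) {
    auto [V, OpsReady] = Stack.back();
    Stack.pop_back();
    if (Built.count(V))
      continue;
    if (OpsReady) {               // second visit: everything below is built
      Built[V] = build(V);
      continue;
    }
    Stack.push_back({V, true});   // revisit V once its operands are done
    for (Val *Op : V->Ops)
      if (!Built.count(Op))
        Stack.push_back({Op, false});
  }
}

int build(Val *V) {
  int R = 1;
  for (Val *Op : V->Ops)
    R += Built[Op];               // operands are guaranteed present here
  return R;
}

int main() {
  Val A, B{{&A}}, C{{&A, &B}};
  buildIteratively(&C);
  return Built[&C] == 4 ? 0 : 1;
}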
- SmallVector AddOps; + SmallVector AddOps; do { if (BO->Op) { - if (auto *OpSCEV = getExistingSCEV(BO->Op)) { + if (auto OpSCEV = getExistingSCEV(BO->Op)) { AddOps.push_back(OpSCEV); break; } @@ -7686,10 +7770,10 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // since the flags are only known to apply to this particular // addition - they may not apply to other additions that can be // formed with operands from AddOps. - const SCEV *RHS = getSCEV(BO->RHS); + SCEVUse RHS = getSCEV(BO->RHS); SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op); if (Flags != SCEV::FlagAnyWrap) { - const SCEV *LHS = getSCEV(BO->LHS); + SCEVUse LHS = getSCEV(BO->LHS); if (BO->Opcode == Instruction::Sub) AddOps.push_back(getMinusSCEV(LHS, RHS, Flags)); else @@ -7717,10 +7801,10 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { } case Instruction::Mul: { - SmallVector MulOps; + SmallVector MulOps; do { if (BO->Op) { - if (auto *OpSCEV = getExistingSCEV(BO->Op)) { + if (auto OpSCEV = getExistingSCEV(BO->Op)) { MulOps.push_back(OpSCEV); break; } @@ -7786,19 +7870,19 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) { - const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ)); - const SCEV *LHS = getSCEV(BO->LHS); - const SCEV *ShiftedLHS = nullptr; + SCEVUse MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ)); + SCEVUse LHS = getSCEV(BO->LHS); + SCEVUse ShiftedLHS = nullptr; if (auto *LHSMul = dyn_cast(LHS)) { if (auto *OpC = dyn_cast(LHSMul->getOperand(0))) { // For an expression like (x * 8) & 8, simplify the multiply. unsigned MulZeros = OpC->getAPInt().countr_zero(); unsigned GCD = std::min(MulZeros, TZ); APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD); - SmallVector MulOps; + SmallVector MulOps; MulOps.push_back(getConstant(OpC->getAPInt().lshr(GCD))); append_range(MulOps, LHSMul->operands().drop_front()); - auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags()); + auto NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags()); ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt)); } } @@ -7846,7 +7930,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { if (const SCEVZeroExtendExpr *Z = dyn_cast(getSCEV(BO->LHS))) { Type *UTy = BO->LHS->getType(); - const SCEV *Z0 = Z->getOperand(); + SCEVUse Z0 = Z->getOperand(); Type *Z0Ty = Z0->getType(); unsigned Z0TySize = getTypeSizeInBits(Z0Ty); @@ -7922,9 +8006,9 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt); Operator *L = dyn_cast(BO->LHS); - const SCEV *AddTruncateExpr = nullptr; + SCEVUse AddTruncateExpr = nullptr; ConstantInt *ShlAmtCI = nullptr; - const SCEV *AddConstant = nullptr; + SCEVUse AddConstant = nullptr; if (L && L->getOpcode() == Instruction::Add) { // X = Shl A, n @@ -7936,7 +8020,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { ConstantInt *AddOperandCI = dyn_cast(L->getOperand(1)); if (LShift && LShift->getOpcode() == Instruction::Shl) { if (AddOperandCI) { - const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0)); + SCEVUse ShlOp0SCEV = getSCEV(LShift->getOperand(0)); ShlAmtCI = dyn_cast(LShift->getOperand(1)); // since we truncate to TruncTy, the AddConstant should be of the // same type, so create a new Constant with type same as TruncTy. @@ -7954,7 +8038,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // Y = AShr X, m // Both n and m are constant. 
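Note: the mask/multiply simplification above ("For an expression like (x * 8) & 8") keeps only the bits of x that survive both the shift and the mask. An exhaustive 8-bit check of the simplest instance:

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned X = 0; X < 256; ++X)
    assert(uint8_t((X * 8) & 8) == uint8_t((X & 1) * 8)); // bit 0 of X, at bit 3
}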
- const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0)); + SCEVUse ShlOp0SCEV = getSCEV(L->getOperand(0)); ShlAmtCI = dyn_cast(L->getOperand(1)); AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy); } @@ -7975,8 +8059,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) { APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt, ShlAmtCI->getZExtValue() - AShrAmt); - const SCEV *CompositeExpr = - getMulExpr(AddTruncateExpr, getConstant(Mul)); + SCEVUse CompositeExpr = getMulExpr(AddTruncateExpr, getConstant(Mul)); if (L->getOpcode() != Instruction::Shl) CompositeExpr = getAddExpr(CompositeExpr, AddConstant); @@ -8006,8 +8089,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // but by that point the NSW information has potentially been lost. if (BO->Opcode == Instruction::Sub && BO->IsNSW) { Type *Ty = U->getType(); - auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty); - auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty); + auto V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty); + auto V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty); return getMinusSCEV(V1, V2, SCEV::FlagNSW); } } @@ -8021,11 +8104,11 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { case Instruction::PtrToInt: { // Pointer to integer cast is straight-forward, so do model it. - const SCEV *Op = getSCEV(U->getOperand(0)); + SCEVUse Op = getSCEV(U->getOperand(0)); Type *DstIntTy = U->getType(); // But only if effective SCEV (integer) type is wide enough to represent // all possible pointer values. - const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy); + SCEVUse IntOp = getPtrToIntExpr(Op, DstIntTy); if (isa(IntOp)) return getUnknown(V); return IntOp; @@ -8049,7 +8132,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { break; case Instruction::GetElementPtr: - return createNodeForGEP(cast(U)); + return createNodeForGEP(cast(U), UseCtx); case Instruction::PHI: return createNodeForPHI(cast(U)); @@ -8086,15 +8169,15 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { RHS = getSCEV(II->getArgOperand(1)); return getSMinExpr(LHS, RHS); case Intrinsic::usub_sat: { - const SCEV *X = getSCEV(II->getArgOperand(0)); - const SCEV *Y = getSCEV(II->getArgOperand(1)); - const SCEV *ClampedY = getUMinExpr(X, Y); + SCEVUse X = getSCEV(II->getArgOperand(0)); + SCEVUse Y = getSCEV(II->getArgOperand(1)); + SCEVUse ClampedY = getUMinExpr(X, Y); return getMinusSCEV(X, ClampedY, SCEV::FlagNUW); } case Intrinsic::uadd_sat: { - const SCEV *X = getSCEV(II->getArgOperand(0)); - const SCEV *Y = getSCEV(II->getArgOperand(1)); - const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y)); + SCEVUse X = getSCEV(II->getArgOperand(0)); + SCEVUse Y = getSCEV(II->getArgOperand(1)); + SCEVUse ClampedX = getUMinExpr(X, getNotSCEV(Y)); return getAddExpr(ClampedX, Y, SCEV::FlagNUW); } case Intrinsic::start_loop_iterations: @@ -8119,7 +8202,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // Iteration Count Computation Code // -const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) { +SCEVUse ScalarEvolution::getTripCountFromExitCount(SCEVUse ExitCount) { if (isa(ExitCount)) return getCouldNotCompute(); @@ -8130,9 +8213,9 @@ const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) { return getTripCountFromExitCount(ExitCount, EvalTy, nullptr); } -const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount, - Type *EvalTy, - const Loop *L) { +SCEVUse ScalarEvolution::getTripCountFromExitCount(SCEVUse ExitCount, + Type *EvalTy, + const Loop *L) { if 
(isa(ExitCount)) return getCouldNotCompute(); @@ -8193,7 +8276,7 @@ ScalarEvolution::getSmallConstantTripCount(const Loop *L, unsigned ScalarEvolution::getSmallConstantMaxTripCount( const Loop *L, SmallVectorImpl *Predicates) { - const auto *MaxExitCount = + SCEVUse MaxExitCount = Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates) : getConstantMaxBackedgeTakenCount(L); return getConstantTripCount(dyn_cast(MaxExitCount)); @@ -8214,12 +8297,12 @@ unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) { } unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L, - const SCEV *ExitCount) { + SCEVUse ExitCount) { if (ExitCount == getCouldNotCompute()) return 1; // Get the trip count - const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L)); + SCEVUse TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L)); APInt Multiple = getNonZeroConstantMultiple(TCExpr); // If a trip multiple is huge (>=2^32), the trip count is still divisible by @@ -8247,13 +8330,13 @@ ScalarEvolution::getSmallConstantTripMultiple(const Loop *L, assert(ExitingBlock && "Must pass a non-null exiting block!"); assert(L->isLoopExiting(ExitingBlock) && "Exiting block must actually branch out of the loop!"); - const SCEV *ExitCount = getExitCount(L, ExitingBlock); + SCEVUse ExitCount = getExitCount(L, ExitingBlock); return getSmallConstantTripMultiple(L, ExitCount); } -const SCEV *ScalarEvolution::getExitCount(const Loop *L, - const BasicBlock *ExitingBlock, - ExitCountKind Kind) { +SCEVUse ScalarEvolution::getExitCount(const Loop *L, + const BasicBlock *ExitingBlock, + ExitCountKind Kind) { switch (Kind) { case Exact: return getBackedgeTakenInfo(L).getExact(ExitingBlock, this); @@ -8265,7 +8348,7 @@ const SCEV *ScalarEvolution::getExitCount(const Loop *L, llvm_unreachable("Invalid ExitCountKind!"); } -const SCEV *ScalarEvolution::getPredicatedExitCount( +SCEVUse ScalarEvolution::getPredicatedExitCount( const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl *Predicates, ExitCountKind Kind) { switch (Kind) { @@ -8282,13 +8365,13 @@ const SCEV *ScalarEvolution::getPredicatedExitCount( llvm_unreachable("Invalid ExitCountKind!"); } -const SCEV *ScalarEvolution::getPredicatedBackedgeTakenCount( +SCEVUse ScalarEvolution::getPredicatedBackedgeTakenCount( const Loop *L, SmallVectorImpl &Preds) { return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds); } -const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L, - ExitCountKind Kind) { +SCEVUse ScalarEvolution::getBackedgeTakenCount(const Loop *L, + ExitCountKind Kind) { switch (Kind) { case Exact: return getBackedgeTakenInfo(L).getExact(L, this); @@ -8300,12 +8383,12 @@ const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L, llvm_unreachable("Invalid ExitCountKind!"); } -const SCEV *ScalarEvolution::getPredicatedSymbolicMaxBackedgeTakenCount( +SCEVUse ScalarEvolution::getPredicatedSymbolicMaxBackedgeTakenCount( const Loop *L, SmallVectorImpl &Preds) { return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds); } -const SCEV *ScalarEvolution::getPredicatedConstantMaxBackedgeTakenCount( +SCEVUse ScalarEvolution::getPredicatedConstantMaxBackedgeTakenCount( const Loop *L, SmallVectorImpl &Preds) { return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds); } @@ -8367,7 +8450,7 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { // only done to produce more precise results. 
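Note: the saturating-intrinsic expansions in createSCEV above rest on two umin identities: usub_sat(x, y) == x - umin(x, y) and uadd_sat(x, y) == umin(x, ~y) + y, both non-wrapping. An exhaustive 8-bit verification:

#include <algorithm>
#include <cassert>

int main() {
  for (unsigned X = 0; X < 256; ++X)
    for (unsigned Y = 0; Y < 256; ++Y) {
      unsigned USubSat = X > Y ? X - Y : 0;         // reference semantics
      unsigned UAddSat = std::min(X + Y, 255u);
      assert(USubSat == X - std::min(X, Y));        // x - ClampedY
      assert(UAddSat == std::min(X, 255u - Y) + Y); // ClampedX + y; ~y == 255-y
    }
}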
if (Result.hasAnyInfo()) { // Invalidate any expression using an addrec in this loop. - SmallVector ToForget; + SmallVector ToForget; auto LoopUsersIt = LoopUsers.find(L); if (LoopUsersIt != LoopUsers.end()) append_range(ToForget, LoopUsersIt->second); @@ -8414,7 +8497,7 @@ void ScalarEvolution::forgetAllLoops() { void ScalarEvolution::visitAndClearUsers( SmallVectorImpl &Worklist, SmallPtrSetImpl &Visited, - SmallVectorImpl &ToForget) { + SmallVectorImpl &ToForget) { while (!Worklist.empty()) { Instruction *I = Worklist.pop_back_val(); if (!isSCEVable(I->getType()) && !isa(I)) @@ -8437,7 +8520,7 @@ void ScalarEvolution::forgetLoop(const Loop *L) { SmallVector LoopWorklist(1, L); SmallVector Worklist; SmallPtrSet Visited; - SmallVector ToForget; + SmallVector ToForget; // Iterate over all the loops and sub-loops to drop SCEV information. while (!LoopWorklist.empty()) { @@ -8450,7 +8533,7 @@ void ScalarEvolution::forgetLoop(const Loop *L) { // Drop information about predicated SCEV rewrites for this loop. for (auto I = PredicatedSCEVRewrites.begin(); I != PredicatedSCEVRewrites.end();) { - std::pair Entry = I->first; + std::pair Entry = I->first; if (Entry.second == CurrL) PredicatedSCEVRewrites.erase(I++); else @@ -8486,7 +8569,7 @@ void ScalarEvolution::forgetValue(Value *V) { // Drop information about expressions based on loop-header PHIs. SmallVector Worklist; SmallPtrSet Visited; - SmallVector ToForget; + SmallVector ToForget; Worklist.push_back(I); Visited.insert(I); visitAndClearUsers(Worklist, Visited, ToForget); @@ -8502,14 +8585,14 @@ void ScalarEvolution::forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V) { // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an // extra predecessor is added, this is no longer valid. Find all Unknowns and // AddRecs defined in the loop and invalidate any SCEV's making use of them. - if (const SCEV *S = getExistingSCEV(V)) { + if (SCEVUse S = getExistingSCEV(V)) { struct InvalidationRootCollector { Loop *L; - SmallVector Roots; + SmallVector Roots; InvalidationRootCollector(Loop *L) : L(L) {} - bool follow(const SCEV *S) { + bool follow(SCEVUse S) { if (auto *SU = dyn_cast(S)) { if (auto *I = dyn_cast(SU->getValue())) if (L->contains(I)) @@ -8546,7 +8629,7 @@ void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) { if (!isSCEVable(V->getType())) return; - const SCEV *S = getExistingSCEV(V); + SCEVUse S = getExistingSCEV(V); if (!S) return; @@ -8554,17 +8637,17 @@ void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) { // S's users may change if S's disposition changes (i.e. a user may change to // loop-invariant, if S changes to loop invariant), so also invalidate // dispositions of S's users recursively. - SmallVector Worklist = {S}; - SmallPtrSet Seen = {S}; + SmallVector Worklist = {S}; + SmallPtrSet Seen = {S}; while (!Worklist.empty()) { - const SCEV *Curr = Worklist.pop_back_val(); + SCEVUse Curr = Worklist.pop_back_val(); bool LoopDispoRemoved = LoopDispositions.erase(Curr); bool BlockDispoRemoved = BlockDispositions.erase(Curr); if (!LoopDispoRemoved && !BlockDispoRemoved) continue; auto Users = SCEVUsers.find(Curr); if (Users != SCEVUsers.end()) - for (const auto *User : Users->second) + for (const auto User : Users->second) if (Seen.insert(User).second) Worklist.push_back(User); } @@ -8576,7 +8659,7 @@ void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) { /// is never skipped. This is a valid assumption as long as the loop exits via /// that test. 
For precise results, it is the caller's responsibility to specify /// the relevant loop exiting block using getExact(ExitingBlock, SE). -const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact( +SCEVUse ScalarEvolution::BackedgeTakenInfo::getExact( const Loop *L, ScalarEvolution *SE, SmallVectorImpl *Preds) const { // If any exits were not computable, the loop is not computable. @@ -8590,9 +8673,9 @@ const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact( // All exiting blocks we have gathered dominate loop's latch, so exact trip // count is simply a minimum out of all these calculated exit counts. - SmallVector Ops; + SmallVector Ops; for (const auto &ENT : ExitNotTaken) { - const SCEV *BECount = ENT.ExactNotTaken; + SCEVUse BECount = ENT.ExactNotTaken; assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!"); assert(SE->DT.dominates(ENT.ExitingBlock, Latch) && "We should only have known counts for exiting blocks that dominate " @@ -8631,7 +8714,7 @@ ScalarEvolution::BackedgeTakenInfo::getExitNotTaken( } /// getConstantMax - Get the constant max backedge taken count for the loop. -const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax( +SCEVUse ScalarEvolution::BackedgeTakenInfo::getConstantMax( ScalarEvolution *SE, SmallVectorImpl *Predicates) const { if (!getConstantMax()) @@ -8650,7 +8733,7 @@ const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax( return getConstantMax(); } -const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax( +SCEVUse ScalarEvolution::BackedgeTakenInfo::getSymbolicMax( const Loop *L, ScalarEvolution *SE, SmallVectorImpl *Predicates) { if (!SymbolicMax) { @@ -8691,13 +8774,11 @@ bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero( return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue); } -ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E) - : ExitLimit(E, E, E, false) {} +ScalarEvolution::ExitLimit::ExitLimit(SCEVUse E) : ExitLimit(E, E, E, false) {} ScalarEvolution::ExitLimit::ExitLimit( - const SCEV *E, const SCEV *ConstantMaxNotTaken, - const SCEV *SymbolicMaxNotTaken, bool MaxOrZero, - ArrayRef> PredLists) + SCEVUse E, SCEVUse ConstantMaxNotTaken, SCEVUse SymbolicMaxNotTaken, + bool MaxOrZero, ArrayRef> PredLists) : ExactNotTaken(E), ConstantMaxNotTaken(ConstantMaxNotTaken), SymbolicMaxNotTaken(SymbolicMaxNotTaken), MaxOrZero(MaxOrZero) { // If we prove the max count is zero, so is the symbolic bound. This happens @@ -8736,9 +8817,8 @@ ScalarEvolution::ExitLimit::ExitLimit( "Max backedge count should be int"); } -ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E, - const SCEV *ConstantMaxNotTaken, - const SCEV *SymbolicMaxNotTaken, +ScalarEvolution::ExitLimit::ExitLimit(SCEVUse E, SCEVUse ConstantMaxNotTaken, + SCEVUse SymbolicMaxNotTaken, bool MaxOrZero, ArrayRef PredList) : ExitLimit(E, ConstantMaxNotTaken, SymbolicMaxNotTaken, MaxOrZero, @@ -8748,7 +8828,7 @@ ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E, /// computable exit into a persistent ExitNotTakenInfo array. ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( ArrayRef ExitCounts, - bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero) + bool IsComplete, SCEVUse ConstantMax, bool MaxOrZero) : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) { using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo; @@ -8779,8 +8859,8 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L, SmallVector ExitCounts; bool CouldComputeBECount = true; BasicBlock *Latch = L->getLoopLatch(); // may be NULL. 
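Note: getExact above folds the per-exit counts with a minimum, since the loop stops at whichever counted exit is reached first. A toy model with two such exits (the real code additionally threads predicates and uses a poison-safe sequential umin):

#include <algorithm>
#include <cassert>

int main() {
  unsigned ExitA = 7, ExitB = 11; // exact not-taken counts of two exits
  unsigned Trips = 0;
  for (unsigned I = 0;; ++I) {
    if (I == ExitA || I == ExitB) // the earlier exit wins
      break;
    ++Trips;
  }
  assert(Trips == std::min(ExitA, ExitB)); // umin of the exit counts
}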
- const SCEV *MustExitMaxBECount = nullptr; - const SCEV *MayExitMaxBECount = nullptr; + SCEVUse MustExitMaxBECount = nullptr; + SCEVUse MayExitMaxBECount = nullptr; bool MustExitMaxOrZero = false; bool IsOnlyExit = ExitingBlocks.size() == 1; @@ -8850,8 +8930,10 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L, } } } - const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount : - (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute()); + SCEVUse MaxBECount = + MustExitMaxBECount + ? MustExitMaxBECount + : (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute()); // The loop backedge will be taken the maximum or zero times if there's // a single exit that must be taken the maximum or zero times. bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1); @@ -9013,7 +9095,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( NWR.getEquivalentICmp(Pred, NewRHSC, Offset); if (!ExitIfTrue) Pred = ICmpInst::getInversePredicate(Pred); - auto *LHS = getSCEV(WO->getLHS()); + auto LHS = getSCEV(WO->getLHS()); if (Offset != 0) LHS = getAddExpr(LHS, getConstant(Offset)); auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC), @@ -9058,9 +9140,9 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp( if (isa(Op0)) return Op0 == NeutralElement ? EL1 : EL0; - const SCEV *BECount = getCouldNotCompute(); - const SCEV *ConstantMaxBECount = getCouldNotCompute(); - const SCEV *SymbolicMaxBECount = getCouldNotCompute(); + SCEVUse BECount = getCouldNotCompute(); + SCEVUse ConstantMaxBECount = getCouldNotCompute(); + SCEVUse SymbolicMaxBECount = getCouldNotCompute(); if (EitherMayExit) { bool UseSequentialUMin = !isa(ExitCond); // Both conditions must be same for the loop to continue executing. @@ -9118,16 +9200,15 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( Pred = ExitCond->getInversePredicate(); const ICmpInst::Predicate OriginalPred = Pred; - const SCEV *LHS = getSCEV(ExitCond->getOperand(0)); - const SCEV *RHS = getSCEV(ExitCond->getOperand(1)); + SCEVUse LHS = getSCEV(ExitCond->getOperand(0)); + SCEVUse RHS = getSCEV(ExitCond->getOperand(1)); ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit, AllowPredicates); if (EL.hasAnyInfo()) return EL; - auto *ExhaustiveCount = - computeExitCountExhaustively(L, ExitCond, ExitIfTrue); + auto ExhaustiveCount = computeExitCountExhaustively(L, ExitCond, ExitIfTrue); if (!isa(ExhaustiveCount)) return ExhaustiveCount; @@ -9136,7 +9217,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( ExitCond->getOperand(1), L, OriginalPred); } ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( - const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, + const Loop *L, ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS, bool ControlsOnlyExit, bool AllowPredicates) { // Try to evaluate any dependencies out of the loop. @@ -9165,7 +9246,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( ConstantRange CompRange = ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt()); - const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this); + SCEVUse Ret = AddRec->getNumIterationsInRange(CompRange, *this); if (!isa(Ret)) return Ret; } @@ -9178,7 +9259,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( // because if it did, we'd have an infinite (undefined) loop. // TODO: We can peel off any functions which are invertible *in L*. 
Loop // invariant terms are effectively constants for our purposes here. - auto *InnerLHS = LHS; + auto InnerLHS = LHS; if (auto *ZExt = dyn_cast(LHS)) InnerLHS = ZExt->getOperand(); if (const SCEVAddRecExpr *AR = dyn_cast(InnerLHS); @@ -9187,7 +9268,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( /*OrNegative=*/true)) { auto Flags = AR->getNoWrapFlags(); Flags = setFlags(Flags, SCEV::FlagNW); - SmallVector Operands{AR->operands()}; + SmallVector Operands{AR->operands()}; Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags); setNoWrapFlags(const_cast(AR), Flags); } @@ -9205,7 +9286,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( isKnownPositive(AR->getStepRecurrence(*this))) { auto Flags = AR->getNoWrapFlags(); Flags = setFlags(Flags, WrapType); - SmallVector Operands{AR->operands()}; + SmallVector Operands{AR->operands()}; Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags); setNoWrapFlags(const_cast(AR), Flags); } @@ -9320,8 +9401,8 @@ ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L, assert(L->contains(Switch->getDefaultDest()) && "Default case must not exit the loop!"); - const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L); - const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock)); + SCEVUse LHS = getSCEVAtScope(Switch->getCondition(), L); + SCEVUse RHS = getConstant(Switch->findCaseDest(ExitingBlock)); // while (X != Y) --> while (X-Y != 0) ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit); @@ -9334,8 +9415,8 @@ ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L, static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE) { - const SCEV *InVal = SE.getConstant(C); - const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE); + SCEVUse InVal = SE.getConstant(C); + SCEVUse Val = AddRec->evaluateAtIteration(InVal, SE); assert(isa(Val) && "Evaluation of SCEV at constant didn't fold correctly?"); return cast(Val)->getValue(); @@ -9476,7 +9557,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit( if (Result->isZeroValue()) { unsigned BitWidth = getTypeSizeInBits(RHS->getType()); - const SCEV *UpperBound = + SCEVUse UpperBound = getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth); return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false); } @@ -9726,9 +9807,9 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, } } -const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L, - Value *Cond, - bool ExitWhen) { +SCEVUse ScalarEvolution::computeExitCountExhaustively(const Loop *L, + Value *Cond, + bool ExitWhen) { PHINode *PN = getConstantEvolvingPHI(Cond, L); if (!PN) return getCouldNotCompute(); @@ -9793,9 +9874,8 @@ const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L, return getCouldNotCompute(); } -const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { - SmallVector, 2> &Values = - ValuesAtScopes[V]; +SCEVUse ScalarEvolution::getSCEVAtScope(SCEVUse V, const Loop *L) { + SmallVector, 2> &Values = ValuesAtScopes[V]; // Check to see if we've folded this expression at this loop before. for (auto &LS : Values) if (LS.first == L) @@ -9804,7 +9884,7 @@ const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { Values.emplace_back(L, nullptr); // Otherwise compute it. 
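Note: computeExitCountExhaustively above literally interprets the constant evolution of the controlling PHI until the exit branch fires, under an iteration budget. A sketch of the idea; the concrete condition and update below are placeholders, not the real folding:

#include <optional>

std::optional<unsigned> exhaustiveExitCount(int Start, bool ExitWhen,
                                            unsigned MaxIters = 100) {
  int X = Start;
  for (unsigned It = 0; It < MaxIters; ++It) {
    bool Cond = (X == 42); // placeholder for folding the exit condition
    if (Cond == ExitWhen)
      return It;           // the branch leaves the loop on iteration It
    X += 3;                // placeholder for the PHI's constant update
  }
  return std::nullopt;     // budget exhausted: could-not-compute
}

int main() { return exhaustiveExitCount(0, /*ExitWhen=*/true) == 14u ? 0 : 1; }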
- const SCEV *C = computeSCEVAtScope(V, L); + SCEVUse C = computeSCEVAtScope(V, L); for (auto &LS : reverse(ValuesAtScopes[V])) if (LS.first == L) { LS.second = C; @@ -9819,7 +9899,7 @@ const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { /// will return Constants for objects which aren't represented by a /// SCEVConstant, because SCEVConstant is restricted to ConstantInt. /// Returns NULL if the SCEV isn't representable as a Constant. -static Constant *BuildConstantFromSCEV(const SCEV *V) { +static Constant *BuildConstantFromSCEV(SCEVUse V) { switch (V->getSCEVType()) { case scCouldNotCompute: case scAddRecExpr: @@ -9845,7 +9925,7 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) { case scAddExpr: { const SCEVAddExpr *SA = cast(V); Constant *C = nullptr; - for (const SCEV *Op : SA->operands()) { + for (SCEVUse Op : SA->operands()) { Constant *OpC = BuildConstantFromSCEV(Op); if (!OpC) return nullptr; @@ -9880,9 +9960,8 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) { llvm_unreachable("Unknown SCEV kind!"); } -const SCEV * -ScalarEvolution::getWithOperands(const SCEV *S, - SmallVectorImpl &NewOps) { +SCEVUse ScalarEvolution::getWithOperands(SCEVUse S, + SmallVectorImpl &NewOps) { switch (S->getSCEVType()) { case scTruncate: case scZeroExtend: @@ -9916,7 +9995,7 @@ ScalarEvolution::getWithOperands(const SCEV *S, llvm_unreachable("Unknown SCEV kind!"); } -const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { +SCEVUse ScalarEvolution::computeSCEVAtScope(SCEVUse V, const Loop *L) { switch (V->getSCEVType()) { case scConstant: case scVScale: @@ -9929,21 +10008,21 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { // Avoid performing the look-up in the common case where the specified // expression has no loop-variant portions. for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) { - const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L); + SCEVUse OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L); if (OpAtScope == AddRec->getOperand(i)) continue; // Okay, at least one of these operands is loop variant but might be // foldable. Build a new instance of the folded commutative expression. - SmallVector NewOps; + SmallVector NewOps; NewOps.reserve(AddRec->getNumOperands()); append_range(NewOps, AddRec->operands().take_front(i)); NewOps.push_back(OpAtScope); for (++i; i != e; ++i) NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L)); - const SCEV *FoldedRec = getAddRecExpr( - NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW)); + SCEVUse FoldedRec = getAddRecExpr(NewOps, AddRec->getLoop(), + AddRec->getNoWrapFlags(SCEV::FlagNW)); AddRec = dyn_cast(FoldedRec); // The addrec may be folded to a nonrecurrence, for example, if the // induction variable is multiplied by zero after constant folding. Go @@ -9958,7 +10037,7 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { if (!AddRec->getLoop()->contains(L)) { // To evaluate this recurrence, we need to know how many times the AddRec // loop iterates. Compute this now. 
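Note: computeSCEVAtScope above rebuilds an expression only when some operand actually changed, reusing the untouched prefix (the take_front pattern). The same copy-if-changed idiom in miniature:

#include <vector>

std::vector<int> mapOperands(const std::vector<int> &Ops, int (*AtScope)(int)) {
  for (unsigned I = 0, E = Ops.size(); I != E; ++I) {
    int New = AtScope(Ops[I]);
    if (New == Ops[I])
      continue;                        // common case: keep scanning
    std::vector<int> NewOps(Ops.begin(), Ops.begin() + I); // unchanged prefix
    NewOps.reserve(E);
    NewOps.push_back(New);
    for (++I; I != E; ++I)
      NewOps.push_back(AtScope(Ops[I]));
    return NewOps;
  }
  return Ops;                          // nothing changed: no rebuild at all
}

int main() {
  return mapOperands({1, 2, 3}, [](int X) { return X * 2; }).front() == 2 ? 0 : 1;
}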
- const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop()); + SCEVUse BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop()); if (BackedgeTakenCount == getCouldNotCompute()) return AddRec; @@ -9980,15 +10059,15 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { case scUMinExpr: case scSMinExpr: case scSequentialUMinExpr: { - ArrayRef Ops = V->operands(); + ArrayRef Ops = V->operands(); // Avoid performing the look-up in the common case where the specified // expression has no loop-variant portions. for (unsigned i = 0, e = Ops.size(); i != e; ++i) { - const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L); + SCEVUse OpAtScope = getSCEVAtScope(Ops[i], L); if (OpAtScope != Ops[i]) { // Okay, at least one of these operands is loop variant but might be // foldable. Build a new instance of the folded commutative expression. - SmallVector NewOps; + SmallVector NewOps; NewOps.reserve(Ops.size()); append_range(NewOps, Ops.take_front(i)); NewOps.push_back(OpAtScope); @@ -10021,7 +10100,7 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { // to see if the loop that contains it has a known backedge-taken // count. If so, we may be able to force computation of the exit // value. - const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop); + SCEVUse BackedgeTakenCount = getBackedgeTakenCount(CurrLoop); // This trivial case can show up in some degenerate cases where // the incoming IR has not yet been fully simplified. if (BackedgeTakenCount->isZero()) { @@ -10086,8 +10165,8 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { if (!isSCEVable(Op->getType())) return V; - const SCEV *OrigV = getSCEV(Op); - const SCEV *OpV = getSCEVAtScope(OrigV, L); + SCEVUse OrigV = getSCEV(Op); + SCEVUse OpV = getSCEVAtScope(OrigV, L); MadeImprovement |= OrigV != OpV; Constant *C = BuildConstantFromSCEV(OpV); @@ -10115,11 +10194,11 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { llvm_unreachable("Unknown SCEV type!"); } -const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) { +SCEVUse ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) { return getSCEVAtScope(getSCEV(V), L); } -const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const { +SCEVUse ScalarEvolution::stripInjectiveFunctions(SCEVUse S) const { if (const SCEVZeroExtendExpr *ZExt = dyn_cast(S)) return stripInjectiveFunctions(ZExt->getOperand()); if (const SCEVSignExtendExpr *SExt = dyn_cast(S)) @@ -10136,7 +10215,7 @@ const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const { /// /// If the equation does not have a solution, SCEVCouldNotCompute is returned. 
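Note: forcing the exit value above amounts to evaluating the addrec at the loop's backedge-taken count; for an affine {Start,+,Step} that is Start + Step * BTC. A numeric check against a simulated loop:

#include <cassert>
#include <cstdint>

int main() {
  uint64_t Start = 5, Step = 7, BTC = 1000; // backedge taken 1000 times
  uint64_t IV = Start;
  for (uint64_t I = 0; I < BTC; ++I)
    IV += Step;                             // the loop's induction update
  assert(IV == Start + Step * BTC);         // evaluateAtIteration(BTC)
}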
static const SCEV * -SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, +SolveLinEquationWithOverflow(const APInt &A, SCEVUse B, SmallVectorImpl *Predicates, ScalarEvolution &SE) { @@ -10185,7 +10264,7 @@ SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, // I * (B / D) mod (N / D) // To simplify the computation, we factor out the divide by D: // (I * B mod N) / D - const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2)); + SCEVUse D = SE.getConstant(APInt::getOneBitSet(BW, Mult2)); return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D); } @@ -10462,7 +10541,7 @@ SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth); } -ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V, +ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(SCEVUse V, const Loop *L, bool ControlsOnlyExit, bool AllowPredicates) { @@ -10476,7 +10555,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V, // If the value is a constant if (const SCEVConstant *C = dyn_cast(V)) { // If the value is already zero, the branch will execute zero times. - if (C->getValue()->isZero()) return C; + if (C->getValue()->isZero()) + return SCEVUse(C); return getCouldNotCompute(); // Otherwise it will loop infinitely. } @@ -10521,8 +10601,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V, // where BW is the common bit width of Start and Step. // Get the initial value for the loop. - const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop()); - const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop()); + SCEVUse Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop()); + SCEVUse Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop()); const SCEVConstant *StepC = dyn_cast(Step); if (!isLoopInvariant(Step, L)) @@ -10557,9 +10637,9 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V, // Explicitly handling this here is necessary because getUnsignedRange // isn't context-sensitive; it doesn't know that we only care about the // range inside the loop. - const SCEV *Zero = getZero(Distance->getType()); - const SCEV *One = getOne(Distance->getType()); - const SCEV *DistancePlusOne = getAddExpr(Distance, One); + SCEVUse Zero = getZero(Distance->getType()); + SCEVUse One = getOne(Distance->getType()); + SCEVUse DistancePlusOne = getAddExpr(Distance, One); if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) { // If Distance + 1 doesn't overflow, we can compute the maximum distance // as "unsigned_max(Distance + 1) - 1". @@ -10584,37 +10664,36 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V, if (!loopIsFiniteByAssumption(L) && !isKnownNonZero(StepWLG)) return getCouldNotCompute(); - const SCEV *Exact = + SCEVUse Exact = getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); - const SCEV *ConstantMax = getCouldNotCompute(); + SCEVUse ConstantMax = getCouldNotCompute(); if (Exact != getCouldNotCompute()) { APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards)); ConstantMax = getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact))); } - const SCEV *SymbolicMax = - isa(Exact) ? ConstantMax : Exact; + SCEVUse SymbolicMax = isa(Exact) ? ConstantMax : Exact; return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates); } // Solve the general equation. 
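The general equation solved here is A * X == B (mod 2^BW). As a rough standalone sketch of the arithmetic SolveLinEquationWithOverflow performs, rendered in the "I * (B / D) mod (N / D)" form of the comment above and assuming BW == 64 and GCC/Clang builtins (the function names are illustrative, not the LLVM API):

#include <cstdint>
#include <optional>

// Inverse of an odd 64-bit value modulo 2^64 via Newton-Raphson; each step
// doubles the number of correct low bits (3 -> 6 -> 12 -> 24 -> 48 -> 96).
static uint64_t inverseOddMod2_64(uint64_t A) {
  uint64_t X = A; // for odd A, A*A == 1 (mod 8), so X starts correct to 3 bits
  for (int I = 0; I < 5; ++I)
    X *= 2 - A * X;
  return X;
}

// Solve A*X == B (mod 2^64) for the minimal X, if a solution exists.
// The degenerate A == 0 case is rejected up front, as in the analysis.
std::optional<uint64_t> solveLinMod2_64(uint64_t A, uint64_t B) {
  if (A == 0)
    return std::nullopt;
  unsigned Mult2 = __builtin_ctzll(A); // A = 2^Mult2 * (odd part)
  if (Mult2 && (B & ((1ULL << Mult2) - 1)))
    return std::nullopt; // no solution unless D = 2^Mult2 also divides B
  uint64_t I = inverseOddMod2_64(A >> Mult2); // inverse of the odd part
  uint64_t X = I * (B >> Mult2);              // I * (B / D)
  if (Mult2)
    X &= (1ULL << (64 - Mult2)) - 1; // ... mod 2^(BW - Mult2)
  return X;
}

For example, solveLinMod2_64(12, 20) factors 12 as 4 * 3, checks 4 divides 20, and returns inv(3) * 5 reduced mod 2^62, the minimal X with 12 * X == 20 (mod 2^64).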
if (!StepC || StepC->getValue()->isZero()) return getCouldNotCompute(); - const SCEV *E = SolveLinEquationWithOverflow( + SCEVUse E = SolveLinEquationWithOverflow( StepC->getAPInt(), getNegativeSCEV(Start), AllowPredicates ? &Predicates : nullptr, *this); - const SCEV *M = E; + SCEVUse M = E; if (E != getCouldNotCompute()) { APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards)); M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E))); } - auto *S = isa(E) ? M : E; + auto S = isa(E) ? M : E; return ExitLimit(E, M, S, false, Predicates); } -ScalarEvolution::ExitLimit -ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) { +ScalarEvolution::ExitLimit ScalarEvolution::howFarToNonZero(SCEVUse V, + const Loop *L) { // Loops that look like: while (X == 0) are very strange indeed. We don't // handle them yet except for the trivial case. This could be expanded in the // future as needed. @@ -10654,9 +10733,10 @@ ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB) /// expressions are equal, however for the purposes of looking for a condition /// guarding a loop, it can be useful to be a little more general, since a /// front-end may have replicated the controlling expression. -static bool HasSameValue(const SCEV *A, const SCEV *B) { - // Quick check to see if they are the same SCEV. - if (A == B) return true; +static bool HasSameValue(SCEVUse A, SCEVUse B) { + // Quick check to see if they are the same SCEV, ignoring use-specific flags. + if (A.getPointer() == B.getPointer()) + return true; auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) { // Not all instructions that are "identical" compute the same value. For @@ -10678,7 +10758,7 @@ static bool HasSameValue(const SCEV *A, const SCEV *B) { return false; } -static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) { +static bool MatchBinarySub(SCEVUse S, SCEVUse &LHS, SCEVUse &RHS) { const SCEVAddExpr *Add = dyn_cast(S); if (!Add || Add->getNumOperands() != 2) return false; @@ -10698,7 +10778,7 @@ static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) { } bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, - const SCEV *&LHS, const SCEV *&RHS, + SCEVUse &LHS, SCEVUse &RHS, unsigned Depth) { bool Changed = false; // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or @@ -10882,23 +10962,23 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, return Changed; } -bool ScalarEvolution::isKnownNegative(const SCEV *S) { +bool ScalarEvolution::isKnownNegative(SCEVUse S) { return getSignedRangeMax(S).isNegative(); } -bool ScalarEvolution::isKnownPositive(const SCEV *S) { +bool ScalarEvolution::isKnownPositive(SCEVUse S) { return getSignedRangeMin(S).isStrictlyPositive(); } -bool ScalarEvolution::isKnownNonNegative(const SCEV *S) { +bool ScalarEvolution::isKnownNonNegative(SCEVUse S) { return !getSignedRangeMin(S).isNegative(); } -bool ScalarEvolution::isKnownNonPositive(const SCEV *S) { +bool ScalarEvolution::isKnownNonPositive(SCEVUse S) { return !getSignedRangeMax(S).isStrictlyPositive(); } -bool ScalarEvolution::isKnownNonZero(const SCEV *S) { +bool ScalarEvolution::isKnownNonZero(SCEVUse S) { // Query push down for cases where the unsigned range is // less than sufficient. 
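HasSameValue's new quick check above compares the underlying pointers so that use-specific flag bits do not defeat node identity. A self-contained illustration of that tagged-pointer discipline (TaggedPtr is an illustrative stand-in, not the actual PointerIntPair-style wrapper):

#include <cassert>
#include <cstdint>

// Low alignment bits of a pointer carry per-use flags; "same node" checks
// must strip the tag first.
struct TaggedPtr {
  uintptr_t Raw = 0;
  static TaggedPtr make(const void *P, unsigned Flags) {
    return {reinterpret_cast<uintptr_t>(P) | (Flags & 0x3u)};
  }
  const void *ptr() const {
    return reinterpret_cast<const void *>(Raw & ~uintptr_t(3));
  }
};

int main() {
  alignas(4) static int Node = 0;
  TaggedPtr A = TaggedPtr::make(&Node, /*Flags=*/0);
  TaggedPtr B = TaggedPtr::make(&Node, /*Flags=*/2);
  assert(A.Raw != B.Raw);     // raw comparison sees the flag bits
  assert(A.ptr() == B.ptr()); // pointer comparison sees the same node
  return 0;
}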
if (const auto *SExt = dyn_cast(S)) @@ -10926,20 +11006,20 @@ bool ScalarEvolution::isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero, return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S)); } -std::pair -ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) { +std::pair +ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, SCEVUse S) { // Compute SCEV on entry of loop L. - const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this); + SCEVUse Start = SCEVInitRewriter::rewrite(S, L, *this); if (Start == getCouldNotCompute()) return { Start, Start }; // Compute post increment SCEV for loop L. - const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this); + SCEVUse PostInc = SCEVPostIncRewriter::rewrite(S, L, *this); assert(PostInc != getCouldNotCompute() && "Unexpected could not compute"); return { Start, PostInc }; } -bool ScalarEvolution::isKnownViaInduction(ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS) { +bool ScalarEvolution::isKnownViaInduction(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS) { // First collect all loops. SmallPtrSet LoopsUsed; getUsedLoops(LHS, LoopsUsed); @@ -10988,8 +11068,8 @@ bool ScalarEvolution::isKnownViaInduction(ICmpInst::Predicate Pred, isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first); } -bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS) { +bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS) { // Canonicalize the inputs first. (void)SimplifyICmpOperands(Pred, LHS, RHS); @@ -11004,8 +11084,8 @@ bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred, } std::optional ScalarEvolution::evaluatePredicate(ICmpInst::Predicate Pred, - const SCEV *LHS, - const SCEV *RHS) { + SCEVUse LHS, + SCEVUse RHS) { if (isKnownPredicate(Pred, LHS, RHS)) return true; if (isKnownPredicate(ICmpInst::getInversePredicate(Pred), LHS, RHS)) @@ -11013,17 +11093,16 @@ std::optional ScalarEvolution::evaluatePredicate(ICmpInst::Predicate Pred, return std::nullopt; } -bool ScalarEvolution::isKnownPredicateAt(ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS, - const Instruction *CtxI) { +bool ScalarEvolution::isKnownPredicateAt(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS, const Instruction *CtxI) { // TODO: Analyze guards and assumes from Context's block. 
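As a worked example of SplitIntoInitAndPostInc above: for S = {A,+,B}<L>, SCEVInitRewriter yields Start = A (the value on entry to L) and SCEVPostIncRewriter yields PostInc = {A+B,+,B}<L> (the value after the first increment); isKnownViaInduction then checks the predicate on the Start pair at loop entry and on the PostInc pair across the backedge.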
  return isKnownPredicate(Pred, LHS, RHS) ||
         isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
 }
 
 std::optional<bool>
-ScalarEvolution::evaluatePredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS,
-                                     const SCEV *RHS, const Instruction *CtxI) {
+ScalarEvolution::evaluatePredicateAt(ICmpInst::Predicate Pred, SCEVUse LHS,
+                                     SCEVUse RHS, const Instruction *CtxI) {
   std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
   if (KnownWithoutContext)
     return KnownWithoutContext;
@@ -11039,7 +11118,7 @@ ScalarEvolution::evaluatePredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS,
 
 bool ScalarEvolution::isKnownOnEveryIteration(ICmpInst::Predicate Pred,
                                               const SCEVAddRecExpr *LHS,
-                                              const SCEV *RHS) {
+                                              SCEVUse RHS) {
   const Loop *L = LHS->getLoop();
   return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
          isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
@@ -11097,7 +11176,7 @@ ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
   if (!LHS->hasNoSignedWrap())
     return std::nullopt;
 
-  const SCEV *Step = LHS->getStepRecurrence(*this);
+  SCEVUse Step = LHS->getStepRecurrence(*this);
 
   if (isKnownNonNegative(Step))
     return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
@@ -11110,7 +11189,7 @@ ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
 
 std::optional<ScalarEvolution::LoopInvariantPredicate>
 ScalarEvolution::getLoopInvariantPredicate(ICmpInst::Predicate Pred,
-                                           const SCEV *LHS, const SCEV *RHS,
+                                           SCEVUse LHS, SCEVUse RHS,
                                            const Loop *L,
                                            const Instruction *CtxI) {
   // If there is a loop-invariant, force it into the RHS, otherwise bail out.
@@ -11196,8 +11275,8 @@ ScalarEvolution::getLoopInvariantPredicate(ICmpInst::Predicate Pred,
 
 std::optional<ScalarEvolution::LoopInvariantPredicate>
 ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
-    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
-    const Instruction *CtxI, const SCEV *MaxIter) {
+    ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS, const Loop *L,
+    const Instruction *CtxI, SCEVUse MaxIter) {
   if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
           Pred, LHS, RHS, L, CtxI, MaxIter))
     return LIP;
@@ -11207,7 +11286,7 @@ ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
   // work, try the following trick: if a predicate is invariant for X, it
   // is also invariant for umin(X, ...). So try to find something that works
   // among subexpressions of MaxIter expressed as umin.
-  for (auto *Op : UMin->operands())
+  for (auto Op : UMin->operands())
     if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
             Pred, LHS, RHS, L, CtxI, Op))
       return LIP;
@@ -11216,8 +11295,8 @@
 
 std::optional<ScalarEvolution::LoopInvariantPredicate>
 ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl(
-    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
-    const Instruction *CtxI, const SCEV *MaxIter) {
+    ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS, const Loop *L,
+    const Instruction *CtxI, SCEVUse MaxIter) {
   // Try to prove the following set of facts:
   // - The predicate is monotonic in the iteration space.
   // - If the check does not fail on the 1st iteration:
@@ -11244,9 +11323,9 @@ ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl(
     return std::nullopt;
 
   // TODO: Support steps other than +/- 1.
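The Impl routine finishes by evaluating the AddRec at the candidate last iteration (AR->evaluateAtIteration(MaxIter, *this) in the hunk that follows). For a chrec with constant operands that evaluation is the binomial expansion value(K) = sum over i of Op_i * C(K, i); a rough sketch, assuming the intermediate binomial products do not overflow (the in-tree implementation works in widened APInt precision instead):

#include <cstdint>
#include <vector>

// value(K) of the chrec {Ops[0],+,Ops[1],+,...}, with wrapping uint64
// arithmetic standing in for APInt.
uint64_t evaluateChrecAt(const std::vector<uint64_t> &Ops, uint64_t K) {
  uint64_t Result = 0, Binom = 1; // binomial(K, 0) == 1
  for (size_t I = 0; I < Ops.size(); ++I) {
    Result += Ops[I] * Binom;
    Binom = Binom * (K - I) / (I + 1); // binomial(K, I+1), exact division
  }
  return Result;
}

For the affine chrec {A,+,B} this reduces to A + K*B, matching the Step == +/-1 reasoning above.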
- const SCEV *Step = AR->getStepRecurrence(*this); - auto *One = getOne(Step->getType()); - auto *MinusOne = getNegativeSCEV(One); + SCEVUse Step = AR->getStepRecurrence(*this); + auto One = getOne(Step->getType()); + auto MinusOne = getNegativeSCEV(One); if (Step != One && Step != MinusOne) return std::nullopt; @@ -11257,7 +11336,7 @@ ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl( return std::nullopt; // Value of IV on suggested last iteration. - const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this); + SCEVUse Last = AR->evaluateAtIteration(MaxIter, *this); // Does it still meet the requirement? if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS)) return std::nullopt; @@ -11270,7 +11349,7 @@ ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl( CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; if (Step == MinusOne) NoOverflowPred = CmpInst::getSwappedPredicate(NoOverflowPred); - const SCEV *Start = AR->getStart(); + SCEVUse Start = AR->getStart(); if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI)) return std::nullopt; @@ -11279,7 +11358,7 @@ ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl( } bool ScalarEvolution::isKnownPredicateViaConstantRanges( - ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { + ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS) { if (HasSameValue(LHS, RHS)) return ICmpInst::isTrueWhenEqual(Pred); @@ -11305,7 +11384,7 @@ bool ScalarEvolution::isKnownPredicateViaConstantRanges( auto UR = getUnsignedRange(RHS); if (CheckRanges(UL, UR)) return true; - auto *Diff = getMinusSCEV(LHS, RHS); + auto Diff = getMinusSCEV(LHS, RHS); return !isa(Diff) && isKnownNonZero(Diff); } @@ -11321,17 +11400,16 @@ bool ScalarEvolution::isKnownPredicateViaConstantRanges( } bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred, - const SCEV *LHS, - const SCEV *RHS) { + SCEVUse LHS, SCEVUse RHS) { // Match X to (A + C1) and Y to (A + C2), where // C1 and C2 are constant integers. If either X or Y are not add expressions, // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via // OutC1 and OutC2. - auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y, - APInt &OutC1, APInt &OutC2, + auto MatchBinaryAddToConst = [this](SCEVUse X, SCEVUse Y, APInt &OutC1, + APInt &OutC2, SCEV::NoWrapFlags ExpectedFlags) { - const SCEV *XNonConstOp, *XConstOp; - const SCEV *YNonConstOp, *YConstOp; + SCEVUse XNonConstOp, XConstOp; + SCEVUse YNonConstOp, YConstOp; SCEV::NoWrapFlags XFlagsPresent; SCEV::NoWrapFlags YFlagsPresent; @@ -11414,8 +11492,7 @@ bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred, } bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred, - const SCEV *LHS, - const SCEV *RHS) { + SCEVUse LHS, SCEVUse RHS) { if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate) return false; @@ -11436,8 +11513,8 @@ bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred, } bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, - ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS) { + ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS) { // No need to even try if we know the module has no guards. if (!HasGuards) return false; @@ -11455,10 +11532,9 @@ bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, /// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is /// protected by a conditional between LHS and RHS. 
This is used to
/// eliminate casts.
-bool
-ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
-                                             ICmpInst::Predicate Pred,
-                                             const SCEV *LHS, const SCEV *RHS) {
+bool ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
+                                                  ICmpInst::Predicate Pred,
+                                                  SCEVUse LHS, SCEVUse RHS) {
   // Interpret a null as meaning no loop, where there is obviously no guard
   // (interprocedural conditions notwithstanding). Do not bother about
   // unreachable loops.
@@ -11494,15 +11570,15 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
 
   // See if we can exploit a trip count to prove the predicate.
   const auto &BETakenInfo = getBackedgeTakenInfo(L);
-  const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
+  SCEVUse LatchBECount = BETakenInfo.getExact(Latch, this);
   if (LatchBECount != getCouldNotCompute()) {
     // We know that Latch branches back to the loop header exactly
     // LatchBECount times. This means the backedge condition at Latch is
     // equivalent to "{0,+,1} u< LatchBECount".
     Type *Ty = LatchBECount->getType();
     auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
-    const SCEV *LoopCounter =
-        getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
+    SCEVUse LoopCounter =
+        getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
     if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
                       LatchBECount))
       return true;
@@ -11563,8 +11639,7 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
 
 bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
                                                      ICmpInst::Predicate Pred,
-                                                     const SCEV *LHS,
-                                                     const SCEV *RHS) {
+                                                     SCEVUse LHS, SCEVUse RHS) {
   // Do not bother proving facts for unreachable code.
   if (!DT.isReachableFromEntry(BB))
     return true;
@@ -11663,8 +11738,7 @@ bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
 
 bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
                                                ICmpInst::Predicate Pred,
-                                               const SCEV *LHS,
-                                               const SCEV *RHS) {
+                                               SCEVUse LHS, SCEVUse RHS) {
   // Interpret a null as meaning no loop, where there is obviously no guard
   // (interprocedural conditions notwithstanding).
   if (!L)
@@ -11682,10 +11756,9 @@ bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
   return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
 }
 
-bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
-                                    const SCEV *RHS,
-                                    const Value *FoundCondValue, bool Inverse,
-                                    const Instruction *CtxI) {
+bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, SCEVUse LHS,
+                                    SCEVUse RHS, const Value *FoundCondValue,
+                                    bool Inverse, const Instruction *CtxI) {
   // A false condition implies anything. Do not bother analyzing it further.
if (FoundCondValue == ConstantInt::getBool(FoundCondValue->getContext(), Inverse)) @@ -11720,16 +11793,15 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, else FoundPred = ICI->getPredicate(); - const SCEV *FoundLHS = getSCEV(ICI->getOperand(0)); - const SCEV *FoundRHS = getSCEV(ICI->getOperand(1)); + SCEVUse FoundLHS = getSCEV(ICI->getOperand(0)); + SCEVUse FoundRHS = getSCEV(ICI->getOperand(1)); return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI); } -bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, - const SCEV *RHS, - ICmpInst::Predicate FoundPred, - const SCEV *FoundLHS, const SCEV *FoundRHS, +bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS, ICmpInst::Predicate FoundPred, + SCEVUse FoundLHS, SCEVUse FoundRHS, const Instruction *CtxI) { // Balance the types. if (getTypeSizeInBits(LHS->getType()) < @@ -11742,14 +11814,14 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, auto *NarrowType = LHS->getType(); auto *WideType = FoundLHS->getType(); auto BitWidth = getTypeSizeInBits(NarrowType); - const SCEV *MaxValue = getZeroExtendExpr( + SCEVUse MaxValue = getZeroExtendExpr( getConstant(APInt::getMaxValue(BitWidth)), WideType); if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS, MaxValue) && isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS, MaxValue)) { - const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType); - const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType); + SCEVUse TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType); + SCEVUse TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType); if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, TruncFoundLHS, TruncFoundRHS, CtxI)) return true; @@ -11781,10 +11853,12 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, FoundRHS, CtxI); } -bool ScalarEvolution::isImpliedCondBalancedTypes( - ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, - ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, const SCEV *FoundRHS, - const Instruction *CtxI) { +bool ScalarEvolution::isImpliedCondBalancedTypes(ICmpInst::Predicate Pred, + SCEVUse LHS, SCEVUse RHS, + ICmpInst::Predicate FoundPred, + SCEVUse FoundLHS, + SCEVUse FoundRHS, + const Instruction *CtxI) { assert(getTypeSizeInBits(LHS->getType()) == getTypeSizeInBits(FoundLHS->getType()) && "Types should be balanced!"); @@ -11862,8 +11936,8 @@ bool ScalarEvolution::isImpliedCondBalancedTypes( // Create local copies that we can freely swap and canonicalize our // conditions to "le/lt". 
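Concretely, a query `A >s B` proved from a found `X >=s Y` is canonicalized by swapping both predicates and both operand pairs into `B <s A` proved from `Y <=s X`, so the logic that follows only has to handle the le/lt forms.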
ICmpInst::Predicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred; - const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS, - *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS; + SCEVUse CanonicalLHS = LHS, CanonicalRHS = RHS, + CanonicalFoundLHS = FoundLHS, CanonicalFoundRHS = FoundRHS; if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) { CanonicalPred = ICmpInst::getSwappedPredicate(CanonicalPred); CanonicalFoundPred = ICmpInst::getSwappedPredicate(CanonicalFoundPred); @@ -11896,7 +11970,7 @@ bool ScalarEvolution::isImpliedCondBalancedTypes( (isa(FoundLHS) || isa(FoundRHS))) { const SCEVConstant *C = nullptr; - const SCEV *V = nullptr; + SCEVUse V = nullptr; if (isa(FoundLHS)) { C = cast(FoundLHS); @@ -11985,8 +12059,7 @@ bool ScalarEvolution::isImpliedCondBalancedTypes( return false; } -bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr, - const SCEV *&L, const SCEV *&R, +bool ScalarEvolution::splitBinaryAdd(SCEVUse Expr, SCEVUse &L, SCEVUse &R, SCEV::NoWrapFlags &Flags) { const auto *AE = dyn_cast(Expr); if (!AE || AE->getNumOperands() != 2) @@ -11998,8 +12071,8 @@ bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr, return true; } -std::optional -ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) { +std::optional ScalarEvolution::computeConstantDifference(SCEVUse More, + SCEVUse Less) { // We avoid subtracting expressions here because this function is usually // fairly deep in the call stack (i.e. is called many times). @@ -12116,8 +12189,8 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) { } bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart( - ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, - const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) { + ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS, SCEVUse FoundLHS, + SCEVUse FoundRHS, const Instruction *CtxI) { // Try to recognize the following pattern: // // FoundRHS = ... 
@@ -12161,8 +12234,8 @@ bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart( } bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow( - ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, - const SCEV *FoundLHS, const SCEV *FoundRHS) { + ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS, SCEVUse FoundLHS, + SCEVUse FoundRHS) { if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT) return false; @@ -12239,10 +12312,9 @@ bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow( getConstant(FoundRHSLimit)); } -bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS, - const SCEV *FoundLHS, - const SCEV *FoundRHS, unsigned Depth) { +bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred, SCEVUse LHS, + SCEVUse RHS, SCEVUse FoundLHS, + SCEVUse FoundRHS, unsigned Depth) { const PHINode *LPhi = nullptr, *RPhi = nullptr; auto ClearOnExit = make_scope_exit([&]() { @@ -12296,7 +12368,7 @@ bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred, const BasicBlock *LBB = LPhi->getParent(); const SCEVAddRecExpr *RAR = dyn_cast(RHS); - auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) { + auto ProvedEasily = [&](SCEVUse S1, SCEVUse S2) { return isKnownViaNonRecursiveReasoning(Pred, S1, S2) || isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) || isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth); @@ -12308,8 +12380,8 @@ bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred, // the predicate is true for incoming values from this block, then the // predicate is also true for the Phis. for (const BasicBlock *IncBB : predecessors(LBB)) { - const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB)); - const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB)); + SCEVUse L = getSCEV(LPhi->getIncomingValueForBlock(IncBB)); + SCEVUse R = getSCEV(RPhi->getIncomingValueForBlock(IncBB)); if (!ProvedEasily(L, R)) return false; } @@ -12324,12 +12396,12 @@ bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred, auto *RLoop = RAR->getLoop(); auto *Predecessor = RLoop->getLoopPredecessor(); assert(Predecessor && "Loop with AddRec with no predecessor?"); - const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor)); + SCEVUse L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor)); if (!ProvedEasily(L1, RAR->getStart())) return false; auto *Latch = RLoop->getLoopLatch(); assert(Latch && "Loop with AddRec with no latch?"); - const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch)); + SCEVUse L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch)); if (!ProvedEasily(L2, RAR->getPostIncExpr(*this))) return false; } else { @@ -12341,7 +12413,7 @@ bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred, // Check that RHS is available in this block. if (!dominates(RHS, IncBB)) return false; - const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB)); + SCEVUse L = getSCEV(LPhi->getIncomingValueForBlock(IncBB)); // Make sure L does not refer to a value from a potentially previous // iteration of a loop. if (!properlyDominates(L, LBB)) @@ -12354,10 +12426,9 @@ bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred, } bool ScalarEvolution::isImpliedCondOperandsViaShift(ICmpInst::Predicate Pred, - const SCEV *LHS, - const SCEV *RHS, - const SCEV *FoundLHS, - const SCEV *FoundRHS) { + SCEVUse LHS, SCEVUse RHS, + SCEVUse FoundLHS, + SCEVUse FoundRHS) { // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). 
First, make
  // sure that we are dealing with same LHS.
  if (RHS == FoundRHS) {
@@ -12377,7 +12448,7 @@ bool ScalarEvolution::isImpliedCondOperandsViaShift(ICmpInst::Predicate Pred,
     using namespace PatternMatch;
     if (match(SUFoundRHS->getValue(),
               m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
-      auto *ShifteeS = getSCEV(Shiftee);
+      auto ShifteeS = getSCEV(Shiftee);
       // Prove one of the following:
       // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
       // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
@@ -12396,9 +12467,8 @@ bool ScalarEvolution::isImpliedCondOperandsViaShift(ICmpInst::Predicate Pred,
 }
 
 bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
-                                            const SCEV *LHS, const SCEV *RHS,
-                                            const SCEV *FoundLHS,
-                                            const SCEV *FoundRHS,
+                                            SCEVUse LHS, SCEVUse RHS,
+                                            SCEVUse FoundLHS, SCEVUse FoundRHS,
                                             const Instruction *CtxI) {
   if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS, FoundRHS))
     return true;
@@ -12419,8 +12489,7 @@ bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
 
 /// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
 template <typename MinMaxExprType>
-static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
-                                 const SCEV *Candidate) {
+static bool IsMinMaxConsistingOf(SCEVUse MaybeMinMaxExpr, SCEVUse Candidate) {
   const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
   if (!MinMaxExpr)
     return false;
@@ -12430,7 +12499,7 @@ static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
 
 static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
                                            ICmpInst::Predicate Pred,
-                                           const SCEV *LHS, const SCEV *RHS) {
+                                           SCEVUse LHS, SCEVUse RHS) {
   // If both sides are affine addrecs for the same loop, with equal
   // steps, and we know the recurrences don't wrap, then we only
   // need to check the predicate on the starting values.
@@ -12463,8 +12532,8 @@ static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
 
 /// Is LHS `Pred` RHS true by virtue of LHS or RHS being a Min or Max
 /// expression?
 static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
-                                        ICmpInst::Predicate Pred,
-                                        const SCEV *LHS, const SCEV *RHS) {
+                                        ICmpInst::Predicate Pred, SCEVUse LHS,
+                                        SCEVUse RHS) {
   switch (Pred) {
   default:
     return false;
@@ -12495,9 +12564,8 @@ static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
 }
 
 bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
-                                             const SCEV *LHS, const SCEV *RHS,
-                                             const SCEV *FoundLHS,
-                                             const SCEV *FoundRHS,
+                                             SCEVUse LHS, SCEVUse RHS,
+                                             SCEVUse FoundLHS, SCEVUse FoundRHS,
                                              unsigned Depth) {
   assert(getTypeSizeInBits(LHS->getType()) ==
          getTypeSizeInBits(RHS->getType()) &&
@@ -12525,7 +12593,7 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
   // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
   // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
   // use this fact to prove that LHS and RHS are non-negative.
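The two isImpliedCondOperandsViaShift rules restored above rest on one unsigned fact: a logical right shift never increases a value, so LHS <u (shiftee >> shiftvalue) chains with (shiftee >> shiftvalue) <=u shiftee <=u RHS. An exhaustive 8-bit sanity check of that fact:

#include <cassert>

int main() {
  for (unsigned X = 0; X < 256; ++X)
    for (unsigned S = 0; S < 8; ++S)
      assert((X >> S) <= X); // shifting right cannot grow an unsigned value
  return 0;
}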
-  const SCEV *MinusOne = getMinusOne(LHS->getType());
+  SCEVUse MinusOne = getMinusOne(LHS->getType());
   if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
                             FoundRHS) &&
       isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
@@ -12536,7 +12604,7 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
   if (Pred != ICmpInst::ICMP_SGT)
     return false;
 
-  auto GetOpFromSExt = [&](const SCEV *S) {
+  auto GetOpFromSExt = [&](SCEVUse S) {
     if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
       return Ext->getOperand();
     // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
@@ -12545,13 +12613,13 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
   };
 
   // Acquire values from extensions.
-  auto *OrigLHS = LHS;
-  auto *OrigFoundLHS = FoundLHS;
+  auto OrigLHS = LHS;
+  auto OrigFoundLHS = FoundLHS;
   LHS = GetOpFromSExt(LHS);
   FoundLHS = GetOpFromSExt(FoundLHS);
 
   // Whether the SGT predicate can be proved trivially or using the found
   // context.
-  auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
+  auto IsSGTViaContext = [&](SCEVUse S1, SCEVUse S2) {
     return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
            isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
                                   FoundRHS, Depth + 1);
@@ -12570,12 +12638,12 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
     if (!LHSAddExpr->hasNoSignedWrap())
       return false;
 
-    auto *LL = LHSAddExpr->getOperand(0);
-    auto *LR = LHSAddExpr->getOperand(1);
-    auto *MinusOne = getMinusOne(RHS->getType());
+    auto LL = LHSAddExpr->getOperand(0);
+    auto LR = LHSAddExpr->getOperand(1);
+    auto MinusOne = getMinusOne(RHS->getType());
 
    // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
-    auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
+    auto IsSumGreaterThanRHS = [&](SCEVUse S1, SCEVUse S2) {
      return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
    };
    // Try to prove the following rule:
@@ -12605,7 +12673,7 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
     // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
     // then a SCEV for the numerator already exists and matches with FoundLHS.
-    auto *Numerator = getExistingSCEV(LL);
+    auto Numerator = getExistingSCEV(LL);
     if (!Numerator || Numerator->getType() != FoundLHS->getType())
       return false;
@@ -12626,14 +12694,14 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
     // Given that:
     // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
     auto *WTy = getWiderType(DTy, FRHSTy);
-    auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
-    auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
+    auto DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
+    auto FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
 
     // Try to prove the following rule:
     // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
     // For example, given that FoundLHS > 2, FoundLHS is at least 3. If we
     // divide it by Denominator < 4, we will have at least 1.
-    auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
+    auto DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
     if (isKnownNonPositive(RHS) &&
         IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
       return true;
@@ -12645,8 +12713,8 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
     // 1. If FoundLHS is negative, then the result is 0.
     // 2. If FoundLHS is non-negative, then the result is non-negative.
// Either way, the result is non-negative.
-    auto *MinusOne = getMinusOne(WTy);
-    auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
+    auto MinusOne = getMinusOne(WTy);
+    auto NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
     if (isKnownNegative(RHS) &&
         IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
       return true;
@@ -12662,8 +12730,8 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
   return false;
 }
 
-static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred,
-                                        const SCEV *LHS, const SCEV *RHS) {
+static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred, SCEVUse LHS,
+                                        SCEVUse RHS) {
   // zext x u<= sext x, sext x s<= zext x
   switch (Pred) {
   case ICmpInst::ICMP_SGE:
@@ -12694,9 +12762,9 @@ static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred,
   return false;
 }
 
-bool
-ScalarEvolution::isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred,
-                                                 const SCEV *LHS, const SCEV *RHS) {
+bool ScalarEvolution::isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred,
+                                                      SCEVUse LHS,
+                                                      SCEVUse RHS) {
   return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
          isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
          IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
@@ -12704,11 +12772,10 @@ ScalarEvolution::isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred,
          isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
 }
 
-bool
-ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
-                                             const SCEV *LHS, const SCEV *RHS,
-                                             const SCEV *FoundLHS,
-                                             const SCEV *FoundRHS) {
+bool ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
+                                                  SCEVUse LHS, SCEVUse RHS,
+                                                  SCEVUse FoundLHS,
+                                                  SCEVUse FoundRHS) {
   switch (Pred) {
   default:
     llvm_unreachable("Unexpected ICmpInst::Predicate value!");
   case ICmpInst::ICMP_EQ:
@@ -12749,12 +12816,9 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
   return false;
 }
 
-bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
-                                                     const SCEV *LHS,
-                                                     const SCEV *RHS,
-                                                     ICmpInst::Predicate FoundPred,
-                                                     const SCEV *FoundLHS,
-                                                     const SCEV *FoundRHS) {
+bool ScalarEvolution::isImpliedCondOperandsViaRanges(
+    ICmpInst::Predicate Pred, SCEVUse LHS, SCEVUse RHS,
+    ICmpInst::Predicate FoundPred, SCEVUse FoundLHS, SCEVUse FoundRHS) {
   if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
     // The restriction on `FoundRHS` can be lifted easily -- it exists only to
     // reduce the compile time impact of this optimization.
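A worked instance of the extend idiom above: for the i8 value x = 0x80, zext(x) to i16 is 0x0080 (128) while sext(x) is 0xFF80 (65408 unsigned, -128 signed), so zext x u<= sext x holds as 128 u<= 65408, and sext x s<= zext x holds as -128 s<= 128.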
@@ -12782,12 +12846,12 @@ bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred, return LHSRange.icmp(Pred, ConstRHS); } -bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, +bool ScalarEvolution::canIVOverflowOnLT(SCEVUse RHS, SCEVUse Stride, bool IsSigned) { assert(isKnownPositive(Stride) && "Positive stride expected!"); unsigned BitWidth = getTypeSizeInBits(RHS->getType()); - const SCEV *One = getOne(Stride->getType()); + SCEVUse One = getOne(Stride->getType()); if (IsSigned) { APInt MaxRHS = getSignedRangeMax(RHS); @@ -12806,11 +12870,11 @@ bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS); } -bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, +bool ScalarEvolution::canIVOverflowOnGT(SCEVUse RHS, SCEVUse Stride, bool IsSigned) { unsigned BitWidth = getTypeSizeInBits(RHS->getType()); - const SCEV *One = getOne(Stride->getType()); + SCEVUse One = getOne(Stride->getType()); if (IsSigned) { APInt MinRHS = getSignedRangeMin(RHS); @@ -12829,20 +12893,18 @@ bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS); } -const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) { +SCEVUse ScalarEvolution::getUDivCeilSCEV(SCEVUse N, SCEVUse D) { // umin(N, 1) + floor((N - umin(N, 1)) / D) // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin // expression fixes the case of N=0. - const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType())); - const SCEV *NMinusOne = getMinusSCEV(N, MinNOne); + SCEVUse MinNOne = getUMinExpr(N, getOne(N->getType())); + SCEVUse NMinusOne = getMinusSCEV(N, MinNOne); return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D)); } -const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start, - const SCEV *Stride, - const SCEV *End, - unsigned BitWidth, - bool IsSigned) { +SCEVUse ScalarEvolution::computeMaxBECountForLT(SCEVUse Start, SCEVUse Stride, + SCEVUse End, unsigned BitWidth, + bool IsSigned) { // The logic in this function assumes we can represent a positive stride. // If we can't, the backedge-taken count must be zero. if (IsSigned && BitWidth == 1) @@ -12888,9 +12950,9 @@ const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start, } ScalarEvolution::ExitLimit -ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, - const Loop *L, bool IsSigned, - bool ControlsOnlyExit, bool AllowPredicates) { +ScalarEvolution::howManyLessThans(SCEVUse LHS, SCEVUse RHS, const Loop *L, + bool IsSigned, bool ControlsOnlyExit, + bool AllowPredicates) { SmallVector Predicates; const SCEVAddRecExpr *IV = dyn_cast(LHS); @@ -12934,11 +12996,11 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, if (AR->hasNoUnsignedWrap()) { // Emulate what getZeroExtendExpr would have done during construction // if we'd been able to infer the fact just above at that time. 
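getUDivCeilSCEV's umin-based formulation computes a ceiling division that neither overflows the way the naive (N + D - 1) / D can nor misbehaves at N == 0. A direct scalar rendering of the same formula, as a sketch (assumes D != 0):

#include <cstdint>

// ceil(N / D) as umin(N, 1) + (N - umin(N, 1)) / D.
uint64_t udivCeil(uint64_t N, uint64_t D) {
  uint64_t MinNOne = N < 1 ? N : 1; // umin(N, 1): 0 when N == 0, else 1
  return MinNOne + (N - MinNOne) / D;
}

For example, udivCeil(7, 2) is 1 + 6/2 = 4 and udivCeil(0, 2) is 0, matching the "1 + floor((N - 1) / D) for N != 0" reading in the comment above.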
- const SCEV *Step = AR->getStepRecurrence(*this); + SCEVUse Step = AR->getStepRecurrence(*this); Type *Ty = ZExt->getType(); - auto *S = getAddRecExpr( - getExtendAddRecStart(AR, Ty, this, 0), - getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags()); + auto S = getAddRecExpr( + getExtendAddRecStart(AR, Ty, this, 0), + getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags()); IV = dyn_cast(S); } } @@ -12972,7 +13034,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType); ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; - const SCEV *Stride = IV->getStepRecurrence(*this); + SCEVUse Stride = IV->getStepRecurrence(*this); bool PositiveStride = isKnownPositive(Stride); @@ -13038,7 +13100,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, // Note: The (Start - Stride) term is used to get the start' term from // (start' + stride,+,stride). Remember that we only care about the // result of this expression when stride == 0 at runtime. - auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride); + auto StartIfZero = getMinusSCEV(IV->getStart(), Stride); return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS); }; if (!wouldZeroStrideBeUB()) { @@ -13061,14 +13123,14 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, // before any possible exit. // Note that we have not yet proved RHS invariant (in general). - const SCEV *Start = IV->getStart(); + SCEVUse Start = IV->getStart(); // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond. // If we convert to integers, isLoopEntryGuardedByCond will miss some cases. // Use integer-typed versions for actual computation; we can't subtract // pointers in general. - const SCEV *OrigStart = Start; - const SCEV *OrigRHS = RHS; + SCEVUse OrigStart = Start; + SCEVUse OrigRHS = RHS; if (Start->getType()->isPointerTy()) { Start = getLosslessPtrToIntExpr(Start); if (isa(Start)) @@ -13080,8 +13142,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, return RHS; } - const SCEV *End = nullptr, *BECount = nullptr, - *BECountIfBackedgeTaken = nullptr; + SCEVUse End = nullptr, BECount = nullptr, BECountIfBackedgeTaken = nullptr; if (!isLoopInvariant(RHS, L)) { const auto *RHSAddRec = dyn_cast(RHS); if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L && @@ -13140,7 +13201,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, // backedge count, as if the backedge is taken at least once // max(End,Start) is End and so the result is as above, and if not // max(End,Start) is Start so we get a backedge count of zero. - auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride); + const auto OrigStartMinusStride = getMinusSCEV(OrigStart, Stride); assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!"); assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!"); assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!"); @@ -13192,7 +13253,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, // // FIXME: Should isLoopEntryGuardedByCond do this for us? auto CondGT = IsSigned ? 
ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT; - auto *StartMinusOne = + auto StartMinusOne = getAddExpr(OrigStart, getMinusOne(OrigStart->getType())); return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne); }; @@ -13302,7 +13363,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, } } - const SCEV *ConstantMaxBECount; + SCEVUse ConstantMaxBECount; bool MaxOrZero = false; if (isa(BECount)) { ConstantMaxBECount = BECount; @@ -13322,15 +13383,16 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, !isa(BECount)) ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount)); - const SCEV *SymbolicMaxBECount = + SCEVUse SymbolicMaxBECount = isa(BECount) ? ConstantMaxBECount : BECount; return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero, Predicates); } -ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans( - const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned, - bool ControlsOnlyExit, bool AllowPredicates) { +ScalarEvolution::ExitLimit +ScalarEvolution::howManyGreaterThans(SCEVUse LHS, SCEVUse RHS, const Loop *L, + bool IsSigned, bool ControlsOnlyExit, + bool AllowPredicates) { SmallVector Predicates; // We handle only IV > Invariant if (!isLoopInvariant(RHS, L)) @@ -13351,7 +13413,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans( bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType); ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT; - const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this)); + SCEVUse Stride = getNegativeSCEV(IV->getStepRecurrence(*this)); // Avoid negative or zero stride values if (!isKnownPositive(Stride)) @@ -13365,8 +13427,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans( if (canIVOverflowOnGT(RHS, Stride, IsSigned)) return getCouldNotCompute(); - const SCEV *Start = IV->getStart(); - const SCEV *End = RHS; + SCEVUse Start = IV->getStart(); + SCEVUse End = RHS; if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) { // If we know that Start >= RHS in the context of loop, then we know that // min(RHS, Start) = RHS at this point. @@ -13391,8 +13453,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans( // Compute ((Start - End) + (Stride - 1)) / Stride. // FIXME: This can overflow. Holding off on fixing this for now; // howManyGreaterThans will hopefully be gone soon. - const SCEV *One = getOne(Stride->getType()); - const SCEV *BECount = getUDivExpr( + SCEVUse One = getOne(Stride->getType()); + SCEVUse BECount = getUDivExpr( getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride); APInt MaxStart = IsSigned ? getSignedRangeMax(Start) @@ -13412,7 +13474,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans( IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit) : APIntOps::umax(getUnsignedRangeMin(RHS), Limit); - const SCEV *ConstantMaxBECount = + SCEVUse ConstantMaxBECount = isa(BECount) ? BECount : getUDivCeilSCEV(getConstant(MaxStart - MinEnd), @@ -13420,25 +13482,25 @@ ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans( if (isa(ConstantMaxBECount)) ConstantMaxBECount = BECount; - const SCEV *SymbolicMaxBECount = + SCEVUse SymbolicMaxBECount = isa(BECount) ? 
ConstantMaxBECount : BECount; return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false, Predicates); } -const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range, - ScalarEvolution &SE) const { +SCEVUse SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range, + ScalarEvolution &SE) const { if (Range.isFullSet()) // Infinite loop. return SE.getCouldNotCompute(); // If the start is a non-zero constant, shift the range to simplify things. if (const SCEVConstant *SC = dyn_cast(getStart())) if (!SC->getValue()->isZero()) { - SmallVector Operands(operands()); + SmallVector Operands(operands()); Operands[0] = SE.getZero(SC->getType()); - const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(), - getNoWrapFlags(FlagNW)); + SCEVUse Shifted = + SE.getAddRecExpr(Operands, getLoop(), getNoWrapFlags(FlagNW)); if (const auto *ShiftedAddRec = dyn_cast(Shifted)) return ShiftedAddRec->getNumIterationsInRange( Range.subtract(SC->getAPInt()), SE); @@ -13448,7 +13510,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range, // The only time we can solve this is when we have all constant indices. // Otherwise, we cannot determine the overflow conditions. - if (any_of(operands(), [](const SCEV *Op) { return !isa(Op); })) + if (any_of(operands(), [](SCEVUse Op) { return !isa(Op); })) return SE.getCouldNotCompute(); // Okay at this point we know that all elements of the chrec are constants and @@ -13507,7 +13569,7 @@ SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const { // simplification: it is legal to return ({rec1} + {rec2}). For example, it // may happen if we reach arithmetic depth limit while simplifying. So we // construct the returned value explicitly. - SmallVector Ops; + SmallVector Ops; // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and // (this + Step) is {A+B,+,B+C,+...,+,N}. for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i) @@ -13516,7 +13578,7 @@ SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const { // have been popped out earlier). This guarantees us that if the result has // the same last operand, then it will also not be popped out, meaning that // the returned value will be an AddRec. - const SCEV *Last = getOperand(getNumOperands() - 1); + SCEVUse Last = getOperand(getNumOperands() - 1); assert(!Last->isZero() && "Recurrency with zero step?"); Ops.push_back(Last); return cast(SE.getAddRecExpr(Ops, getLoop(), @@ -13524,8 +13586,8 @@ SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const { } // Return true when S contains at least an undef value. -bool ScalarEvolution::containsUndefs(const SCEV *S) const { - return SCEVExprContains(S, [](const SCEV *S) { +bool ScalarEvolution::containsUndefs(SCEVUse S) const { + return SCEVExprContains(S, [](SCEVUse S) { if (const auto *SU = dyn_cast(S)) return isa(SU->getValue()); return false; @@ -13533,8 +13595,8 @@ bool ScalarEvolution::containsUndefs(const SCEV *S) const { } // Return true when S contains a value that is a nullptr. -bool ScalarEvolution::containsErasedValue(const SCEV *S) const { - return SCEVExprContains(S, [](const SCEV *S) { +bool ScalarEvolution::containsErasedValue(SCEVUse S) const { + return SCEVExprContains(S, [](SCEVUse S) { if (const auto *SU = dyn_cast(S)) return SU->getValue() == nullptr; return false; @@ -13542,7 +13604,7 @@ bool ScalarEvolution::containsErasedValue(const SCEV *S) const { } /// Return the size of an element read or written by Inst. 
-const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) { +SCEVUse ScalarEvolution::getElementSize(Instruction *Inst) { Type *Ty; if (StoreInst *Store = dyn_cast(Inst)) Ty = Store->getValueOperand()->getType(); @@ -13651,6 +13713,19 @@ ScalarEvolution::~ScalarEvolution() { HasRecMap.clear(); BackedgeTakenCounts.clear(); PredicatedBackedgeTakenCounts.clear(); + UnsignedRanges.clear(); + SignedRanges.clear(); + + BECountUsers.clear(); + SCEVUsers.clear(); + FoldCache.clear(); + FoldCacheUser.clear(); + ValuesAtScopes.clear(); + ValuesAtScopesUsers.clear(); + LoopDispositions.clear(); + + BlockDispositions.clear(); + ConstantMultipleCache.clear(); assert(PendingLoopPredicates.empty() && "isImpliedCond garbage"); assert(PendingPhiRanges.empty() && "getRangeRef garbage"); @@ -13686,7 +13761,7 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, if (ExitingBlocks.size() != 1) OS << " "; - auto *BTC = SE->getBackedgeTakenCount(L); + auto BTC = SE->getBackedgeTakenCount(L); if (!isa(BTC)) { OS << "backedge-taken count is "; PrintSCEVWithTypeHint(OS, BTC); @@ -13719,7 +13794,7 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, L->getHeader()->printAsOperand(OS, /*PrintType=*/false); OS << ": "; - auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L); + auto ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L); if (!isa(ConstantBTC)) { OS << "constant max backedge-taken count is "; PrintSCEVWithTypeHint(OS, ConstantBTC); @@ -13734,7 +13809,7 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, L->getHeader()->printAsOperand(OS, /*PrintType=*/false); OS << ": "; - auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L); + auto SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L); if (!isa(SymbolicBTC)) { OS << "symbolic max backedge-taken count is "; PrintSCEVWithTypeHint(OS, SymbolicBTC); @@ -13748,8 +13823,8 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, if (ExitingBlocks.size() > 1) for (BasicBlock *ExitingBlock : ExitingBlocks) { OS << " symbolic max exit count for " << ExitingBlock->getName() << ": "; - auto *ExitBTC = SE->getExitCount(L, ExitingBlock, - ScalarEvolution::SymbolicMaximum); + auto ExitBTC = + SE->getExitCount(L, ExitingBlock, ScalarEvolution::SymbolicMaximum); PrintSCEVWithTypeHint(OS, ExitBTC); if (isa(ExitBTC)) { // Retry with predicates. 
@@ -13769,7 +13844,7 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, } SmallVector Preds; - auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds); + auto PBT = SE->getPredicatedBackedgeTakenCount(L, Preds); if (PBT != BTC) { assert(!Preds.empty() && "Different predicated BTC, but no predicates"); OS << "Loop "; @@ -13787,7 +13862,7 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, } Preds.clear(); - auto *PredConstantMax = + auto PredConstantMax = SE->getPredicatedConstantMaxBackedgeTakenCount(L, Preds); if (PredConstantMax != ConstantBTC) { assert(!Preds.empty() && @@ -13807,7 +13882,7 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, } Preds.clear(); - auto *PredSymbolicMax = + auto PredSymbolicMax = SE->getPredicatedSymbolicMaxBackedgeTakenCount(L, Preds); if (SymbolicBTC != PredSymbolicMax) { assert(!Preds.empty() && @@ -13883,8 +13958,8 @@ void ScalarEvolution::print(raw_ostream &OS) const { if (isSCEVable(I.getType()) && !isa(I)) { OS << I << '\n'; OS << " --> "; - const SCEV *SV = SE.getSCEV(&I); - SV->print(OS); + SCEVUse SV = SE.getSCEV(&I, /*UseCtx=*/true); + SV.print(OS); if (!isa(SV)) { OS << " U: "; SE.getUnsignedRange(SV).print(OS); @@ -13894,7 +13969,7 @@ void ScalarEvolution::print(raw_ostream &OS) const { const Loop *L = LI.getLoopFor(I.getParent()); - const SCEV *AtUse = SE.getSCEVAtScope(SV, L); + SCEVUse AtUse = SE.getSCEVAtScope(SV, L); if (AtUse != SV) { OS << " --> "; AtUse->print(OS); @@ -13908,7 +13983,7 @@ void ScalarEvolution::print(raw_ostream &OS) const { if (L) { OS << "\t\t" "Exits: "; - const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop()); + SCEVUse ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop()); if (!SE.isLoopInvariant(ExitValue, L)) { OS << "<>"; } else { @@ -13957,7 +14032,7 @@ void ScalarEvolution::print(raw_ostream &OS) const { } ScalarEvolution::LoopDisposition -ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) { +ScalarEvolution::getLoopDisposition(SCEVUse S, const Loop *L) { auto &Values = LoopDispositions[S]; for (auto &V : Values) { if (V.getPointer() == L) @@ -13976,7 +14051,7 @@ ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) { } ScalarEvolution::LoopDisposition -ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) { +ScalarEvolution::computeLoopDisposition(SCEVUse S, const Loop *L) { switch (S->getSCEVType()) { case scConstant: case scVScale: @@ -14004,7 +14079,7 @@ ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) { // This recurrence is variant w.r.t. L if any of its operands // are variant. 
- for (const auto *Op : AR->operands()) + for (const auto Op : AR->operands()) if (!isLoopInvariant(Op, L)) return LoopVariant; @@ -14024,7 +14099,7 @@ ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) { case scSMinExpr: case scSequentialUMinExpr: { bool HasVarying = false; - for (const auto *Op : S->operands()) { + for (const auto Op : S->operands()) { LoopDisposition D = getLoopDisposition(Op, L); if (D == LoopVariant) return LoopVariant; @@ -14047,16 +14122,16 @@ ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) { llvm_unreachable("Unknown SCEV kind!"); } -bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) { +bool ScalarEvolution::isLoopInvariant(SCEVUse S, const Loop *L) { return getLoopDisposition(S, L) == LoopInvariant; } -bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) { +bool ScalarEvolution::hasComputableLoopEvolution(SCEVUse S, const Loop *L) { return getLoopDisposition(S, L) == LoopComputable; } ScalarEvolution::BlockDisposition -ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) { +ScalarEvolution::getBlockDisposition(SCEVUse S, const BasicBlock *BB) { auto &Values = BlockDispositions[S]; for (auto &V : Values) { if (V.getPointer() == BB) @@ -14075,7 +14150,7 @@ ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) { } ScalarEvolution::BlockDisposition -ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) { +ScalarEvolution::computeBlockDisposition(SCEVUse S, const BasicBlock *BB) { switch (S->getSCEVType()) { case scConstant: case scVScale: @@ -14105,7 +14180,7 @@ ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) { case scSMinExpr: case scSequentialUMinExpr: { bool Proper = true; - for (const SCEV *NAryOp : S->operands()) { + for (SCEVUse NAryOp : S->operands()) { BlockDisposition D = getBlockDisposition(NAryOp, BB); if (D == DoesNotDominateBlock) return DoesNotDominateBlock; @@ -14130,16 +14205,16 @@ ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) { llvm_unreachable("Unknown SCEV kind!"); } -bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) { +bool ScalarEvolution::dominates(SCEVUse S, const BasicBlock *BB) { return getBlockDisposition(S, BB) >= DominatesBlock; } -bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) { +bool ScalarEvolution::properlyDominates(SCEVUse S, const BasicBlock *BB) { return getBlockDisposition(S, BB) == ProperlyDominatesBlock; } -bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const { - return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; }); +bool ScalarEvolution::hasOperand(SCEVUse S, SCEVUse Op) const { + return SCEVExprContains(S, [&](SCEVUse Expr) { return Expr == Op; }); } void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L, @@ -14149,7 +14224,7 @@ void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L, auto It = BECounts.find(L); if (It != BECounts.end()) { for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) { - for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) { + for (SCEVUse S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) { if (!isa(S)) { auto UserIt = BECountUsers.find(S); assert(UserIt != BECountUsers.end()); @@ -14161,25 +14236,25 @@ void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L, } } -void ScalarEvolution::forgetMemoizedResults(ArrayRef SCEVs) { - SmallPtrSet ToForget(SCEVs.begin(), 
SCEVs.end()); - SmallVector Worklist(ToForget.begin(), ToForget.end()); +void ScalarEvolution::forgetMemoizedResults(ArrayRef SCEVs) { + SmallPtrSet ToForget(SCEVs.begin(), SCEVs.end()); + SmallVector Worklist(ToForget.begin(), ToForget.end()); while (!Worklist.empty()) { - const SCEV *Curr = Worklist.pop_back_val(); + SCEVUse Curr = Worklist.pop_back_val(); auto Users = SCEVUsers.find(Curr); if (Users != SCEVUsers.end()) - for (const auto *User : Users->second) + for (const auto User : Users->second) if (ToForget.insert(User).second) Worklist.push_back(User); } - for (const auto *S : ToForget) + for (const auto S : ToForget) forgetMemoizedResultsImpl(S); for (auto I = PredicatedSCEVRewrites.begin(); I != PredicatedSCEVRewrites.end();) { - std::pair Entry = I->first; + std::pair Entry = I->first; if (ToForget.count(Entry.first)) PredicatedSCEVRewrites.erase(I++); else @@ -14187,7 +14262,7 @@ void ScalarEvolution::forgetMemoizedResults(ArrayRef SCEVs) { } } -void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) { +void ScalarEvolution::forgetMemoizedResultsImpl(SCEVUse S) { LoopDispositions.erase(S); BlockDispositions.erase(S); UnsignedRanges.erase(S); @@ -14242,14 +14317,13 @@ void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) { FoldCacheUser.erase(S); } -void -ScalarEvolution::getUsedLoops(const SCEV *S, - SmallPtrSetImpl &LoopsUsed) { +void ScalarEvolution::getUsedLoops(SCEVUse S, + SmallPtrSetImpl &LoopsUsed) { struct FindUsedLoops { FindUsedLoops(SmallPtrSetImpl &LoopsUsed) : LoopsUsed(LoopsUsed) {} SmallPtrSetImpl &LoopsUsed; - bool follow(const SCEV *S) { + bool follow(SCEVUse S) { if (auto *AR = dyn_cast(S)) LoopsUsed.insert(AR->getLoop()); return true; @@ -14281,8 +14355,8 @@ void ScalarEvolution::getReachableBlocks( } if (auto *Cmp = dyn_cast(Cond)) { - const SCEV *L = getSCEV(Cmp->getOperand(0)); - const SCEV *R = getSCEV(Cmp->getOperand(1)); + SCEVUse L = getSCEV(Cmp->getOperand(0)); + SCEVUse R = getSCEV(Cmp->getOperand(1)); if (isKnownPredicateViaConstantRanges(Cmp->getPredicate(), L, R)) { Worklist.push_back(TrueBB); continue; @@ -14309,15 +14383,15 @@ void ScalarEvolution::verify() const { struct SCEVMapper : public SCEVRewriteVisitor { SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {} - const SCEV *visitConstant(const SCEVConstant *Constant) { + SCEVUse visitConstant(const SCEVConstant *Constant) { return SE.getConstant(Constant->getAPInt()); } - const SCEV *visitUnknown(const SCEVUnknown *Expr) { + SCEVUse visitUnknown(const SCEVUnknown *Expr) { return SE.getUnknown(Expr->getValue()); } - const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { + SCEVUse visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return SE.getCouldNotCompute(); } }; @@ -14326,7 +14400,7 @@ void ScalarEvolution::verify() const { SmallPtrSet ReachableBlocks; SE2.getReachableBlocks(ReachableBlocks, F); - auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * { + auto GetDelta = [&](SCEVUse Old, SCEVUse New) -> SCEVUse { if (containsUndefs(Old) || containsUndefs(New)) { // SCEV treats "undef" as an unknown but consistent value (i.e. it does // not propagate undef aggressively). This means we can (and do) fail @@ -14337,7 +14411,7 @@ void ScalarEvolution::verify() const { } // Unless VerifySCEVStrict is set, we only compare constant deltas. 
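Note (editorial sketch, not patch code): a worked case for the constant-delta policy above, assuming A is the SCEVUnknown for a value %a and Ty is its type:

    const SCEV *Old = SE2.getAddExpr(SE2.getConstant(Ty, 3), A);
    const SCEV *New = SE2.getAddExpr(SE2.getConstant(Ty, 2), A);
    const SCEV *Delta = SE2.getMinusSCEV(Old, New); // folds to the constant 1
    // A symbolic delta such as %a - ((%a /u 2) * 2) does not fold to a
    // SCEVConstant and is therefore ignored unless VerifySCEVStrict is set.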
- const SCEV *Delta = SE2.getMinusSCEV(Old, New); + SCEVUse Delta = SE2.getMinusSCEV(Old, New); if (!VerifySCEVStrict && !isa(Delta)) return nullptr; @@ -14359,9 +14433,9 @@ void ScalarEvolution::verify() const { if (It == BackedgeTakenCounts.end()) continue; - auto *CurBECount = + auto CurBECount = SCM.visit(It->second.getExact(L, const_cast(this))); - auto *NewBECount = SE2.getBackedgeTakenCount(L); + auto NewBECount = SE2.getBackedgeTakenCount(L); if (CurBECount == SE2.getCouldNotCompute() || NewBECount == SE2.getCouldNotCompute()) { @@ -14380,7 +14454,7 @@ void ScalarEvolution::verify() const { SE.getTypeSizeInBits(NewBECount->getType())) CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType()); - const SCEV *Delta = GetDelta(CurBECount, NewBECount); + SCEVUse Delta = GetDelta(CurBECount, NewBECount); if (Delta && !Delta->isZero()) { dbgs() << "Trip Count for " << *L << " Changed!\n"; dbgs() << "Old: " << *CurBECount << "\n"; @@ -14418,9 +14492,9 @@ void ScalarEvolution::verify() const { if (auto *I = dyn_cast(&*KV.first)) { if (!ReachableBlocks.contains(I->getParent())) continue; - const SCEV *OldSCEV = SCM.visit(KV.second); - const SCEV *NewSCEV = SE2.getSCEV(I); - const SCEV *Delta = GetDelta(OldSCEV, NewSCEV); + SCEVUse OldSCEV = SCM.visit(KV.second); + SCEVUse NewSCEV = SE2.getSCEV(I); + SCEVUse Delta = GetDelta(OldSCEV, NewSCEV); if (Delta && !Delta->isZero()) { dbgs() << "SCEV for value " << *I << " changed!\n" << "Old: " << *OldSCEV << "\n" @@ -14449,7 +14523,7 @@ void ScalarEvolution::verify() const { // Verify integrity of SCEV users. for (const auto &S : UniqueSCEVs) { - for (const auto *Op : S.operands()) { + for (const auto Op : S.operands()) { // We do not store dependencies of constants. if (isa(Op)) continue; @@ -14464,10 +14538,10 @@ void ScalarEvolution::verify() const { // Verify integrity of ValuesAtScopes users. for (const auto &ValueAndVec : ValuesAtScopes) { - const SCEV *Value = ValueAndVec.first; + SCEVUse Value = ValueAndVec.first; for (const auto &LoopAndValueAtScope : ValueAndVec.second) { const Loop *L = LoopAndValueAtScope.first; - const SCEV *ValueAtScope = LoopAndValueAtScope.second; + SCEVUse ValueAtScope = LoopAndValueAtScope.second; if (!isa(ValueAtScope)) { auto It = ValuesAtScopesUsers.find(ValueAtScope); if (It != ValuesAtScopesUsers.end() && @@ -14481,10 +14555,10 @@ void ScalarEvolution::verify() const { } for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) { - const SCEV *ValueAtScope = ValueAtScopeAndVec.first; + SCEVUse ValueAtScope = ValueAtScopeAndVec.first; for (const auto &LoopAndValue : ValueAtScopeAndVec.second) { const Loop *L = LoopAndValue.first; - const SCEV *Value = LoopAndValue.second; + SCEVUse Value = LoopAndValue.second; assert(!isa(Value)); auto It = ValuesAtScopes.find(Value); if (It != ValuesAtScopes.end() && @@ -14502,7 +14576,7 @@ void ScalarEvolution::verify() const { Predicated ? 
PredicatedBackedgeTakenCounts : BackedgeTakenCounts; for (const auto &LoopAndBEInfo : BECounts) { for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) { - for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) { + for (SCEVUse S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) { if (!isa(S)) { auto UserIt = BECountUsers.find(S); if (UserIt != BECountUsers.end() && @@ -14677,22 +14751,22 @@ void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequiredTransitive(); } -const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS, - const SCEV *RHS) { +const SCEVPredicate *ScalarEvolution::getEqualPredicate(SCEVUse LHS, + SCEVUse RHS) { return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS); } const SCEVPredicate * ScalarEvolution::getComparePredicate(const ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS) { + SCEVUse LHS, SCEVUse RHS) { FoldingSetNodeID ID; assert(LHS->getType() == RHS->getType() && "Type mismatch between LHS and RHS"); // Unique this node based on the arguments ID.AddInteger(SCEVPredicate::P_Compare); ID.AddInteger(Pred); - ID.AddPointer(LHS); - ID.AddPointer(RHS); + ID.AddPointer(LHS.getRawPointer()); + ID.AddPointer(RHS.getRawPointer()); void *IP = nullptr; if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP)) return S; @@ -14708,6 +14782,7 @@ const SCEVPredicate *ScalarEvolution::getWrapPredicate( FoldingSetNodeID ID; // Unique this node based on the arguments ID.AddInteger(SCEVPredicate::P_Wrap); + // TODO: Use SCEVUse ID.AddPointer(AR); ID.AddInteger(AddedFlags); void *IP = nullptr; @@ -14732,14 +14807,14 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor { /// /// If \p NewPreds is non-null, rewrite is free to add further predicates to /// \p NewPreds such that the result will be an AddRecExpr. - static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE, - SmallVectorImpl *NewPreds, - const SCEVPredicate *Pred) { + static SCEVUse rewrite(SCEVUse S, const Loop *L, ScalarEvolution &SE, + SmallVectorImpl *NewPreds, + const SCEVPredicate *Pred) { SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred); return Rewriter.visit(S); } - const SCEV *visitUnknown(const SCEVUnknown *Expr) { + SCEVUse visitUnknown(const SCEVUnknown *Expr) { if (Pred) { if (auto *U = dyn_cast(Pred)) { for (const auto *Pred : U->getPredicates()) @@ -14756,13 +14831,13 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor { return convertToAddRecWithPreds(Expr); } - const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { - const SCEV *Operand = visit(Expr->getOperand()); + SCEVUse visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { + SCEVUse Operand = visit(Expr->getOperand()); const SCEVAddRecExpr *AR = dyn_cast(Operand); if (AR && AR->getLoop() == L && AR->isAffine()) { // This couldn't be folded because the operand didn't have the nuw // flag. Add the nusw flag as an assumption that we could make. 
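Note (worked example, widths illustrative, not part of the patch): for an i8 recurrence, the rewrite below turns

    zext i8 {%a,+,%s}<%L> to i16

into

    {zext i8 %a to i16,+,sext i8 %s to i16}<%L>

once the SCEVWrapPredicate::IncrementNUSW assumption on the AddRec is recorded; the client is expected to emit a runtime check for that predicate.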
- const SCEV *Step = AR->getStepRecurrence(SE); + SCEVUse Step = AR->getStepRecurrence(SE); Type *Ty = Expr->getType(); if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW)) return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty), @@ -14772,13 +14847,13 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor { return SE.getZeroExtendExpr(Operand, Expr->getType()); } - const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { - const SCEV *Operand = visit(Expr->getOperand()); + SCEVUse visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { + SCEVUse Operand = visit(Expr->getOperand()); const SCEVAddRecExpr *AR = dyn_cast(Operand); if (AR && AR->getLoop() == L && AR->isAffine()) { // This couldn't be folded because the operand didn't have the nsw // flag. Add the nssw flag as an assumption that we could make. - const SCEV *Step = AR->getStepRecurrence(SE); + SCEVUse Step = AR->getStepRecurrence(SE); Type *Ty = Expr->getType(); if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW)) return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty), @@ -14816,11 +14891,10 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor { // If \p Expr does not meet these conditions (is not a PHI node, or we // couldn't create an AddRec for it, or couldn't add the predicate), we just // return \p Expr. - const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) { + SCEVUse convertToAddRecWithPreds(const SCEVUnknown *Expr) { if (!isa(Expr->getValue())) return Expr; - std::optional< - std::pair>> + std::optional>> PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr); if (!PredicatedRewrite) return Expr; @@ -14843,15 +14917,13 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor { } // end anonymous namespace -const SCEV * -ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L, - const SCEVPredicate &Preds) { +SCEVUse ScalarEvolution::rewriteUsingPredicate(SCEVUse S, const Loop *L, + const SCEVPredicate &Preds) { return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds); } const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates( - const SCEV *S, const Loop *L, - SmallVectorImpl &Preds) { + SCEVUse S, const Loop *L, SmallVectorImpl &Preds) { SmallVector TransformPreds; S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr); auto *AddRec = dyn_cast(S); @@ -14872,9 +14944,9 @@ SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID, : FastID(ID), Kind(Kind) {} SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID, - const ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS) - : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) { + const ICmpInst::Predicate Pred, + SCEVUse LHS, SCEVUse RHS) + : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) { assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match"); assert(LHS != RHS && "LHS and RHS are the same SCEV"); } @@ -15000,9 +15072,8 @@ PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE, Preds = std::make_unique(Empty); } -void ScalarEvolution::registerUser(const SCEV *User, - ArrayRef Ops) { - for (const auto *Op : Ops) +void ScalarEvolution::registerUser(SCEVUse User, ArrayRef Ops) { + for (const auto Op : Ops) // We do not expect that forgetting cached data for SCEVConstants will ever // open any prospects for sharpening or introduce any correctness issues, // so we don't bother storing their dependencies. 
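Note (editorial, expressions assumed): a consequence of skipping constants here is that registering Add = (4 + %a) records only the unknown operand:

    // registerUser(Add, {C4, A}) with Add = (4 + %a):
    //   SCEVUsers[A].insert(Add);  // dependency on %a is recorded
    //   (no entry is created for the constant C4)

so forgetting a constant never invalidates users, while forgetValue on %a transitively forgets (4 + %a) as well.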
@@ -15010,8 +15081,8 @@ void ScalarEvolution::registerUser(const SCEV *User, SCEVUsers[Op].insert(User); } -const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) { - const SCEV *Expr = SE.getSCEV(V); +SCEVUse PredicatedScalarEvolution::getSCEV(Value *V) { + SCEVUse Expr = SE.getSCEV(V); RewriteEntry &Entry = RewriteMap[Expr]; // If we already have an entry and the version matches, return it. @@ -15023,13 +15094,13 @@ const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) { if (Entry.second) Expr = Entry.second; - const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds); + SCEVUse NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds); Entry = {Generation, NewSCEV}; return NewSCEV; } -const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() { +SCEVUse PredicatedScalarEvolution::getBackedgeTakenCount() { if (!BackedgeCount) { SmallVector Preds; BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds); @@ -15068,7 +15139,7 @@ void PredicatedScalarEvolution::updateGeneration() { // If the generation number wrapped recompute everything. if (++Generation == 0) { for (auto &II : RewriteMap) { - const SCEV *Rewritten = II.second.second; + SCEVUse Rewritten = II.second.second; II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)}; } } @@ -15076,7 +15147,7 @@ void PredicatedScalarEvolution::updateGeneration() { void PredicatedScalarEvolution::setNoOverflow( Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) { - const SCEV *Expr = getSCEV(V); + SCEVUse Expr = getSCEV(V); const auto *AR = cast(Expr); auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE); @@ -15092,7 +15163,7 @@ void PredicatedScalarEvolution::setNoOverflow( bool PredicatedScalarEvolution::hasNoOverflow( Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) { - const SCEV *Expr = getSCEV(V); + SCEVUse Expr = getSCEV(V); const auto *AR = cast(Expr); Flags = SCEVWrapPredicate::clearFlags( @@ -15107,7 +15178,7 @@ bool PredicatedScalarEvolution::hasNoOverflow( } const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) { - const SCEV *Expr = this->getSCEV(V); + SCEVUse Expr = this->getSCEV(V); SmallVector NewPreds; auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds); @@ -15137,7 +15208,7 @@ void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const { if (!SE.isSCEVable(I.getType())) continue; - auto *Expr = SE.getSCEV(&I); + auto Expr = SE.getSCEV(&I); auto II = RewriteMap.find(Expr); if (II == RewriteMap.end()) @@ -15158,8 +15229,7 @@ void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const { // for URem with constant power-of-2 second operands. // It's not always easy, as A and B can be folded (imagine A is X / 2, and B is // 4, A / B becomes X / 8). -bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS, - const SCEV *&RHS) { +bool ScalarEvolution::matchURem(SCEVUse Expr, SCEVUse &LHS, SCEVUse &RHS) { if (Expr->getType()->isPointerTy()) return false; @@ -15184,13 +15254,13 @@ bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS, if (Add == nullptr || Add->getNumOperands() != 2) return false; - const SCEV *A = Add->getOperand(1); + SCEVUse A = Add->getOperand(1); const auto *Mul = dyn_cast(Add->getOperand(0)); if (Mul == nullptr) return false; - const auto MatchURemWithDivisor = [&](const SCEV *B) { + const auto MatchURemWithDivisor = [&](SCEVUse B) { // (SomeExpr + (-(SomeExpr / B) * B)). 
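Note (editorial sketch): this is the canonical shape of an unsimplified urem. A hypothetical caller of the matcher (matchURem is private to ScalarEvolution; the expander and unit tests changed further below reach it):

    SCEVUse URemLHS = nullptr, URemRHS = nullptr;
    if (SE.matchURem(Expr, URemLHS, URemRHS)) {
      // For Expr = ((-8 * (%x /u 8)) + %x) this recovers URemLHS = %x and
      // URemRHS = 8, i.e. Expr == (%x urem 8).
    }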
if (Expr == getURemExpr(A, B)) { LHS = A; @@ -15218,10 +15288,9 @@ ScalarEvolution::LoopGuards ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) { LoopGuards Guards(SE); SmallVector ExprsToRewrite; - auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS, - const SCEV *RHS, - DenseMap - &RewriteMap) { + auto CollectCondition = [&](ICmpInst::Predicate Predicate, SCEVUse LHS, + SCEVUse RHS, + DenseMap &RewriteMap) { // WARNING: It is generally unsound to apply any wrap flags to the proposed // replacement SCEV which isn't directly implied by the structure of that // SCEV. In particular, using contextual facts to imply flags is *NOT* @@ -15256,7 +15325,7 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) { if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet()) return false; auto I = RewriteMap.find(LHSUnknown); - const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHSUnknown; + SCEVUse RewrittenLHS = I != RewriteMap.end() ? I->second : LHSUnknown; RewriteMap[LHSUnknown] = SE.getUMaxExpr( SE.getConstant(ExactRegion.getUnsignedMin()), SE.getUMinExpr(RewrittenLHS, @@ -15271,8 +15340,7 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) { // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS // the non-constant operand and in \p LHS the constant operand. auto IsMinMaxSCEVWithNonNegativeConstant = - [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS, - const SCEV *&RHS) { + [&](SCEVUse Expr, SCEVTypes &SCTy, SCEVUse &LHS, SCEVUse &RHS) { if (auto *MinMax = dyn_cast(Expr)) { if (MinMax->getNumOperands() != 2) return false; @@ -15290,7 +15358,7 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) { // Checks whether Expr is a non-negative constant, and Divisor is a positive // constant, and returns their APInt in ExprVal and in DivisorVal. - auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor, + auto GetNonNegExprAndPosDivisor = [&](SCEVUse Expr, SCEVUse Divisor, APInt &ExprVal, APInt &DivisorVal) { auto *ConstExpr = dyn_cast(Expr); auto *ConstDivisor = dyn_cast(Divisor); @@ -15304,8 +15372,7 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) { // Return a new SCEV that modifies \p Expr to the closest number divides by // \p Divisor and greater or equal than Expr. // For now, only handle constant Expr and Divisor. - auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr, - const SCEV *Divisor) { + auto GetNextSCEVDividesByDivisor = [&](SCEVUse Expr, SCEVUse Divisor) { APInt ExprVal; APInt DivisorVal; if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal)) @@ -15320,8 +15387,7 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) { // Return a new SCEV that modifies \p Expr to the closest number divides by // \p Divisor and less or equal than Expr. // For now, only handle constant Expr and Divisor. - auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr, - const SCEV *Divisor) { + auto GetPreviousSCEVDividesByDivisor = [&](SCEVUse Expr, SCEVUse Divisor) { APInt ExprVal; APInt DivisorVal; if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal)) @@ -15334,10 +15400,9 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) { // Apply divisibilty by \p Divisor on MinMaxExpr with constant values, // recursively. This is done by aligning up/down the constant value to the // Divisor. 
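Note (self-contained sketch, values illustrative): on constants the two alignment helpers above reduce to plain remainder arithmetic:

    #include "llvm/ADT/APInt.h"
    using llvm::APInt;

    APInt Expr(/*numBits=*/32, 15), Divisor(/*numBits=*/32, 8);
    APInt Rem = Expr.urem(Divisor);                  // 7
    APInt Down = Expr - Rem;                         // previous multiple: 8
    APInt Up = Rem.isZero() ? Expr : Down + Divisor; // next multiple: 16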
- std::function - ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr, - const SCEV *Divisor) { - const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr; + std::function ApplyDivisibiltyOnMinMaxExpr = + [&](SCEVUse MinMaxExpr, SCEVUse Divisor) { + SCEVUse MinMaxLHS = nullptr, MinMaxRHS = nullptr; SCEVTypes SCTy; if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS, MinMaxRHS)) @@ -15346,10 +15411,10 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) { isa(MinMaxExpr) || isa(MinMaxExpr); assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!"); - auto *DivisibleExpr = + auto DivisibleExpr = IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor) : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor); - SmallVector Ops = { + SmallVector Ops = { ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr}; return SE.getMinMaxExpr(SCTy, Ops); }; @@ -15361,15 +15426,14 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) { RHSC->getValue()->isNullValue()) { // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to // explicitly express that. - const SCEV *URemLHS = nullptr; - const SCEV *URemRHS = nullptr; + SCEVUse URemLHS = nullptr; + SCEVUse URemRHS = nullptr; if (SE.matchURem(LHS, URemLHS, URemRHS)) { if (const SCEVUnknown *LHSUnknown = dyn_cast(URemLHS)) { auto I = RewriteMap.find(LHSUnknown); - const SCEV *RewrittenLHS = - I != RewriteMap.end() ? I->second : LHSUnknown; + SCEVUse RewrittenLHS = I != RewriteMap.end() ? I->second : LHSUnknown; RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS); - const auto *Multiple = + const auto Multiple = SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS); RewriteMap[LHSUnknown] = Multiple; ExprsToRewrite.push_back(LHSUnknown); @@ -15392,8 +15456,7 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) { // and \p FromRewritten are the same (i.e. there has been no rewrite // registered for \p From), then puts this value in the list of rewritten // expressions. - auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten, - const SCEV *To) { + auto AddRewrite = [&](SCEVUse From, SCEVUse FromRewritten, SCEVUse To) { if (From == FromRewritten) ExprsToRewrite.push_back(From); RewriteMap[From] = To; @@ -15402,7 +15465,7 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) { // Checks whether \p S has already been rewritten. In that case returns the // existing rewrite because we want to chain further rewrites onto the // already rewritten value. Otherwise returns \p S. - auto GetMaybeRewritten = [&](const SCEV *S) { + auto GetMaybeRewritten = [&](SCEVUse S) { auto I = RewriteMap.find(S); return I != RewriteMap.end() ? I->second : S; }; @@ -15414,13 +15477,13 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) { // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p // DividesBy. 
- std::function<bool(const SCEV *, const SCEV *&)> HasDivisibiltyInfo = - [&](const SCEV *Expr, const SCEV *&DividesBy) { + std::function<bool(SCEVUse, SCEVUse &)> HasDivisibiltyInfo = + [&](SCEVUse Expr, SCEVUse &DividesBy) { if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) { if (Mul->getNumOperands() != 2) return false; - auto *MulLHS = Mul->getOperand(0); - auto *MulRHS = Mul->getOperand(1); + auto MulLHS = Mul->getOperand(0); + auto MulRHS = Mul->getOperand(1); if (isa<SCEVConstant>(MulLHS)) std::swap(MulLHS, MulRHS); if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS)) @@ -15436,8 +15499,8 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) { }; // Return true if Expr known to divide by \p DividesBy. - std::function<bool(const SCEV *, const SCEV *)> IsKnownToDivideBy = - [&](const SCEV *Expr, const SCEV *DividesBy) { + std::function<bool(SCEVUse, SCEVUse)> IsKnownToDivideBy = + [&](SCEVUse Expr, SCEVUse DividesBy) { if (SE.getURemExpr(Expr, DividesBy)->isZero()) return true; if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) @@ -15446,8 +15509,8 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) { return false; }; - const SCEV *RewrittenLHS = GetMaybeRewritten(LHS); - const SCEV *DividesBy = nullptr; + SCEVUse RewrittenLHS = GetMaybeRewritten(LHS); + SCEVUse DividesBy = nullptr; if (HasDivisibiltyInfo(RewrittenLHS, DividesBy)) // Check that the whole expression is divided by DividesBy DividesBy = @@ -15464,50 +15527,50 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) { // We cannot express strict predicates in SCEV, so instead we replace them // with non-strict ones against plus or minus one of RHS depending on the // predicate. - const SCEV *One = SE.getOne(RHS->getType()); + SCEVUse One = SE.getOne(RHS->getType()); switch (Predicate) { - case CmpInst::ICMP_ULT: - if (RHS->getType()->isPointerTy()) - return; - RHS = SE.getUMaxExpr(RHS, One); - [[fallthrough]]; - case CmpInst::ICMP_SLT: { - RHS = SE.getMinusSCEV(RHS, One); - RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS; - break; - } - case CmpInst::ICMP_UGT: - case CmpInst::ICMP_SGT: - RHS = SE.getAddExpr(RHS, One); - RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS; - break; - case CmpInst::ICMP_ULE: - case CmpInst::ICMP_SLE: - RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS; - break; - case CmpInst::ICMP_UGE: - case CmpInst::ICMP_SGE: - RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS; - break; - default: - break; + case CmpInst::ICMP_ULT: + if (RHS->getType()->isPointerTy()) + return; + RHS = SE.getUMaxExpr(RHS, One); + [[fallthrough]]; + case CmpInst::ICMP_SLT: { + RHS = SE.getMinusSCEV(RHS, One); + RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS; + break; + } + case CmpInst::ICMP_UGT: + case CmpInst::ICMP_SGT: + RHS = SE.getAddExpr(RHS, One); + RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS; + break; + case CmpInst::ICMP_ULE: + case CmpInst::ICMP_SLE: + RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS; + break; + case CmpInst::ICMP_UGE: + case CmpInst::ICMP_SGE: + RHS = DividesBy ?
GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS; + break; + default: + break; } - SmallVector Worklist(1, LHS); - SmallPtrSet Visited; + SmallVector Worklist(1, LHS); + SmallPtrSet Visited; auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) { append_range(Worklist, S->operands()); }; while (!Worklist.empty()) { - const SCEV *From = Worklist.pop_back_val(); + SCEVUse From = Worklist.pop_back_val(); if (isa(From)) continue; if (!Visited.insert(From).second) continue; - const SCEV *FromRewritten = GetMaybeRewritten(From); - const SCEV *To = nullptr; + SCEVUse FromRewritten = GetMaybeRewritten(From); + SCEVUse To = nullptr; switch (Predicate) { case CmpInst::ICMP_ULT: @@ -15541,7 +15604,7 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) { case CmpInst::ICMP_NE: if (isa(RHS) && cast(RHS)->getValue()->isNullValue()) { - const SCEV *OneAlignedUp = + SCEVUse OneAlignedUp = DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One; To = SE.getUMaxExpr(FromRewritten, OneAlignedUp); } @@ -15612,8 +15675,8 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) { if (auto *Cmp = dyn_cast(Cond)) { auto Predicate = EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate(); - const auto *LHS = SE.getSCEV(Cmp->getOperand(0)); - const auto *RHS = SE.getSCEV(Cmp->getOperand(1)); + const auto LHS = SE.getSCEV(Cmp->getOperand(0)); + const auto RHS = SE.getSCEV(Cmp->getOperand(1)); CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap); continue; } @@ -15645,7 +15708,7 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) { // sub-expressions. if (ExprsToRewrite.size() > 1) { for (const SCEV *Expr : ExprsToRewrite) { - const SCEV *RewriteTo = Guards.RewriteMap[Expr]; + SCEVUse RewriteTo = Guards.RewriteMap[Expr]; Guards.RewriteMap.erase(Expr); Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)}); } @@ -15659,7 +15722,7 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const { /// replacement is loop invariant in the loop of the AddRec. class SCEVLoopGuardRewriter : public SCEVRewriteVisitor { - const DenseMap ⤅ + const DenseMap ⤅ SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap; @@ -15688,12 +15751,12 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const { // If we didn't find the extact ZExt expr in the map, check if there's // an entry for a smaller ZExt we can use instead. 
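Note (worked case, widths hypothetical) for the probe implemented below:

    // Map holds:  (zext i8 %x to i16) --> V
    // Visiting:   (zext i8 %x to i32) has no direct entry, so halve the
    // width and retry; the i16 entry hits and its value is widened:
    //   rewrite(zext i8 %x to i32) == SE.getZeroExtendExpr(V, i32)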
Type *Ty = Expr->getType(); - const SCEV *Op = Expr->getOperand(0); + SCEVUse Op = Expr->getOperand(0); unsigned Bitwidth = Ty->getScalarSizeInBits() / 2; while (Bitwidth % 8 == 0 && Bitwidth >= 8 && Bitwidth > Op->getType()->getScalarSizeInBits()) { Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth); - auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy); + auto NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy); auto I = Map.find(NarrowExt); if (I != Map.end()) return SE.getZeroExtendExpr(I->second, Ty); @@ -15731,7 +15794,7 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const { const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { SmallVector Operands; bool Changed = false; - for (const auto *Op : Expr->operands()) { + for (const auto Op : Expr->operands()) { Operands.push_back( SCEVRewriteVisitor::visit(Op)); Changed |= Op != Operands.back(); @@ -15747,7 +15810,7 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const { const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { SmallVector Operands; bool Changed = false; - for (const auto *Op : Expr->operands()) { + for (const auto Op : Expr->operands()) { Operands.push_back( SCEVRewriteVisitor::visit(Op)); Changed |= Op != Operands.back(); @@ -15768,11 +15831,11 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const { return Rewriter.visit(Expr); } -const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) { +SCEVUse ScalarEvolution::applyLoopGuards(SCEVUse Expr, const Loop *L) { return applyLoopGuards(Expr, LoopGuards::collect(L, *this)); } -const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, - const LoopGuards &Guards) { +SCEVUse ScalarEvolution::applyLoopGuards(SCEVUse Expr, + const LoopGuards &Guards) { return Guards.rewrite(Expr); } diff --git a/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp b/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp index 53a79db5843a9..3004931ad9716 100644 --- a/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp +++ b/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp @@ -565,6 +565,7 @@ bool PPCLoopInstrFormPrep::rewriteLoadStoresForCommoningChains( const SCEV *BaseSCEV = ChainIdx ? SE->getAddExpr(Bucket.BaseSCEV, Bucket.Elements[BaseElemIdx].Offset) + .getPointer() : Bucket.BaseSCEV; const SCEVAddRecExpr *BasePtrSCEV = cast(BaseSCEV); @@ -598,6 +599,7 @@ bool PPCLoopInstrFormPrep::rewriteLoadStoresForCommoningChains( const SCEV *OffsetSCEV = BaseElemIdx ? SE->getMinusSCEV(Bucket.Elements[Idx].Offset, Bucket.Elements[BaseElemIdx].Offset) + .getPointer() : Bucket.Elements[Idx].Offset; // Make sure offset is able to expand. 
Only need to check one time as the diff --git a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp index 104e8ceb79670..78375d14d204f 100644 --- a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp +++ b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -430,8 +430,8 @@ bool InductiveRangeCheck::reassociateSubLHS( auto getExprScaledIfOverflow = [&](Instruction::BinaryOps BinOp, const SCEV *LHS, const SCEV *RHS) -> const SCEV * { - const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *, - SCEV::NoWrapFlags, unsigned); + SCEVUse (ScalarEvolution::*Operation)(SCEVUse, SCEVUse, SCEV::NoWrapFlags, + unsigned); switch (BinOp) { default: llvm_unreachable("Unsupported binary op"); @@ -750,7 +750,7 @@ InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE, const SCEV *Zero = SE.getZero(M->getType()); // This function returns SCEV equal to 1 if X is non-negative 0 otherwise. - auto SCEVCheckNonNegative = [&](const SCEV *X) { + auto SCEVCheckNonNegative = [&](const SCEV *X) -> const SCEV * { const Loop *L = IndVar->getLoop(); const SCEV *Zero = SE.getZero(X->getType()); const SCEV *One = SE.getOne(X->getType()); diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index 740e1e39b9ee7..79f84fa188741 100644 --- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -846,8 +846,8 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI, return false; } - const SCEV *PointerStrideSCEV = Ev->getOperand(1); - const SCEV *MemsetSizeSCEV = SE->getSCEV(MSI->getLength()); + SCEVUse PointerStrideSCEV = Ev->getOperand(1); + SCEVUse MemsetSizeSCEV = SE->getSCEV(MSI->getLength()); if (!PointerStrideSCEV || !MemsetSizeSCEV) return false; @@ -889,9 +889,9 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI, // Compare positive direction PointerStrideSCEV with MemsetSizeSCEV IsNegStride = PointerStrideSCEV->isNonConstantNegative(); - const SCEV *PositiveStrideSCEV = - IsNegStride ? SE->getNegativeSCEV(PointerStrideSCEV) - : PointerStrideSCEV; + SCEVUse PositiveStrideSCEV = IsNegStride + ? SE->getNegativeSCEV(PointerStrideSCEV) + : PointerStrideSCEV; LLVM_DEBUG(dbgs() << " MemsetSizeSCEV: " << *MemsetSizeSCEV << "\n" << " PositiveStrideSCEV: " << *PositiveStrideSCEV << "\n"); diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp index 575395eda1c5b..75dcc6ab922f9 100644 --- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -3442,8 +3442,9 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain, // be signed. const SCEV *IncExpr = SE.getNoopOrSignExtend(Inc.IncExpr, IntTy); Accum = SE.getAddExpr(Accum, IncExpr); - LeftOverExpr = LeftOverExpr ? - SE.getAddExpr(LeftOverExpr, IncExpr) : IncExpr; + LeftOverExpr = LeftOverExpr + ? SE.getAddExpr(LeftOverExpr, IncExpr).getPointer() + : IncExpr; } // Look through each base to see if any can produce a nice addressing mode. @@ -3844,7 +3845,7 @@ static const SCEV *CollectSubexprs(const SCEV *S, const SCEVConstant *C, for (const SCEV *S : Add->operands()) { const SCEV *Remainder = CollectSubexprs(S, C, Ops, L, SE, Depth+1); if (Remainder) - Ops.push_back(C ? SE.getMulExpr(C, Remainder) : Remainder); + Ops.push_back(C ? 
SE.getMulExpr(C, Remainder).getPointer() : Remainder); } return nullptr; } else if (const SCEVAddRecExpr *AR = dyn_cast(S)) { @@ -3857,7 +3858,7 @@ static const SCEV *CollectSubexprs(const SCEV *S, const SCEVConstant *C, // Split the non-zero AddRec unless it is part of a nested recurrence that // does not pertain to this loop. if (Remainder && (AR->getLoop() == L || !isa(Remainder))) { - Ops.push_back(C ? SE.getMulExpr(C, Remainder) : Remainder); + Ops.push_back(C ? SE.getMulExpr(C, Remainder).getPointer() : Remainder); Remainder = nullptr; } if (Remainder != AR->getStart()) { diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp index 1ff3cd78aa987..90c9e68558c14 100644 --- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp +++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp @@ -504,8 +504,8 @@ class LoopCompare { Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) { // Recognize the canonical representation of an unsimplifed urem. - const SCEV *URemLHS = nullptr; - const SCEV *URemRHS = nullptr; + SCEVUse URemLHS = nullptr; + SCEVUse URemRHS = nullptr; if (SE.matchURem(S, URemLHS, URemRHS)) { Value *LHS = expand(URemLHS); Value *RHS = expand(URemRHS); diff --git a/llvm/test/Analysis/ScalarEvolution/min-max-exprs.ll b/llvm/test/Analysis/ScalarEvolution/min-max-exprs.ll index 6ededf2477711..a80c0f4ebc5ea 100644 --- a/llvm/test/Analysis/ScalarEvolution/min-max-exprs.ll +++ b/llvm/test/Analysis/ScalarEvolution/min-max-exprs.ll @@ -42,7 +42,7 @@ define void @f(ptr %A, i32 %N) { ; CHECK-NEXT: %tmp19 = select i1 %tmp14, i64 0, i64 %tmp17 ; CHECK-NEXT: --> (-3 + (3 smax {0,+,1}<%bb1>)) U: [0,2147483645) S: [0,2147483645) Exits: (-3 + (3 smax (zext i32 (0 smax %N) to i64))) LoopDispositions: { %bb1: Computable } ; CHECK-NEXT: %tmp21 = getelementptr inbounds i32, ptr %A, i64 %tmp19 -; CHECK-NEXT: --> (-12 + (4 * (3 smax {0,+,1}<%bb1>)) + %A) U: full-set S: full-set Exits: (-12 + (4 * (3 smax (zext i32 (0 smax %N) to i64))) + %A) LoopDispositions: { %bb1: Computable } +; CHECK-NEXT: --> (-12 + (4 * (3 smax {0,+,1}<%bb1>)) + %A)(u nuw) U: full-set S: full-set Exits: (-12 + (4 * (3 smax (zext i32 (0 smax %N) to i64))) + %A) LoopDispositions: { %bb1: Computable } ; CHECK-NEXT: %tmp23 = add nuw nsw i32 %i.0, 1 ; CHECK-NEXT: --> {1,+,1}<%bb1> U: [1,-2147483647) S: [1,-2147483647) Exits: (1 + (0 smax %N)) LoopDispositions: { %bb1: Computable } ; CHECK-NEXT: Determining loop execution counts for: @f diff --git a/llvm/test/Analysis/ScalarEvolution/no-wrap-add-exprs.ll b/llvm/test/Analysis/ScalarEvolution/no-wrap-add-exprs.ll index b096adc7c5eb4..8a8b3e2c4dbe8 100644 --- a/llvm/test/Analysis/ScalarEvolution/no-wrap-add-exprs.ll +++ b/llvm/test/Analysis/ScalarEvolution/no-wrap-add-exprs.ll @@ -209,11 +209,11 @@ define void @f3(ptr %x_addr, ptr %y_addr, ptr %tmp_addr) { ; CHECK-NEXT: %sunkaddr3 = mul i64 %add4.zext, 4 ; CHECK-NEXT: --> (4 * (zext i32 (4 + (4 * (%tmp /u 4))) to i64)) U: [0,17179869169) S: [0,17179869181) ; CHECK-NEXT: %sunkaddr4 = getelementptr inbounds i8, ptr @tmp_addr, i64 %sunkaddr3 -; CHECK-NEXT: --> ((4 * (zext i32 (4 + (4 * (%tmp /u 4))) to i64)) + @tmp_addr) U: [0,-3) S: [-9223372036854775808,9223372036854775805) +; CHECK-NEXT: --> ((4 * (zext i32 (4 + (4 * (%tmp /u 4))) to i64)) + @tmp_addr)(u nuw) U: [0,-3) S: [-9223372036854775808,9223372036854775805) ; CHECK-NEXT: %sunkaddr5 = getelementptr inbounds i8, ptr %sunkaddr4, i64 4096 -; CHECK-NEXT: --> (4096 + (4 * (zext i32 (4 + (4 * (%tmp /u 4))) to 
i64)) + @tmp_addr) U: [0,-3) S: [-9223372036854775808,9223372036854775805) +; CHECK-NEXT: --> (4096 + (4 * (zext i32 (4 + (4 * (%tmp /u 4))) to i64)) + @tmp_addr)(u nuw) U: [0,-3) S: [-9223372036854775808,9223372036854775805) ; CHECK-NEXT: %addr4.cast = bitcast ptr %sunkaddr5 to ptr -; CHECK-NEXT: --> (4096 + (4 * (zext i32 (4 + (4 * (%tmp /u 4))) to i64)) + @tmp_addr) U: [0,-3) S: [-9223372036854775808,9223372036854775805) +; CHECK-NEXT: --> (4096 + (4 * (zext i32 (4 + (4 * (%tmp /u 4))) to i64)) + @tmp_addr)(u nuw) U: [0,-3) S: [-9223372036854775808,9223372036854775805) ; CHECK-NEXT: %addr4.incr = getelementptr i32, ptr %addr4.cast, i64 1 ; CHECK-NEXT: --> (4100 + (4 * (zext i32 (4 + (4 * (%tmp /u 4))) to i64)) + @tmp_addr) U: [0,-3) S: [-9223372036854775808,9223372036854775805) ; CHECK-NEXT: %add5 = add i32 %mul, 5 @@ -223,11 +223,11 @@ define void @f3(ptr %x_addr, ptr %y_addr, ptr %tmp_addr) { ; CHECK-NEXT: %sunkaddr0 = mul i64 %add5.zext, 4 ; CHECK-NEXT: --> (4 + (4 * (zext i32 (4 + (4 * (%tmp /u 4))) to i64))) U: [4,17179869173) S: [4,17179869185) ; CHECK-NEXT: %sunkaddr1 = getelementptr inbounds i8, ptr @tmp_addr, i64 %sunkaddr0 -; CHECK-NEXT: --> (4 + (4 * (zext i32 (4 + (4 * (%tmp /u 4))) to i64)) + @tmp_addr) U: [0,-3) S: [-9223372036854775808,9223372036854775805) +; CHECK-NEXT: --> (4 + (4 * (zext i32 (4 + (4 * (%tmp /u 4))) to i64)) + @tmp_addr)(u nuw) U: [0,-3) S: [-9223372036854775808,9223372036854775805) ; CHECK-NEXT: %sunkaddr2 = getelementptr inbounds i8, ptr %sunkaddr1, i64 4096 -; CHECK-NEXT: --> (4100 + (4 * (zext i32 (4 + (4 * (%tmp /u 4))) to i64)) + @tmp_addr) U: [0,-3) S: [-9223372036854775808,9223372036854775805) +; CHECK-NEXT: --> (4100 + (4 * (zext i32 (4 + (4 * (%tmp /u 4))) to i64)) + @tmp_addr)(u nuw) U: [0,-3) S: [-9223372036854775808,9223372036854775805) ; CHECK-NEXT: %addr5.cast = bitcast ptr %sunkaddr2 to ptr -; CHECK-NEXT: --> (4100 + (4 * (zext i32 (4 + (4 * (%tmp /u 4))) to i64)) + @tmp_addr) U: [0,-3) S: [-9223372036854775808,9223372036854775805) +; CHECK-NEXT: --> (4100 + (4 * (zext i32 (4 + (4 * (%tmp /u 4))) to i64)) + @tmp_addr)(u nuw) U: [0,-3) S: [-9223372036854775808,9223372036854775805) ; CHECK-NEXT: Determining loop execution counts for: @f3 ; entry: diff --git a/llvm/test/Analysis/ScalarEvolution/no-wrap-symbolic-becount.ll b/llvm/test/Analysis/ScalarEvolution/no-wrap-symbolic-becount.ll index e13a8976bf5ac..20ebcd6158e98 100644 --- a/llvm/test/Analysis/ScalarEvolution/no-wrap-symbolic-becount.ll +++ b/llvm/test/Analysis/ScalarEvolution/no-wrap-symbolic-becount.ll @@ -96,7 +96,7 @@ define void @pointer_iv_nowrap(ptr %startptr, ptr %endptr) local_unnamed_addr { ; CHECK-LABEL: 'pointer_iv_nowrap' ; CHECK-NEXT: Classifying expressions for: @pointer_iv_nowrap ; CHECK-NEXT: %init = getelementptr inbounds i8, ptr %startptr, i64 2000 -; CHECK-NEXT: --> (2000 + %startptr) U: full-set S: full-set +; CHECK-NEXT: --> (2000 + %startptr)(u nuw) U: full-set S: full-set ; CHECK-NEXT: %iv = phi ptr [ %init, %entry ], [ %iv.next, %loop ] ; CHECK-NEXT: --> {(2000 + %startptr),+,1}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %iv.next = getelementptr inbounds i8, ptr %iv, i64 1 diff --git a/llvm/test/Analysis/ScalarEvolution/ptrtoint.ll b/llvm/test/Analysis/ScalarEvolution/ptrtoint.ll index e784d25385980..72f59682b023a 100644 --- a/llvm/test/Analysis/ScalarEvolution/ptrtoint.ll +++ b/llvm/test/Analysis/ScalarEvolution/ptrtoint.ll @@ -192,7 +192,7 @@ define void @ptrtoint_of_gep(ptr %in, ptr %out0) { ; X64-LABEL: 
'ptrtoint_of_gep' ; X64-NEXT: Classifying expressions for: @ptrtoint_of_gep ; X64-NEXT: %in_adj = getelementptr inbounds i8, ptr %in, i64 42 -; X64-NEXT: --> (42 + %in) U: full-set S: full-set +; X64-NEXT: --> (42 + %in)(u nuw) U: full-set S: full-set ; X64-NEXT: %p0 = ptrtoint ptr %in_adj to i64 ; X64-NEXT: --> (42 + (ptrtoint ptr %in to i64)) U: full-set S: full-set ; X64-NEXT: Determining loop execution counts for: @ptrtoint_of_gep @@ -200,7 +200,7 @@ define void @ptrtoint_of_gep(ptr %in, ptr %out0) { ; X32-LABEL: 'ptrtoint_of_gep' ; X32-NEXT: Classifying expressions for: @ptrtoint_of_gep ; X32-NEXT: %in_adj = getelementptr inbounds i8, ptr %in, i64 42 -; X32-NEXT: --> (42 + %in) U: full-set S: full-set +; X32-NEXT: --> (42 + %in)(u nuw) U: full-set S: full-set ; X32-NEXT: %p0 = ptrtoint ptr %in_adj to i64 ; X32-NEXT: --> (zext i32 (42 + (ptrtoint ptr %in to i32)) to i64) U: [0,4294967296) S: [0,4294967296) ; X32-NEXT: Determining loop execution counts for: @ptrtoint_of_gep diff --git a/llvm/test/Analysis/ScalarEvolution/sdiv.ll b/llvm/test/Analysis/ScalarEvolution/sdiv.ll index e01f84fb2226e..ce531a967f634 100644 --- a/llvm/test/Analysis/ScalarEvolution/sdiv.ll +++ b/llvm/test/Analysis/ScalarEvolution/sdiv.ll @@ -18,7 +18,7 @@ define dso_local void @_Z4loopi(i32 %width) local_unnamed_addr #0 { ; CHECK-NEXT: %idxprom = sext i32 %rem to i64 ; CHECK-NEXT: --> ({0,+,1}<%for.cond> /u 2) U: [0,2147483648) S: [0,2147483648) Exits: ((zext i32 %width to i64) /u 2) LoopDispositions: { %for.cond: Computable } ; CHECK-NEXT: %arrayidx = getelementptr inbounds [2 x i32], ptr %storage, i64 0, i64 %idxprom -; CHECK-NEXT: --> ((4 * ({0,+,1}<%for.cond> /u 2)) + %storage) U: [0,-3) S: [-9223372036854775808,9223372036854775805) Exits: ((4 * ((zext i32 %width to i64) /u 2)) + %storage) LoopDispositions: { %for.cond: Computable } +; CHECK-NEXT: --> ((4 * ({0,+,1}<%for.cond> /u 2)) + %storage)(u nuw) U: [0,-3) S: [-9223372036854775808,9223372036854775805) Exits: ((4 * ((zext i32 %width to i64) /u 2)) + %storage) LoopDispositions: { %for.cond: Computable } ; CHECK-NEXT: %1 = load i32, ptr %arrayidx, align 4 ; CHECK-NEXT: --> %1 U: full-set S: full-set Exits: <> LoopDispositions: { %for.cond: Variant } ; CHECK-NEXT: %call = call i32 @_Z3adji(i32 %1) diff --git a/llvm/test/Analysis/ScalarEvolution/srem.ll b/llvm/test/Analysis/ScalarEvolution/srem.ll index ff898c963d0dc..6ea3921880a2d 100644 --- a/llvm/test/Analysis/ScalarEvolution/srem.ll +++ b/llvm/test/Analysis/ScalarEvolution/srem.ll @@ -18,7 +18,7 @@ define dso_local void @_Z4loopi(i32 %width) local_unnamed_addr #0 { ; CHECK-NEXT: %idxprom = sext i32 %rem to i64 ; CHECK-NEXT: --> (zext i1 {false,+,true}<%for.cond> to i64) U: [0,2) S: [0,2) Exits: (zext i1 (trunc i32 %width to i1) to i64) LoopDispositions: { %for.cond: Computable } ; CHECK-NEXT: %arrayidx = getelementptr inbounds [2 x i32], ptr %storage, i64 0, i64 %idxprom -; CHECK-NEXT: --> ((4 * (zext i1 {false,+,true}<%for.cond> to i64)) + %storage) U: [4,-7) S: [-9223372036854775808,9223372036854775805) Exits: ((4 * (zext i1 (trunc i32 %width to i1) to i64)) + %storage) LoopDispositions: { %for.cond: Computable } +; CHECK-NEXT: --> ((4 * (zext i1 {false,+,true}<%for.cond> to i64)) + %storage)(u nuw) U: [4,-7) S: [-9223372036854775808,9223372036854775805) Exits: ((4 * (zext i1 (trunc i32 %width to i1) to i64)) + %storage) LoopDispositions: { %for.cond: Computable } ; CHECK-NEXT: %1 = load i32, ptr %arrayidx, align 4 ; CHECK-NEXT: --> %1 U: full-set S: full-set Exits: <> LoopDispositions: { %for.cond: 
Variant } ; CHECK-NEXT: %call = call i32 @_Z3adji(i32 %1) diff --git a/llvm/test/Transforms/IndVarSimplify/turn-to-invariant.ll b/llvm/test/Transforms/IndVarSimplify/turn-to-invariant.ll index 326ee75e135b0..bc9e8004ec5df 100644 --- a/llvm/test/Transforms/IndVarSimplify/turn-to-invariant.ll +++ b/llvm/test/Transforms/IndVarSimplify/turn-to-invariant.ll @@ -846,11 +846,9 @@ failed: define i32 @test_litter_conditions_constant(i32 %start, i32 %len) { ; CHECK-LABEL: @test_litter_conditions_constant( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[TMP0:%.*]] = add i32 [[START:%.*]], -1 -; CHECK-NEXT: [[RANGE_CHECK_FIRST_ITER:%.*]] = icmp ult i32 [[TMP0]], [[LEN:%.*]] ; CHECK-NEXT: br label [[LOOP:%.*]] ; CHECK: loop: -; CHECK-NEXT: [[IV:%.*]] = phi i32 [ [[START]], [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[BACKEDGE:%.*]] ] +; CHECK-NEXT: [[IV:%.*]] = phi i32 [ [[START:%.*]], [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[BACKEDGE:%.*]] ] ; CHECK-NEXT: [[CANONICAL_IV:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[CANONICAL_IV_NEXT:%.*]], [[BACKEDGE]] ] ; CHECK-NEXT: [[CONSTANT_CHECK:%.*]] = icmp ult i32 [[CANONICAL_IV]], 65635 ; CHECK-NEXT: br i1 [[CONSTANT_CHECK]], label [[CONSTANT_CHECK_PASSED:%.*]], label [[CONSTANT_CHECK_FAILED:%.*]] @@ -860,8 +858,10 @@ define i32 @test_litter_conditions_constant(i32 %start, i32 %len) { ; CHECK-NEXT: [[AND_1:%.*]] = and i1 [[ZERO_CHECK]], [[FAKE_1]] ; CHECK-NEXT: br i1 [[AND_1]], label [[RANGE_CHECK_BLOCK:%.*]], label [[FAILED_1:%.*]] ; CHECK: range_check_block: +; CHECK-NEXT: [[IV_MINUS_1:%.*]] = add i32 [[IV]], -1 +; CHECK-NEXT: [[RANGE_CHECK:%.*]] = icmp ult i32 [[IV_MINUS_1]], [[LEN:%.*]] ; CHECK-NEXT: [[FAKE_2:%.*]] = call i1 @cond() -; CHECK-NEXT: [[AND_2:%.*]] = and i1 [[RANGE_CHECK_FIRST_ITER]], [[FAKE_2]] +; CHECK-NEXT: [[AND_2:%.*]] = and i1 [[RANGE_CHECK]], [[FAKE_2]] ; CHECK-NEXT: br i1 [[AND_2]], label [[BACKEDGE]], label [[FAILED_2:%.*]] ; CHECK: backedge: ; CHECK-NEXT: [[IV_NEXT]] = add i32 [[IV]], -1 diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp index 37c61e4e4fa71..a32a66a59c240 100644 --- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp +++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp @@ -63,10 +63,10 @@ static std::optional computeConstantDifference(ScalarEvolution &SE, return SE.computeConstantDifference(LHS, RHS); } - static bool matchURem(ScalarEvolution &SE, const SCEV *Expr, const SCEV *&LHS, - const SCEV *&RHS) { - return SE.matchURem(Expr, LHS, RHS); - } +static bool matchURem(ScalarEvolution &SE, const SCEV *Expr, SCEVUse &LHS, + SCEVUse &RHS) { + return SE.matchURem(Expr, LHS, RHS); +} static bool isImpliedCond( ScalarEvolution &SE, ICmpInst::Predicate Pred, const SCEV *LHS, @@ -1522,8 +1522,8 @@ TEST_F(ScalarEvolutionsTest, MatchURem) { runWithSE(*M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) { for (auto *N : {"rem1", "rem2", "rem3", "rem5"}) { auto *URemI = getInstructionByName(F, N); - auto *S = SE.getSCEV(URemI); - const SCEV *LHS, *RHS; + const SCEV *S = SE.getSCEV(URemI); + SCEVUse LHS, RHS; EXPECT_TRUE(matchURem(SE, S, LHS, RHS)); EXPECT_EQ(LHS, SE.getSCEV(URemI->getOperand(0))); EXPECT_EQ(RHS, SE.getSCEV(URemI->getOperand(1))); @@ -1535,8 +1535,8 @@ TEST_F(ScalarEvolutionsTest, MatchURem) { // match results are extended to the size of the input expression. 
auto *Ext = getInstructionByName(F, "ext"); auto *URem1 = getInstructionByName(F, "rem4"); - auto *S = SE.getSCEV(Ext); - const SCEV *LHS, *RHS; + const SCEV *S = SE.getSCEV(Ext); + SCEVUse LHS, RHS; EXPECT_TRUE(matchURem(SE, S, LHS, RHS)); EXPECT_NE(LHS, SE.getSCEV(URem1->getOperand(0))); // RHS and URem1->getOperand(1) have different widths, so compare the @@ -1662,11 +1662,11 @@ TEST_F(ScalarEvolutionsTest, ForgetValueWithOverflowInst) { auto *ExtractValue = getInstructionByName(F, "extractvalue"); auto *IV = getInstructionByName(F, "iv"); - auto *ExtractValueScev = SE.getSCEV(ExtractValue); + auto ExtractValueScev = SE.getSCEV(ExtractValue); EXPECT_NE(ExtractValueScev, nullptr); SE.forgetValue(IV); - auto *ExtractValueScevForgotten = SE.getExistingSCEV(ExtractValue); + auto ExtractValueScevForgotten = SE.getExistingSCEV(ExtractValue); EXPECT_EQ(ExtractValueScevForgotten, nullptr); }); } @@ -1707,4 +1707,59 @@ TEST_F(ScalarEvolutionsTest, ComplexityComparatorIsStrictWeakOrdering) { }); } +TEST_F(ScalarEvolutionsTest, SCEVUseWithFlags) { + Type *Ty = IntegerType::get(Context, 32); + FunctionType *FTy = + FunctionType::get(Type::getVoidTy(Context), {Ty, Ty, Ty}, false); + Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M); + BasicBlock *BB = BasicBlock::Create(Context, "entry", F); + ReturnInst::Create(Context, nullptr, BB); + + Value *V0 = F->getArg(0); + Value *V1 = F->getArg(1); + Value *V2 = F->getArg(2); + + ScalarEvolution SE = buildSE(*F); + + const SCEV *S0 = SE.getSCEV(V0); + const SCEV *S1 = SE.getSCEV(V1); + const SCEV *S2 = SE.getSCEV(V2); + + SCEVUse AddNoFlags = SE.getAddExpr(S0, SE.getConstant(S0->getType(), 2)); + SCEVUse AddWithFlag2 = {AddNoFlags, 2}; + SCEVUse MulNoFlags = SE.getMulExpr(AddNoFlags, S1); + SCEVUse MulFlags2 = SE.getMulExpr(AddWithFlag2, S1); + EXPECT_EQ(AddNoFlags.getCanonical(SE), AddWithFlag2.getCanonical(SE)); + EXPECT_EQ(MulNoFlags.getCanonical(SE), MulFlags2.getCanonical(SE)); + + SCEVUse AddWithFlag1 = {AddNoFlags, 1}; + SCEVUse MulFlags1 = SE.getMulExpr(AddWithFlag1, S1); + EXPECT_EQ(MulNoFlags.getCanonical(SE), MulFlags1.getCanonical(SE)); + EXPECT_EQ(MulFlags1.getCanonical(SE), MulFlags2.getCanonical(SE)); + + SCEVUse AddNoFlags2 = SE.getAddExpr(S0, SE.getConstant(S0->getType(), 2)); + EXPECT_EQ(AddNoFlags.getCanonical(SE), AddNoFlags2.getCanonical(SE)); + EXPECT_EQ(AddNoFlags2.getCanonical(SE), AddWithFlag2.getCanonical(SE)); + + SCEVUse MulFlags22 = SE.getMulExpr(AddWithFlag2, S1); + EXPECT_EQ(MulFlags22.getCanonical(SE), MulFlags2.getCanonical(SE)); + EXPECT_EQ(MulNoFlags.getCanonical(SE), MulFlags22.getCanonical(SE)); + + SCEVUse MulNoFlags2 = SE.getMulExpr(AddNoFlags, S1); + EXPECT_EQ(MulNoFlags.getCanonical(SE), MulNoFlags2.getCanonical(SE)); + EXPECT_EQ(MulNoFlags2.getCanonical(SE), MulFlags2.getCanonical(SE)); + EXPECT_EQ(MulNoFlags2.getCanonical(SE), MulFlags22.getCanonical(SE)); + + SE.getAddExpr(MulNoFlags, S2); + SE.getAddExpr(MulFlags1, S2); + SE.getAddExpr(MulFlags2, S2); + SCEVUse AddMulNoFlags = SE.getAddExpr(MulNoFlags, S2); + SCEVUse AddMulFlags1 = SE.getAddExpr(MulFlags1, S2); + SCEVUse AddMulFlags2 = SE.getAddExpr(MulFlags2, S2); + + EXPECT_EQ(AddMulNoFlags.getCanonical(SE), AddMulFlags1.getCanonical(SE)); + EXPECT_EQ(AddMulNoFlags.getCanonical(SE), AddMulFlags2.getCanonical(SE)); + EXPECT_EQ(AddMulFlags1.getCanonical(SE), AddMulFlags2.getCanonical(SE)); +} + } // end namespace llvm diff --git a/polly/include/polly/Support/ScopHelper.h b/polly/include/polly/Support/ScopHelper.h index 
13852ecb18ee7..9a0319241e62d 100644 --- a/polly/include/polly/Support/ScopHelper.h +++ b/polly/include/polly/Support/ScopHelper.h @@ -14,6 +14,7 @@ #define POLLY_SUPPORT_IRHELPER_H #include "llvm/ADT/SetVector.h" +#include "llvm/Analysis/ScalarEvolution.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/ValueHandle.h" @@ -37,7 +38,7 @@ class Scop; class ScopStmt; /// Same as llvm/Analysis/ScalarEvolutionExpressions.h -using LoopToScevMapT = llvm::DenseMap; +using LoopToScevMapT = llvm::DenseMap; /// Enumeration of assumptions Polly can take. enum AssumptionKind { @@ -401,7 +402,7 @@ void splitEntryBlockForAlloca(llvm::BasicBlock *EntryBlock, llvm::Value *expandCodeFor(Scop &S, llvm::ScalarEvolution &SE, llvm::Function *GenFn, llvm::ScalarEvolution &GenSE, const llvm::DataLayout &DL, const char *Name, - const llvm::SCEV *E, llvm::Type *Ty, + llvm::SCEVUse E, llvm::Type *Ty, llvm::Instruction *IP, ValueMapT *VMap, LoopToScevMapT *LoopMap, llvm::BasicBlock *RTCBB); diff --git a/polly/lib/Analysis/ScopBuilder.cpp b/polly/lib/Analysis/ScopBuilder.cpp index c05fc1a347c25..24ab756a6d041 100644 --- a/polly/lib/Analysis/ScopBuilder.cpp +++ b/polly/lib/Analysis/ScopBuilder.cpp @@ -1567,7 +1567,7 @@ bool ScopBuilder::buildAccessMemIntrinsic(MemAccInst Inst, ScopStmt *Stmt) { return false; auto *L = LI.getLoopFor(Inst->getParent()); - auto *LengthVal = SE.getSCEVAtScope(MemIntr->getLength(), L); + SCEVUse LengthVal = SE.getSCEVAtScope(MemIntr->getLength(), L); assert(LengthVal); // Check if the length val is actually affine or if we overapproximate it @@ -1586,7 +1586,7 @@ bool ScopBuilder::buildAccessMemIntrinsic(MemAccInst Inst, ScopStmt *Stmt) { auto *DestPtrVal = MemIntr->getDest(); assert(DestPtrVal); - auto *DestAccFunc = SE.getSCEVAtScope(DestPtrVal, L); + SCEVUse DestAccFunc = SE.getSCEVAtScope(DestPtrVal, L); assert(DestAccFunc); // Ignore accesses to "NULL". // TODO: We could use this to optimize the region further, e.g., intersect @@ -1616,7 +1616,7 @@ bool ScopBuilder::buildAccessMemIntrinsic(MemAccInst Inst, ScopStmt *Stmt) { auto *SrcPtrVal = MemTrans->getSource(); assert(SrcPtrVal); - auto *SrcAccFunc = SE.getSCEVAtScope(SrcPtrVal, L); + SCEVUse SrcAccFunc = SE.getSCEVAtScope(SrcPtrVal, L); assert(SrcAccFunc); // Ignore accesses to "NULL". // TODO: See above TODO @@ -1643,7 +1643,7 @@ bool ScopBuilder::buildAccessCallInst(MemAccInst Inst, ScopStmt *Stmt) { if (CI->doesNotAccessMemory() || isIgnoredIntrinsic(CI) || isDebugCall(CI)) return true; - auto *AF = SE.getConstant(IntegerType::getInt64Ty(CI->getContext()), 0); + SCEVUse AF = SE.getConstant(IntegerType::getInt64Ty(CI->getContext()), 0); auto *CalledFunction = CI->getCalledFunction(); MemoryEffects ME = AA.getMemoryEffects(CalledFunction); if (ME.doesNotAccessMemory()) @@ -1658,7 +1658,7 @@ bool ScopBuilder::buildAccessCallInst(MemAccInst Inst, ScopStmt *Stmt) { if (!Arg->getType()->isPointerTy()) continue; - auto *ArgSCEV = SE.getSCEVAtScope(Arg, L); + SCEVUse ArgSCEV = SE.getSCEVAtScope(Arg, L); if (ArgSCEV->isZero()) continue; @@ -2169,7 +2169,7 @@ static bool isDivisible(const SCEV *Expr, unsigned Size, ScalarEvolution &SE) { // Only one factor needs to be divisible. 
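Note (worked cases for Size == 4, illustrative, not part of the patch):

    //   (4 * %n)       -> divisible: a single Mul factor (here 4) suffices.
    //   (8 + (4 * %n)) -> divisible: for n-ary nodes such as Add, every
    //                     operand must be divisible, and both are.
    //   %a             -> decided by the generic fallback below, which
    //                     checks whether ((%a /u 4) * 4) == %a.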
if (auto *MulExpr = dyn_cast(Expr)) { - for (auto *FactorExpr : MulExpr->operands()) + for (SCEVUse FactorExpr : MulExpr->operands()) if (isDivisible(FactorExpr, Size, SE)) return true; return false; @@ -2178,15 +2178,15 @@ static bool isDivisible(const SCEV *Expr, unsigned Size, ScalarEvolution &SE) { // For other n-ary expressions (Add, AddRec, Max,...) all operands need // to be divisible. if (auto *NAryExpr = dyn_cast(Expr)) { - for (auto *OpExpr : NAryExpr->operands()) + for (SCEVUse OpExpr : NAryExpr->operands()) if (!isDivisible(OpExpr, Size, SE)) return false; return true; } - auto *SizeSCEV = SE.getConstant(Expr->getType(), Size); - auto *UDivSCEV = SE.getUDivExpr(Expr, SizeSCEV); - auto *MulSCEV = SE.getMulExpr(UDivSCEV, SizeSCEV); + SCEVUse SizeSCEV = SE.getConstant(Expr->getType(), Size); + SCEVUse UDivSCEV = SE.getUDivExpr(Expr, SizeSCEV); + SCEVUse MulSCEV = SE.getMulExpr(UDivSCEV, SizeSCEV); return MulSCEV == Expr; } @@ -3672,7 +3672,7 @@ void ScopBuilder::buildScop(Region &R, AssumptionCache &AC) { } // Create memory accesses for global reads since all arrays are now known. - auto *AF = SE.getConstant(IntegerType::getInt64Ty(SE.getContext()), 0); + SCEVUse AF = SE.getConstant(IntegerType::getInt64Ty(SE.getContext()), 0); for (auto GlobalReadPair : GlobalReads) { ScopStmt *GlobalReadStmt = GlobalReadPair.first; Instruction *GlobalRead = GlobalReadPair.second; diff --git a/polly/lib/Analysis/ScopDetection.cpp b/polly/lib/Analysis/ScopDetection.cpp index 79db3965de023..bfd7ba500494c 100644 --- a/polly/lib/Analysis/ScopDetection.cpp +++ b/polly/lib/Analysis/ScopDetection.cpp @@ -520,7 +520,7 @@ bool ScopDetection::involvesMultiplePtrs(const SCEV *S0, const SCEV *S1, if (!V->getType()->isPointerTy()) continue; - auto *PtrSCEV = SE.getSCEVAtScope(V, Scope); + SCEVUse PtrSCEV = SE.getSCEVAtScope(V, Scope); if (isa(PtrSCEV)) continue; @@ -720,7 +720,7 @@ bool ScopDetection::isValidCallInst(CallInst &CI, // Bail if a pointer argument has a base address not known to // ScalarEvolution. Note that a zero pointer is acceptable. 
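Note (sketch of intent, details hedged): for a call such as f(ptr null, ptr %p), the null argument folds to a zero SCEV and is skipped by the check below, while %p is accepted only if ScalarEvolution can resolve it to a base object the SCoP can model; otherwise the call invalidates the candidate region.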
@@ -3672,7 +3672,7 @@ void ScopBuilder::buildScop(Region &R, AssumptionCache &AC) {
   }
 
   // Create memory accesses for global reads since all arrays are now known.
-  auto *AF = SE.getConstant(IntegerType::getInt64Ty(SE.getContext()), 0);
+  SCEVUse AF = SE.getConstant(IntegerType::getInt64Ty(SE.getContext()), 0);
   for (auto GlobalReadPair : GlobalReads) {
     ScopStmt *GlobalReadStmt = GlobalReadPair.first;
     Instruction *GlobalRead = GlobalReadPair.second;
diff --git a/polly/lib/Analysis/ScopDetection.cpp b/polly/lib/Analysis/ScopDetection.cpp
index 79db3965de023..bfd7ba500494c 100644
--- a/polly/lib/Analysis/ScopDetection.cpp
+++ b/polly/lib/Analysis/ScopDetection.cpp
@@ -520,7 +520,7 @@ bool ScopDetection::involvesMultiplePtrs(const SCEV *S0, const SCEV *S1,
     if (!V->getType()->isPointerTy())
       continue;
 
-    auto *PtrSCEV = SE.getSCEVAtScope(V, Scope);
+    SCEVUse PtrSCEV = SE.getSCEVAtScope(V, Scope);
     if (isa<SCEVCouldNotCompute>(PtrSCEV))
       continue;
 
@@ -720,7 +720,7 @@ bool ScopDetection::isValidCallInst(CallInst &CI,
       // Bail if a pointer argument has a base address not known to
       // ScalarEvolution. Note that a zero pointer is acceptable.
-      auto *ArgSCEV = SE.getSCEVAtScope(Arg, LI.getLoopFor(CI.getParent()));
+      SCEVUse ArgSCEV = SE.getSCEVAtScope(Arg, LI.getLoopFor(CI.getParent()));
       if (ArgSCEV->isZero())
         continue;
 
@@ -891,7 +891,7 @@ ScopDetection::getDelinearizationTerms(DetectionContext &Context,
       if (auto *AF2 = dyn_cast<SCEVMulExpr>(Op)) {
        SmallVector<const SCEV *> Operands;
 
-        for (auto *MulOp : AF2->operands()) {
+        for (SCEVUse MulOp : AF2->operands()) {
          if (auto *Const = dyn_cast<SCEVConstant>(MulOp))
            Operands.push_back(Const);
          if (auto *Unknown = dyn_cast<SCEVUnknown>(MulOp)) {
@@ -1366,7 +1366,7 @@ bool ScopDetection::isValidLoop(Loop *L, DetectionContext &Context) {
 ScopDetection::LoopStats
 ScopDetection::countBeneficialSubLoops(Loop *L, ScalarEvolution &SE,
                                        unsigned MinProfitableTrips) {
-  auto *TripCount = SE.getBackedgeTakenCount(L);
+  SCEVUse TripCount = SE.getBackedgeTakenCount(L);
 
   int NumLoops = 1;
   int MaxLoopDepth = 1;
diff --git a/polly/lib/Analysis/ScopInfo.cpp b/polly/lib/Analysis/ScopInfo.cpp
index 56ffb990faf1c..38b93b6520258 100644
--- a/polly/lib/Analysis/ScopInfo.cpp
+++ b/polly/lib/Analysis/ScopInfo.cpp
@@ -215,12 +215,12 @@ static const ScopArrayInfo *identifyBasePtrOriginSAI(Scop *S, Value *BasePtr) {
 
   ScalarEvolution &SE = *S->getSE();
 
-  auto *OriginBaseSCEV =
+  SCEVUse OriginBaseSCEV =
       SE.getPointerBase(SE.getSCEV(BasePtrLI->getPointerOperand()));
   if (!OriginBaseSCEV)
     return nullptr;
 
-  auto *OriginBaseSCEVUnknown = dyn_cast<SCEVUnknown>(OriginBaseSCEV);
+  auto OriginBaseSCEVUnknown = dyn_cast<SCEVUnknown>(OriginBaseSCEV);
   if (!OriginBaseSCEVUnknown)
     return nullptr;
 
@@ -713,11 +713,11 @@ void MemoryAccess::computeBoundsOnAccessRelation(unsigned ElementSize) {
   if (!Ptr || !SE->isSCEVable(Ptr->getType()))
     return;
 
-  auto *PtrSCEV = SE->getSCEV(Ptr);
+  SCEVUse PtrSCEV = SE->getSCEV(Ptr);
   if (isa<SCEVCouldNotCompute>(PtrSCEV))
     return;
 
-  auto *BasePtrSCEV = SE->getPointerBase(PtrSCEV);
+  SCEVUse BasePtrSCEV = SE->getPointerBase(PtrSCEV);
   if (BasePtrSCEV && !isa<SCEVCouldNotCompute>(BasePtrSCEV))
     PtrSCEV = SE->getMinusSCEV(PtrSCEV, BasePtrSCEV);
 
@@ -1384,10 +1384,10 @@ class SCEVSensitiveParameterRewriter final
   }
 
   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *E) {
-    auto *Start = visit(E->getStart());
-    auto *AddRec = SE.getAddRecExpr(SE.getConstant(E->getType(), 0),
-                                    visit(E->getStepRecurrence(SE)),
-                                    E->getLoop(), SCEV::FlagAnyWrap);
+    SCEVUse Start = visit(E->getStart());
+    SCEVUse AddRec = SE.getAddRecExpr(SE.getConstant(E->getType(), 0),
+                                      visit(E->getStepRecurrence(SE)),
+                                      E->getLoop(), SCEV::FlagAnyWrap);
     return SE.getAddExpr(Start, AddRec);
   }
diff --git a/polly/lib/Support/SCEVAffinator.cpp b/polly/lib/Support/SCEVAffinator.cpp
index d8463b238822d..cebb9e9133bb7 100644
--- a/polly/lib/Support/SCEVAffinator.cpp
+++ b/polly/lib/Support/SCEVAffinator.cpp
@@ -281,7 +281,7 @@ PWACtx SCEVAffinator::visitTruncateExpr(const SCEVTruncateExpr *Expr) {
   // to fit in the new type size instead of introducing a modulo with a very
   // large constant.
 
-  auto *Op = Expr->getOperand();
+  SCEVUse Op = Expr->getOperand();
   auto OpPWAC = visit(Op);
 
   unsigned Width = TD.getTypeSizeInBits(Expr->getType());
@@ -354,7 +354,7 @@ PWACtx SCEVAffinator::visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
   // bit-width is bigger than MaxZextSmallBitWidth we will employ overflow
   // assumptions and assume the "former negative" piece will not exist.
 
-  auto *Op = Expr->getOperand();
+  SCEVUse Op = Expr->getOperand();
   auto OpPWAC = visit(Op);
 
   // If the width is to big we assume the negative part does not occur.
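The visitTruncateExpr comment above alludes to the arithmetic meaning of a truncation: cutting an integer down to width w is the same as reducing it modulo 2^w, which is why modeling it directly would introduce a modulo with a very large constant for wide types. A quick standalone illustration of that equivalence (not from this patch):

#include <cassert>
#include <cstdint>

int main() {
  uint64_t X = 0x1234567890ABCDEFull;
  uint8_t Truncated = static_cast<uint8_t>(X); // trunc i64 -> i8
  assert(Truncated == X % 256);                // identical to X mod 2^8
  return 0;
}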
@@ -483,8 +483,8 @@ PWACtx SCEVAffinator::visitUDivExpr(const SCEVUDivExpr *Expr) {
   // For the dividend we could choose from the different representation
   // schemes introduced for zero-extend operations but for now we will
   // simply use an assumption.
-  auto *Dividend = Expr->getLHS();
-  auto *Divisor = Expr->getRHS();
+  SCEVUse Dividend = Expr->getLHS();
+  SCEVUse Divisor = Expr->getRHS();
   assert(isa<SCEVConstant>(Divisor) &&
          "UDiv is no parameter but has a non-constant RHS.");
 
@@ -518,13 +518,13 @@ PWACtx SCEVAffinator::visitSDivInstruction(Instruction *SDiv) {
   auto *Scope = getScope();
 
   auto *Divisor = SDiv->getOperand(1);
-  auto *DivisorSCEV = SE.getSCEVAtScope(Divisor, Scope);
+  SCEVUse DivisorSCEV = SE.getSCEVAtScope(Divisor, Scope);
   auto DivisorPWAC = visit(DivisorSCEV);
   assert(isa<SCEVConstant>(DivisorSCEV) &&
          "SDiv is no parameter but has a non-constant RHS.");
 
   auto *Dividend = SDiv->getOperand(0);
-  auto *DividendSCEV = SE.getSCEVAtScope(Dividend, Scope);
+  SCEVUse DividendSCEV = SE.getSCEVAtScope(Dividend, Scope);
   auto DividendPWAC = visit(DividendSCEV);
   DividendPWAC = combine(DividendPWAC, DivisorPWAC, isl_pw_aff_tdiv_q);
   return DividendPWAC;
@@ -535,13 +535,13 @@ PWACtx SCEVAffinator::visitSRemInstruction(Instruction *SRem) {
   auto *Scope = getScope();
 
   auto *Divisor = SRem->getOperand(1);
-  auto *DivisorSCEV = SE.getSCEVAtScope(Divisor, Scope);
+  SCEVUse DivisorSCEV = SE.getSCEVAtScope(Divisor, Scope);
   auto DivisorPWAC = visit(DivisorSCEV);
   assert(isa<ConstantInt>(Divisor) &&
          "SRem is no parameter but has a non-constant RHS.");
 
   auto *Dividend = SRem->getOperand(0);
-  auto *DividendSCEV = SE.getSCEVAtScope(Dividend, Scope);
+  SCEVUse DividendSCEV = SE.getSCEVAtScope(Dividend, Scope);
   auto DividendPWAC = visit(DividendSCEV);
   DividendPWAC = combine(DividendPWAC, DivisorPWAC, isl_pw_aff_tdiv_r);
   return DividendPWAC;
diff --git a/polly/lib/Support/SCEVValidator.cpp b/polly/lib/Support/SCEVValidator.cpp
index 5bb82624ed784..a7cd1eb9917dc 100644
--- a/polly/lib/Support/SCEVValidator.cpp
+++ b/polly/lib/Support/SCEVValidator.cpp
@@ -403,8 +403,8 @@ class SCEVValidator : public SCEVVisitor<SCEVValidator, ValidatorResult> {
     if (!PollyAllowUnsignedOperations)
       return ValidatorResult(SCEVType::INVALID);
 
-    auto *Dividend = Expr->getLHS();
-    auto *Divisor = Expr->getRHS();
+    SCEVUse Dividend = Expr->getLHS();
+    SCEVUse Divisor = Expr->getRHS();
     return visitDivision(Dividend, Divisor, Expr);
   }
 
@@ -412,8 +412,8 @@ class SCEVValidator : public SCEVVisitor<SCEVValidator, ValidatorResult> {
     assert(SDiv->getOpcode() == Instruction::SDiv &&
           "Assumed SDiv instruction!");
 
-    auto *Dividend = SE.getSCEV(SDiv->getOperand(0));
-    auto *Divisor = SE.getSCEV(SDiv->getOperand(1));
+    SCEVUse Dividend = SE.getSCEV(SDiv->getOperand(0));
+    SCEVUse Divisor = SE.getSCEV(SDiv->getOperand(1));
     return visitDivision(Dividend, Divisor, Expr, SDiv);
   }
 
@@ -427,7 +427,7 @@ class SCEVValidator : public SCEVVisitor<SCEVValidator, ValidatorResult> {
       return visitGenericInst(SRem, S);
 
     auto *Dividend = SRem->getOperand(0);
-    auto *DividendSCEV = SE.getSCEV(Dividend);
+    SCEVUse DividendSCEV = SE.getSCEV(Dividend);
     return visit(DividendSCEV);
   }
 
@@ -566,11 +566,11 @@ class SCEVFindValues final {
         Inst->getOpcode() != Instruction::SDiv))
       return false;
 
-    auto *Dividend = SE.getSCEV(Inst->getOperand(1));
+    SCEVUse Dividend = SE.getSCEV(Inst->getOperand(1));
     if (!isa<SCEVConstant>(Dividend))
       return false;
 
-    auto *Divisor = SE.getSCEV(Inst->getOperand(0));
+    SCEVUse Divisor = SE.getSCEV(Inst->getOperand(0));
     SCEVFindValues FindValues(SE, Values);
     SCEVTraversal<SCEVFindValues> ST(FindValues);
     ST.visitAll(Dividend);
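SCEVFindValues above is driven by llvm::SCEVTraversal, which walks an expression DAG and calls back into a visitor exposing a follow/isDone interface. A minimal sketch of such a visitor (`CountUnknowns` and `countUnknowns` are hypothetical; the interface is the one SCEVTraversal expects):

#include "llvm/Analysis/ScalarEvolutionExpressions.h"

using namespace llvm;

struct CountUnknowns {
  unsigned Count = 0;
  // Called once per visited node; returning true descends into operands.
  bool follow(const SCEV *S) {
    if (isa<SCEVUnknown>(S))
      ++Count;
    return true;
  }
  // Returning true here would abort the traversal early.
  bool isDone() { return false; }
};

static unsigned countUnknowns(const SCEV *Expr) {
  CountUnknowns Visitor;
  SCEVTraversal<CountUnknowns> Walker(Visitor);
  Walker.visitAll(Expr);
  return Visitor.Count;
}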
@@ -623,7 +623,7 @@ bool polly::isAffineExpr(const Region *R, llvm::Loop *Scope, const SCEV *Expr,
 
 static bool isAffineExpr(Value *V, const Region *R, Loop *Scope,
                          ScalarEvolution &SE, ParameterSetTy &Params) {
-  auto *E = SE.getSCEV(V);
+  SCEVUse E = SE.getSCEV(V);
   if (isa<SCEVCouldNotCompute>(E))
     return false;
 
@@ -684,10 +684,10 @@ polly::extractConstantFactor(const SCEV *S, ScalarEvolution &SE) {
 
   auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
   if (AddRec) {
-    auto *StartExpr = AddRec->getStart();
+    SCEVUse StartExpr = AddRec->getStart();
     if (StartExpr->isZero()) {
       auto StepPair = extractConstantFactor(AddRec->getStepRecurrence(SE), SE);
-      auto *LeftOverAddRec =
+      SCEVUse LeftOverAddRec =
           SE.getAddRecExpr(StartExpr, StepPair.second, AddRec->getLoop(),
                            AddRec->getNoWrapFlags());
       return std::make_pair(StepPair.first, LeftOverAddRec);
@@ -717,7 +717,7 @@ polly::extractConstantFactor(const SCEV *S, ScalarEvolution &SE) {
     return std::make_pair(ConstPart, S);
   }
 
-  auto *NewAdd = SE.getAddExpr(LeftOvers, Add->getNoWrapFlags());
+  SCEVUse NewAdd = SE.getAddExpr(LeftOvers, Add->getNoWrapFlags());
   return std::make_pair(Factor, NewAdd);
 }
 
@@ -726,7 +726,7 @@ polly::extractConstantFactor(const SCEV *S, ScalarEvolution &SE) {
     return std::make_pair(ConstPart, S);
 
   SmallVector<const SCEV *, 4> LeftOvers;
-  for (auto *Op : Mul->operands())
+  for (SCEVUse Op : Mul->operands())
     if (isa<SCEVConstant>(Op))
       ConstPart = cast<SCEVConstant>(SE.getMulExpr(ConstPart, Op));
     else
diff --git a/polly/lib/Support/ScopHelper.cpp b/polly/lib/Support/ScopHelper.cpp
index 754bf50e2911f..330d9ed0aad51 100644
--- a/polly/lib/Support/ScopHelper.cpp
+++ b/polly/lib/Support/ScopHelper.cpp
@@ -250,8 +250,8 @@ void polly::recordAssumption(polly::RecordedAssumptionsTy *RecordedAssumptions,
 /// and we generate code outside/in front of that region. Hence, we generate the
 /// code for the SDiv/SRem operands in front of the analyzed region and then
 /// create a new SDiv/SRem operation there too.
-struct ScopExpander final : SCEVVisitor<ScopExpander, const SCEV *> {
-  friend struct SCEVVisitor<ScopExpander, const SCEV *>;
+struct ScopExpander final : SCEVVisitor<ScopExpander, SCEVUse> {
+  friend struct SCEVVisitor<ScopExpander, SCEVUse>;
 
   explicit ScopExpander(const Region &R, ScalarEvolution &SE, Function *GenFn,
                         ScalarEvolution &GenSE, const DataLayout &DL,
@@ -261,20 +261,20 @@ struct ScopExpander final : SCEVVisitor<ScopExpander, const SCEV *> {
       VMap(VMap), LoopMap(LoopMap), RTCBB(RTCBB), GenSE(GenSE), GenFn(GenFn) {}
 
-  Value *expandCodeFor(const SCEV *E, Type *Ty, Instruction *IP) {
+  Value *expandCodeFor(SCEVUse E, Type *Ty, Instruction *IP) {
     assert(isInGenRegion(IP) &&
           "ScopExpander assumes to be applied to generated code region");
-    const SCEV *GenE = visit(E);
+    SCEVUse GenE = visit(E);
     return Expander.expandCodeFor(GenE, Ty, IP);
   }
 
-  const SCEV *visit(const SCEV *E) {
+  SCEVUse visit(SCEVUse E) {
     // Cache the expansion results for intermediate SCEV expressions. A SCEV
     // expression can refer to an operand multiple times (e.g. "x*x), so
     // a naive visitor takes exponential time.
     if (SCEVCache.count(E))
      return SCEVCache[E];
-    const SCEV *Result = SCEVVisitor::visit(E);
+    SCEVUse Result = SCEVVisitor::visit(E);
     SCEVCache[E] = Result;
     return Result;
   }
 
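The cached visit above is plain memoization: SCEVs are uniqued DAGs, so a shared subexpression such as the `x` in `x*x` is reached along every path that contains it, and re-expanding it each time would take time exponential in the expression depth. The cache keys on node identity, exactly what a DenseMap over SCEVUse provides. A stripped-down sketch of the same pattern over a toy expression type (`Expr` and `evalMemo` are hypothetical):

#include <unordered_map>

struct Expr {
  int Leaf = 0;
  const Expr *LHS = nullptr, *RHS = nullptr; // null for leaves
};

static int evalMemo(const Expr *E,
                    std::unordered_map<const Expr *, int> &Cache) {
  auto It = Cache.find(E);
  if (It != Cache.end())
    return It->second; // shared node: O(1) instead of a repeated subtree walk
  int Result = E->LHS ? evalMemo(E->LHS, Cache) + evalMemo(E->RHS, Cache)
                      : E->Leaf;
  Cache[E] = Result;
  return Result;
}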
@@ -286,7 +286,7 @@ struct ScopExpander final : SCEVVisitor<ScopExpander, const SCEV *> {
   ValueMapT *VMap;
   LoopToScevMapT *LoopMap;
   BasicBlock *RTCBB;
-  DenseMap<const SCEV *, const SCEV *> SCEVCache;
+  DenseMap<SCEVUse, SCEVUse> SCEVCache;
 
   ScalarEvolution &GenSE;
   Function *GenFn;
@@ -304,8 +304,8 @@ struct ScopExpander final : SCEVVisitor<ScopExpander, const SCEV *> {
 
   bool isInGenRegion(Instruction *Inst) { return !isInOrigRegion(Inst); }
 
-  const SCEV *visitGenericInst(const SCEVUnknown *E, Instruction *Inst,
-                               Instruction *IP) {
+  SCEVUse visitGenericInst(const SCEVUnknown *E, Instruction *Inst,
+                           Instruction *IP) {
     if (!Inst || isInGenRegion(Inst))
       return E;
 
@@ -315,7 +315,7 @@ struct ScopExpander final : SCEVVisitor<ScopExpander, const SCEV *> {
     auto *InstClone = Inst->clone();
     for (auto &Op : Inst->operands()) {
       assert(GenSE.isSCEVable(Op->getType()));
-      auto *OpSCEV = GenSE.getSCEV(Op);
+      SCEVUse OpSCEV = GenSE.getSCEV(Op);
       auto *OpClone = expandCodeFor(OpSCEV, Op->getType(), IP);
       InstClone->replaceUsesOfWith(Op, OpClone);
     }
@@ -325,12 +325,12 @@ struct ScopExpander final : SCEVVisitor<ScopExpander, const SCEV *> {
     return GenSE.getSCEV(InstClone);
   }
 
-  const SCEV *visitUnknown(const SCEVUnknown *E) {
+  SCEVUse visitUnknown(const SCEVUnknown *E) {
 
     // If a value mapping was given try if the underlying value is remapped.
     Value *NewVal = VMap ? VMap->lookup(E->getValue()) : nullptr;
     if (NewVal) {
-      auto *NewE = GenSE.getSCEV(NewVal);
+      SCEVUse NewE = GenSE.getSCEV(NewVal);
 
       // While the mapped value might be different the SCEV representation might
       // not be. To this end we will check before we go into recursion here.
@@ -359,8 +359,8 @@ struct ScopExpander final : SCEVVisitor<ScopExpander, const SCEV *> {
         Inst->getOpcode() != Instruction::SDiv))
       return visitGenericInst(E, Inst, IP);
 
-    const SCEV *LHSScev = GenSE.getSCEV(Inst->getOperand(0));
-    const SCEV *RHSScev = GenSE.getSCEV(Inst->getOperand(1));
+    SCEVUse LHSScev = GenSE.getSCEV(Inst->getOperand(0));
+    SCEVUse RHSScev = GenSE.getSCEV(Inst->getOperand(1));
 
     if (!GenSE.isKnownNonZero(RHSScev))
       RHSScev = GenSE.getUMaxExpr(RHSScev, GenSE.getConstant(E->getType(), 1));
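Note the guard at the end of the hunk above (and the same one in visitUDivExpr below): when ScalarEvolution cannot prove the divisor non-zero, the expander substitutes `umax(divisor, 1)` so that speculatively emitted code never divides by zero where the original program did not. Value-wise, a non-zero divisor is untouched; only the never-validly-executed zero case is redirected. A standalone illustration on unsigned integers (`safeUDiv` is a hypothetical helper, not from this patch):

#include <algorithm>
#include <cassert>

static unsigned safeUDiv(unsigned LHS, unsigned RHS) {
  // umax(RHS, 1) leaves any non-zero divisor unchanged and replaces the
  // zero case with a harmless 1.
  return LHS / std::max(RHS, 1u);
}

int main() {
  assert(safeUDiv(10, 2) == 5);  // non-zero divisors behave as before
  assert(safeUDiv(10, 0) == 10); // the guarded form no longer traps
  return 0;
}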
@@ -378,80 +378,80 @@ struct ScopExpander final : SCEVVisitor<ScopExpander, const SCEV *> {
   /// GenSE and the new operands returned by the traversal.
   ///
   ///{
-  const SCEV *visitConstant(const SCEVConstant *E) { return E; }
-  const SCEV *visitVScale(const SCEVVScale *E) { return E; }
-  const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *E) {
+  SCEVUse visitConstant(const SCEVConstant *E) { return E; }
+  SCEVUse visitVScale(const SCEVVScale *E) { return E; }
+  SCEVUse visitPtrToIntExpr(const SCEVPtrToIntExpr *E) {
     return GenSE.getPtrToIntExpr(visit(E->getOperand()), E->getType());
   }
-  const SCEV *visitTruncateExpr(const SCEVTruncateExpr *E) {
+  SCEVUse visitTruncateExpr(const SCEVTruncateExpr *E) {
     return GenSE.getTruncateExpr(visit(E->getOperand()), E->getType());
   }
-  const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *E) {
+  SCEVUse visitZeroExtendExpr(const SCEVZeroExtendExpr *E) {
     return GenSE.getZeroExtendExpr(visit(E->getOperand()), E->getType());
   }
-  const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *E) {
+  SCEVUse visitSignExtendExpr(const SCEVSignExtendExpr *E) {
     return GenSE.getSignExtendExpr(visit(E->getOperand()), E->getType());
   }
-  const SCEV *visitUDivExpr(const SCEVUDivExpr *E) {
-    auto *RHSScev = visit(E->getRHS());
+  SCEVUse visitUDivExpr(const SCEVUDivExpr *E) {
+    SCEVUse RHSScev = visit(E->getRHS());
     if (!GenSE.isKnownNonZero(RHSScev))
       RHSScev = GenSE.getUMaxExpr(RHSScev, GenSE.getConstant(E->getType(), 1));
     return GenSE.getUDivExpr(visit(E->getLHS()), RHSScev);
   }
-  const SCEV *visitAddExpr(const SCEVAddExpr *E) {
-    SmallVector<const SCEV *, 4> NewOps;
-    for (const SCEV *Op : E->operands())
+  SCEVUse visitAddExpr(const SCEVAddExpr *E) {
+    SmallVector<SCEVUse, 4> NewOps;
+    for (SCEVUse Op : E->operands())
       NewOps.push_back(visit(Op));
     return GenSE.getAddExpr(NewOps);
   }
-  const SCEV *visitMulExpr(const SCEVMulExpr *E) {
-    SmallVector<const SCEV *, 4> NewOps;
-    for (const SCEV *Op : E->operands())
+  SCEVUse visitMulExpr(const SCEVMulExpr *E) {
+    SmallVector<SCEVUse, 4> NewOps;
+    for (SCEVUse Op : E->operands())
      NewOps.push_back(visit(Op));
     return GenSE.getMulExpr(NewOps);
   }
-  const SCEV *visitUMaxExpr(const SCEVUMaxExpr *E) {
-    SmallVector<const SCEV *, 4> NewOps;
-    for (const SCEV *Op : E->operands())
+  SCEVUse visitUMaxExpr(const SCEVUMaxExpr *E) {
+    SmallVector<SCEVUse, 4> NewOps;
+    for (SCEVUse Op : E->operands())
      NewOps.push_back(visit(Op));
     return GenSE.getUMaxExpr(NewOps);
   }
-  const SCEV *visitSMaxExpr(const SCEVSMaxExpr *E) {
-    SmallVector<const SCEV *, 4> NewOps;
-    for (const SCEV *Op : E->operands())
+  SCEVUse visitSMaxExpr(const SCEVSMaxExpr *E) {
+    SmallVector<SCEVUse, 4> NewOps;
+    for (SCEVUse Op : E->operands())
      NewOps.push_back(visit(Op));
     return GenSE.getSMaxExpr(NewOps);
   }
-  const SCEV *visitUMinExpr(const SCEVUMinExpr *E) {
-    SmallVector<const SCEV *, 4> NewOps;
-    for (const SCEV *Op : E->operands())
+  SCEVUse visitUMinExpr(const SCEVUMinExpr *E) {
+    SmallVector<SCEVUse, 4> NewOps;
+    for (SCEVUse Op : E->operands())
      NewOps.push_back(visit(Op));
     return GenSE.getUMinExpr(NewOps);
   }
-  const SCEV *visitSMinExpr(const SCEVSMinExpr *E) {
-    SmallVector<const SCEV *, 4> NewOps;
-    for (const SCEV *Op : E->operands())
+  SCEVUse visitSMinExpr(const SCEVSMinExpr *E) {
+    SmallVector<SCEVUse, 4> NewOps;
+    for (SCEVUse Op : E->operands())
      NewOps.push_back(visit(Op));
     return GenSE.getSMinExpr(NewOps);
   }
-  const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *E) {
-    SmallVector<const SCEV *, 4> NewOps;
-    for (const SCEV *Op : E->operands())
+  SCEVUse visitSequentialUMinExpr(const SCEVSequentialUMinExpr *E) {
+    SmallVector<SCEVUse, 4> NewOps;
+    for (SCEVUse Op : E->operands())
      NewOps.push_back(visit(Op));
     return GenSE.getUMinExpr(NewOps, /*Sequential=*/true);
   }
-  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *E) {
-    SmallVector<const SCEV *, 4> NewOps;
-    for (const SCEV *Op : E->operands())
+  SCEVUse visitAddRecExpr(const SCEVAddRecExpr *E) {
+    SmallVector<SCEVUse, 4> NewOps;
+    for (SCEVUse Op : E->operands())
       NewOps.push_back(visit(Op));
 
     const Loop *L = E->getLoop();
-    const SCEV *GenLRepl = LoopMap ? LoopMap->lookup(L) : nullptr;
+    SCEVUse GenLRepl = LoopMap ? LoopMap->lookup(L) : nullptr;
     if (!GenLRepl)
       return GenSE.getAddRecExpr(NewOps, L, E->getNoWrapFlags());
 
     // evaluateAtIteration replaces the SCEVAddrExpr with a direct calculation.
-    const SCEV *Evaluated =
+    SCEVUse Evaluated =
         SCEVAddRecExpr::evaluateAtIteration(NewOps, GenLRepl, GenSE);
 
     // FIXME: This emits a SCEV for GenSE (since GenLRepl will refer to the
@@ -464,10 +464,9 @@ struct ScopExpander final : SCEVVisitor<ScopExpander, const SCEV *> {
 
 Value *polly::expandCodeFor(Scop &S, llvm::ScalarEvolution &SE,
                             llvm::Function *GenFn, ScalarEvolution &GenSE,
-                            const DataLayout &DL, const char *Name,
-                            const SCEV *E, Type *Ty, Instruction *IP,
-                            ValueMapT *VMap, LoopToScevMapT *LoopMap,
-                            BasicBlock *RTCBB) {
+                            const DataLayout &DL, const char *Name, SCEVUse E,
+                            Type *Ty, Instruction *IP, ValueMapT *VMap,
+                            LoopToScevMapT *LoopMap, BasicBlock *RTCBB) {
   ScopExpander Expander(S.getRegion(), SE, GenFn, GenSE, DL, Name, VMap,
                         LoopMap, RTCBB);
   return Expander.expandCodeFor(E, Ty, IP);
@@ -564,7 +563,7 @@ Loop *polly::getRegionNodeLoop(RegionNode *RN, LoopInfo &LI) {
 static bool hasVariantIndex(GetElementPtrInst *Gep, Loop *L, Region &R,
                             ScalarEvolution &SE) {
   for (const Use &Val : llvm::drop_begin(Gep->operands(), 1)) {
-    const SCEV *PtrSCEV = SE.getSCEVAtScope(Val, L);
+    SCEVUse PtrSCEV = SE.getSCEVAtScope(Val, L);
     Loop *OuterLoop = R.outermostLoopInRegion(L);
     if (!SE.isLoopInvariant(PtrSCEV, OuterLoop))
       return true;
@@ -595,7 +594,7 @@ bool polly::isHoistableLoad(LoadInst *LInst, Region &R, LoopInfo &LI,
     }
   }
 
-  const SCEV *PtrSCEV = SE.getSCEVAtScope(Ptr, L);
+  SCEVUse PtrSCEV = SE.getSCEVAtScope(Ptr, L);
   while (L && R.contains(L)) {
     if (!SE.isLoopInvariant(PtrSCEV, L))
       return false;
@@ -665,7 +664,7 @@ bool polly::canSynthesize(const Value *V, const Scop &S, ScalarEvolution *SE,
     return false;
 
   const InvariantLoadsSetTy &ILS = S.getRequiredInvariantLoads();
-  if (const SCEV *Scev = SE->getSCEVAtScope(const_cast<Value *>(V), Scope))
+  if (SCEVUse Scev = SE->getSCEVAtScope(const_cast<Value *>(V), Scope))
     if (!isa<SCEVCouldNotCompute>(Scev))
       if (!hasScalarDepsInsideRegion(Scev, &S.getRegion(), Scope, false, ILS))
         return true;
diff --git a/polly/lib/Support/VirtualInstruction.cpp b/polly/lib/Support/VirtualInstruction.cpp
index e570d8d546494..4911a7437e212 100644
--- a/polly/lib/Support/VirtualInstruction.cpp
+++ b/polly/lib/Support/VirtualInstruction.cpp
@@ -65,7 +65,7 @@ VirtualUse VirtualUse::create(Scop *S, ScopStmt *UserStmt, Loop *UserScope,
   // We assume synthesizable which practically should have the same effect.
   auto *SE = S->getSE();
   if (SE->isSCEVable(Val->getType())) {
-    auto *ScevExpr = SE->getSCEVAtScope(Val, UserScope);
+    SCEVUse ScevExpr = SE->getSCEVAtScope(Val, UserScope);
     if (!UserStmt || canSynthesize(Val, *UserStmt->getParent(), SE, UserScope))
       return VirtualUse(UserStmt, Val, Synthesizable, ScevExpr, nullptr);
   }
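For context on the visitAddRecExpr change above: SCEVAddRecExpr::evaluateAtIteration folds a recurrence into a direct computation at a given iteration. For an affine recurrence {Start,+,Step}<L> the value in iteration k is Start + Step * k; higher-degree recurrences generalize via binomial coefficients. A plain-integer illustration of that closed form (not from this patch):

#include <cassert>

int main() {
  int Start = 7, Step = 3;
  int Value = Start; // value of {7,+,3} in iteration 0
  for (int K = 0; K < 10; ++K) {
    assert(Value == Start + Step * K); // closed form matches the recurrence
    Value += Step;                     // advance to the next iteration
  }
  return 0;
}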