diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h index 58577a6b6eb5c..4eab357f1b33b 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -777,9 +777,9 @@ class TargetTransformInfo { bool forceScalarizeMaskedScatter(VectorType *Type, Align Alignment) const; /// Return true if the target supports masked compress store. - bool isLegalMaskedCompressStore(Type *DataType) const; + bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) const; /// Return true if the target supports masked expand load. - bool isLegalMaskedExpandLoad(Type *DataType) const; + bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) const; /// Return true if the target supports strided load. bool isLegalStridedLoadStore(Type *DataType, Align Alignment) const; @@ -1863,8 +1863,8 @@ class TargetTransformInfo::Concept { Align Alignment) = 0; virtual bool forceScalarizeMaskedScatter(VectorType *DataType, Align Alignment) = 0; - virtual bool isLegalMaskedCompressStore(Type *DataType) = 0; - virtual bool isLegalMaskedExpandLoad(Type *DataType) = 0; + virtual bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) = 0; + virtual bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) = 0; virtual bool isLegalStridedLoadStore(Type *DataType, Align Alignment) = 0; virtual bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1, @@ -2358,11 +2358,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept { Align Alignment) override { return Impl.forceScalarizeMaskedScatter(DataType, Alignment); } - bool isLegalMaskedCompressStore(Type *DataType) override { - return Impl.isLegalMaskedCompressStore(DataType); + bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) override { + return Impl.isLegalMaskedCompressStore(DataType, Alignment); } - bool isLegalMaskedExpandLoad(Type *DataType) override { - return Impl.isLegalMaskedExpandLoad(DataType); + bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) override { + return Impl.isLegalMaskedExpandLoad(DataType, Alignment); } bool isLegalStridedLoadStore(Type *DataType, Align Alignment) override { return Impl.isLegalStridedLoadStore(DataType, Alignment); diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index 13379cc126a40..95fb13d1c9715 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -295,14 +295,18 @@ class TargetTransformInfoImplBase { return false; } - bool isLegalMaskedCompressStore(Type *DataType) const { return false; } + bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) const { + return false; + } bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1, const SmallBitVector &OpcodeMask) const { return false; } - bool isLegalMaskedExpandLoad(Type *DataType) const { return false; } + bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) const { + return false; + } bool isLegalStridedLoadStore(Type *DataType, Align Alignment) const { return false; diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index 1f11f0d7dd620..15311be4dba27 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -492,12 +492,14 @@ bool TargetTransformInfo::forceScalarizeMaskedScatter(VectorType *DataType, return TTIImpl->forceScalarizeMaskedScatter(DataType, Alignment); } -bool TargetTransformInfo::isLegalMaskedCompressStore(Type *DataType) const { - return TTIImpl->isLegalMaskedCompressStore(DataType); +bool TargetTransformInfo::isLegalMaskedCompressStore(Type *DataType, + Align Alignment) const { + return TTIImpl->isLegalMaskedCompressStore(DataType, Alignment); } -bool TargetTransformInfo::isLegalMaskedExpandLoad(Type *DataType) const { - return TTIImpl->isLegalMaskedExpandLoad(DataType); +bool TargetTransformInfo::isLegalMaskedExpandLoad(Type *DataType, + Align Alignment) const { + return TTIImpl->isLegalMaskedExpandLoad(DataType, Alignment); } bool TargetTransformInfo::isLegalStridedLoadStore(Type *DataType, diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp index 4cca291a24562..d336ab9d309c4 100644 --- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp +++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp @@ -5938,7 +5938,7 @@ bool X86TTIImpl::isLegalBroadcastLoad(Type *ElementTy, ElementTy == Type::getDoubleTy(ElementTy->getContext()); } -bool X86TTIImpl::isLegalMaskedExpandLoad(Type *DataTy) { +bool X86TTIImpl::isLegalMaskedExpandLoad(Type *DataTy, Align Alignment) { if (!isa(DataTy)) return false; @@ -5962,8 +5962,8 @@ bool X86TTIImpl::isLegalMaskedExpandLoad(Type *DataTy) { ((IntWidth == 8 || IntWidth == 16) && ST->hasVBMI2()); } -bool X86TTIImpl::isLegalMaskedCompressStore(Type *DataTy) { - return isLegalMaskedExpandLoad(DataTy); +bool X86TTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) { + return isLegalMaskedExpandLoad(DataTy, Alignment); } bool X86TTIImpl::supportsGather() const { diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h index 07a3fff4f84b3..1a5e6bc886aa6 100644 --- a/llvm/lib/Target/X86/X86TargetTransformInfo.h +++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h @@ -269,8 +269,8 @@ class X86TTIImpl : public BasicTTIImplBase { bool isLegalMaskedGatherScatter(Type *DataType, Align Alignment); bool isLegalMaskedGather(Type *DataType, Align Alignment); bool isLegalMaskedScatter(Type *DataType, Align Alignment); - bool isLegalMaskedExpandLoad(Type *DataType); - bool isLegalMaskedCompressStore(Type *DataType); + bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment); + bool isLegalMaskedCompressStore(Type *DataType, Align Alignment); bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1, const SmallBitVector &OpcodeMask) const; bool hasDivRemOp(Type *DataType, bool IsSigned); diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp index f362dc5708b79..a4111fad5d9f2 100644 --- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp +++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp @@ -979,12 +979,16 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, return true; } case Intrinsic::masked_expandload: - if (TTI.isLegalMaskedExpandLoad(CI->getType())) + if (TTI.isLegalMaskedExpandLoad( + CI->getType(), + CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne())) return false; scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT); return true; case Intrinsic::masked_compressstore: - if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType())) + if (TTI.isLegalMaskedCompressStore( + CI->getArgOperand(0)->getType(), + CI->getAttributes().getParamAttrs(1).getAlignment().valueOrOne())) return false; scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT); return true;