From c463571d3f417bad336db534e49a3214c75555b3 Mon Sep 17 00:00:00 2001 From: Kolya Panchenko Date: Thu, 29 Feb 2024 13:22:01 -0800 Subject: [PATCH 1/2] [TTI] Add alignment argument to TTI for compress/expand support Since `llvm.compressstore` and `llvm.expandload` do require memory access, it's essential for some target to check if alignment is good to be able to lower them to target-specific instructions --- .../llvm/Analysis/TargetTransformInfo.h | 16 ++++++------ .../llvm/Analysis/TargetTransformInfoImpl.h | 8 ++++-- llvm/lib/Analysis/TargetTransformInfo.cpp | 10 +++++--- .../Target/RISCV/RISCVTargetTransformInfo.cpp | 25 +++++++++++++++++++ .../Target/RISCV/RISCVTargetTransformInfo.h | 2 ++ .../lib/Target/X86/X86TargetTransformInfo.cpp | 6 ++--- llvm/lib/Target/X86/X86TargetTransformInfo.h | 4 +-- .../Scalar/ScalarizeMaskedMemIntrin.cpp | 8 ++++-- 8 files changed, 58 insertions(+), 21 deletions(-) diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h index 58577a6b6eb5c..4eab357f1b33b 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -777,9 +777,9 @@ class TargetTransformInfo { bool forceScalarizeMaskedScatter(VectorType *Type, Align Alignment) const; /// Return true if the target supports masked compress store. - bool isLegalMaskedCompressStore(Type *DataType) const; + bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) const; /// Return true if the target supports masked expand load. - bool isLegalMaskedExpandLoad(Type *DataType) const; + bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) const; /// Return true if the target supports strided load. bool isLegalStridedLoadStore(Type *DataType, Align Alignment) const; @@ -1863,8 +1863,8 @@ class TargetTransformInfo::Concept { Align Alignment) = 0; virtual bool forceScalarizeMaskedScatter(VectorType *DataType, Align Alignment) = 0; - virtual bool isLegalMaskedCompressStore(Type *DataType) = 0; - virtual bool isLegalMaskedExpandLoad(Type *DataType) = 0; + virtual bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) = 0; + virtual bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) = 0; virtual bool isLegalStridedLoadStore(Type *DataType, Align Alignment) = 0; virtual bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1, @@ -2358,11 +2358,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept { Align Alignment) override { return Impl.forceScalarizeMaskedScatter(DataType, Alignment); } - bool isLegalMaskedCompressStore(Type *DataType) override { - return Impl.isLegalMaskedCompressStore(DataType); + bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) override { + return Impl.isLegalMaskedCompressStore(DataType, Alignment); } - bool isLegalMaskedExpandLoad(Type *DataType) override { - return Impl.isLegalMaskedExpandLoad(DataType); + bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) override { + return Impl.isLegalMaskedExpandLoad(DataType, Alignment); } bool isLegalStridedLoadStore(Type *DataType, Align Alignment) override { return Impl.isLegalStridedLoadStore(DataType, Alignment); diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index 13379cc126a40..95fb13d1c9715 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -295,14 +295,18 @@ class TargetTransformInfoImplBase { return false; } - bool isLegalMaskedCompressStore(Type *DataType) const { return false; } + bool isLegalMaskedCompressStore(Type *DataType, Align Alignment) const { + return false; + } bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1, const SmallBitVector &OpcodeMask) const { return false; } - bool isLegalMaskedExpandLoad(Type *DataType) const { return false; } + bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) const { + return false; + } bool isLegalStridedLoadStore(Type *DataType, Align Alignment) const { return false; diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index 1f11f0d7dd620..15311be4dba27 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -492,12 +492,14 @@ bool TargetTransformInfo::forceScalarizeMaskedScatter(VectorType *DataType, return TTIImpl->forceScalarizeMaskedScatter(DataType, Alignment); } -bool TargetTransformInfo::isLegalMaskedCompressStore(Type *DataType) const { - return TTIImpl->isLegalMaskedCompressStore(DataType); +bool TargetTransformInfo::isLegalMaskedCompressStore(Type *DataType, + Align Alignment) const { + return TTIImpl->isLegalMaskedCompressStore(DataType, Alignment); } -bool TargetTransformInfo::isLegalMaskedExpandLoad(Type *DataType) const { - return TTIImpl->isLegalMaskedExpandLoad(DataType); +bool TargetTransformInfo::isLegalMaskedExpandLoad(Type *DataType, + Align Alignment) const { + return TTIImpl->isLegalMaskedExpandLoad(DataType, Alignment); } bool TargetTransformInfo::isLegalStridedLoadStore(Type *DataType, diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp index 2e4e69fb4f920..0bd623e1196e1 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp @@ -1609,3 +1609,28 @@ bool RISCVTTIImpl::isLSRCostLess(const TargetTransformInfo::LSRCost &C1, C2.NumIVMuls, C2.NumBaseAdds, C2.ScaleCost, C2.ImmCost, C2.SetupCost); } + +bool RISCVTTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) { + auto *VTy = dyn_cast(DataTy); + if (!VTy || VTy->isScalableTy() || !ST->hasVInstructions()) + return false; + + Type *ScalarTy = VTy->getScalarType(); + if (ScalarTy->isFloatTy() || ScalarTy->isDoubleTy()) + return true; + + if (!ScalarTy->isIntegerTy()) + return false; + + switch (ScalarTy->getIntegerBitWidth()) { + case 8: + case 16: + case 32: + case 64: + break; + default: + return false; + } + + return getRegUsageForType(VTy) <= 8; +} diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h index af36e9d5d5e88..8daf6845dc8bc 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h @@ -261,6 +261,8 @@ class RISCVTTIImpl : public BasicTTIImplBase { return TLI->isLegalStridedLoadStore(DataTypeVT, Alignment); } + bool isLegalMaskedCompressStore(Type *DataTy, Align Alignment); + bool isVScaleKnownToBeAPowerOfTwo() const { return TLI->isVScaleKnownToBeAPowerOfTwo(); } diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp index 18bf32fe1acaa..9c1e4b2f83ab7 100644 --- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp +++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp @@ -5938,7 +5938,7 @@ bool X86TTIImpl::isLegalBroadcastLoad(Type *ElementTy, ElementTy == Type::getDoubleTy(ElementTy->getContext()); } -bool X86TTIImpl::isLegalMaskedExpandLoad(Type *DataTy) { +bool X86TTIImpl::isLegalMaskedExpandLoad(Type *DataTy, Align Alignment) { if (!isa(DataTy)) return false; @@ -5962,8 +5962,8 @@ bool X86TTIImpl::isLegalMaskedExpandLoad(Type *DataTy) { ((IntWidth == 8 || IntWidth == 16) && ST->hasVBMI2()); } -bool X86TTIImpl::isLegalMaskedCompressStore(Type *DataTy) { - return isLegalMaskedExpandLoad(DataTy); +bool X86TTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) { + return isLegalMaskedExpandLoad(DataTy, Alignment); } bool X86TTIImpl::supportsGather() const { diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h index 07a3fff4f84b3..1a5e6bc886aa6 100644 --- a/llvm/lib/Target/X86/X86TargetTransformInfo.h +++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h @@ -269,8 +269,8 @@ class X86TTIImpl : public BasicTTIImplBase { bool isLegalMaskedGatherScatter(Type *DataType, Align Alignment); bool isLegalMaskedGather(Type *DataType, Align Alignment); bool isLegalMaskedScatter(Type *DataType, Align Alignment); - bool isLegalMaskedExpandLoad(Type *DataType); - bool isLegalMaskedCompressStore(Type *DataType); + bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment); + bool isLegalMaskedCompressStore(Type *DataType, Align Alignment); bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1, const SmallBitVector &OpcodeMask) const; bool hasDivRemOp(Type *DataType, bool IsSigned); diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp index c01d03f644724..d545c0ae49f5a 100644 --- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp +++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp @@ -969,12 +969,16 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, return true; } case Intrinsic::masked_expandload: - if (TTI.isLegalMaskedExpandLoad(CI->getType())) + if (TTI.isLegalMaskedExpandLoad( + CI->getType(), + CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne())) return false; scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT); return true; case Intrinsic::masked_compressstore: - if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType())) + if (TTI.isLegalMaskedCompressStore( + CI->getArgOperand(0)->getType(), + CI->getAttributes().getParamAttrs(1).getAlignment().valueOrOne())) return false; scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT); return true; From 98ca9a958046b50cac4affa40bdd5a7fdf30828d Mon Sep 17 00:00:00 2001 From: Kolya Panchenko Date: Fri, 1 Mar 2024 05:48:14 -0800 Subject: [PATCH 2/2] removed isLegalMaskedCompressStore from RISCVTTIImpl --- .../Target/RISCV/RISCVTargetTransformInfo.cpp | 25 ------------------- .../Target/RISCV/RISCVTargetTransformInfo.h | 2 -- 2 files changed, 27 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp index 0bd623e1196e1..2e4e69fb4f920 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp @@ -1609,28 +1609,3 @@ bool RISCVTTIImpl::isLSRCostLess(const TargetTransformInfo::LSRCost &C1, C2.NumIVMuls, C2.NumBaseAdds, C2.ScaleCost, C2.ImmCost, C2.SetupCost); } - -bool RISCVTTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) { - auto *VTy = dyn_cast(DataTy); - if (!VTy || VTy->isScalableTy() || !ST->hasVInstructions()) - return false; - - Type *ScalarTy = VTy->getScalarType(); - if (ScalarTy->isFloatTy() || ScalarTy->isDoubleTy()) - return true; - - if (!ScalarTy->isIntegerTy()) - return false; - - switch (ScalarTy->getIntegerBitWidth()) { - case 8: - case 16: - case 32: - case 64: - break; - default: - return false; - } - - return getRegUsageForType(VTy) <= 8; -} diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h index 8daf6845dc8bc..af36e9d5d5e88 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h @@ -261,8 +261,6 @@ class RISCVTTIImpl : public BasicTTIImplBase { return TLI->isLegalStridedLoadStore(DataTypeVT, Alignment); } - bool isLegalMaskedCompressStore(Type *DataTy, Align Alignment); - bool isVScaleKnownToBeAPowerOfTwo() const { return TLI->isVScaleKnownToBeAPowerOfTwo(); }