diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp index d47b9c7a25b8f..aa7769899ff27 100644 --- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp +++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp @@ -15,6 +15,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Analysis/DXILMetadataAnalysis.h" #include "llvm/BinaryFormat/DXContainer.h" #include "llvm/CodeGen/Passes.h" #include "llvm/IR/Constants.h" @@ -57,6 +58,7 @@ class DXContainerGlobals : public llvm::ModulePass { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesAll(); AU.addRequired(); + AU.addRequired(); } }; @@ -143,23 +145,35 @@ void DXContainerGlobals::addPipelineStateValidationInfo( SmallString<256> Data; raw_svector_ostream OS(Data); PSVRuntimeInfo PSV; - Triple TT(M.getTargetTriple()); PSV.BaseData.MinimumWaveLaneCount = 0; PSV.BaseData.MaximumWaveLaneCount = std::numeric_limits::max(); + + dxil::ModuleMetadataInfo &MMI = + getAnalysis().getModuleMetadata(); + assert(MMI.EntryPropertyVec.size() == 1 || + MMI.ShaderStage == Triple::Library); PSV.BaseData.ShaderStage = - static_cast(TT.getEnvironment() - Triple::Pixel); + static_cast(MMI.ShaderStage - Triple::Pixel); // Hardcoded values here to unblock loading the shader into D3D. // // TODO: Lots more stuff to do here! // // See issue https://github.com/llvm/llvm-project/issues/96674. - PSV.BaseData.NumThreadsX = 1; - PSV.BaseData.NumThreadsY = 1; - PSV.BaseData.NumThreadsZ = 1; - PSV.EntryName = "main"; + switch (MMI.ShaderStage) { + case Triple::Compute: + PSV.BaseData.NumThreadsX = MMI.EntryPropertyVec[0].NumThreadsX; + PSV.BaseData.NumThreadsY = MMI.EntryPropertyVec[0].NumThreadsY; + PSV.BaseData.NumThreadsZ = MMI.EntryPropertyVec[0].NumThreadsZ; + break; + default: + break; + } + + if (MMI.ShaderStage != Triple::Library) + PSV.EntryName = MMI.EntryPropertyVec[0].Entry->getName(); - PSV.finalize(TT.getEnvironment()); + PSV.finalize(MMI.ShaderStage); PSV.write(OS); Constant *Constant = ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false); @@ -170,6 +184,7 @@ char DXContainerGlobals::ID = 0; INITIALIZE_PASS_BEGIN(DXContainerGlobals, "dxil-globals", "DXContainer Global Emitter", false, true) INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper) +INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass) INITIALIZE_PASS_END(DXContainerGlobals, "dxil-globals", "DXContainer Global Emitter", false, true) diff --git a/llvm/lib/Target/DirectX/DXILPrepare.cpp b/llvm/lib/Target/DirectX/DXILPrepare.cpp index 56098864e987f..f6b7355b93625 100644 --- a/llvm/lib/Target/DirectX/DXILPrepare.cpp +++ b/llvm/lib/Target/DirectX/DXILPrepare.cpp @@ -19,6 +19,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSet.h" +#include "llvm/Analysis/DXILMetadataAnalysis.h" #include "llvm/CodeGen/Passes.h" #include "llvm/IR/AttributeMask.h" #include "llvm/IR/IRBuilder.h" @@ -247,6 +248,7 @@ class DXILPrepareModule : public ModulePass { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addPreserved(); AU.addPreserved(); + AU.addPreserved(); } static char ID; // Pass identification. }; diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp index 2c6d20112060d..11cd9df1d1dc4 100644 --- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp +++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp @@ -13,6 +13,7 @@ #include "DXILShaderFlags.h" #include "DirectX.h" #include "llvm/ADT/StringSet.h" +#include "llvm/Analysis/DXILMetadataAnalysis.h" #include "llvm/Analysis/DXILResource.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Metadata.h" @@ -103,6 +104,7 @@ class DXILTranslateMetadataLegacy : public ModulePass { AU.addRequired(); AU.addRequired(); AU.addRequired(); + AU.addRequired(); } bool runOnModule(Module &M) override { diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RuntimeInfoCS.ll b/llvm/test/CodeGen/DirectX/ContainerData/RuntimeInfoCS.ll new file mode 100644 index 0000000000000..595e70092bb08 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/ContainerData/RuntimeInfoCS.ll @@ -0,0 +1,41 @@ +; RUN: opt %s -dxil-embed -dxil-globals -S -o - | FileCheck %s +; RUN: llc %s --filetype=obj -o - | obj2yaml | FileCheck %s --check-prefix=DXC +target triple = "dxil-unknown-shadermodel6.0-compute" + +; CHECK: @dx.psv0 = private constant [80 x i8] c"{{.*}}", section "PSV0", align 4 + +define void @cs_main() #0 { +entry: + ret void +} + +attributes #0 = { "hlsl.numthreads"="8,8,1" "hlsl.shader"="compute" } + +!dx.valver = !{!0} + +!0 = !{i32 1, i32 7} + +; DXC: - Name: PSV0 +; DXC-NEXT: Size: 80 +; DXC-NEXT: PSVInfo: +; DXC-NEXT: Version: 3 +; DXC-NEXT: ShaderStage: 5 +; DXC-NEXT: MinimumWaveLaneCount: 0 +; DXC-NEXT: MaximumWaveLaneCount: 4294967295 +; DXC-NEXT: UsesViewID: 0 +; DXC-NEXT: SigInputVectors: 0 +; DXC-NEXT: SigOutputVectors: [ 0, 0, 0, 0 ] +; DXC-NEXT: NumThreadsX: 8 +; DXC-NEXT: NumThreadsY: 8 +; DXC-NEXT: NumThreadsZ: 1 +; DXC-NEXT: EntryName: cs_main +; DXC-NEXT: ResourceStride: 24 +; DXC-NEXT: Resources: [] +; DXC-NEXT: SigInputElements: [] +; DXC-NEXT: SigOutputElements: [] +; DXC-NEXT: SigPatchOrPrimElements: [] +; DXC-NEXT: InputOutputMap: +; DXC-NEXT: - [ ] +; DXC-NEXT: - [ ] +; DXC-NEXT: - [ ] +; DXC-NEXT: - [ ]