Skip to content

[DXIL][Analysis] Implement enough of DXILResourceAnalysis for buffers #100699

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
123 changes: 104 additions & 19 deletions llvm/include/llvm/Analysis/DXILResource.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,16 @@
#ifndef LLVM_ANALYSIS_DXILRESOURCE_H
#define LLVM_ANALYSIS_DXILRESOURCE_H

#include "llvm/ADT/MapVector.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/DXILABI.h"

namespace llvm {
class CallInst;
class MDTuple;
class TargetExtType;

namespace dxil {

Expand Down Expand Up @@ -47,7 +52,7 @@ class ResourceInfo {

struct StructInfo {
uint32_t Stride;
Align Alignment;
MaybeAlign Alignment;

bool operator==(const StructInfo &RHS) const {
return std::tie(Stride, Alignment) == std::tie(RHS.Stride, RHS.Alignment);
Expand Down Expand Up @@ -106,6 +111,11 @@ class ResourceInfo {

MSInfo MultiSample;

public:
ResourceInfo(dxil::ResourceClass RC, dxil::ResourceKind Kind, Value *Symbol,
StringRef Name)
: Symbol(Symbol), Name(Name), RC(RC), Kind(Kind) {}

// Conditions to check before accessing union members.
bool isUAV() const;
bool isCBuffer() const;
Expand All @@ -115,17 +125,53 @@ class ResourceInfo {
bool isFeedback() const;
bool isMultiSample() const;

ResourceInfo(dxil::ResourceClass RC, dxil::ResourceKind Kind, Value *Symbol,
StringRef Name)
: Symbol(Symbol), Name(Name), RC(RC), Kind(Kind) {}
void bind(uint32_t UniqueID, uint32_t Space, uint32_t LowerBound,
uint32_t Size) {
Binding.UniqueID = UniqueID;
Binding.Space = Space;
Binding.LowerBound = LowerBound;
Binding.Size = Size;
}
void setUAV(bool GloballyCoherent, bool HasCounter, bool IsROV) {
assert(isUAV() && "Not a UAV");
UAVFlags.GloballyCoherent = GloballyCoherent;
UAVFlags.HasCounter = HasCounter;
UAVFlags.IsROV = IsROV;
}
void setCBuffer(uint32_t Size) {
assert(isCBuffer() && "Not a CBuffer");
CBufferSize = Size;
}
void setSampler(dxil::SamplerType Ty) {
SamplerTy = Ty;
}
void setStruct(uint32_t Stride, MaybeAlign Alignment) {
assert(isStruct() && "Not a Struct");
Struct.Stride = Stride;
Struct.Alignment = Alignment;
}
void setTyped(dxil::ElementType ElementTy, uint32_t ElementCount) {
assert(isTyped() && "Not Typed");
Typed.ElementTy = ElementTy;
Typed.ElementCount = ElementCount;
}
void setFeedback(dxil::SamplerFeedbackType Type) {
assert(isFeedback() && "Not Feedback");
Feedback.Type = Type;
}
void setMultiSample(uint32_t Count) {
assert(isMultiSample() && "Not MultiSampled");
MultiSample.Count = Count;
}

bool operator==(const ResourceInfo &RHS) const;

public:
static ResourceInfo SRV(Value *Symbol, StringRef Name,
dxil::ElementType ElementTy, uint32_t ElementCount,
dxil::ResourceKind Kind);
static ResourceInfo RawBuffer(Value *Symbol, StringRef Name);
static ResourceInfo StructuredBuffer(Value *Symbol, StringRef Name,
uint32_t Stride, Align Alignment);
uint32_t Stride, MaybeAlign Alignment);
static ResourceInfo Texture2DMS(Value *Symbol, StringRef Name,
dxil::ElementType ElementTy,
uint32_t ElementCount, uint32_t SampleCount);
Expand All @@ -141,9 +187,9 @@ class ResourceInfo {
static ResourceInfo RWRawBuffer(Value *Symbol, StringRef Name,
bool GloballyCoherent, bool IsROV);
static ResourceInfo RWStructuredBuffer(Value *Symbol, StringRef Name,
uint32_t Stride,
Align Alignment, bool GloballyCoherent,
bool IsROV, bool HasCounter);
uint32_t Stride, MaybeAlign Alignment,
bool GloballyCoherent, bool IsROV,
bool HasCounter);
static ResourceInfo RWTexture2DMS(Value *Symbol, StringRef Name,
dxil::ElementType ElementTy,
uint32_t ElementCount, uint32_t SampleCount,
Expand All @@ -164,23 +210,62 @@ class ResourceInfo {
static ResourceInfo Sampler(Value *Symbol, StringRef Name,
dxil::SamplerType SamplerTy);

void bind(uint32_t UniqueID, uint32_t Space, uint32_t LowerBound,
uint32_t Size) {
Binding.UniqueID = UniqueID;
Binding.Space = Space;
Binding.LowerBound = LowerBound;
Binding.Size = Size;
}

bool operator==(const ResourceInfo &RHS) const;

MDTuple *getAsMetadata(LLVMContext &Ctx) const;

ResourceBinding getBinding() const { return Binding; }
std::pair<uint32_t, uint32_t> getAnnotateProps() const;

void print(raw_ostream &OS) const;
};

} // namespace dxil

using DXILResourceMap = MapVector<CallInst *, dxil::ResourceInfo>;

class DXILResourceAnalysis : public AnalysisInfoMixin<DXILResourceAnalysis> {
friend AnalysisInfoMixin<DXILResourceAnalysis>;

static AnalysisKey Key;

public:
using Result = DXILResourceMap;

/// Gather resource info for the module \c M.
DXILResourceMap run(Module &M, ModuleAnalysisManager &AM);
};

/// Printer pass for the \c DXILResourceAnalysis results.
class DXILResourcePrinterPass : public PassInfoMixin<DXILResourcePrinterPass> {
raw_ostream &OS;

public:
explicit DXILResourcePrinterPass(raw_ostream &OS) : OS(OS) {}

PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);

static bool isRequired() { return true; }
};

class DXILResourceWrapperPass : public ModulePass {
std::unique_ptr<DXILResourceMap> ResourceMap;

public:
static char ID; // Class identification, replacement for typeinfo

DXILResourceWrapperPass();
~DXILResourceWrapperPass() override;

const DXILResourceMap &getResourceMap() const { return *ResourceMap; }
DXILResourceMap &getResourceMap() { return *ResourceMap; }

void getAnalysisUsage(AnalysisUsage &AU) const override;
bool runOnModule(Module &M) override;
void releaseMemory() override;

void print(raw_ostream &OS, const Module *M) const override;
void dump() const;
};

} // namespace llvm

#endif // LLVM_ANALYSIS_DXILRESOURCE_H
1 change: 1 addition & 0 deletions llvm/include/llvm/InitializePasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ void initializeCycleInfoWrapperPassPass(PassRegistry &);
void initializeDAEPass(PassRegistry&);
void initializeDAHPass(PassRegistry&);
void initializeDCELegacyPassPass(PassRegistry&);
void initializeDXILResourceWrapperPassPass(PassRegistry &);
void initializeDeadMachineInstructionElimPass(PassRegistry&);
void initializeDebugifyMachineModulePass(PassRegistry &);
void initializeDependenceAnalysisWrapperPassPass(PassRegistry&);
Expand Down
Loading
Loading