Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 84 additions & 57 deletions tools/clang/lib/Sema/SemaDXR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

#include "dxc/DXIL/DxilConstants.h"
#include "dxc/DXIL/DxilShaderModel.h"
#include "dxc/HlslIntrinsicOp.h"

using namespace clang;
using namespace sema;
Expand All @@ -49,9 +50,9 @@ struct PayloadUse {
const MemberExpr *Member = nullptr;
};

struct TraceRayCall {
TraceRayCall() = default;
TraceRayCall(const CallExpr *Call, const CFGBlock *Parent)
struct PayloadBuiltinCall {
PayloadBuiltinCall() = default;
PayloadBuiltinCall(const CallExpr *Call, const CFGBlock *Parent)
: Call(Call), Parent(Parent) {}
const CallExpr *Call = nullptr;
const CFGBlock *Parent = nullptr;
Expand All @@ -71,7 +72,7 @@ struct DxrShaderDiagnoseInfo {
const FunctionDecl *funcDecl;
const VarDecl *Payload;
DXIL::PayloadAccessShaderStage Stage;
std::vector<TraceRayCall> TraceCalls;
std::vector<PayloadBuiltinCall> PayloadBuiltinCalls;
std::map<const FieldDecl *, std::vector<PayloadUse>> WritesPerField;
std::map<const FieldDecl *, std::vector<PayloadUse>> ReadsPerField;
std::vector<PayloadUse> PayloadAsCallArg;
Expand Down Expand Up @@ -121,24 +122,42 @@ GetPayloadQualifierForStage(FieldDecl *Field,
return DXIL::PayloadAccessQualifier::NoAccess;
}

// Returns the declaration of the payload used in a TraceRay call
const VarDecl *GetPayloadParameterForTraceCall(const CallExpr *Trace) {
const Decl *callee = Trace->getCalleeDecl();
if (!callee)
static int GetPayloadParamIdxForIntrinsic(const FunctionDecl *FD) {
HLSLIntrinsicAttr *IntrinAttr = FD->getAttr<HLSLIntrinsicAttr>();
if (!IntrinAttr)
return -1;
switch ((IntrinsicOp)IntrinAttr->getOpcode()) {
default:
return -1;
case IntrinsicOp::IOP_TraceRay:
case IntrinsicOp::MOP_DxHitObject_TraceRay:
case IntrinsicOp::MOP_DxHitObject_Invoke:
return FD->getNumParams() - 1;
}
}

static bool IsBuiltinWithPayload(const FunctionDecl *FD) {
return GetPayloadParamIdxForIntrinsic(FD) >= 0;
}

// Returns the declaration of the payload used in a call to TraceRay,
// HitObject::TraceRay or HitObject::Invoke.
const VarDecl *GetPayloadParameterForBuiltinCall(const CallExpr *Call) {
const Decl *Callee = Call->getCalleeDecl();
if (!Callee)
return nullptr;

if (!isa<FunctionDecl>(callee))
if (!isa<FunctionDecl>(Callee))
return nullptr;

const FunctionDecl *FD = cast<FunctionDecl>(callee);
int PldParamIdx = GetPayloadParamIdxForIntrinsic(cast<FunctionDecl>(Callee));
if (PldParamIdx < 0)
return nullptr;

if (FD->isImplicit() && FD->getName() == "TraceRay") {
const Stmt *Param = IgnoreParensAndDecay(Trace->getArg(7));
if (const DeclRefExpr *ParamRef = dyn_cast<DeclRefExpr>(Param)) {
if (const VarDecl *Decl = dyn_cast<VarDecl>(ParamRef->getDecl()))
return Decl;
}
}
const Stmt *Param = IgnoreParensAndDecay(Call->getArg(PldParamIdx));
if (const DeclRefExpr *ParamRef = dyn_cast<DeclRefExpr>(Param))
if (const VarDecl *Decl = dyn_cast<VarDecl>(ParamRef->getDecl()))
return Decl;
return nullptr;
}

Expand Down Expand Up @@ -190,12 +209,9 @@ void CollectReadsWritesAndCallsForPayload(const Stmt *S,
}
}

// Collects all TraceRay calls.
void CollectTraceRayCalls(const Stmt *S, DxrShaderDiagnoseInfo &Info,
const CFGBlock *Block) {
// TraceRay has void as return type so it should never be something else
// than a plain CallExpr.

// Collects all calls to TraceRay, HitObject::TraceRay and HitObject::Invoke.
void CollectBuiltinCallsWithPayload(const Stmt *S, DxrShaderDiagnoseInfo &Info,
const CFGBlock *Block) {
if (const CallExpr *Call = dyn_cast<CallExpr>(S)) {

const Decl *Callee = Call->getCalleeDecl();
Expand All @@ -204,11 +220,8 @@ void CollectTraceRayCalls(const Stmt *S, DxrShaderDiagnoseInfo &Info,

const FunctionDecl *CalledFunction = cast<FunctionDecl>(Callee);

// Ignore trace calls here.
if (CalledFunction->isImplicit() &&
CalledFunction->getName() == "TraceRay") {
Info.TraceCalls.push_back({Call, Block});
}
if (IsBuiltinWithPayload(CalledFunction))
Info.PayloadBuiltinCalls.push_back({Call, Block});
}
}

Expand Down Expand Up @@ -528,13 +541,14 @@ void TraverseCFG(const CFGBlock &Block, Action PerElementAction,
}
}

// Forward traverse the CFG and collect calls to TraceRay.
void ForwardTraverseCFGAndCollectTraceCalls(
// Forward traverse the CFG and collect calls to TraceRay, HitObject::TraceRay
// and HitObject::Invoke.
void ForwardTraverseCFGAndCollectBuiltinCallsWithPayload(
const CFGBlock &Block, DxrShaderDiagnoseInfo &Info,
std::set<const CFGBlock *> &Visited) {
auto Action = [&Info](const CFGBlock &Block, const CFGElement &Element) {
if (Optional<CFGStmt> S = Element.getAs<CFGStmt>()) {
CollectTraceRayCalls(S->getStmt(), Info, &Block);
CollectBuiltinCallsWithPayload(S->getStmt(), Info, &Block);
}
};

Expand Down Expand Up @@ -664,9 +678,9 @@ DiagnosePayloadAsFunctionArg(
const FunctionDecl *CalledFunction = cast<FunctionDecl>(Callee);

// Ignore trace calls here.
if (CalledFunction->isImplicit() &&
CalledFunction->getName() == "TraceRay") {
Info.TraceCalls.push_back(TraceRayCall{Call, Use.Parent});
if (IsBuiltinWithPayload(CalledFunction)) {
Info.PayloadBuiltinCalls.push_back(
PayloadBuiltinCall{Call, Use.Parent});
continue;
}

Expand Down Expand Up @@ -789,10 +803,12 @@ void HandlePayloadInitializer(DxrShaderDiagnoseInfo &Info) {
}
}

// Emit diagnostics for a TraceRay call.
void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
const TraceRayCall &Trace, DominatorTree &DT) {
// For each TraceRay call check if write(caller) fields are written.
// Emit diagnostics for this call to either TraceRay, HitObject::TraceRay or
// HitObject::Invoke.
void DiagnoseBuiltinCallWithPayload(Sema &S, const VarDecl *Payload,
const PayloadBuiltinCall &PldCall,
DominatorTree &DT) {
// For each call check if write(caller) fields are written.
const DXIL::PayloadAccessShaderStage CallerStage =
DXIL::PayloadAccessShaderStage::Caller;

Expand All @@ -810,6 +826,13 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
return;
}

// Verify that the payload type is legal
if (!hlsl::IsHLSLCopyableAnnotatableRecord(Payload->getType())) {
S.Diag(Payload->getLocation(), diag::err_payload_attrs_must_be_udt)
<< /*payload|attributes|callable*/ 0 << Payload;
return;
}

if (ContainsLongVector(Payload->getType())) {
const unsigned PayloadParametersIdx = 10;
S.Diag(Payload->getLocation(), diag::err_hlsl_unsupported_long_vector)
Expand All @@ -832,12 +855,12 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,

std::set<const CFGBlock *> Visited;

const CFGBlock *Parent = Trace.Parent;
const CFGBlock *Parent = PldCall.Parent;
Visited.insert(Parent);
// Collect payload accesses in the same block until we reach the TraceRay call
// Collect payload accesses in the same block until we reach the call
for (auto Element : *Parent) {
if (Optional<CFGStmt> S = Element.getAs<CFGStmt>()) {
if (S->getStmt() == Trace.Call)
if (S->getStmt() == PldCall.Call)
break;
CollectReadsWritesAndCallsForPayload(S->getStmt(), TraceInfo, Parent);
}
Expand All @@ -850,10 +873,12 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
BackwardTraverseCFGAndCollectReadsWrites(*Pred, TraceInfo, Visited);
}

int PldArgIdx = PldCall.Call->getNumArgs() - 1;

// Warn if a writeable field has not been written.
for (const FieldDecl *Field : WriteableFields) {
if (!TraceInfo.WritesPerField.count(Field)) {
S.Diag(Trace.Call->getArg(7)->getExprLoc(),
S.Diag(PldCall.Call->getArg(PldArgIdx)->getExprLoc(),
diag::warn_hlsl_payload_access_no_write_for_trace_payload)
<< Field->getName();
}
Expand All @@ -862,7 +887,7 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
for (const FieldDecl *Field : NonWriteableFields) {
if (TraceInfo.WritesPerField.count(Field)) {
S.Diag(
Trace.Call->getArg(7)->getExprLoc(),
PldCall.Call->getArg(PldArgIdx)->getExprLoc(),
diag::warn_hlsl_payload_access_write_but_no_write_for_trace_payload)
<< Field->getName();
}
Expand All @@ -878,7 +903,7 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
bool CallFound = false;
for (auto Element : *Parent) { // TODO: reverse iterate?
if (Optional<CFGStmt> S = Element.getAs<CFGStmt>()) {
if (S->getStmt() == Trace.Call) {
if (S->getStmt() == PldCall.Call) {
CallFound = true;
continue;
}
Expand All @@ -895,7 +920,7 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,

for (const FieldDecl *Field : ReadableFields) {
if (!TraceInfo.ReadsPerField.count(Field)) {
S.Diag(Trace.Call->getArg(7)->getExprLoc(),
S.Diag(PldCall.Call->getArg(PldArgIdx)->getExprLoc(),
diag::warn_hlsl_payload_access_read_but_no_read_after_trace)
<< Field->getName();
}
Expand Down Expand Up @@ -928,27 +953,29 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
}
}

// Emit diagnostics for all TraceRay calls.
void DiagnoseTraceCalls(Sema &S, CFG &ShaderCFG, DominatorTree &DT,
DxrShaderDiagnoseInfo &Info) {
// Collect TraceRay calls in the shader.
// Emit diagnostics for all calls to TraceRay, HitObject::TraceRay or
// HitObject::Invoke.
void DiagnoseBuiltinCallsWithPayload(Sema &S, CFG &ShaderCFG, DominatorTree &DT,
DxrShaderDiagnoseInfo &Info) {
// Collect calls with payload in the shader.
std::set<const CFGBlock *> Visited;
ForwardTraverseCFGAndCollectTraceCalls(ShaderCFG.getEntry(), Info, Visited);
ForwardTraverseCFGAndCollectBuiltinCallsWithPayload(ShaderCFG.getEntry(),
Info, Visited);

std::set<const CallExpr *> Diagnosed;

for (const TraceRayCall &TraceCall : Info.TraceCalls) {
if (Diagnosed.count(TraceCall.Call))
for (const PayloadBuiltinCall &PldCall : Info.PayloadBuiltinCalls) {
if (Diagnosed.count(PldCall.Call))
continue;
Diagnosed.insert(TraceCall.Call);
Diagnosed.insert(PldCall.Call);

const VarDecl *Payload = GetPayloadParameterForTraceCall(TraceCall.Call);
DiagnoseTraceCall(S, Payload, TraceCall, DT);
const VarDecl *Payload = GetPayloadParameterForBuiltinCall(PldCall.Call);
DiagnoseBuiltinCallWithPayload(S, Payload, PldCall, DT);
}
}

// Emit diagnostics for all access to the payload of a shader,
// and the input to TraceRay calls.
// and the input to TraceRay, HitObject::TraceRay or HitObject::Invoke calls.
std::vector<const FieldDecl *>
DiagnosePayloadAccess(Sema &S, DxrShaderDiagnoseInfo &Info,
const std::set<const FieldDecl *> &FieldsToIgnoreRead,
Expand Down Expand Up @@ -1012,7 +1039,7 @@ DiagnosePayloadAccess(Sema &S, DxrShaderDiagnoseInfo &Info,
DiagnosePayloadReads(S, TheCFG, DT, Info, NonReadableFields);
}

DiagnoseTraceCalls(S, TheCFG, DT, Info);
DiagnoseBuiltinCallsWithPayload(S, TheCFG, DT, Info);

return WrittenFields;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// RUN: %dxc -T lib_6_9 %s -D TEST_NUM=0 %s -verify
// RUN: %dxc -T lib_6_9 %s -D TEST_NUM=1 %s -verify

RaytracingAccelerationStructure scene : register(t0);

struct Payload
{
int a : read (caller, closesthit, miss) : write(caller, closesthit, miss);
};

struct Attribs
{
float2 barys;
};

[shader("raygeneration")]
void RayGen()
{
// expected-error@+1{{type 'Payload' used as payload requires that it is annotated with the [raypayload] attribute}}
Payload payload_in_rg;
RayDesc ray;
#if TEST_NUM == 0
dx::HitObject::TraceRay( scene, RAY_FLAG_NONE, 0xff, 0, 1, 0, ray, payload_in_rg );
#else
dx::HitObject::Invoke( dx::HitObject(), payload_in_rg );
#endif
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// RUN: %dxc -T lib_6_9 %s -verify

struct
[raypayload]
Payload
{
int a : read(caller, closesthit, miss) : write(caller, closesthit, miss);
dx::HitObject hit;
};

struct Attribs
{
float2 barys;
};

[shader("raygeneration")]
void RayGen()
{
// expected-error@+1{{payload parameter 'payload_in_rg' must be a user-defined type composed of only numeric types}}
Payload payload_in_rg;
dx::HitObject::Invoke( dx::HitObject(), payload_in_rg );
}
Loading