Skip to content

Commit d2233dd

Browse files
committed
[SYCL][NVPTX] Emit 'grid_constant' annotations for by-val kernel params
Also fix up the DeadArgumentElimination passes to correctly preserve the annotations; when removing arguments from functions, dead parameters need pruned and alive ones may need their values shifted down by the number of dead arguments that came before them.
1 parent 922c2d5 commit d2233dd

File tree

5 files changed

+353
-5
lines changed

5 files changed

+353
-5
lines changed

clang/lib/CodeGen/Targets/NVPTX.cpp

+133
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "ABIInfoImpl.h"
1010
#include "TargetInfo.h"
11+
#include "clang/Basic/Cuda.h"
1112
#include "llvm/IR/IntrinsicsNVPTX.h"
1213

1314
using namespace clang;
@@ -80,6 +81,9 @@ class NVPTXTargetCodeGenInfo : public TargetCodeGenInfo {
8081
static void addNVVMMetadata(llvm::GlobalValue *GV, StringRef Name,
8182
int Operand);
8283

84+
static void addNVVMMetadata(llvm::GlobalValue *GV, StringRef Name,
85+
const std::vector<int> &Operands);
86+
8387
private:
8488
static void emitBuiltinSurfTexDeviceCopy(CodeGenFunction &CGF, LValue Dst,
8589
LValue Src) {
@@ -218,6 +222,98 @@ Address NVPTXABIInfo::EmitVAArg(CodeGenFunction &CGF, Address VAListAddr,
218222
llvm_unreachable("NVPTX does not support varargs");
219223
}
220224

225+
// Get current CudaArch and ignore any unknown values
226+
// Copied from CGOpenMPRuntimeGPU
227+
static CudaArch getCudaArch(CodeGenModule &CGM) {
228+
if (!CGM.getTarget().hasFeature("ptx"))
229+
return CudaArch::UNKNOWN;
230+
for (const auto &Feature : CGM.getTarget().getTargetOpts().FeatureMap) {
231+
if (Feature.getValue()) {
232+
CudaArch Arch = StringToCudaArch(Feature.getKey());
233+
if (Arch != CudaArch::UNKNOWN)
234+
return Arch;
235+
}
236+
}
237+
return CudaArch::UNKNOWN;
238+
}
239+
240+
static bool supportsGridConstant(CudaArch Arch) {
241+
switch (Arch) {
242+
case CudaArch::SM_70:
243+
case CudaArch::SM_72:
244+
case CudaArch::SM_75:
245+
case CudaArch::SM_80:
246+
case CudaArch::SM_86:
247+
case CudaArch::SM_87:
248+
case CudaArch::SM_89:
249+
case CudaArch::SM_90:
250+
case CudaArch::SM_90a:
251+
return true;
252+
case CudaArch::UNKNOWN:
253+
case CudaArch::UNUSED:
254+
case CudaArch::SM_20:
255+
case CudaArch::SM_21:
256+
case CudaArch::SM_30:
257+
case CudaArch::SM_32_:
258+
case CudaArch::SM_35:
259+
case CudaArch::SM_37:
260+
case CudaArch::SM_50:
261+
case CudaArch::SM_52:
262+
case CudaArch::SM_53:
263+
case CudaArch::SM_60:
264+
case CudaArch::SM_61:
265+
case CudaArch::SM_62:
266+
return false;
267+
case CudaArch::GFX600:
268+
case CudaArch::GFX601:
269+
case CudaArch::GFX602:
270+
case CudaArch::GFX700:
271+
case CudaArch::GFX701:
272+
case CudaArch::GFX702:
273+
case CudaArch::GFX703:
274+
case CudaArch::GFX704:
275+
case CudaArch::GFX705:
276+
case CudaArch::GFX801:
277+
case CudaArch::GFX802:
278+
case CudaArch::GFX803:
279+
case CudaArch::GFX805:
280+
case CudaArch::GFX810:
281+
case CudaArch::GFX900:
282+
case CudaArch::GFX902:
283+
case CudaArch::GFX904:
284+
case CudaArch::GFX906:
285+
case CudaArch::GFX908:
286+
case CudaArch::GFX909:
287+
case CudaArch::GFX90a:
288+
case CudaArch::GFX90c:
289+
case CudaArch::GFX940:
290+
case CudaArch::GFX941:
291+
case CudaArch::GFX942:
292+
case CudaArch::GFX1010:
293+
case CudaArch::GFX1011:
294+
case CudaArch::GFX1012:
295+
case CudaArch::GFX1013:
296+
case CudaArch::GFX1030:
297+
case CudaArch::GFX1031:
298+
case CudaArch::GFX1032:
299+
case CudaArch::GFX1033:
300+
case CudaArch::GFX1034:
301+
case CudaArch::GFX1035:
302+
case CudaArch::GFX1036:
303+
case CudaArch::GFX1100:
304+
case CudaArch::GFX1101:
305+
case CudaArch::GFX1102:
306+
case CudaArch::GFX1103:
307+
case CudaArch::GFX1150:
308+
case CudaArch::GFX1151:
309+
case CudaArch::GFX1200:
310+
case CudaArch::GFX1201:
311+
case CudaArch::Generic:
312+
case CudaArch::LAST:
313+
llvm_unreachable("unhandled CudaArch");
314+
}
315+
}
316+
221317
void NVPTXTargetCodeGenInfo::setTargetAttributes(
222318
const Decl *D, llvm::GlobalValue *GV, CodeGen::CodeGenModule &M) const {
223319
if (GV->isDeclaration())
@@ -248,6 +344,21 @@ void NVPTXTargetCodeGenInfo::setTargetAttributes(
248344
addNVVMMetadata(F, "kernel", 1);
249345
// And kernel functions are not subject to inlining
250346
F->addFnAttr(llvm::Attribute::NoInline);
347+
348+
if (supportsGridConstant(getCudaArch(M))) {
349+
// Add grid_constant annotations to all relevant kernel-function
350+
// parameters. We can guarantee that in SYCL, all by-val kernel
351+
// parameters are "grid_constant".
352+
std::vector<int> GridConstantParamIdxs;
353+
for (auto [Idx, Arg] : llvm::enumerate(F->args())) {
354+
if (Arg.getType()->isPointerTy() && Arg.hasByValAttr()) {
355+
// Note - the parameter indices are numbered from 1.
356+
GridConstantParamIdxs.push_back(Idx + 1);
357+
}
358+
}
359+
if (!GridConstantParamIdxs.empty())
360+
addNVVMMetadata(F, "grid_constant", GridConstantParamIdxs);
361+
}
251362
}
252363
bool HasMaxWorkGroupSize = false;
253364
bool HasMinWorkGroupPerCU = false;
@@ -329,6 +440,28 @@ void NVPTXTargetCodeGenInfo::addNVVMMetadata(llvm::GlobalValue *GV,
329440
MD->addOperand(llvm::MDNode::get(Ctx, MDVals));
330441
}
331442

443+
void NVPTXTargetCodeGenInfo::addNVVMMetadata(llvm::GlobalValue *GV,
444+
StringRef Name,
445+
const std::vector<int> &Operands) {
446+
llvm::Module *M = GV->getParent();
447+
llvm::LLVMContext &Ctx = M->getContext();
448+
449+
// Get "nvvm.annotations" metadata node
450+
llvm::NamedMDNode *MD = M->getOrInsertNamedMetadata("nvvm.annotations");
451+
452+
llvm::SmallVector<llvm::Metadata *, 8> MDOps;
453+
for (int Op : Operands) {
454+
MDOps.push_back(llvm::ConstantAsMetadata::get(
455+
llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), Op)));
456+
}
457+
auto *OpList = llvm::MDNode::get(Ctx, MDOps);
458+
459+
llvm::Metadata *MDVals[] = {llvm::ConstantAsMetadata::get(GV),
460+
llvm::MDString::get(Ctx, Name), OpList};
461+
// Append metadata to nvvm.annotations
462+
MD->addOperand(llvm::MDNode::get(Ctx, MDVals));
463+
}
464+
332465
bool NVPTXTargetCodeGenInfo::shouldEmitStaticExternCAliases() const {
333466
return false;
334467
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// RUN: %clang_cc1 -fno-sycl-force-inline-kernel-lambda -fsycl-is-device -internal-isystem %S/Inputs -triple nvptx-nvidia-cuda -target-cpu sm_70 -disable-llvm-passes -sycl-std=2020 -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK,GRIDCONST
2+
// RUN: %clang_cc1 -fno-sycl-force-inline-kernel-lambda -fsycl-is-device -internal-isystem %S/Inputs -triple nvptx64-nvidia-cuda -target-cpu sm_70 -disable-llvm-passes -sycl-std=2020 -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK,GRIDCONST
3+
4+
// RUN: %clang_cc1 -fno-sycl-force-inline-kernel-lambda -fsycl-is-device -internal-isystem %S/Inputs -triple nvptx-nvidia-cuda -target-cpu sm_60 -disable-llvm-passes -sycl-std=2020 -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK,NOGRIDCONST
5+
// RUN: %clang_cc1 -fno-sycl-force-inline-kernel-lambda -fsycl-is-device -internal-isystem %S/Inputs -triple nvptx64-nvidia-cuda -target-cpu sm_60 -disable-llvm-passes -sycl-std=2020 -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK,NOGRIDCONST
6+
7+
// Tests that work_group_size_hint and reqd_work_group_size generate the same
8+
// metadata nodes for the same arguments.
9+
10+
#include "sycl.hpp"
11+
12+
using namespace sycl;
13+
14+
int main() {
15+
queue q;
16+
17+
struct S {
18+
int a;
19+
} s;
20+
q.submit([&](handler &h) {
21+
// CHECK: define{{.*}} void @[[FUNC1:.*kernel_grid_const_params]](ptr noundef byval(%struct.S) align 4 %_arg_s)
22+
h.single_task<class kernel_grid_const_params>([=]() { (void) s;});
23+
});
24+
25+
return 0;
26+
}
27+
// Don't emit grid_constant annotations for older architectures.
28+
29+
// NOGRIDCONST-NOT: "grid_constant"
30+
31+
// This isn't stable in general, as it depends on the order of the captured
32+
// parameters, but in this case there's only one parameter so we know it's 1.
33+
// GRIDCONST-DAG: = !{ptr @[[FUNC1]], !"grid_constant", [[MD:\![0-9]+]]}
34+
// GRIDCONST-DAG: [[MD]] = !{i32 1}
35+

llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,8 @@ class DeadArgumentEliminationPass
145145
bool removeDeadArgumentsFromCallers(Function &F);
146146
void propagateVirtMustcallLiveness(const Module &M);
147147

148-
void UpdateNVPTXMetadata(Module &M, Function *F, Function *NF);
148+
void UpdateNVPTXMetadata(Module &M, Function *F, Function *NF,
149+
const SmallVectorImpl<bool> &ArgAlive);
149150
llvm::DenseSet<Function *> NVPTXKernelSet;
150151

151152
bool IsNVPTXKernel(const Function *F) { return NVPTXKernelSet.contains(F); };

llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp

+66-4
Original file line numberDiff line numberDiff line change
@@ -1171,7 +1171,7 @@ bool DeadArgumentEliminationPass::removeDeadStuffFromFunction(Function *F) {
11711171
NF->addMetadata(KindID, *Node);
11721172

11731173
if (IsNVPTXKernel(F))
1174-
UpdateNVPTXMetadata(*(F->getParent()), F, NF);
1174+
UpdateNVPTXMetadata(*(F->getParent()), F, NF, ArgAlive);
11751175

11761176
// If either the return value(s) or argument(s) are removed, then probably the
11771177
// function does not follow standard calling conventions anymore. Hence, add
@@ -1249,8 +1249,9 @@ PreservedAnalyses DeadArgumentEliminationPass::run(Module &M,
12491249
return PreservedAnalyses::none();
12501250
}
12511251

1252-
void DeadArgumentEliminationPass::UpdateNVPTXMetadata(Module &M, Function *F,
1253-
Function *NF) {
1252+
void DeadArgumentEliminationPass::UpdateNVPTXMetadata(
1253+
Module &M, Function *F, Function *NF,
1254+
const SmallVectorImpl<bool> &ArgAlive) {
12541255

12551256
auto *NvvmMetadata = M.getNamedMetadata("nvvm.annotations");
12561257
if (!NvvmMetadata)
@@ -1260,13 +1261,74 @@ void DeadArgumentEliminationPass::UpdateNVPTXMetadata(Module &M, Function *F,
12601261
const auto &FuncOperand = MetadataNode->getOperand(0);
12611262
if (!FuncOperand)
12621263
continue;
1263-
auto FuncConstant = dyn_cast<ConstantAsMetadata>(FuncOperand);
1264+
auto *FuncConstant = dyn_cast<ConstantAsMetadata>(FuncOperand);
12641265
if (!FuncConstant)
12651266
continue;
12661267
auto *Func = dyn_cast<Function>(FuncConstant->getValue());
12671268
if (Func != F)
12681269
continue;
12691270
// Update the metadata with the new function
12701271
MetadataNode->replaceOperandWith(0, llvm::ConstantAsMetadata::get(NF));
1272+
1273+
// Carefully update grid_constant annotations, since those are denoted
1274+
// parameter indices, which may have changed
1275+
if (auto *Type = dyn_cast<MDString>(MetadataNode->getOperand(1));
1276+
Type && Type->getString() == "grid_constant") {
1277+
LLVMContext &Ctx = NF->getContext();
1278+
LLVM_DEBUG(dbgs() << "DeadArgumentEliminationPass - updating nvvm "
1279+
"grid_constant annotations for fn: "
1280+
<< NF->getName() << "\n");
1281+
// The 'value' operand is a list of integers denoting parameter indices
1282+
auto *OldGridConstParamIdxs = dyn_cast<MDNode>(MetadataNode->getOperand(2));
1283+
if (!OldGridConstParamIdxs)
1284+
continue;
1285+
// For each parameter that's identified as a grid_constant, count how
1286+
// many arguments before that position are dead, and shift the number
1287+
// down by that amount.
1288+
// Note that there's no guaranteed order to the parameter indices, so
1289+
// there's fewer 'smart' things like counting up incrementally as we go.
1290+
SmallVector<Metadata *, 8> NewGridConstParamOps;
1291+
for (const auto &Op : OldGridConstParamIdxs->operands()) {
1292+
auto *ParamIdx = mdconst::dyn_extract<ConstantInt>(Op);
1293+
// If the operand's not a constant, or its constant value is not
1294+
// within the range of the old function's parameter list (note - it
1295+
// counts from 1), it's not well-defined. Just strip it out for
1296+
// safety.
1297+
if (!ParamIdx || ParamIdx->isZero() ||
1298+
ParamIdx->getZExtValue() > F->arg_size())
1299+
continue;
1300+
1301+
size_t OldParamIdx = ParamIdx->getZExtValue() - 1;
1302+
// If the parameter is no longer alive, it's definitely not a
1303+
// grid_constant. Strip it out.
1304+
if (!ArgAlive[OldParamIdx])
1305+
continue;
1306+
1307+
unsigned ShiftDownAmt = 0;
1308+
for (unsigned i = 0; i < std::min(F->arg_size(), OldParamIdx); i++) {
1309+
if (!ArgAlive[i])
1310+
ShiftDownAmt++;
1311+
}
1312+
NewGridConstParamOps.push_back(ConstantAsMetadata::get(ConstantInt::get(
1313+
Type::getInt32Ty(Ctx), OldParamIdx - ShiftDownAmt + 1)));
1314+
}
1315+
1316+
// Update the metadata with the new grid_constant information
1317+
MDNode *NewGridConstParamIdxs = MDNode::get(Ctx, NewGridConstParamOps);
1318+
1319+
LLVM_DEBUG(dbgs() << " * updating old annotation {";
1320+
auto PrintList =
1321+
[](const MDNode *MD) {
1322+
for (const auto &O : MD->operands())
1323+
if (const auto *ParamNo =
1324+
mdconst::dyn_extract<ConstantInt>(O))
1325+
dbgs() << ParamNo->getZExtValue() << ",";
1326+
};
1327+
PrintList(OldGridConstParamIdxs);
1328+
dbgs() << "} to new annotation {";
1329+
PrintList(NewGridConstParamIdxs); dbgs() << "}\n";);
1330+
1331+
MetadataNode->replaceOperandWith(2, NewGridConstParamIdxs);
1332+
}
12711333
}
12721334
}

0 commit comments

Comments
 (0)