Skip to content

Commit d77a348

Browse files
authored
[SYCL][NVPTX] Emit 'grid_constant' annotations for by-val kernel params (#14332)
Also fix up the `DeadArgumentElimination` passes to correctly preserve the annotations: when arguments are removed from a function, the indices of dead parameters need to be pruned, and the indices of live parameters may need to be shifted down by the number of dead arguments that came before them.
1 parent b3e7606 commit d77a348

File tree

5 files changed

+283
-5
lines changed

5 files changed

+283
-5
lines changed

clang/lib/CodeGen/Targets/NVPTX.cpp

+64
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "ABIInfoImpl.h"
1010
#include "TargetInfo.h"
11+
#include "clang/Basic/Cuda.h"
1112
#include "llvm/IR/IntrinsicsNVPTX.h"
1213

1314
using namespace clang;
@@ -80,6 +81,9 @@ class NVPTXTargetCodeGenInfo : public TargetCodeGenInfo {
8081
static void addNVVMMetadata(llvm::GlobalValue *GV, StringRef Name,
8182
int Operand);
8283

84+
static void addNVVMMetadata(llvm::GlobalValue *GV, StringRef Name,
85+
const std::vector<int> &Operands);
86+
8387
private:
8488
static void emitBuiltinSurfTexDeviceCopy(CodeGenFunction &CGF, LValue Dst,
8589
LValue Src) {
@@ -218,6 +222,28 @@ RValue NVPTXABIInfo::EmitVAArg(CodeGenFunction &CGF, Address VAListAddr,
218222
llvm_unreachable("NVPTX does not support varargs");
219223
}
220224

225+
// Get the current CudaArch for the module, ignoring any unknown values.
//
// Returns CudaArch::UNKNOWN when the target does not have the "ptx" feature.
// Otherwise walks the target options' FeatureMap and returns the first
// enabled feature whose key StringToCudaArch maps to something other than
// CudaArch::UNKNOWN. Returns CudaArch::UNKNOWN if no such feature is found.
226+
// Copied from CGOpenMPRuntimeGPU
227+
static CudaArch getCudaArch(CodeGenModule &CGM) {
228+
// Targets without the "ptx" feature are reported as UNKNOWN.
if (!CGM.getTarget().hasFeature("ptx"))
229+
return CudaArch::UNKNOWN;
230+
for (const auto &Feature : CGM.getTarget().getTargetOpts().FeatureMap) {
231+
// Only features whose map value is true (i.e. enabled) are considered.
if (Feature.getValue()) {
232+
CudaArch Arch = StringToCudaArch(Feature.getKey());
233+
// Skip keys that StringToCudaArch maps to UNKNOWN and keep looking.
if (Arch != CudaArch::UNKNOWN)
234+
return Arch;
235+
}
236+
}
237+
// No enabled feature named a known architecture.
return CudaArch::UNKNOWN;
238+
}
239+
240+
// Returns true when Arch compares at or above CudaArch::SM_70 in enumerator
// order, and false otherwise (including for CudaArch::UNKNOWN). The result is
// used by setTargetAttributes to decide whether "grid_constant" annotations
// are emitted for SYCL device kernels.
static bool supportsGridConstant(CudaArch Arch) {
241+
// Only UNKNOWN or an NVIDIA GPU architecture is expected here.
assert((Arch == CudaArch::UNKNOWN || IsNVIDIAGpuArch(Arch)) &&
242+
"Unexpected architecture");
243+
// Compile-time guarantee that UNKNOWN orders below SM_70, so the comparison
// below rejects it.
static_assert(CudaArch::UNKNOWN < CudaArch::SM_70);
244+
return Arch >= CudaArch::SM_70;
245+
}
246+
221247
void NVPTXTargetCodeGenInfo::setTargetAttributes(
222248
const Decl *D, llvm::GlobalValue *GV, CodeGen::CodeGenModule &M) const {
223249
if (GV->isDeclaration())
@@ -248,6 +274,22 @@ void NVPTXTargetCodeGenInfo::setTargetAttributes(
248274
addNVVMMetadata(F, "kernel", 1);
249275
// And kernel functions are not subject to inlining
250276
F->addFnAttr(llvm::Attribute::NoInline);
277+
278+
if (M.getLangOpts().SYCLIsDevice &&
279+
supportsGridConstant(getCudaArch(M))) {
280+
// Add grid_constant annotations to all relevant kernel-function
281+
// parameters. We can guarantee that in SYCL, all by-val kernel
282+
// parameters are "grid_constant".
283+
std::vector<int> GridConstantParamIdxs;
284+
for (auto [Idx, Arg] : llvm::enumerate(F->args())) {
285+
if (Arg.getType()->isPointerTy() && Arg.hasByValAttr()) {
286+
// Note - the parameter indices are numbered from 1.
287+
GridConstantParamIdxs.push_back(Idx + 1);
288+
}
289+
}
290+
if (!GridConstantParamIdxs.empty())
291+
addNVVMMetadata(F, "grid_constant", GridConstantParamIdxs);
292+
}
251293
}
252294
bool HasMaxWorkGroupSize = false;
253295
bool HasMinWorkGroupPerCU = false;
@@ -329,6 +371,28 @@ void NVPTXTargetCodeGenInfo::addNVVMMetadata(llvm::GlobalValue *GV,
329371
MD->addOperand(llvm::MDNode::get(Ctx, MDVals));
330372
}
331373

374+
// Appends an annotation of the form
//   !{GV, !"<Name>", !{i32 Op0, i32 Op1, ...}}
// to the module-level "nvvm.annotations" named metadata.
//
// GV       - the global value being annotated; its parent module receives the
//            annotation.
// Name     - the annotation name, stored as an MDString.
// Operands - the values for the nested operand list; each one is emitted as
//            an i32 constant, in the order given.
void NVPTXTargetCodeGenInfo::addNVVMMetadata(llvm::GlobalValue *GV,
375+
StringRef Name,
376+
const std::vector<int> &Operands) {
377+
llvm::Module *M = GV->getParent();
378+
llvm::LLVMContext &Ctx = M->getContext();
379+
380+
// Get "nvvm.annotations" metadata node
381+
llvm::NamedMDNode *MD = M->getOrInsertNamedMetadata("nvvm.annotations");
382+
383+
// Wrap every operand as i32 constant metadata and collect them into a single
// MDNode, which becomes the third operand of the annotation.
llvm::SmallVector<llvm::Metadata *, 8> MDOps;
384+
for (int Op : Operands) {
385+
MDOps.push_back(llvm::ConstantAsMetadata::get(
386+
llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), Op)));
387+
}
388+
auto *OpList = llvm::MDNode::get(Ctx, MDOps);
389+
390+
// The annotation tuple: annotated value, annotation name, operand list.
llvm::Metadata *MDVals[] = {llvm::ConstantAsMetadata::get(GV),
391+
llvm::MDString::get(Ctx, Name), OpList};
392+
// Append metadata to nvvm.annotations
393+
MD->addOperand(llvm::MDNode::get(Ctx, MDVals));
394+
}
395+
332396
bool NVPTXTargetCodeGenInfo::shouldEmitStaticExternCAliases() const {
333397
return false;
334398
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// RUN: %clang_cc1 -fno-sycl-force-inline-kernel-lambda -fsycl-is-device -internal-isystem %S/Inputs -triple nvptx-nvidia-cuda -target-cpu sm_70 -disable-llvm-passes -sycl-std=2020 -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK,GRIDCONST
2+
// RUN: %clang_cc1 -fno-sycl-force-inline-kernel-lambda -fsycl-is-device -internal-isystem %S/Inputs -triple nvptx64-nvidia-cuda -target-cpu sm_70 -disable-llvm-passes -sycl-std=2020 -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK,GRIDCONST
3+
4+
// RUN: %clang_cc1 -fno-sycl-force-inline-kernel-lambda -fsycl-is-device -internal-isystem %S/Inputs -triple nvptx-nvidia-cuda -target-cpu sm_60 -disable-llvm-passes -sycl-std=2020 -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK,NOGRIDCONST
5+
// RUN: %clang_cc1 -fno-sycl-force-inline-kernel-lambda -fsycl-is-device -internal-isystem %S/Inputs -triple nvptx64-nvidia-cuda -target-cpu sm_60 -disable-llvm-passes -sycl-std=2020 -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK,NOGRIDCONST
6+
7+
// Tests that certain SYCL kernel parameters are annotated with "grid_constant" for supported microarchitectures.
8+
9+
#include "sycl.hpp"
10+
11+
using namespace sycl;
12+
13+
int main() {
14+
queue q;
15+
16+
struct S {
17+
int a;
18+
} s;
19+
20+
q.submit([&](handler &h) {
21+
// CHECK: define{{.*}} void @[[FUNC1:.*kernel_grid_const_params]](ptr noundef byval(%struct.S) align 4 %_arg_s)
22+
h.single_task<class kernel_grid_const_params>([=]() { (void) s;});
23+
});
24+
25+
return 0;
26+
}
27+
28+
// Don't emit grid_constant annotations for older architectures.
29+
// NOGRIDCONST-NOT: "grid_constant"
30+
31+
// This isn't stable in general, as it depends on the order of the captured
32+
// parameters, but in this case there's only one parameter so we know it's 1.
33+
// GRIDCONST-DAG: = !{ptr @[[FUNC1]], !"grid_constant", [[MD:\![0-9]+]]}
34+
// GRIDCONST-DAG: [[MD]] = !{i32 1}

llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,8 @@ class DeadArgumentEliminationPass
145145
bool removeDeadArgumentsFromCallers(Function &F);
146146
void propagateVirtMustcallLiveness(const Module &M);
147147

148-
void UpdateNVPTXMetadata(Module &M, Function *F, Function *NF);
148+
void UpdateNVPTXMetadata(Module &M, Function *F, Function *NF,
149+
const SmallVectorImpl<bool> &ArgAlive);
149150
llvm::DenseSet<Function *> NVPTXKernelSet;
150151

151152
bool IsNVPTXKernel(const Function *F) { return NVPTXKernelSet.contains(F); };

llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp

+70-4
Original file line numberDiff line numberDiff line change
@@ -1171,7 +1171,7 @@ bool DeadArgumentEliminationPass::removeDeadStuffFromFunction(Function *F) {
11711171
NF->addMetadata(KindID, *Node);
11721172

11731173
if (IsNVPTXKernel(F))
1174-
UpdateNVPTXMetadata(*(F->getParent()), F, NF);
1174+
UpdateNVPTXMetadata(*(F->getParent()), F, NF, ArgAlive);
11751175

11761176
// If either the return value(s) or argument(s) are removed, then probably the
11771177
// function does not follow standard calling conventions anymore. Hence, add
@@ -1249,8 +1249,9 @@ PreservedAnalyses DeadArgumentEliminationPass::run(Module &M,
12491249
return PreservedAnalyses::none();
12501250
}
12511251

1252-
void DeadArgumentEliminationPass::UpdateNVPTXMetadata(Module &M, Function *F,
1253-
Function *NF) {
1252+
void DeadArgumentEliminationPass::UpdateNVPTXMetadata(
1253+
Module &M, Function *F, Function *NF,
1254+
const SmallVectorImpl<bool> &ArgAlive) {
12541255

12551256
auto *NvvmMetadata = M.getNamedMetadata("nvvm.annotations");
12561257
if (!NvvmMetadata)
@@ -1260,13 +1261,78 @@ void DeadArgumentEliminationPass::UpdateNVPTXMetadata(Module &M, Function *F,
12601261
const auto &FuncOperand = MetadataNode->getOperand(0);
12611262
if (!FuncOperand)
12621263
continue;
1263-
auto FuncConstant = dyn_cast<ConstantAsMetadata>(FuncOperand);
1264+
auto *FuncConstant = dyn_cast<ConstantAsMetadata>(FuncOperand);
12641265
if (!FuncConstant)
12651266
continue;
12661267
auto *Func = dyn_cast<Function>(FuncConstant->getValue());
12671268
if (Func != F)
12681269
continue;
12691270
// Update the metadata with the new function
12701271
MetadataNode->replaceOperandWith(0, llvm::ConstantAsMetadata::get(NF));
1272+
1273+
// Carefully update any and all grid_constant annotations, since those are
1274+
// denoted by parameter indices, which may have changed
1275+
for (unsigned I = 1; I < MetadataNode->getNumOperands() - 1; I += 2) {
1276+
if (auto *Type = dyn_cast<MDString>(MetadataNode->getOperand(I));
1277+
Type && Type->getString() == "grid_constant") {
1278+
LLVMContext &Ctx = NF->getContext();
1279+
LLVM_DEBUG(dbgs() << "DeadArgumentEliminationPass - updating nvvm "
1280+
"grid_constant annotations for fn: "
1281+
<< NF->getName() << "\n");
1282+
// The 'value' operand is a list of integers denoting parameter indices
1283+
const auto *OldGridConstParamIdxs =
1284+
dyn_cast<MDNode>(MetadataNode->getOperand(I + 1));
1285+
assert(OldGridConstParamIdxs &&
1286+
"Unexpected NVVM annotation format: expected MDNode operand");
1287+
// For each parameter that's identified as a grid_constant, count how
1288+
// many arguments before that position are dead, and shift the number
1289+
// down by that amount.
1290+
// Note that there's no guaranteed order to the parameter indices, so
1291+
// there's fewer 'smart' things like counting up incrementally as we go.
1292+
SmallVector<Metadata *, 8> NewGridConstParamOps;
1293+
for (const auto &Op : OldGridConstParamIdxs->operands()) {
1294+
auto *ParamIdx = mdconst::dyn_extract<ConstantInt>(Op);
1295+
// If the operand's not a constant, or its constant value is not
1296+
// within the range of the old function's parameter list (note - it
1297+
// counts from 1), it's not well-defined. Just strip it out for
1298+
// safety.
1299+
if (!ParamIdx || ParamIdx->isZero() ||
1300+
ParamIdx->getZExtValue() > F->arg_size())
1301+
continue;
1302+
1303+
size_t OldParamIdx = ParamIdx->getZExtValue() - 1;
1304+
// If the parameter is no longer alive, it's definitely not a
1305+
// grid_constant. Strip it out.
1306+
if (!ArgAlive[OldParamIdx])
1307+
continue;
1308+
1309+
unsigned ShiftDownAmt = 0;
1310+
for (unsigned i = 0; i < std::min(F->arg_size(), OldParamIdx); i++) {
1311+
if (!ArgAlive[i])
1312+
ShiftDownAmt++;
1313+
}
1314+
NewGridConstParamOps.push_back(
1315+
ConstantAsMetadata::get(ConstantInt::get(
1316+
Type::getInt32Ty(Ctx), OldParamIdx - ShiftDownAmt + 1)));
1317+
}
1318+
1319+
// Update the metadata with the new grid_constant information
1320+
MDNode *NewGridConstParamIdxs = MDNode::get(Ctx, NewGridConstParamOps);
1321+
1322+
LLVM_DEBUG(dbgs() << " * updating old annotation {";
1323+
auto PrintList =
1324+
[](const MDNode *MD) {
1325+
for (const auto &O : MD->operands())
1326+
if (const auto *ParamNo =
1327+
mdconst::dyn_extract<ConstantInt>(O))
1328+
dbgs() << ParamNo->getZExtValue() << ",";
1329+
};
1330+
PrintList(OldGridConstParamIdxs);
1331+
dbgs() << "} to new annotation {";
1332+
PrintList(NewGridConstParamIdxs); dbgs() << "}\n";);
1333+
1334+
MetadataNode->replaceOperandWith(I + 1, NewGridConstParamIdxs);
1335+
}
1336+
}
12711337
}
12721338
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals all --version 5
2+
; RUN: opt < %s -passes=deadargelim -S | FileCheck %s
3+
4+
define internal void @test1(i32 %v, ptr byval(i32) %DEADARG1, ptr %p) {
5+
; CHECK-LABEL: define internal void @test1(
6+
; CHECK-SAME: i32 [[V:%.*]], ptr [[P:%.*]]) {
7+
; CHECK-NEXT: store i32 [[V]], ptr [[P]], align 4
8+
; CHECK-NEXT: ret void
9+
;
10+
store i32 %v, ptr %p
11+
ret void
12+
}
13+
14+
define internal void @test2(ptr byval(i32) %DEADARG1, ptr byval(i32) %p) {
15+
; CHECK-LABEL: define internal void @test2(
16+
; CHECK-SAME: ptr byval(i32) [[P:%.*]]) {
17+
; CHECK-NEXT: store i32 0, ptr [[P]], align 4
18+
; CHECK-NEXT: ret void
19+
;
20+
store i32 0, ptr %p
21+
ret void
22+
}
23+
24+
define internal void @test3(ptr byval(i32) %DEADARG1, i32 %v, ptr byval(i32) %p) {
25+
; CHECK-LABEL: define internal void @test3(
26+
; CHECK-SAME: i32 [[V:%.*]], ptr byval(i32) [[P:%.*]]) {
27+
; CHECK-NEXT: store i32 [[V]], ptr [[P]], align 4
28+
; CHECK-NEXT: ret void
29+
;
30+
store i32 %v, ptr %p
31+
ret void
32+
}
33+
34+
define internal void @test4(ptr byval(i32) %p, i32 %v, ptr byval(i32) %DEADARG) {
35+
; CHECK-LABEL: define internal void @test4(
36+
; CHECK-SAME: ptr byval(i32) [[P:%.*]], i32 [[V:%.*]]) {
37+
; CHECK-NEXT: store i32 [[V]], ptr [[P]], align 4
38+
; CHECK-NEXT: ret void
39+
;
40+
store i32 %v, ptr %p
41+
ret void
42+
}
43+
44+
define internal void @test5(ptr byval(i32) %p, i32 %x, ptr byval(i32) %DEADARG1, ptr byval(i32) %DEADARG2, i32 %y, ptr byval(i32) %q) {
45+
; CHECK-LABEL: define internal void @test5(
46+
; CHECK-SAME: ptr byval(i32) [[P:%.*]], i32 [[X:%.*]], i32 [[Y:%.*]], ptr byval(i32) [[Q:%.*]]) {
47+
; CHECK-NEXT: [[T:%.*]] = add i32 [[X]], [[Y]]
48+
; CHECK-NEXT: store i32 [[T]], ptr [[P]], align 4
49+
; CHECK-NEXT: store i32 [[T]], ptr [[Q]], align 4
50+
; CHECK-NEXT: ret void
51+
;
52+
%t = add i32 %x, %y
53+
store i32 %t, ptr %p
54+
store i32 %t, ptr %q
55+
ret void
56+
}
57+
58+
!nvvm.annotations = !{
59+
!0, !1,
60+
!3, !4, !6,
61+
!8, !9, !11,
62+
!13, !14, !16,
63+
!18, !19
64+
}
65+
66+
; Note - also test various permutations of the parameter lists, as they are not
67+
; specified to be in any particular order (e.g., consecutive).
68+
!0 = !{ptr @test1, !"kernel", i32 1}
69+
!1 = !{ptr @test1, !"grid_constant", !2}
70+
!2 = !{i32 2}
71+
72+
!3 = !{ptr @test2, !"kernel", i32 1}
73+
!4 = !{ptr @test2, !"grid_constant", !5}
74+
!5 = !{i32 1, i32 2}
75+
!6 = !{ptr @test2, !"grid_constant", !7}
76+
!7 = !{i32 2, i32 1}
77+
78+
!8 = !{ptr @test3, !"kernel", i32 1}
79+
!9 = !{ptr @test3, !"grid_constant", !10}
80+
!10 = !{i32 1, i32 3}
81+
!11 = !{ptr @test3, !"grid_constant", !12}
82+
!12 = !{i32 3, i32 1}
83+
84+
!13 = !{ptr @test4, !"kernel", i32 1}
85+
!14 = !{ptr @test4, !"grid_constant", !15}
86+
!15 = !{i32 1, i32 3}
87+
!16 = !{ptr @test4, !"grid_constant", !17}
88+
!17 = !{i32 3, i32 1}
89+
90+
!18 = !{ptr @test5, !"kernel", i32 1}
91+
!19 = !{ptr @test5, !"grid_constant", !20, !"grid_constant", !21, !"grid_constant", !22}
92+
!20 = !{i32 1, i32 3, i32 4, i32 6}
93+
!21 = !{i32 3, i32 1, i32 4, i32 6}
94+
!22 = !{i32 3, i32 1, i32 6, i32 4}
95+
;.
96+
; CHECK: [[META0:![0-9]+]] = !{ptr @test1, !"kernel", i32 1}
97+
; CHECK: [[META1:![0-9]+]] = !{ptr @test1, !"grid_constant", [[META2:![0-9]+]]}
98+
; CHECK: [[META2]] = !{}
99+
; CHECK: [[META3:![0-9]+]] = !{ptr @test2, !"kernel", i32 1}
100+
; CHECK: [[META4:![0-9]+]] = !{ptr @test2, !"grid_constant", [[META5:![0-9]+]]}
101+
; CHECK: [[META5]] = !{i32 1}
102+
; CHECK: [[META6:![0-9]+]] = distinct !{ptr @test2, !"grid_constant", [[META5]]}
103+
; CHECK: [[META7:![0-9]+]] = !{ptr @test3, !"kernel", i32 1}
104+
; CHECK: [[META8:![0-9]+]] = !{ptr @test3, !"grid_constant", [[META9:![0-9]+]]}
105+
; CHECK: [[META9]] = !{i32 2}
106+
; CHECK: [[META10:![0-9]+]] = distinct !{ptr @test3, !"grid_constant", [[META9]]}
107+
; CHECK: [[META11:![0-9]+]] = !{ptr @test4, !"kernel", i32 1}
108+
; CHECK: [[META12:![0-9]+]] = !{ptr @test4, !"grid_constant", [[META5]]}
109+
; CHECK: [[META13:![0-9]+]] = distinct !{ptr @test4, !"grid_constant", [[META5]]}
110+
; CHECK: [[META14:![0-9]+]] = !{ptr @test5, !"kernel", i32 1}
111+
; CHECK: [[META15:![0-9]+]] = !{ptr @test5, !"grid_constant", [[META16:![0-9]+]], !"grid_constant", [[META16]], !"grid_constant", [[META16]]}
112+
; CHECK: [[META16]] = !{i32 1, i32 4}
113+
;.

0 commit comments

Comments
 (0)