Skip to content

Commit 5cf2a37

Browse files
committed
[HIP] Emit kernel symbol
Currently clang uses a stub function to launch a kernel. This is inconvenient for interop with C++ programs, since the stub function has a different name from the kernel (the distinct stub name is required by the ROCm debugger). This patch emits a variable symbol which has the same name as the kernel and uses it to register and launch the kernel. This allows a C++ program to launch a kernel by using the original kernel name. Reviewed by: Artem Belevich Differential Revision: https://reviews.llvm.org/D86376
1 parent 154c47d commit 5cf2a37

File tree

9 files changed

+208
-23
lines changed

9 files changed

+208
-23
lines changed

clang/lib/CodeGen/CGCUDANV.cpp

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,18 @@ class CGNVCUDARuntime : public CGCUDARuntime {
4242
llvm::LLVMContext &Context;
4343
/// Convenience reference to the current module
4444
llvm::Module &TheModule;
45-
/// Keeps track of kernel launch stubs emitted in this module
45+
/// Keeps track of kernel launch stubs and handles emitted in this module
4646
struct KernelInfo {
47-
llvm::Function *Kernel;
47+
llvm::Function *Kernel; // stub function to help launch kernel
4848
const Decl *D;
4949
};
5050
llvm::SmallVector<KernelInfo, 16> EmittedKernels;
51+
// Map a device stub function to a symbol for identifying kernel in host code.
52+
// For CUDA, the symbol for identifying the kernel is the same as the device
53+
// stub function. For HIP, they are different.
54+
llvm::DenseMap<llvm::Function *, llvm::GlobalValue *> KernelHandles;
55+
// Map a kernel handle to the kernel stub.
56+
llvm::DenseMap<llvm::GlobalValue *, llvm::Function *> KernelStubs;
5157
struct VarInfo {
5258
llvm::GlobalVariable *Var;
5359
const VarDecl *D;
@@ -154,6 +160,12 @@ class CGNVCUDARuntime : public CGCUDARuntime {
154160
public:
155161
CGNVCUDARuntime(CodeGenModule &CGM);
156162

163+
llvm::GlobalValue *getKernelHandle(llvm::Function *F, GlobalDecl GD) override;
164+
/// Return the kernel stub function recorded in KernelStubs for \p Handle.
/// The handle must already have an entry in KernelStubs (entries are added by
/// getKernelHandle); a missing entry is only caught by the assert below.
llvm::Function *getKernelStub(llvm::GlobalValue *Handle) override {
165+
auto Loc = KernelStubs.find(Handle);
166+
assert(Loc != KernelStubs.end());
167+
return Loc->second;
168+
}
157169
void emitDeviceStub(CodeGenFunction &CGF, FunctionArgList &Args) override;
158170
void handleVarRegistration(const VarDecl *VD,
159171
llvm::GlobalVariable &Var) override;
@@ -272,6 +284,10 @@ std::string CGNVCUDARuntime::getDeviceSideName(const NamedDecl *ND) {
272284
void CGNVCUDARuntime::emitDeviceStub(CodeGenFunction &CGF,
273285
FunctionArgList &Args) {
274286
EmittedKernels.push_back({CGF.CurFn, CGF.CurFuncDecl});
287+
if (auto *GV = dyn_cast<llvm::GlobalVariable>(KernelHandles[CGF.CurFn])) {
288+
GV->setLinkage(CGF.CurFn->getLinkage());
289+
GV->setInitializer(CGF.CurFn);
290+
}
275291
if (CudaFeatureEnabled(CGM.getTarget().getSDKVersion(),
276292
CudaFeature::CUDA_USES_NEW_LAUNCH) ||
277293
(CGF.getLangOpts().HIP && CGF.getLangOpts().HIPUseNewLaunchAPI))
@@ -350,7 +366,8 @@ void CGNVCUDARuntime::emitDeviceStubBodyNew(CodeGenFunction &CGF,
350366
ShmemSize.getPointer(), Stream.getPointer()});
351367

352368
// Emit the call to cudaLaunch
353-
llvm::Value *Kernel = CGF.Builder.CreatePointerCast(CGF.CurFn, VoidPtrTy);
369+
llvm::Value *Kernel =
370+
CGF.Builder.CreatePointerCast(KernelHandles[CGF.CurFn], VoidPtrTy);
354371
CallArgList LaunchKernelArgs;
355372
LaunchKernelArgs.add(RValue::get(Kernel),
356373
cudaLaunchKernelFD->getParamDecl(0)->getType());
@@ -405,7 +422,8 @@ void CGNVCUDARuntime::emitDeviceStubBodyLegacy(CodeGenFunction &CGF,
405422

406423
// Emit the call to cudaLaunch
407424
llvm::FunctionCallee cudaLaunchFn = getLaunchFn();
408-
llvm::Value *Arg = CGF.Builder.CreatePointerCast(CGF.CurFn, CharPtrTy);
425+
llvm::Value *Arg =
426+
CGF.Builder.CreatePointerCast(KernelHandles[CGF.CurFn], CharPtrTy);
409427
CGF.EmitRuntimeCallOrInvoke(cudaLaunchFn, Arg);
410428
CGF.EmitBranch(EndBlock);
411429

@@ -499,7 +517,7 @@ llvm::Function *CGNVCUDARuntime::makeRegisterGlobalsFn() {
499517
llvm::Constant *NullPtr = llvm::ConstantPointerNull::get(VoidPtrTy);
500518
llvm::Value *Args[] = {
501519
&GpuBinaryHandlePtr,
502-
Builder.CreateBitCast(I.Kernel, VoidPtrTy),
520+
Builder.CreateBitCast(KernelHandles[I.Kernel], VoidPtrTy),
503521
KernelName,
504522
KernelName,
505523
llvm::ConstantInt::get(IntTy, -1),
@@ -1070,3 +1088,28 @@ llvm::Function *CGNVCUDARuntime::finalizeModule() {
10701088
}
10711089
return makeModuleCtorFunction();
10721090
}
1091+
1092+
/// Return the kernel handle associated with the stub function \p F, creating
/// and caching it on first use.
///
/// When not compiling HIP, the stub itself is the handle. For HIP, the handle
/// is a constant global variable named with the mangled name obtained from
/// \p GD using KernelReferenceKind::Kernel. Both KernelHandles (stub -> handle)
/// and KernelStubs (handle -> stub) are updated.
llvm::GlobalValue *CGNVCUDARuntime::getKernelHandle(llvm::Function *F,
1093+
GlobalDecl GD) {
1094+
// Reuse the handle if one was already created for this stub.
auto Loc = KernelHandles.find(F);
1095+
if (Loc != KernelHandles.end())
1096+
return Loc->second;
1097+
1098+
// Non-HIP: the stub function doubles as its own handle.
if (!CGM.getLangOpts().HIP) {
1099+
KernelHandles[F] = F;
1100+
KernelStubs[F] = F;
1101+
return F;
1102+
}
1103+
1104+
// HIP: create the handle variable without an initializer; emitDeviceStub
// sets the initializer (and linkage) once the stub is emitted.
auto *Var = new llvm::GlobalVariable(
1105+
TheModule, F->getType(), /*isConstant=*/true, F->getLinkage(),
1106+
/*Initializer=*/nullptr,
1107+
CGM.getMangledName(
1108+
GD.getWithKernelReferenceKind(KernelReferenceKind::Kernel)));
1109+
// Pointer-aligned; DSO-locality and visibility are copied from the stub.
Var->setAlignment(CGM.getPointerAlign().getAsAlign());
1110+
Var->setDSOLocal(F->isDSOLocal());
1111+
Var->setVisibility(F->getVisibility());
1112+
KernelHandles[F] = Var;
1113+
KernelStubs[Var] = F;
1114+
return Var;
1115+
}

clang/lib/CodeGen/CGCUDARuntime.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#ifndef LLVM_CLANG_LIB_CODEGEN_CGCUDARUNTIME_H
1616
#define LLVM_CLANG_LIB_CODEGEN_CGCUDARUNTIME_H
1717

18+
#include "clang/AST/GlobalDecl.h"
1819
#include "llvm/ADT/StringRef.h"
1920
#include "llvm/IR/GlobalValue.h"
2021

@@ -94,6 +95,13 @@ class CGCUDARuntime {
9495
/// compilation is for host.
9596
virtual std::string getDeviceSideName(const NamedDecl *ND) = 0;
9697

98+
/// Get kernel handle by stub function.
99+
virtual llvm::GlobalValue *getKernelHandle(llvm::Function *Stub,
100+
GlobalDecl GD) = 0;
101+
102+
/// Get kernel stub by kernel handle.
103+
virtual llvm::Function *getKernelStub(llvm::GlobalValue *Handle) = 0;
104+
97105
/// Adjust linkage of shadow variables in host compilation.
98106
virtual void
99107
internalizeDeviceSideVar(const VarDecl *D,

clang/lib/CodeGen/CGExpr.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13+
#include "CGCUDARuntime.h"
1314
#include "CGCXXABI.h"
1415
#include "CGCall.h"
1516
#include "CGCleanup.h"
@@ -4871,8 +4872,12 @@ static CGCallee EmitDirectCallee(CodeGenFunction &CGF, GlobalDecl GD) {
48714872
return CGCallee::forBuiltin(builtinID, FD);
48724873
}
48734874

4874-
llvm::Constant *calleePtr = EmitFunctionDeclPointer(CGF.CGM, GD);
4875-
return CGCallee::forDirect(calleePtr, GD);
4875+
llvm::Constant *CalleePtr = EmitFunctionDeclPointer(CGF.CGM, GD);
4876+
if (CGF.CGM.getLangOpts().CUDA && !CGF.CGM.getLangOpts().CUDAIsDevice &&
4877+
FD->hasAttr<CUDAGlobalAttr>())
4878+
CalleePtr = CGF.CGM.getCUDARuntime().getKernelStub(
4879+
cast<llvm::GlobalValue>(CalleePtr->stripPointerCasts()));
4880+
return CGCallee::forDirect(CalleePtr, GD);
48764881
}
48774882

48784883
CGCallee CodeGenFunction::EmitCallee(const Expr *E) {
@@ -5266,6 +5271,19 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType, const CGCallee &OrigCallee
52665271
Callee.setFunctionPointer(CalleePtr);
52675272
}
52685273

5274+
// In HIP host code, a function pointer used with the triple-chevron launch
5275+
// syntax holds a kernel handle, not the stub. Load the kernel stub from the
5276+
// handle and use it as the callee.
5277+
if (CGM.getLangOpts().HIP && !CGM.getLangOpts().CUDAIsDevice &&
5278+
isa<CUDAKernelCallExpr>(E) &&
5279+
(!TargetDecl || !isa<FunctionDecl>(TargetDecl))) {
5280+
llvm::Value *Handle = Callee.getFunctionPointer();
5282+
auto *Cast =
5283+
Builder.CreateBitCast(Handle, Handle->getType()->getPointerTo());
5284+
auto *Stub = Builder.CreateLoad(Address(Cast, CGM.getPointerAlign()));
5285+
Callee.setFunctionPointer(Stub);
5286+
}
52695287
llvm::CallBase *CallOrInvoke = nullptr;
52705288
RValue Call = EmitCall(FnInfo, Callee, ReturnValue, Args, &CallOrInvoke,
52715289
E->getExprLoc());

clang/lib/CodeGen/CodeGenModule.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3571,9 +3571,19 @@ llvm::Constant *CodeGenModule::GetAddrOfFunction(GlobalDecl GD,
35713571
}
35723572

35733573
StringRef MangledName = getMangledName(GD);
3574-
return GetOrCreateLLVMFunction(MangledName, Ty, GD, ForVTable, DontDefer,
3575-
/*IsThunk=*/false, llvm::AttributeList(),
3576-
IsForDefinition);
3574+
auto *F = GetOrCreateLLVMFunction(MangledName, Ty, GD, ForVTable, DontDefer,
3575+
/*IsThunk=*/false, llvm::AttributeList(),
3576+
IsForDefinition);
3577+
// Returns kernel handle for HIP kernel stub function.
3578+
if (LangOpts.CUDA && !LangOpts.CUDAIsDevice &&
3579+
cast<FunctionDecl>(GD.getDecl())->hasAttr<CUDAGlobalAttr>()) {
3580+
auto *Handle = getCUDARuntime().getKernelHandle(
3581+
cast<llvm::Function>(F->stripPointerCasts()), GD);
3582+
if (IsForDefinition)
3583+
return F;
3584+
return llvm::ConstantExpr::getBitCast(Handle, Ty->getPointerTo());
3585+
}
3586+
return F;
35773587
}
35783588

35793589
static const FunctionDecl *

clang/test/CodeGenCUDA/Inputs/cuda.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <stddef.h>
44

5+
#if __HIP__ || __CUDA__
56
#define __constant__ __attribute__((constant))
67
#define __device__ __attribute__((device))
78
#define __global__ __attribute__((global))
@@ -11,13 +12,22 @@
1112
#define __managed__ __attribute__((managed))
1213
#endif
1314
#define __launch_bounds__(...) __attribute__((launch_bounds(__VA_ARGS__)))
15+
#else
16+
#define __constant__
17+
#define __device__
18+
#define __global__
19+
#define __host__
20+
#define __shared__
21+
#define __managed__
22+
#define __launch_bounds__(...)
23+
#endif
1424

1525
struct dim3 {
1626
unsigned x, y, z;
1727
__host__ __device__ dim3(unsigned x, unsigned y = 1, unsigned z = 1) : x(x), y(y), z(z) {}
1828
};
1929

20-
#ifdef __HIP__
30+
#if __HIP__ || HIP_PLATFORM
2131
typedef struct hipStream *hipStream_t;
2232
typedef enum hipError {} hipError_t;
2333
int hipConfigureCall(dim3 gridSize, dim3 blockSize, size_t sharedSize = 0,
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// RUN: %clang_cc1 -x hip -emit-llvm-bc %s -o %t.hip.bc
2+
// RUN: %clang_cc1 -mlink-bitcode-file %t.hip.bc -DHIP_PLATFORM -emit-llvm \
3+
// RUN: %s -o - | FileCheck %s
4+
5+
#include "Inputs/cuda.h"
6+
7+
// CHECK: @_Z2g1i = constant void (i32)* @_Z17__device_stub__g1i, align 8
8+
#if __HIP__
9+
__global__ void g1(int x) {}
10+
#else
11+
extern void g1(int x);
12+
13+
// CHECK: call i32 @hipLaunchKernel{{.*}}@_Z2g1i
14+
void test() {
15+
hipLaunchKernel((void*)g1, 1, 1, nullptr, 0, 0);
16+
}
17+
18+
// CHECK: __hipRegisterFunction{{.*}}@_Z2g1i
19+
#endif

clang/test/CodeGenCUDA/kernel-dbg-info.cu

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ extern "C" __global__ void ckernel(int *a) {
3030
*a = 1;
3131
}
3232

33+
// Kernel symbol for launching kernel.
34+
// CHECK: @[[SYM:ckernel]] = constant void (i32*)* @__device_stub__ckernel, align 8
35+
3336
// Device side kernel names
3437
// CHECK: @[[CKERN:[0-9]*]] = {{.*}} c"ckernel\00"
3538

@@ -40,7 +43,7 @@ extern "C" __global__ void ckernel(int *a) {
4043
// Make sure there is no !dbg between function attributes and '{'
4144
// CHECK: define{{.*}} void @[[CSTUB:__device_stub__ckernel]]{{.*}} #{{[0-9]+}} {
4245
// CHECK-NOT: call {{.*}}@hipLaunchByPtr{{.*}}!dbg
43-
// CHECK: call {{.*}}@hipLaunchByPtr{{.*}}@[[CSTUB]]
46+
// CHECK: call {{.*}}@hipLaunchByPtr{{.*}}@[[SYM]]
4447
// CHECK-NOT: ret {{.*}}!dbg
4548

4649
// CHECK-LABEL: define {{.*}}@_Z8hostfuncPi{{.*}}!dbg

clang/test/CodeGenCUDA/kernel-stub-name.cu

Lines changed: 83 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,17 @@
22

33
// RUN: %clang_cc1 -triple x86_64-linux-gnu -emit-llvm %s \
44
// RUN: -fcuda-include-gpubinary %t -o - -x hip\
5-
// RUN: | FileCheck -allow-deprecated-dag-overlap %s --check-prefixes=CHECK
5+
// RUN: | FileCheck %s
66

77
#include "Inputs/cuda.h"
88

9+
// Kernel handles
10+
11+
// CHECK: @[[HCKERN:ckernel]] = constant void ()* @__device_stub__ckernel, align 8
12+
// CHECK: @[[HNSKERN:_ZN2ns8nskernelEv]] = constant void ()* @_ZN2ns23__device_stub__nskernelEv, align 8
13+
// CHECK: @[[HTKERN:_Z10kernelfuncIiEvv]] = linkonce_odr constant void ()* @_Z25__device_stub__kernelfuncIiEvv, align 8
14+
// CHECK: @[[HDKERN:_Z11kernel_declv]] = external constant void ()*, align 8
15+
916
extern "C" __global__ void ckernel() {}
1017

1118
namespace ns {
@@ -17,6 +24,11 @@ __global__ void kernelfunc() {}
1724

1825
__global__ void kernel_decl();
1926

27+
void (*kernel_ptr)();
28+
void *void_ptr;
29+
30+
void launch(void *kern);
31+
2032
// Device side kernel names
2133

2234
// CHECK: @[[CKERN:[0-9]*]] = {{.*}} c"ckernel\00"
@@ -26,16 +38,20 @@ __global__ void kernel_decl();
2638
// Non-template kernel stub functions
2739

2840
// CHECK: define{{.*}}@[[CSTUB:__device_stub__ckernel]]
29-
// CHECK: call{{.*}}@hipLaunchByPtr{{.*}}@[[CSTUB]]
41+
// CHECK: call{{.*}}@hipLaunchByPtr{{.*}}@[[HCKERN]]
3042
// CHECK: define{{.*}}@[[NSSTUB:_ZN2ns23__device_stub__nskernelEv]]
31-
// CHECK: call{{.*}}@hipLaunchByPtr{{.*}}@[[NSSTUB]]
43+
// CHECK: call{{.*}}@hipLaunchByPtr{{.*}}@[[HNSKERN]]
44+
3245

33-
// CHECK-LABEL: define{{.*}}@_Z8hostfuncv()
46+
// Check kernel stub is used for triple chevron
47+
48+
// CHECK-LABEL: define{{.*}}@_Z4fun1v()
3449
// CHECK: call void @[[CSTUB]]()
3550
// CHECK: call void @[[NSSTUB]]()
3651
// CHECK: call void @[[TSTUB:_Z25__device_stub__kernelfuncIiEvv]]()
3752
// CHECK: call void @[[DSTUB:_Z26__device_stub__kernel_declv]]()
38-
void hostfunc(void) {
53+
54+
void fun1(void) {
3955
ckernel<<<1, 1>>>();
4056
ns::nskernel<<<1, 1>>>();
4157
kernelfunc<int><<<1, 1>>>();
@@ -45,11 +61,69 @@ void hostfunc(void) {
4561
// Template kernel stub functions
4662

4763
// CHECK: define{{.*}}@[[TSTUB]]
48-
// CHECK: call{{.*}}@hipLaunchByPtr{{.*}}@[[TSTUB]]
64+
// CHECK: call{{.*}}@hipLaunchByPtr{{.*}}@[[HTKERN]]
65+
66+
// Check declaration of stub function for external kernel.
4967

5068
// CHECK: declare{{.*}}@[[DSTUB]]
5169

70+
// Check kernel handle is used for passing the kernel as a function pointer
71+
72+
// CHECK-LABEL: define{{.*}}@_Z4fun2v()
73+
// CHECK: call void @_Z6launchPv({{.*}}[[HCKERN]]
74+
// CHECK: call void @_Z6launchPv({{.*}}[[HNSKERN]]
75+
// CHECK: call void @_Z6launchPv({{.*}}[[HTKERN]]
76+
// CHECK: call void @_Z6launchPv({{.*}}[[HDKERN]]
77+
void fun2() {
78+
launch((void *)ckernel);
79+
launch((void *)ns::nskernel);
80+
launch((void *)kernelfunc<int>);
81+
launch((void *)kernel_decl);
82+
}
83+
84+
// Check kernel handle is used for assigning a kernel to a function pointer
85+
86+
// CHECK-LABEL: define{{.*}}@_Z4fun3v()
87+
// CHECK: store void ()* bitcast (void ()** @[[HCKERN]] to void ()*), void ()** @kernel_ptr, align 8
88+
// CHECK: store void ()* bitcast (void ()** @[[HCKERN]] to void ()*), void ()** @kernel_ptr, align 8
89+
// CHECK: store i8* bitcast (void ()** @[[HCKERN]] to i8*), i8** @void_ptr, align 8
90+
// CHECK: store i8* bitcast (void ()** @[[HCKERN]] to i8*), i8** @void_ptr, align 8
91+
void fun3() {
92+
kernel_ptr = ckernel;
93+
kernel_ptr = &ckernel;
94+
void_ptr = (void *)ckernel;
95+
void_ptr = (void *)&ckernel;
96+
}
97+
98+
// Check kernel stub is loaded from kernel handle when function pointer is
99+
// used with triple chevron
100+
101+
// CHECK-LABEL: define{{.*}}@_Z4fun4v()
102+
// CHECK: store void ()* bitcast (void ()** @[[HCKERN]] to void ()*), void ()** @kernel_ptr
103+
// CHECK: call i32 @_Z16hipConfigureCall4dim3S_mP9hipStream
104+
// CHECK: %[[HANDLE:.*]] = load void ()*, void ()** @kernel_ptr, align 8
105+
// CHECK: %[[CAST:.*]] = bitcast void ()* %[[HANDLE]] to void ()**
106+
// CHECK: %[[STUB:.*]] = load void ()*, void ()** %[[CAST]], align 8
107+
// CHECK: call void %[[STUB]]()
108+
void fun4() {
109+
kernel_ptr = ckernel;
110+
kernel_ptr<<<1,1>>>();
111+
}
112+
113+
// Check kernel handle is passed to a function
114+
115+
// CHECK-LABEL: define{{.*}}@_Z4fun5v()
116+
// CHECK: store void ()* bitcast (void ()** @[[HCKERN]] to void ()*), void ()** @kernel_ptr
117+
// CHECK: %[[HANDLE:.*]] = load void ()*, void ()** @kernel_ptr, align 8
118+
// CHECK: %[[CAST:.*]] = bitcast void ()* %[[HANDLE]] to i8*
119+
// CHECK: call void @_Z6launchPv(i8* %[[CAST]])
120+
void fun5() {
121+
kernel_ptr = ckernel;
122+
launch((void *)kernel_ptr);
123+
}
124+
52125
// CHECK-LABEL: define{{.*}}@__hip_register_globals
53-
// CHECK: call{{.*}}@__hipRegisterFunction{{.*}}@[[CSTUB]]{{.*}}@[[CKERN]]
54-
// CHECK: call{{.*}}@__hipRegisterFunction{{.*}}@[[NSSTUB]]{{.*}}@[[NSKERN]]
55-
// CHECK: call{{.*}}@__hipRegisterFunction{{.*}}@[[TSTUB]]{{.*}}@[[TKERN]]
126+
// CHECK: call{{.*}}@__hipRegisterFunction{{.*}}@[[HCKERN]]{{.*}}@[[CKERN]]
127+
// CHECK: call{{.*}}@__hipRegisterFunction{{.*}}@[[HNSKERN]]{{.*}}@[[NSKERN]]
128+
// CHECK: call{{.*}}@__hipRegisterFunction{{.*}}@[[HTKERN]]{{.*}}@[[TKERN]]
129+
// CHECK-NOT: call{{.*}}@__hipRegisterFunction{{.*}}@[[HDKERN]]{{.*}}@[[DKERN]]

0 commit comments

Comments
 (0)