Skip to content

Commit 3e317dc

Browse files
lalaniket8anikelal
authored andcommitted
[Clang][OpenCL][AMDGPU] Allow a kernel to call another kernel (#115821)
This feature is currently not supported in the compiler. To facilitate this we emit a stub version of each kernel function body with different name mangling scheme, and replaces the respective kernel call-sites appropriately. Fixes llvm/llvm-project#60313 D120566 was an earlier attempt made to upstream a solution for this issue. --------- Co-authored-by: anikelal <[email protected]>
1 parent 8e42124 commit 3e317dc

33 files changed

+3375
-1375
lines changed

clang/include/clang/AST/Decl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3049,6 +3049,8 @@ class FunctionDecl : public DeclaratorDecl,
30493049
static FunctionDecl *castFromDeclContext(const DeclContext *DC) {
30503050
return static_cast<FunctionDecl *>(const_cast<DeclContext*>(DC));
30513051
}
3052+
3053+
bool isReferenceableKernel() const;
30523054
};
30533055

30543056
/// Represents a member of a struct/union/class.

clang/include/clang/AST/GlobalDecl.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,15 @@ class GlobalDecl {
7070
GlobalDecl(const VarDecl *D) { Init(D);}
7171
GlobalDecl(const FunctionDecl *D, unsigned MVIndex = 0)
7272
: MultiVersionIndex(MVIndex) {
73-
if (!D->hasAttr<CUDAGlobalAttr>()) {
74-
Init(D);
73+
if (D->isReferenceableKernel()) {
74+
Value.setPointerAndInt(D, unsigned(getDefaultKernelReference(D)));
7575
return;
7676
}
77-
Value.setPointerAndInt(D, unsigned(getDefaultKernelReference(D)));
77+
Init(D);
7878
}
7979
GlobalDecl(const FunctionDecl *D, KernelReferenceKind Kind)
8080
: Value(D, unsigned(Kind)) {
81-
assert(D->hasAttr<CUDAGlobalAttr>() && "Decl is not a GPU kernel!");
81+
assert(D->isReferenceableKernel() && "Decl is not a GPU kernel!");
8282
}
8383
GlobalDecl(const NamedDecl *D) { Init(D); }
8484
GlobalDecl(const BlockDecl *D) { Init(D); }
@@ -131,12 +131,13 @@ class GlobalDecl {
131131

132132
KernelReferenceKind getKernelReferenceKind() const {
133133
assert(((isa<FunctionDecl>(getDecl()) &&
134-
cast<FunctionDecl>(getDecl())->hasAttr<CUDAGlobalAttr>()) ||
134+
cast<FunctionDecl>(getDecl())->isReferenceableKernel()) ||
135135
(isa<FunctionTemplateDecl>(getDecl()) &&
136136
cast<FunctionTemplateDecl>(getDecl())
137137
->getTemplatedDecl()
138138
->hasAttr<CUDAGlobalAttr>())) &&
139139
"Decl is not a GPU kernel!");
140+
140141
return static_cast<KernelReferenceKind>(Value.getInt());
141142
}
142143

@@ -160,8 +161,9 @@ class GlobalDecl {
160161
}
161162

162163
static KernelReferenceKind getDefaultKernelReference(const FunctionDecl *D) {
163-
return D->getLangOpts().CUDAIsDevice ? KernelReferenceKind::Kernel
164-
: KernelReferenceKind::Stub;
164+
return (D->hasAttr<OpenCLKernelAttr>() || D->getLangOpts().CUDAIsDevice)
165+
? KernelReferenceKind::Kernel
166+
: KernelReferenceKind::Stub;
165167
}
166168

167169
GlobalDecl getWithDecl(const Decl *D) {
@@ -197,7 +199,7 @@ class GlobalDecl {
197199

198200
GlobalDecl getWithKernelReferenceKind(KernelReferenceKind Kind) {
199201
assert(isa<FunctionDecl>(getDecl()) &&
200-
cast<FunctionDecl>(getDecl())->hasAttr<CUDAGlobalAttr>() &&
202+
cast<FunctionDecl>(getDecl())->isReferenceableKernel() &&
201203
"Decl is not a GPU kernel!");
202204
GlobalDecl Result(*this);
203205
Result.Value.setInt(unsigned(Kind));

clang/lib/AST/Decl.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5472,6 +5472,10 @@ FunctionDecl *FunctionDecl::CreateDeserialized(ASTContext &C, GlobalDeclID ID) {
54725472
/*TrailingRequiresClause=*/{});
54735473
}
54745474

5475+
bool FunctionDecl::isReferenceableKernel() const {
5476+
return hasAttr<CUDAGlobalAttr>() || hasAttr<OpenCLKernelAttr>();
5477+
}
5478+
54755479
BlockDecl *BlockDecl::Create(ASTContext &C, DeclContext *DC, SourceLocation L) {
54765480
return new (C, DC) BlockDecl(DC, L);
54775481
}

clang/lib/AST/Expr.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -750,9 +750,9 @@ std::string PredefinedExpr::ComputeName(PredefinedIdentKind IK,
750750
GD = GlobalDecl(CD, Ctor_Base);
751751
else if (const CXXDestructorDecl *DD = dyn_cast<CXXDestructorDecl>(ND))
752752
GD = GlobalDecl(DD, Dtor_Base);
753-
else if (ND->hasAttr<CUDAGlobalAttr>())
754-
GD = GlobalDecl(cast<FunctionDecl>(ND));
755-
else
753+
else if (auto FD = dyn_cast<FunctionDecl>(ND)) {
754+
GD = FD->isReferenceableKernel() ? GlobalDecl(FD) : GlobalDecl(ND);
755+
} else
756756
GD = GlobalDecl(ND);
757757
MC->mangleName(GD, Out);
758758

clang/lib/AST/ItaniumMangle.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,7 @@ class CXXNameMangler {
526526
void mangleSourceName(const IdentifierInfo *II);
527527
void mangleRegCallName(const IdentifierInfo *II);
528528
void mangleDeviceStubName(const IdentifierInfo *II);
529+
void mangleOCLDeviceStubName(const IdentifierInfo *II);
529530
void mangleSourceNameWithAbiTags(
530531
const NamedDecl *ND, const AbiTagList *AdditionalAbiTags = nullptr);
531532
void mangleLocalName(GlobalDecl GD,
@@ -1561,8 +1562,13 @@ void CXXNameMangler::mangleUnqualifiedName(
15611562
bool IsDeviceStub =
15621563
FD && FD->hasAttr<CUDAGlobalAttr>() &&
15631564
GD.getKernelReferenceKind() == KernelReferenceKind::Stub;
1565+
bool IsOCLDeviceStub =
1566+
FD && FD->hasAttr<OpenCLKernelAttr>() &&
1567+
GD.getKernelReferenceKind() == KernelReferenceKind::Stub;
15641568
if (IsDeviceStub)
15651569
mangleDeviceStubName(II);
1570+
else if (IsOCLDeviceStub)
1571+
mangleOCLDeviceStubName(II);
15661572
else if (IsRegCall)
15671573
mangleRegCallName(II);
15681574
else
@@ -1780,6 +1786,15 @@ void CXXNameMangler::mangleDeviceStubName(const IdentifierInfo *II) {
17801786
<< II->getName();
17811787
}
17821788

1789+
void CXXNameMangler::mangleOCLDeviceStubName(const IdentifierInfo *II) {
1790+
// <source-name> ::= <positive length number> __clang_ocl_kern_imp_
1791+
// <identifier> <number> ::= [n] <non-negative decimal integer> <identifier>
1792+
// ::= <unqualified source code identifier>
1793+
StringRef OCLDeviceStubNamePrefix = "__clang_ocl_kern_imp_";
1794+
Out << II->getLength() + OCLDeviceStubNamePrefix.size()
1795+
<< OCLDeviceStubNamePrefix << II->getName();
1796+
}
1797+
17831798
void CXXNameMangler::mangleSourceName(const IdentifierInfo *II) {
17841799
// <source-name> ::= <positive length number> <identifier>
17851800
// <number> ::= [n] <non-negative decimal integer>

clang/lib/AST/Mangle.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -540,9 +540,9 @@ class ASTNameGenerator::Implementation {
540540
GD = GlobalDecl(CtorD, Ctor_Complete);
541541
else if (const auto *DtorD = dyn_cast<CXXDestructorDecl>(D))
542542
GD = GlobalDecl(DtorD, Dtor_Complete);
543-
else if (D->hasAttr<CUDAGlobalAttr>())
544-
GD = GlobalDecl(cast<FunctionDecl>(D));
545-
else
543+
else if (const FunctionDecl *FD = dyn_cast<FunctionDecl>(D)) {
544+
GD = FD->isReferenceableKernel() ? GlobalDecl(FD) : GlobalDecl(D);
545+
} else
546546
GD = GlobalDecl(D);
547547
MC->mangleName(GD, OS);
548548
return false;

clang/lib/AST/MicrosoftMangle.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,9 +1162,15 @@ void MicrosoftCXXNameMangler::mangleUnqualifiedName(GlobalDecl GD,
11621162
->getTemplatedDecl()
11631163
->hasAttr<CUDAGlobalAttr>())) &&
11641164
GD.getKernelReferenceKind() == KernelReferenceKind::Stub;
1165+
bool IsOCLDeviceStub =
1166+
ND && isa<FunctionDecl>(ND) && ND->hasAttr<OpenCLKernelAttr>() &&
1167+
GD.getKernelReferenceKind() == KernelReferenceKind::Stub;
11651168
if (IsDeviceStub)
11661169
mangleSourceName(
11671170
(llvm::Twine("__device_stub__") + II->getName()).str());
1171+
else if (IsOCLDeviceStub)
1172+
mangleSourceName(
1173+
(llvm::Twine("__clang_ocl_kern_imp_") + II->getName()).str());
11681174
else
11691175
mangleSourceName(II->getName());
11701176
break;

clang/lib/CodeGen/CGCall.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,8 @@ CodeGenTypes::arrangeCXXConstructorCall(const CallArgList &args,
501501
/// Arrange the argument and result information for the declaration or
502502
/// definition of the given function.
503503
const CGFunctionInfo &
504-
CodeGenTypes::arrangeFunctionDeclaration(const FunctionDecl *FD) {
504+
CodeGenTypes::arrangeFunctionDeclaration(const GlobalDecl GD) {
505+
const FunctionDecl *FD = cast<FunctionDecl>(GD.getDecl());
505506
if (const CXXMethodDecl *MD = dyn_cast<CXXMethodDecl>(FD))
506507
if (MD->isImplicitObjectMemberFunction())
507508
return arrangeCXXMethodDeclaration(MD);
@@ -511,6 +512,13 @@ CodeGenTypes::arrangeFunctionDeclaration(const FunctionDecl *FD) {
511512
assert(isa<FunctionType>(FTy));
512513
setCUDAKernelCallingConvention(FTy, CGM, FD);
513514

515+
if (FD->hasAttr<OpenCLKernelAttr>() &&
516+
GD.getKernelReferenceKind() == KernelReferenceKind::Stub) {
517+
const FunctionType *FT = FTy->getAs<FunctionType>();
518+
CGM.getTargetCodeGenInfo().setOCLKernelStubCallingConvention(FT);
519+
FTy = FT->getCanonicalTypeUnqualified();
520+
}
521+
514522
// When declaring a function without a prototype, always use a
515523
// non-variadic type.
516524
if (CanQual<FunctionNoProtoType> noProto = FTy.getAs<FunctionNoProtoType>()) {
@@ -583,13 +591,11 @@ CodeGenTypes::arrangeUnprototypedObjCMessageSend(QualType returnType,
583591
const CGFunctionInfo &
584592
CodeGenTypes::arrangeGlobalDeclaration(GlobalDecl GD) {
585593
// FIXME: Do we need to handle ObjCMethodDecl?
586-
const FunctionDecl *FD = cast<FunctionDecl>(GD.getDecl());
587-
588594
if (isa<CXXConstructorDecl>(GD.getDecl()) ||
589595
isa<CXXDestructorDecl>(GD.getDecl()))
590596
return arrangeCXXStructorDeclaration(GD);
591597

592-
return arrangeFunctionDeclaration(FD);
598+
return arrangeFunctionDeclaration(GD);
593599
}
594600

595601
/// Arrange a thunk that takes 'this' as the first parameter followed by
@@ -2473,7 +2479,6 @@ void CodeGenModule::ConstructAttributeList(StringRef Name,
24732479
// Collect function IR attributes from the callee prototype if we have one.
24742480
AddAttributesFromFunctionProtoType(getContext(), FuncAttrs,
24752481
CalleeInfo.getCalleeFunctionProtoType());
2476-
24772482
const Decl *TargetDecl = CalleeInfo.getCalleeDecl().getDecl();
24782483

24792484
// Attach assumption attributes to the declaration. If this is a call
@@ -2580,7 +2585,11 @@ void CodeGenModule::ConstructAttributeList(StringRef Name,
25802585
NumElemsParam);
25812586
}
25822587

2583-
if (TargetDecl->hasAttr<OpenCLKernelAttr>()) {
2588+
if (TargetDecl->hasAttr<OpenCLKernelAttr>() &&
2589+
CallingConv != CallingConv::CC_C &&
2590+
CallingConv != CallingConv::CC_SpirFunction) {
2591+
// Check CallingConv to avoid adding uniform-work-group-size attribute to
2592+
// OpenCL Kernel Stub
25842593
if (getLangOpts().OpenCLVersion <= 120) {
25852594
// OpenCL v1.2 Work groups are always uniform
25862595
FuncAttrs.addAttribute("uniform-work-group-size", "true");

clang/lib/CodeGen/CGExpr.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5800,6 +5800,12 @@ static CGCallee EmitDirectCallee(CodeGenFunction &CGF, GlobalDecl GD) {
58005800
return CGCallee::forDirect(CalleePtr, GD);
58015801
}
58025802

5803+
static GlobalDecl getGlobalDeclForDirectCall(const FunctionDecl *FD) {
5804+
if (FD->hasAttr<OpenCLKernelAttr>())
5805+
return GlobalDecl(FD, KernelReferenceKind::Stub);
5806+
return GlobalDecl(FD);
5807+
}
5808+
58035809
CGCallee CodeGenFunction::EmitCallee(const Expr *E) {
58045810
E = E->IgnoreParens();
58055811

@@ -5813,7 +5819,7 @@ CGCallee CodeGenFunction::EmitCallee(const Expr *E) {
58135819
// Resolve direct calls.
58145820
} else if (auto DRE = dyn_cast<DeclRefExpr>(E)) {
58155821
if (auto FD = dyn_cast<FunctionDecl>(DRE->getDecl())) {
5816-
return EmitDirectCallee(*this, FD);
5822+
return EmitDirectCallee(*this, getGlobalDeclForDirectCall(FD));
58175823
}
58185824
} else if (auto ME = dyn_cast<MemberExpr>(E)) {
58195825
if (auto FD = dyn_cast<FunctionDecl>(ME->getMemberDecl())) {
@@ -6182,6 +6188,10 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType,
61826188

61836189
const auto *FnType = cast<FunctionType>(PointeeType);
61846190

6191+
if (const auto *FD = dyn_cast_or_null<FunctionDecl>(TargetDecl);
6192+
FD && FD->hasAttr<OpenCLKernelAttr>())
6193+
CGM.getTargetCodeGenInfo().setOCLKernelStubCallingConvention(FnType);
6194+
61856195
// If we are checking indirect calls and this call is indirect, check that the
61866196
// function pointer is a member of the bit set for the function type.
61876197
if (SanOpts.has(SanitizerKind::CFIICall) &&

clang/lib/CodeGen/CodeGenFunction.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1931,6 +1931,26 @@ void CodeGenFunction::GenerateCode(GlobalDecl GD, llvm::Function *Fn,
19311931
// Implicit copy-assignment gets the same special treatment as implicit
19321932
// copy-constructors.
19331933
emitImplicitAssignmentOperatorBody(Args);
1934+
} else if (FD->hasAttr<OpenCLKernelAttr>() &&
1935+
GD.getKernelReferenceKind() == KernelReferenceKind::Kernel) {
1936+
CallArgList CallArgs;
1937+
for (unsigned i = 0; i < Args.size(); ++i) {
1938+
Address ArgAddr = GetAddrOfLocalVar(Args[i]);
1939+
QualType ArgQualType = Args[i]->getType();
1940+
RValue ArgRValue = convertTempToRValue(ArgAddr, ArgQualType, Loc);
1941+
CallArgs.add(ArgRValue, ArgQualType);
1942+
}
1943+
GlobalDecl GDStub = GlobalDecl(FD, KernelReferenceKind::Stub);
1944+
const FunctionType *FT = cast<FunctionType>(FD->getType());
1945+
CGM.getTargetCodeGenInfo().setOCLKernelStubCallingConvention(FT);
1946+
const CGFunctionInfo &FnInfo = CGM.getTypes().arrangeFreeFunctionCall(
1947+
CallArgs, FT, /*ChainCall=*/false);
1948+
llvm::FunctionType *FTy = CGM.getTypes().GetFunctionType(FnInfo);
1949+
llvm::Constant *GDStubFunctionPointer =
1950+
CGM.getRawFunctionPointer(GDStub, FTy);
1951+
CGCallee GDStubCallee = CGCallee::forDirect(GDStubFunctionPointer, GDStub);
1952+
EmitCall(FnInfo, GDStubCallee, ReturnValueSlot(), CallArgs, nullptr, false,
1953+
Loc);
19341954
} else if (Body) {
19351955
EmitFunctionBody(Body);
19361956
} else

clang/lib/CodeGen/CodeGenModule.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2030,6 +2030,9 @@ static std::string getMangledNameImpl(CodeGenModule &CGM, GlobalDecl GD,
20302030
} else if (FD && FD->hasAttr<CUDAGlobalAttr>() &&
20312031
GD.getKernelReferenceKind() == KernelReferenceKind::Stub) {
20322032
Out << "__device_stub__" << II->getName();
2033+
} else if (FD && FD->hasAttr<OpenCLKernelAttr>() &&
2034+
GD.getKernelReferenceKind() == KernelReferenceKind::Stub) {
2035+
Out << "__clang_ocl_kern_imp_" << II->getName();
20332036
} else {
20342037
Out << II->getName();
20352038
}
@@ -4301,6 +4304,9 @@ void CodeGenModule::EmitGlobal(GlobalDecl GD) {
43014304

43024305
// Ignore declarations, they will be emitted on their first use.
43034306
if (const auto *FD = dyn_cast<FunctionDecl>(Global)) {
4307+
if (FD->hasAttr<OpenCLKernelAttr>() && FD->doesThisDeclarationHaveABody())
4308+
addDeferredDeclToEmit(GlobalDecl(FD, KernelReferenceKind::Stub));
4309+
43044310
// Update deferred annotations with the latest declaration if the function
43054311
// function was already used or defined.
43064312
if (FD->hasAttr<AnnotateAttr>()) {
@@ -5327,6 +5333,11 @@ CodeGenModule::GetAddrOfFunction(GlobalDecl GD, llvm::Type *Ty, bool ForVTable,
53275333
if (!Ty) {
53285334
const auto *FD = cast<FunctionDecl>(GD.getDecl());
53295335
Ty = getTypes().ConvertType(FD->getType());
5336+
if (FD->hasAttr<OpenCLKernelAttr>() &&
5337+
GD.getKernelReferenceKind() == KernelReferenceKind::Stub) {
5338+
const CGFunctionInfo &FI = getTypes().arrangeGlobalDeclaration(GD);
5339+
Ty = getTypes().GetFunctionType(FI);
5340+
}
53305341
}
53315342

53325343
// Devirtualized destructor calls may come through here instead of via

clang/lib/CodeGen/CodeGenTypes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ class CodeGenTypes {
216216

217217
/// Free functions are functions that are compatible with an ordinary
218218
/// C function pointer type.
219-
const CGFunctionInfo &arrangeFunctionDeclaration(const FunctionDecl *FD);
219+
const CGFunctionInfo &arrangeFunctionDeclaration(const GlobalDecl GD);
220220
const CGFunctionInfo &arrangeFreeFunctionCall(const CallArgList &Args,
221221
const FunctionType *Ty,
222222
bool ChainCall);

clang/lib/CodeGen/TargetInfo.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,12 @@ unsigned TargetCodeGenInfo::getOpenCLKernelCallingConv() const {
117117
return llvm::CallingConv::SPIR_KERNEL;
118118
}
119119

120+
void TargetCodeGenInfo::setOCLKernelStubCallingConvention(
121+
const FunctionType *&FT) const {
122+
FT = getABIInfo().getContext().adjustFunctionType(
123+
FT, FT->getExtInfo().withCallingConv(CC_C));
124+
}
125+
120126
llvm::Constant *TargetCodeGenInfo::getNullPointer(const CodeGen::CodeGenModule &CGM,
121127
llvm::PointerType *T, QualType QT) const {
122128
return llvm::ConstantPointerNull::get(T);

clang/lib/CodeGen/TargetInfo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ class TargetCodeGenInfo {
400400
virtual bool shouldEmitDWARFBitFieldSeparators() const { return false; }
401401

402402
virtual void setCUDAKernelCallingConvention(const FunctionType *&FT) const {}
403-
403+
virtual void setOCLKernelStubCallingConvention(const FunctionType *&FT) const;
404404
/// Return the device-side type for the CUDA device builtin surface type.
405405
virtual llvm::Type *getCUDADeviceBuiltinSurfaceDeviceType() const {
406406
// By default, no change from the original one.

clang/lib/CodeGen/Targets/SPIR.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,8 @@ class CommonSPIRTargetCodeGenInfo : public TargetCodeGenInfo {
203203
llvm::Type *getSPIRVImageTypeFromHLSLResource(
204204
const HLSLAttributedResourceType::Attributes &attributes,
205205
llvm::Type *ElementType, llvm::LLVMContext &Ctx) const;
206+
void
207+
setOCLKernelStubCallingConvention(const FunctionType *&FT) const override;
206208
};
207209
class SPIRVTargetCodeGenInfo : public CommonSPIRTargetCodeGenInfo {
208210
public:
@@ -379,6 +381,12 @@ void SPIRVTargetCodeGenInfo::setCUDAKernelCallingConvention(
379381
}
380382
}
381383

384+
void CommonSPIRTargetCodeGenInfo::setOCLKernelStubCallingConvention(
385+
const FunctionType *&FT) const {
386+
FT = getABIInfo().getContext().adjustFunctionType(
387+
FT, FT->getExtInfo().withCallingConv(CC_SpirFunction));
388+
}
389+
382390
LangAS
383391
SPIRVTargetCodeGenInfo::getGlobalVarAddressSpace(CodeGenModule &CGM,
384392
const VarDecl *D) const {

0 commit comments

Comments
 (0)