Skip to content

Commit 0577f01

Browse files
author
anikelal
committed
[Clang][OpenCL][AMDGPU] Allow a kernel to call another kernel
This feature is currently not supported in the compiler. To facilitate this we emit a stub version of each kernel function body with different name mangling scheme, and replaces the respective kernel call-sites appropriately. Fixes llvm#60313 D120566 was an earlier attempt made to upstream a solution for this.
1 parent ae03197 commit 0577f01

31 files changed

+2314
-496
lines changed

clang/include/clang/AST/Decl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3048,6 +3048,8 @@ class FunctionDecl : public DeclaratorDecl,
30483048
static FunctionDecl *castFromDeclContext(const DeclContext *DC) {
30493049
return static_cast<FunctionDecl *>(const_cast<DeclContext*>(DC));
30503050
}
3051+
3052+
bool isReferenceableKernel() const;
30513053
};
30523054

30533055
/// Represents a member of a struct/union/class.

clang/include/clang/AST/GlobalDecl.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,15 @@ class GlobalDecl {
7070
GlobalDecl(const VarDecl *D) { Init(D);}
7171
GlobalDecl(const FunctionDecl *D, unsigned MVIndex = 0)
7272
: MultiVersionIndex(MVIndex) {
73-
if (!D->hasAttr<CUDAGlobalAttr>()) {
74-
Init(D);
73+
if (D->isReferenceableKernel()) {
74+
Value.setPointerAndInt(D, unsigned(getDefaultKernelReference(D)));
7575
return;
7676
}
77-
Value.setPointerAndInt(D, unsigned(getDefaultKernelReference(D)));
77+
Init(D);
7878
}
7979
GlobalDecl(const FunctionDecl *D, KernelReferenceKind Kind)
8080
: Value(D, unsigned(Kind)) {
81-
assert(D->hasAttr<CUDAGlobalAttr>() && "Decl is not a GPU kernel!");
81+
assert(D->isReferenceableKernel() && "Decl is not a GPU kernel!");
8282
}
8383
GlobalDecl(const NamedDecl *D) { Init(D); }
8484
GlobalDecl(const BlockDecl *D) { Init(D); }
@@ -131,12 +131,13 @@ class GlobalDecl {
131131

132132
KernelReferenceKind getKernelReferenceKind() const {
133133
assert(((isa<FunctionDecl>(getDecl()) &&
134-
cast<FunctionDecl>(getDecl())->hasAttr<CUDAGlobalAttr>()) ||
134+
cast<FunctionDecl>(getDecl())->isReferenceableKernel()) ||
135135
(isa<FunctionTemplateDecl>(getDecl()) &&
136136
cast<FunctionTemplateDecl>(getDecl())
137137
->getTemplatedDecl()
138138
->hasAttr<CUDAGlobalAttr>())) &&
139139
"Decl is not a GPU kernel!");
140+
140141
return static_cast<KernelReferenceKind>(Value.getInt());
141142
}
142143

@@ -160,8 +161,9 @@ class GlobalDecl {
160161
}
161162

162163
static KernelReferenceKind getDefaultKernelReference(const FunctionDecl *D) {
163-
return D->getLangOpts().CUDAIsDevice ? KernelReferenceKind::Kernel
164-
: KernelReferenceKind::Stub;
164+
return (D->hasAttr<OpenCLKernelAttr>() || D->getLangOpts().CUDAIsDevice)
165+
? KernelReferenceKind::Kernel
166+
: KernelReferenceKind::Stub;
165167
}
166168

167169
GlobalDecl getWithDecl(const Decl *D) {
@@ -197,7 +199,7 @@ class GlobalDecl {
197199

198200
GlobalDecl getWithKernelReferenceKind(KernelReferenceKind Kind) {
199201
assert(isa<FunctionDecl>(getDecl()) &&
200-
cast<FunctionDecl>(getDecl())->hasAttr<CUDAGlobalAttr>() &&
202+
cast<FunctionDecl>(getDecl())->isReferenceableKernel() &&
201203
"Decl is not a GPU kernel!");
202204
GlobalDecl Result(*this);
203205
Result.Value.setInt(unsigned(Kind));

clang/lib/AST/Decl.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5468,6 +5468,10 @@ FunctionDecl *FunctionDecl::CreateDeserialized(ASTContext &C, GlobalDeclID ID) {
54685468
/*TrailingRequiresClause=*/{});
54695469
}
54705470

5471+
bool FunctionDecl::isReferenceableKernel() const {
5472+
return hasAttr<CUDAGlobalAttr>() || hasAttr<OpenCLKernelAttr>();
5473+
}
5474+
54715475
BlockDecl *BlockDecl::Create(ASTContext &C, DeclContext *DC, SourceLocation L) {
54725476
return new (C, DC) BlockDecl(DC, L);
54735477
}

clang/lib/AST/Expr.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -695,9 +695,9 @@ std::string PredefinedExpr::ComputeName(PredefinedIdentKind IK,
695695
GD = GlobalDecl(CD, Ctor_Base);
696696
else if (const CXXDestructorDecl *DD = dyn_cast<CXXDestructorDecl>(ND))
697697
GD = GlobalDecl(DD, Dtor_Base);
698-
else if (ND->hasAttr<CUDAGlobalAttr>())
699-
GD = GlobalDecl(cast<FunctionDecl>(ND));
700-
else
698+
else if (auto FD = dyn_cast<FunctionDecl>(ND)) {
699+
GD = FD->isReferenceableKernel() ? GlobalDecl(FD) : GlobalDecl(ND);
700+
} else
701701
GD = GlobalDecl(ND);
702702
MC->mangleName(GD, Out);
703703

clang/lib/AST/ItaniumMangle.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,7 @@ class CXXNameMangler {
526526
void mangleSourceName(const IdentifierInfo *II);
527527
void mangleRegCallName(const IdentifierInfo *II);
528528
void mangleDeviceStubName(const IdentifierInfo *II);
529+
void mangleOCLDeviceStubName(const IdentifierInfo *II);
529530
void mangleSourceNameWithAbiTags(
530531
const NamedDecl *ND, const AbiTagList *AdditionalAbiTags = nullptr);
531532
void mangleLocalName(GlobalDecl GD,
@@ -1561,8 +1562,13 @@ void CXXNameMangler::mangleUnqualifiedName(
15611562
bool IsDeviceStub =
15621563
FD && FD->hasAttr<CUDAGlobalAttr>() &&
15631564
GD.getKernelReferenceKind() == KernelReferenceKind::Stub;
1565+
bool IsOCLDeviceStub =
1566+
FD && FD->hasAttr<OpenCLKernelAttr>() &&
1567+
GD.getKernelReferenceKind() == KernelReferenceKind::Stub;
15641568
if (IsDeviceStub)
15651569
mangleDeviceStubName(II);
1570+
else if (IsOCLDeviceStub)
1571+
mangleOCLDeviceStubName(II);
15661572
else if (IsRegCall)
15671573
mangleRegCallName(II);
15681574
else
@@ -1780,6 +1786,15 @@ void CXXNameMangler::mangleDeviceStubName(const IdentifierInfo *II) {
17801786
<< II->getName();
17811787
}
17821788

1789+
void CXXNameMangler::mangleOCLDeviceStubName(const IdentifierInfo *II) {
1790+
// <source-name> ::= <positive length number> __clang_ocl_kern_imp_
1791+
// <identifier> <number> ::= [n] <non-negative decimal integer> <identifier>
1792+
// ::= <unqualified source code identifier>
1793+
StringRef OCLDeviceStubNamePrefix = "__clang_ocl_kern_imp_";
1794+
Out << II->getLength() + OCLDeviceStubNamePrefix.size()
1795+
<< OCLDeviceStubNamePrefix << II->getName();
1796+
}
1797+
17831798
void CXXNameMangler::mangleSourceName(const IdentifierInfo *II) {
17841799
// <source-name> ::= <positive length number> <identifier>
17851800
// <number> ::= [n] <non-negative decimal integer>

clang/lib/AST/Mangle.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -540,9 +540,9 @@ class ASTNameGenerator::Implementation {
540540
GD = GlobalDecl(CtorD, Ctor_Complete);
541541
else if (const auto *DtorD = dyn_cast<CXXDestructorDecl>(D))
542542
GD = GlobalDecl(DtorD, Dtor_Complete);
543-
else if (D->hasAttr<CUDAGlobalAttr>())
544-
GD = GlobalDecl(cast<FunctionDecl>(D));
545-
else
543+
else if (const FunctionDecl *FD = dyn_cast<FunctionDecl>(D)) {
544+
GD = FD->isReferenceableKernel() ? GlobalDecl(FD) : GlobalDecl(D);
545+
} else
546546
GD = GlobalDecl(D);
547547
MC->mangleName(GD, OS);
548548
return false;

clang/lib/AST/MicrosoftMangle.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,9 +1162,15 @@ void MicrosoftCXXNameMangler::mangleUnqualifiedName(GlobalDecl GD,
11621162
->getTemplatedDecl()
11631163
->hasAttr<CUDAGlobalAttr>())) &&
11641164
GD.getKernelReferenceKind() == KernelReferenceKind::Stub;
1165+
bool IsOCLDeviceStub =
1166+
ND && (isa<FunctionDecl>(ND) && ND->hasAttr<OpenCLKernelAttr>()) &&
1167+
GD.getKernelReferenceKind() == KernelReferenceKind::Stub;
11651168
if (IsDeviceStub)
11661169
mangleSourceName(
11671170
(llvm::Twine("__device_stub__") + II->getName()).str());
1171+
else if (IsOCLDeviceStub)
1172+
mangleSourceName(
1173+
(llvm::Twine("__clang_ocl_kern_imp_") + II->getName()).str());
11681174
else
11691175
mangleSourceName(II->getName());
11701176
break;

clang/lib/CodeGen/CGCall.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,9 @@ CodeGenTypes::arrangeCXXConstructorCall(const CallArgList &args,
499499
/// Arrange the argument and result information for the declaration or
500500
/// definition of the given function.
501501
const CGFunctionInfo &
502-
CodeGenTypes::arrangeFunctionDeclaration(const FunctionDecl *FD) {
502+
CodeGenTypes::arrangeFunctionDeclaration(const GlobalDecl GD) {
503+
const FunctionDecl *FD = dyn_cast<FunctionDecl>(GD.getDecl());
504+
assert(FD && "GD must contain FunctionDecl");
503505
if (const CXXMethodDecl *MD = dyn_cast<CXXMethodDecl>(FD))
504506
if (MD->isImplicitObjectMemberFunction())
505507
return arrangeCXXMethodDeclaration(MD);
@@ -509,6 +511,13 @@ CodeGenTypes::arrangeFunctionDeclaration(const FunctionDecl *FD) {
509511
assert(isa<FunctionType>(FTy));
510512
setCUDAKernelCallingConvention(FTy, CGM, FD);
511513

514+
if (FD->hasAttr<OpenCLKernelAttr>() &&
515+
GD.getKernelReferenceKind() == KernelReferenceKind::Stub) {
516+
const FunctionType *FT = FTy->getAs<FunctionType>();
517+
CGM.getTargetCodeGenInfo().setOCLKernelStubCallingConvention(FT);
518+
FTy = FT->getCanonicalTypeUnqualified();
519+
}
520+
512521
// When declaring a function without a prototype, always use a
513522
// non-variadic type.
514523
if (CanQual<FunctionNoProtoType> noProto = FTy.getAs<FunctionNoProtoType>()) {
@@ -581,13 +590,11 @@ CodeGenTypes::arrangeUnprototypedObjCMessageSend(QualType returnType,
581590
const CGFunctionInfo &
582591
CodeGenTypes::arrangeGlobalDeclaration(GlobalDecl GD) {
583592
// FIXME: Do we need to handle ObjCMethodDecl?
584-
const FunctionDecl *FD = cast<FunctionDecl>(GD.getDecl());
585-
586593
if (isa<CXXConstructorDecl>(GD.getDecl()) ||
587594
isa<CXXDestructorDecl>(GD.getDecl()))
588595
return arrangeCXXStructorDeclaration(GD);
589596

590-
return arrangeFunctionDeclaration(FD);
597+
return arrangeFunctionDeclaration(GD);
591598
}
592599

593600
/// Arrange a thunk that takes 'this' as the first parameter followed by
@@ -2392,7 +2399,6 @@ void CodeGenModule::ConstructAttributeList(StringRef Name,
23922399
// Collect function IR attributes from the callee prototype if we have one.
23932400
AddAttributesFromFunctionProtoType(getContext(), FuncAttrs,
23942401
CalleeInfo.getCalleeFunctionProtoType());
2395-
23962402
const Decl *TargetDecl = CalleeInfo.getCalleeDecl().getDecl();
23972403

23982404
// Attach assumption attributes to the declaration. If this is a call

clang/lib/CodeGen/CGExpr.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5756,6 +5756,12 @@ static CGCallee EmitDirectCallee(CodeGenFunction &CGF, GlobalDecl GD) {
57565756
return CGCallee::forDirect(CalleePtr, GD);
57575757
}
57585758

5759+
static GlobalDecl getGlobalDeclForDirectCall(const FunctionDecl *FD) {
5760+
if (FD->hasAttr<OpenCLKernelAttr>())
5761+
return GlobalDecl(FD, KernelReferenceKind::Stub);
5762+
return GlobalDecl(FD);
5763+
}
5764+
57595765
CGCallee CodeGenFunction::EmitCallee(const Expr *E) {
57605766
E = E->IgnoreParens();
57615767

@@ -5769,7 +5775,7 @@ CGCallee CodeGenFunction::EmitCallee(const Expr *E) {
57695775
// Resolve direct calls.
57705776
} else if (auto DRE = dyn_cast<DeclRefExpr>(E)) {
57715777
if (auto FD = dyn_cast<FunctionDecl>(DRE->getDecl())) {
5772-
return EmitDirectCallee(*this, FD);
5778+
return EmitDirectCallee(*this, getGlobalDeclForDirectCall(FD));
57735779
}
57745780
} else if (auto ME = dyn_cast<MemberExpr>(E)) {
57755781
if (auto FD = dyn_cast<FunctionDecl>(ME->getMemberDecl())) {
@@ -6138,6 +6144,12 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType,
61386144

61396145
const auto *FnType = cast<FunctionType>(PointeeType);
61406146

6147+
if (auto FD = dyn_cast_or_null<FunctionDecl>(TargetDecl)) {
6148+
if (FD->hasAttr<OpenCLKernelAttr>()) {
6149+
CGM.getTargetCodeGenInfo().setOCLKernelStubCallingConvention(FnType);
6150+
}
6151+
}
6152+
61416153
// If we are checking indirect calls and this call is indirect, check that the
61426154
// function pointer is a member of the bit set for the function type.
61436155
if (SanOpts.has(SanitizerKind::CFIICall) &&

clang/lib/CodeGen/CodeGenFunction.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1595,6 +1595,26 @@ void CodeGenFunction::GenerateCode(GlobalDecl GD, llvm::Function *Fn,
15951595
// Implicit copy-assignment gets the same special treatment as implicit
15961596
// copy-constructors.
15971597
emitImplicitAssignmentOperatorBody(Args);
1598+
} else if (FD->hasAttr<OpenCLKernelAttr>() &&
1599+
GD.getKernelReferenceKind() == KernelReferenceKind::Kernel) {
1600+
CallArgList CallArgs;
1601+
for (unsigned i = 0; i < Args.size(); ++i) {
1602+
Address ArgAddr = GetAddrOfLocalVar(Args[i]);
1603+
QualType ArgQualType = Args[i]->getType();
1604+
RValue ArgRValue = convertTempToRValue(ArgAddr, ArgQualType, Loc);
1605+
CallArgs.add(ArgRValue, ArgQualType);
1606+
}
1607+
GlobalDecl GDStub = GlobalDecl(FD, KernelReferenceKind::Stub);
1608+
const FunctionType *FT = cast<FunctionType>(FD->getType());
1609+
CGM.getTargetCodeGenInfo().setOCLKernelStubCallingConvention(FT);
1610+
const CGFunctionInfo &FnInfo = CGM.getTypes().arrangeFreeFunctionCall(
1611+
CallArgs, FT, /*ChainCall=*/false);
1612+
llvm::FunctionType *FTy = CGM.getTypes().GetFunctionType(FnInfo);
1613+
llvm::Constant *GDStubFunctionPointer =
1614+
CGM.getRawFunctionPointer(GDStub, FTy);
1615+
CGCallee GDStubCallee = CGCallee::forDirect(GDStubFunctionPointer, GDStub);
1616+
EmitCall(FnInfo, GDStubCallee, ReturnValueSlot(), CallArgs, nullptr, false,
1617+
Loc); // set IsMustTail=true?
15981618
} else if (Body) {
15991619
EmitFunctionBody(Body);
16001620
} else

clang/lib/CodeGen/CodeGenModule.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1913,6 +1913,9 @@ static std::string getMangledNameImpl(CodeGenModule &CGM, GlobalDecl GD,
19131913
} else if (FD && FD->hasAttr<CUDAGlobalAttr>() &&
19141914
GD.getKernelReferenceKind() == KernelReferenceKind::Stub) {
19151915
Out << "__device_stub__" << II->getName();
1916+
} else if (FD && FD->hasAttr<OpenCLKernelAttr>() &&
1917+
GD.getKernelReferenceKind() == KernelReferenceKind::Stub) {
1918+
Out << "__clang_ocl_kern_imp_" << II->getName();
19161919
} else {
19171920
Out << II->getName();
19181921
}
@@ -3903,6 +3906,10 @@ void CodeGenModule::EmitGlobal(GlobalDecl GD) {
39033906

39043907
// Ignore declarations, they will be emitted on their first use.
39053908
if (const auto *FD = dyn_cast<FunctionDecl>(Global)) {
3909+
3910+
if (FD->hasAttr<OpenCLKernelAttr>() && FD->doesThisDeclarationHaveABody())
3911+
addDeferredDeclToEmit(GlobalDecl(FD, KernelReferenceKind::Stub));
3912+
39063913
// Update deferred annotations with the latest declaration if the function
39073914
// function was already used or defined.
39083915
if (FD->hasAttr<AnnotateAttr>()) {
@@ -4870,6 +4877,11 @@ CodeGenModule::GetAddrOfFunction(GlobalDecl GD, llvm::Type *Ty, bool ForVTable,
48704877
if (!Ty) {
48714878
const auto *FD = cast<FunctionDecl>(GD.getDecl());
48724879
Ty = getTypes().ConvertType(FD->getType());
4880+
if (FD->hasAttr<OpenCLKernelAttr>() &&
4881+
GD.getKernelReferenceKind() == KernelReferenceKind::Stub) {
4882+
const CGFunctionInfo &FI = getTypes().arrangeGlobalDeclaration(GD);
4883+
Ty = getTypes().GetFunctionType(FI);
4884+
}
48734885
}
48744886

48754887
// Devirtualized destructor calls may come through here instead of via
@@ -6149,6 +6161,17 @@ void CodeGenModule::EmitGlobalFunctionDefinition(GlobalDecl GD,
61496161
CodeGenFunction(*this).GenerateCode(GD, Fn, FI);
61506162

61516163
setNonAliasAttributes(GD, Fn);
6164+
6165+
if (D->hasAttr<OpenCLKernelAttr>()) {
6166+
if (GD.getKernelReferenceKind() == KernelReferenceKind::Stub) {
6167+
if (Fn->hasFnAttribute(llvm::Attribute::NoInline))
6168+
Fn->removeFnAttr(llvm::Attribute::NoInline);
6169+
if (Fn->hasFnAttribute(llvm::Attribute::InlineHint))
6170+
Fn->removeFnAttr(llvm::Attribute::InlineHint);
6171+
Fn->addFnAttr(llvm::Attribute::AlwaysInline);
6172+
}
6173+
}
6174+
61526175
SetLLVMFunctionAttributesForDefinition(D, Fn);
61536176

61546177
if (const ConstructorAttr *CA = D->getAttr<ConstructorAttr>())

clang/lib/CodeGen/CodeGenTypes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ class CodeGenTypes {
207207

208208
/// Free functions are functions that are compatible with an ordinary
209209
/// C function pointer type.
210-
const CGFunctionInfo &arrangeFunctionDeclaration(const FunctionDecl *FD);
210+
const CGFunctionInfo &arrangeFunctionDeclaration(const GlobalDecl GD);
211211
const CGFunctionInfo &arrangeFreeFunctionCall(const CallArgList &Args,
212212
const FunctionType *Ty,
213213
bool ChainCall);

clang/lib/CodeGen/TargetInfo.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,12 @@ unsigned TargetCodeGenInfo::getOpenCLKernelCallingConv() const {
117117
return llvm::CallingConv::SPIR_KERNEL;
118118
}
119119

120+
void TargetCodeGenInfo::setOCLKernelStubCallingConvention(
121+
const FunctionType *&FT) const {
122+
FT = getABIInfo().getContext().adjustFunctionType(
123+
FT, FT->getExtInfo().withCallingConv(CC_C));
124+
}
125+
120126
llvm::Constant *TargetCodeGenInfo::getNullPointer(const CodeGen::CodeGenModule &CGM,
121127
llvm::PointerType *T, QualType QT) const {
122128
return llvm::ConstantPointerNull::get(T);

clang/lib/CodeGen/TargetInfo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ class TargetCodeGenInfo {
400400
virtual bool shouldEmitDWARFBitFieldSeparators() const { return false; }
401401

402402
virtual void setCUDAKernelCallingConvention(const FunctionType *&FT) const {}
403-
403+
virtual void setOCLKernelStubCallingConvention(const FunctionType *&FT) const;
404404
/// Return the device-side type for the CUDA device builtin surface type.
405405
virtual llvm::Type *getCUDADeviceBuiltinSurfaceDeviceType() const {
406406
// By default, no change from the original one.

clang/lib/CodeGen/Targets/SPIR.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ class CommonSPIRTargetCodeGenInfo : public TargetCodeGenInfo {
5858
llvm::Type *getSPIRVImageTypeFromHLSLResource(
5959
const HLSLAttributedResourceType::Attributes &attributes,
6060
llvm::Type *ElementType, llvm::LLVMContext &Ctx) const;
61+
void
62+
setOCLKernelStubCallingConvention(const FunctionType *&FT) const override;
6163
};
6264
class SPIRVTargetCodeGenInfo : public CommonSPIRTargetCodeGenInfo {
6365
public:
@@ -230,6 +232,12 @@ void SPIRVTargetCodeGenInfo::setCUDAKernelCallingConvention(
230232
}
231233
}
232234

235+
void CommonSPIRTargetCodeGenInfo::setOCLKernelStubCallingConvention(
236+
const FunctionType *&FT) const {
237+
FT = getABIInfo().getContext().adjustFunctionType(
238+
FT, FT->getExtInfo().withCallingConv(CC_SpirFunction));
239+
}
240+
233241
LangAS
234242
SPIRVTargetCodeGenInfo::getGlobalVarAddressSpace(CodeGenModule &CGM,
235243
const VarDecl *D) const {

0 commit comments

Comments
 (0)