Skip to content

Commit 11b05e8

Browse files
committed
[CIR][CUDA] Generate registration function (Part 1)
1 parent 8746bd4 commit 11b05e8

File tree

3 files changed

+214
-2
lines changed

3 files changed

+214
-2
lines changed

clang/include/clang/CIR/MissingFeatures.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,8 @@ struct MissingFeatures {
251251
static bool emitEmptyRecordCheck() { return false; }
252252
static bool isPPC_FP128Ty() { return false; }
253253
static bool createLaunderInvariantGroup() { return false; }
254+
static bool hipModuleCtor() { return false; }
255+
static bool checkMacOSXTriple() { return false; }
254256

255257
// Inline assembly
256258
static bool asmGoto() { return false; }

clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "clang/AST/ASTContext.h"
1414
#include "clang/AST/CharUnits.h"
1515
#include "clang/AST/Mangle.h"
16+
#include "clang/Basic/Cuda.h"
1617
#include "clang/Basic/Module.h"
1718
#include "clang/Basic/TargetInfo.h"
1819
#include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h"
@@ -27,6 +28,7 @@
2728
#include "llvm/ADT/Twine.h"
2829
#include "llvm/Support/ErrorHandling.h"
2930
#include "llvm/Support/Path.h"
31+
#include "llvm/Support/VirtualFileSystem.h"
3032

3133
#include <memory>
3234

@@ -117,6 +119,17 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
117119
/// has an empty name, and prevent collisions.
118120
uint64_t annonGlobalConstArrayCount = 0;
119121

122+
///
123+
/// CUDA related
124+
/// ------------
125+
126+
// Maps CUDA device stub name to kernel name.
127+
llvm::DenseMap<llvm::StringRef, std::string> cudaKernelMap;
128+
129+
void buildCUDAModuleCtor();
130+
void buildCUDAModuleDtor();
131+
std::optional<FuncOp> buildCUDARegisterGlobals();
132+
120133
///
121134
/// AST related
122135
/// -----------
@@ -953,6 +966,143 @@ void LoweringPreparePass::buildCXXGlobalInitFunc() {
953966
builder.create<ReturnOp>(f.getLoc());
954967
}
955968

969+
void LoweringPreparePass::buildCUDAModuleCtor() {
970+
if (astCtx->getLangOpts().HIP)
971+
assert(!cir::MissingFeatures::hipModuleCtor());
972+
if (astCtx->getLangOpts().GPURelocatableDeviceCode)
973+
llvm_unreachable("NYI");
974+
975+
// There's no device-side binary, so no need to proceed for CUDA.
976+
// HIP has to create an external symbol in this case, which is NYI.
977+
auto cudaBinaryHandleAttr =
978+
theModule->getAttr(CIRDialect::getCUDABinaryHandleAttrName());
979+
if (!cudaBinaryHandleAttr)
980+
return;
981+
std::string cudaGPUBinaryName =
982+
cast<CUDABinaryHandleAttr>(cudaBinaryHandleAttr).getName();
983+
984+
llvm::StringRef prefix = "cuda";
985+
986+
constexpr unsigned cudaFatMagic = 0x466243b1;
987+
constexpr unsigned hipFatMagic = 0x48495046; // "HIPF"
988+
989+
const unsigned fatMagic =
990+
astCtx->getLangOpts().HIP ? hipFatMagic : cudaFatMagic;
991+
992+
auto addUnderscoredPrefix = [&](llvm::StringRef name) -> std::string {
993+
return ("__" + prefix + name).str();
994+
};
995+
996+
// MAC OS X needs special care, but we haven't supported that in CIR yet.
997+
assert(!cir::MissingFeatures::checkMacOSXTriple());
998+
999+
CIRBaseBuilderTy builder(getContext());
1000+
builder.setInsertionPointToStart(theModule.getBody());
1001+
1002+
mlir::Location loc = theModule.getLoc();
1003+
1004+
// Extract types from the module.
1005+
auto typeSizesAttr = cast<TypeSizeInfoAttr>(
1006+
theModule->getAttr(CIRDialect::getTypeSizeInfoAttrName()));
1007+
1008+
auto voidTy = VoidType::get(&getContext());
1009+
auto voidPtrTy = PointerType::get(voidTy);
1010+
auto voidPtrPtrTy = PointerType::get(voidPtrTy);
1011+
auto intTy = typeSizesAttr.getIntType(&getContext());
1012+
auto charTy = typeSizesAttr.getCharType(&getContext());
1013+
1014+
// Read the GPU binary and create a constant array for it.
1015+
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> cudaGPUBinaryOrErr =
1016+
llvm::MemoryBuffer::getFile(cudaGPUBinaryName);
1017+
if (std::error_code ec = cudaGPUBinaryOrErr.getError()) {
1018+
theModule->emitError("cannot open file: " + cudaGPUBinaryName +
1019+
ec.message());
1020+
return;
1021+
}
1022+
std::unique_ptr<llvm::MemoryBuffer> cudaGPUBinary =
1023+
std::move(cudaGPUBinaryOrErr.get());
1024+
1025+
// The section names are different for MAC OS X.
1026+
llvm::StringRef fatbinConstName = ".nv_fatbin";
1027+
llvm::StringRef fatbinSectionName = ".nvFatBinSegment";
1028+
1029+
// Create a global variable with the contents of GPU binary.
1030+
auto fatbinType =
1031+
ArrayType::get(&getContext(), charTy, cudaGPUBinary->getBuffer().size());
1032+
1033+
// OG gives an empty name to this global constant,
1034+
// which is not allowed in CIR.
1035+
std::string fatbinStrName = addUnderscoredPrefix("_fatbin_str");
1036+
GlobalOp fatbinStr = builder.create<GlobalOp>(
1037+
loc, fatbinStrName, fatbinType, /*isConstant=*/true,
1038+
/*linkage=*/cir::GlobalLinkageKind::PrivateLinkage);
1039+
fatbinStr.setAlignment(8);
1040+
fatbinStr.setInitialValueAttr(cir::ConstArrayAttr::get(
1041+
fatbinType, builder.getStringAttr(cudaGPUBinary->getBuffer())));
1042+
fatbinStr.setSection(fatbinConstName);
1043+
fatbinStr.setPrivate();
1044+
1045+
// Create a struct FatbinWrapper, pointing to the GPU binary.
1046+
// Struct layout:
1047+
// struct { int magicNum; int version; void *fatbin; void *unused; };
1048+
// This will be initialized in the module ctor below.
1049+
auto fatbinWrapperType = StructType::get(
1050+
&getContext(), {intTy, intTy, voidPtrTy, voidPtrTy}, /*packed=*/false,
1051+
/*padded=*/false, StructType::RecordKind::Struct);
1052+
1053+
std::string fatbinWrapperName = addUnderscoredPrefix("_fatbin_wrapper");
1054+
GlobalOp fatbinWrapper = builder.create<GlobalOp>(
1055+
loc, fatbinWrapperName, fatbinWrapperType, /*isConstant=*/true,
1056+
/*linkage=*/cir::GlobalLinkageKind::InternalLinkage);
1057+
fatbinWrapper.setPrivate();
1058+
fatbinWrapper.setSection(fatbinSectionName);
1059+
1060+
auto magicInit = IntAttr::get(intTy, fatMagic);
1061+
auto versionInit = IntAttr::get(intTy, 1);
1062+
// `fatbinInit` is only a placeholder. The value will be initialized at the
1063+
// beginning of module ctor.
1064+
auto fatbinInit = builder.getConstNullPtrAttr(voidPtrTy);
1065+
auto unusedInit = builder.getConstNullPtrAttr(voidPtrTy);
1066+
fatbinWrapper.setInitialValueAttr(cir::ConstStructAttr::get(
1067+
fatbinWrapperType,
1068+
ArrayAttr::get(&getContext(),
1069+
{magicInit, versionInit, fatbinInit, unusedInit})));
1070+
1071+
// Declare this function:
1072+
// void **__{cuda|hip}RegisterFatBinary(void *);
1073+
1074+
std::string regFuncName = addUnderscoredPrefix("RegisterFatBinary");
1075+
auto regFuncType = FuncType::get({voidPtrTy}, voidPtrPtrTy);
1076+
auto regFunc = buildRuntimeFunction(builder, regFuncName, loc, regFuncType);
1077+
1078+
// Create the module constructor.
1079+
1080+
std::string moduleCtorName = addUnderscoredPrefix("_module_ctor");
1081+
auto moduleCtor = buildRuntimeFunction(builder, moduleCtorName, loc,
1082+
FuncType::get({}, voidTy),
1083+
GlobalLinkageKind::InternalLinkage);
1084+
globalCtorList.push_back(GlobalCtorAttr::get(&getContext(), moduleCtorName));
1085+
builder.setInsertionPointToStart(moduleCtor.addEntryBlock());
1086+
1087+
auto wrapper = builder.createGetGlobal(fatbinWrapper);
1088+
// Put fatbinStr inside fatbinWrapper.
1089+
mlir::Value fatbinStrValue = builder.createGetGlobal(fatbinStr);
1090+
mlir::Value fatbinField = builder.createGetMemberOp(loc, wrapper, "", 2);
1091+
builder.createStore(loc, fatbinStrValue, fatbinField);
1092+
1093+
// Register binary with CUDA runtime. This is substantially different in
1094+
// default mode vs. separate compilation.
1095+
// Corresponding code:
1096+
// gpuBinaryHandle = __cudaRegisterFatBinary(&fatbinWrapper);
1097+
auto fatbinVoidPtr = builder.createBitcast(wrapper, voidPtrTy);
1098+
auto gpuBinaryHandle = builder.createCallOp(loc, regFunc, fatbinVoidPtr);
1099+
1100+
// This is currently incomplete.
1101+
// TODO(cir): create __cuda_register_globals(), and call it here.
1102+
1103+
builder.create<cir::ReturnOp>(loc);
1104+
}
1105+
9561106
void LoweringPreparePass::lowerDynamicCastOp(DynamicCastOp op) {
9571107
CIRBaseBuilderTy builder(getContext());
9581108
builder.setInsertionPointAfter(op);
@@ -1213,6 +1363,13 @@ void LoweringPreparePass::runOnOp(Operation *op) {
12131363
} else if (auto globalDtor = fnOp.getGlobalDtorAttr()) {
12141364
globalDtorList.push_back(globalDtor);
12151365
}
1366+
if (auto attr = fnOp.getExtraAttrs().getElements().get(
1367+
CIRDialect::getCUDABinaryHandleAttrName())) {
1368+
auto cudaBinaryAttr = dyn_cast<CUDABinaryHandleAttr>(attr);
1369+
std::string kernelName = cudaBinaryAttr.getName();
1370+
llvm::StringRef stubName = fnOp.getSymName();
1371+
cudaKernelMap[stubName] = kernelName;
1372+
}
12161373
if (std::optional<mlir::ArrayAttr> annotations = fnOp.getAnnotations())
12171374
addGlobalAnnotations(fnOp, annotations.value());
12181375
} else if (auto throwOp = dyn_cast<cir::ThrowOp>(op)) {
@@ -1240,6 +1397,10 @@ void LoweringPreparePass::runOnOperation() {
12401397
for (auto *o : opsToTransform)
12411398
runOnOp(o);
12421399

1400+
if (astCtx->getLangOpts().CUDA && !astCtx->getLangOpts().CUDAIsDevice) {
1401+
buildCUDAModuleCtor();
1402+
}
1403+
12431404
buildCXXGlobalInitFunc();
12441405
buildGlobalCtorDtorList();
12451406
buildGlobalAnnotationValues();
Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,58 @@
11
#include "../Inputs/cuda.h"
22

3+
// RUN: echo "sample fatbin" > %t.fatbin
34
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir \
45
// RUN: -x cuda -emit-cir -target-sdk-version=12.3 \
5-
// RUN: -fcuda-include-gpubinary fatbin.o\
6+
// RUN: -fcuda-include-gpubinary %t.fatbin \
67
// RUN: %s -o %t.cir
78
// RUN: FileCheck --check-prefix=CIR-HOST --input-file=%t.cir %s
89

9-
// CIR-HOST: module @"{{.*}}" attributes{{.*}}cir.cu.binary_handle = #cir.cu.binary_handle<fatbin.o>{{.*}}
10+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir \
11+
// RUN: -x cuda -emit-llvm -target-sdk-version=12.3 \
12+
// RUN: -fcuda-include-gpubinary %t.fatbin \
13+
// RUN: %s -o %t.ll
14+
// RUN: FileCheck --check-prefix=LLVM-HOST --input-file=%t.ll %s
15+
16+
// CIR-HOST: module @"{{.*}}" attributes {
17+
// CIR-HOST: cir.cu.binary_handle = #cir.cu.binary_handle<{{.*}}.fatbin>,
18+
// CIR-HOST: cir.global_ctors = [#cir.global_ctor<"__cuda_module_ctor", {{[0-9]+}}>]
19+
// CIR-HOST: }
20+
21+
// The content in const array should be the same as echoed above,
22+
// with a trailing line break ('\n', 0x0A).
23+
// CIR-HOST: cir.global "private" constant cir_private @__cuda_fatbin_str =
24+
// CIR-HOST-SAME: #cir.const_array<"sample fatbin\0A">
25+
// CIR-HOST-SAME: {{.*}}section = ".nv_fatbin"
26+
27+
// LLVM-HOST: @__cuda_fatbin_str = private constant [14 x i8] c"sample fatbin\0A", section ".nv_fatbin"
28+
29+
// The first value is CUDA file head magic number.
30+
// CIR-HOST: cir.global "private" constant internal @__cuda_fatbin_wrapper
31+
// CIR-HOST: = #cir.const_struct<{
32+
// CIR-HOST: #cir.int<1180844977> : !s32i,
33+
// CIR-HOST: #cir.int<1> : !s32i,
34+
// CIR-HOST: #cir.ptr<null> : !cir.ptr<!void>,
35+
// CIR-HOST: #cir.ptr<null> : !cir.ptr<!void>
36+
// CIR-HOST: }>
37+
// CIR-HOST-SAME: {{.*}}section = ".nvFatBinSegment"
38+
39+
// LLVM-HOST: @__cuda_fatbin_wrapper = internal constant {
40+
// LLVM-HOST: i32 1180844977, i32 1, ptr null, ptr null
41+
// LLVM-HOST: }
42+
43+
// LLVM-HOST: @llvm.global_ctors = {{.*}}ptr @__cuda_module_ctor
44+
45+
// CIR-HOST: cir.func private @__cudaRegisterFatBinary
46+
// CIR-HOST: cir.func {{.*}} @__cuda_module_ctor() {
47+
// CIR-HOST: %[[#F0:]] = cir.get_global @__cuda_fatbin_wrapper
48+
// CIR-HOST: %[[#F1:]] = cir.get_global @__cuda_fatbin_str
49+
// CIR-HOST: %[[#F2:]] = cir.get_member %[[#F0]][2]
50+
// CIR-HOST: %[[#F3:]] = cir.cast(bitcast, %[[#F2]]
51+
// CIR-HOST: cir.store %[[#F1]], %[[#F3]]
52+
// CIR-HOST: cir.call @__cudaRegisterFatBinary
53+
// CIR-HOST: }
54+
55+
// LLVM-HOST: define internal void @__cuda_module_ctor() {
56+
// LLVM-HOST: store ptr @__cuda_fatbin_str, ptr getelementptr {{.*}}, ptr @__cuda_fatbin_wrapper
57+
// LLVM-HOST: call ptr @__cudaRegisterFatBinary(ptr @__cuda_fatbin_wrapper)
58+
// LLVM-HOST: }

0 commit comments

Comments
 (0)