Skip to content

Commit 49fd770

Browse files
joppermsommerlukas
andauthored
[SYCL][RTC] Use program manager (#16316)
The idea is to assign a per-compilation prefix (e.g., `rtc_42$`) to all offload entries in the `sycl_device_binaries` datastructure before feeding them into `ProgramManager::addImages`, resulting in unique `kernel_id`s as far as the PM is concerned. When querying the PM for the device images to construct the kernel bundle, I look for kernel IDs starting with the current prefix, which should reliably return only the device images corresponding to current compilation request. Note that the actual kernel names don't change, i.e. `__sycl_kernel_foo` keeps that name even though the PM might know it as `rtc_42$__sycl_kernel_foo`. The prefix is stored inside the bundle, and prepended to the requested kernel name in the `ext_onapi_[has|get]_kernel(string)` methods. Kernel objects are also obtained via the program manager. Compared to creating the UR kernel from the selected device image's UR program directly, this approach ensures eliminated arguments are handled correctly. Hence, I was able to drop the previously mandatory `-fno-sycl-dead-args-optimization` from the pipeline. The compilation pipeline and extended kernel bundle now support multiple device images. To test this, I added support for the `-fsycl-device-code-split=` option, and apply it to one of the compilations in the E2E test. The persistent cache is circumvented for now for the `sycl_jit` language (lack of suitable on-disk format), but should be brought back in the future. --------- Signed-off-by: Julian Oppermann <[email protected]> Co-authored-by: Lukas Sommer <[email protected]>
1 parent df9fba6 commit 49fd770

File tree

12 files changed

+272
-169
lines changed

12 files changed

+272
-169
lines changed

sycl-jit/common/include/Kernel.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ struct InMemoryFile {
359359
const char *Contents;
360360
};
361361

362-
using RTCBundleBinaryInfo = SYCLKernelBinaryInfo;
362+
using RTCDevImgBinaryInfo = SYCLKernelBinaryInfo;
363363
using FrozenSymbolTable = DynArray<sycl::detail::string>;
364364

365365
// Note: `FrozenPropertyValue` and `FrozenPropertySet` constructors take
@@ -399,16 +399,18 @@ struct FrozenPropertySet {
399399

400400
using FrozenPropertyRegistry = DynArray<FrozenPropertySet>;
401401

402-
struct RTCBundleInfo {
403-
RTCBundleBinaryInfo BinaryInfo;
402+
struct RTCDevImgInfo {
403+
RTCDevImgBinaryInfo BinaryInfo;
404404
FrozenSymbolTable SymbolTable;
405405
FrozenPropertyRegistry Properties;
406406

407-
RTCBundleInfo() = default;
408-
RTCBundleInfo(RTCBundleInfo &&) = default;
409-
RTCBundleInfo &operator=(RTCBundleInfo &&) = default;
407+
RTCDevImgInfo() = default;
408+
RTCDevImgInfo(RTCDevImgInfo &&) = default;
409+
RTCDevImgInfo &operator=(RTCDevImgInfo &&) = default;
410410
};
411411

412+
using RTCBundleInfo = DynArray<RTCDevImgInfo>;
413+
412414
} // namespace jit_compiler
413415

414416
#endif // SYCL_FUSION_COMMON_KERNEL_H

sycl-jit/jit-compiler/lib/KernelFusion.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -266,17 +266,18 @@ compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
266266
return errorTo<RTCResult>(PostLinkResultOrError.takeError(),
267267
"Post-link phase failed");
268268
}
269-
RTCBundleInfo BundleInfo;
270-
std::tie(BundleInfo, Module) = std::move(*PostLinkResultOrError);
271-
272-
auto BinaryInfoOrError =
273-
translation::KernelTranslator::translateBundleToSPIRV(
274-
*Module, JITContext::getInstance());
275-
if (!BinaryInfoOrError) {
276-
return errorTo<RTCResult>(BinaryInfoOrError.takeError(),
277-
"SPIR-V translation failed");
269+
auto [BundleInfo, Modules] = std::move(*PostLinkResultOrError);
270+
271+
for (auto [DevImgInfo, Module] : llvm::zip_equal(BundleInfo, Modules)) {
272+
auto BinaryInfoOrError =
273+
translation::KernelTranslator::translateDevImgToSPIRV(
274+
*Module, JITContext::getInstance());
275+
if (!BinaryInfoOrError) {
276+
return errorTo<RTCResult>(BinaryInfoOrError.takeError(),
277+
"SPIR-V translation failed");
278+
}
279+
DevImgInfo.BinaryInfo = std::move(*BinaryInfoOrError);
278280
}
279-
BundleInfo.BinaryInfo = std::move(*BinaryInfoOrError);
280281

281282
return RTCResult{std::move(BundleInfo), BuildLog.c_str()};
282283
}

sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp

Lines changed: 92 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,6 @@ Expected<std::unique_ptr<llvm::Module>> jit_compiler::compileDeviceCode(
233233
DerivedArgList DAL{UserArgList};
234234
const auto &OptTable = getDriverOptTable();
235235
DAL.AddFlagArg(nullptr, OptTable.getOption(OPT_fsycl_device_only));
236-
DAL.AddFlagArg(nullptr,
237-
OptTable.getOption(OPT_fno_sycl_dead_args_optimization));
238236
DAL.AddJoinedArg(
239237
nullptr, OptTable.getOption(OPT_resource_dir_EQ),
240238
(DPCPPRoot + "/lib/clang/" + Twine(CLANG_VERSION_MAJOR)).str());
@@ -436,15 +434,35 @@ template <class PassClass> static bool runModulePass(llvm::Module &M) {
436434
return !Res.areAllPreserved();
437435
}
438436

439-
llvm::Expected<PostLinkResult> jit_compiler::performPostLink(
440-
std::unique_ptr<llvm::Module> Module,
441-
[[maybe_unused]] const llvm::opt::InputArgList &UserArgList) {
437+
static IRSplitMode getDeviceCodeSplitMode(const InputArgList &UserArgList) {
438+
// This is the (combined) logic from
439+
// `get[NonTriple|Triple]BasedSYCLPostLinkOpts` in
440+
// `clang/lib/Driver/ToolChains/Clang.cpp`: Default is auto mode, but the user
441+
// can override it by specifying the `-fsycl-device-code-split=` option. The
442+
// no-argument variant `-fsycl-device-code-split` is ignored.
443+
if (auto *Arg = UserArgList.getLastArg(OPT_fsycl_device_code_split_EQ)) {
444+
StringRef ArgVal{Arg->getValue()};
445+
if (ArgVal == "per_kernel") {
446+
return SPLIT_PER_KERNEL;
447+
}
448+
if (ArgVal == "per_source") {
449+
return SPLIT_PER_TU;
450+
}
451+
if (ArgVal == "off") {
452+
return SPLIT_NONE;
453+
}
454+
}
455+
return SPLIT_AUTO;
456+
}
457+
458+
Expected<PostLinkResult>
459+
jit_compiler::performPostLink(std::unique_ptr<llvm::Module> Module,
460+
const InputArgList &UserArgList) {
442461
// This is a simplified version of `processInputModule` in
443462
// `llvm/tools/sycl-post-link.cpp`. Assertions/TODOs point to functionality
444463
// left out of the algorithm for now.
445464

446-
// TODO: SplitMode can be controlled by the user.
447-
const auto SplitMode = SPLIT_NONE;
465+
const auto SplitMode = getDeviceCodeSplitMode(UserArgList);
448466

449467
// TODO: EmitOnlyKernelsAsEntryPoints is controlled by
450468
// `shouldEmitOnlyKernelsAsEntryPoints` in
@@ -480,77 +498,87 @@ llvm::Expected<PostLinkResult> jit_compiler::performPostLink(
480498
return createStringError("`invoke_simd` calls detected");
481499
}
482500

483-
// TODO: Implement actual device code splitting. We're just using the splitter
484-
// to obtain additional information about the module for now.
485-
486501
std::unique_ptr<ModuleSplitterBase> Splitter = getDeviceCodeSplitter(
487502
ModuleDesc{std::move(Module)}, SplitMode,
488503
/*IROutputOnly=*/false, EmitOnlyKernelsAsEntryPoints);
489504
assert(Splitter->hasMoreSplits());
490-
if (Splitter->remainingSplits() > 1) {
491-
return createStringError("Device code requires splitting");
492-
}
493505

494506
// TODO: Call `verifyNoCrossModuleDeviceGlobalUsage` if device globals shall
495507
// be processed.
496508

497-
ModuleDesc MDesc = Splitter->nextSplit();
509+
// TODO: This allocation assumes that there are no further splits required,
510+
// i.e. there are no mixed SYCL/ESIMD modules.
511+
RTCBundleInfo BundleInfo{Splitter->remainingSplits()};
512+
SmallVector<std::unique_ptr<llvm::Module>> Modules;
498513

499-
// TODO: Call `MDesc.fixupLinkageOfDirectInvokeSimdTargets()` when
500-
// `invoke_simd` is supported.
514+
auto *DevImgInfoIt = BundleInfo.begin();
515+
while (Splitter->hasMoreSplits()) {
516+
assert(DevImgInfoIt != BundleInfo.end());
501517

502-
SmallVector<ModuleDesc, 2> ESIMDSplits =
503-
splitByESIMD(std::move(MDesc), EmitOnlyKernelsAsEntryPoints);
504-
assert(!ESIMDSplits.empty());
505-
if (ESIMDSplits.size() > 1) {
506-
return createStringError("Mixing SYCL and ESIMD code is unsupported");
507-
}
508-
MDesc = std::move(ESIMDSplits.front());
518+
ModuleDesc MDesc = Splitter->nextSplit();
519+
RTCDevImgInfo &DevImgInfo = *DevImgInfoIt++;
509520

510-
if (MDesc.isESIMD()) {
511-
// `sycl-post-link` has a `-lower-esimd` option, but there's no clang driver
512-
// option to influence it. Rather, the driver sets it unconditionally in the
513-
// multi-file output mode, which we are mimicking here.
514-
lowerEsimdConstructs(MDesc, PerformOpts);
515-
}
521+
// TODO: Call `MDesc.fixupLinkageOfDirectInvokeSimdTargets()` when
522+
// `invoke_simd` is supported.
516523

517-
MDesc.saveSplitInformationAsMetadata();
518-
519-
RTCBundleInfo BundleInfo;
520-
BundleInfo.SymbolTable = FrozenSymbolTable{MDesc.entries().size()};
521-
transform(MDesc.entries(), BundleInfo.SymbolTable.begin(),
522-
[](Function *F) { return F->getName(); });
523-
524-
// TODO: Determine what is requested.
525-
GlobalBinImageProps PropReq{
526-
/*EmitKernelParamInfo=*/true, /*EmitProgramMetadata=*/true,
527-
/*EmitExportedSymbols=*/true, /*EmitImportedSymbols=*/true,
528-
/*DeviceGlobals=*/false};
529-
PropertySetRegistry Properties =
530-
computeModuleProperties(MDesc.getModule(), MDesc.entries(), PropReq);
531-
// TODO: Manually add `compile_target` property as in
532-
// `saveModuleProperties`?
533-
const auto &PropertySets = Properties.getPropSets();
534-
535-
BundleInfo.Properties = FrozenPropertyRegistry{PropertySets.size()};
536-
for (auto &&[KV, FrozenPropSet] : zip(PropertySets, BundleInfo.Properties)) {
537-
const auto &PropertySetName = KV.first;
538-
const auto &PropertySet = KV.second;
539-
FrozenPropSet =
540-
FrozenPropertySet{PropertySetName.str(), PropertySet.size()};
541-
for (auto &&[KV2, FrozenProp] : zip(PropertySet, FrozenPropSet.Values)) {
542-
const auto &PropertyName = KV2.first;
543-
const auto &PropertyValue = KV2.second;
544-
FrozenProp = PropertyValue.getType() == PropertyValue::Type::UINT32
545-
? FrozenPropertyValue{PropertyName.str(),
546-
PropertyValue.asUint32()}
547-
: FrozenPropertyValue{
548-
PropertyName.str(), PropertyValue.asRawByteArray(),
549-
PropertyValue.getRawByteArraySize()};
524+
SmallVector<ModuleDesc, 2> ESIMDSplits =
525+
splitByESIMD(std::move(MDesc), EmitOnlyKernelsAsEntryPoints);
526+
assert(!ESIMDSplits.empty());
527+
if (ESIMDSplits.size() > 1) {
528+
return createStringError("Mixing SYCL and ESIMD code is unsupported");
550529
}
551-
};
530+
MDesc = std::move(ESIMDSplits.front());
531+
532+
if (MDesc.isESIMD()) {
533+
// `sycl-post-link` has a `-lower-esimd` option, but there's no clang
534+
// driver option to influence it. Rather, the driver sets it
535+
// unconditionally in the multi-file output mode, which we are mimicking
536+
// here.
537+
lowerEsimdConstructs(MDesc, PerformOpts);
538+
}
539+
540+
MDesc.saveSplitInformationAsMetadata();
541+
542+
DevImgInfo.SymbolTable = FrozenSymbolTable{MDesc.entries().size()};
543+
transform(MDesc.entries(), DevImgInfo.SymbolTable.begin(),
544+
[](Function *F) { return F->getName(); });
545+
546+
// TODO: Determine what is requested.
547+
GlobalBinImageProps PropReq{
548+
/*EmitKernelParamInfo=*/true, /*EmitProgramMetadata=*/true,
549+
/*EmitExportedSymbols=*/true, /*EmitImportedSymbols=*/true,
550+
/*DeviceGlobals=*/false};
551+
PropertySetRegistry Properties =
552+
computeModuleProperties(MDesc.getModule(), MDesc.entries(), PropReq);
553+
// TODO: Manually add `compile_target` property as in
554+
// `saveModuleProperties`?
555+
const auto &PropertySets = Properties.getPropSets();
556+
557+
DevImgInfo.Properties = FrozenPropertyRegistry{PropertySets.size()};
558+
for (auto [KV, FrozenPropSet] :
559+
zip_equal(PropertySets, DevImgInfo.Properties)) {
560+
const auto &PropertySetName = KV.first;
561+
const auto &PropertySet = KV.second;
562+
FrozenPropSet =
563+
FrozenPropertySet{PropertySetName.str(), PropertySet.size()};
564+
for (auto [KV2, FrozenProp] :
565+
zip_equal(PropertySet, FrozenPropSet.Values)) {
566+
const auto &PropertyName = KV2.first;
567+
const auto &PropertyValue = KV2.second;
568+
FrozenProp =
569+
PropertyValue.getType() == PropertyValue::Type::UINT32
570+
? FrozenPropertyValue{PropertyName.str(),
571+
PropertyValue.asUint32()}
572+
: FrozenPropertyValue{PropertyName.str(),
573+
PropertyValue.asRawByteArray(),
574+
PropertyValue.getRawByteArraySize()};
575+
}
576+
};
577+
578+
Modules.push_back(MDesc.releaseModulePtr());
579+
}
552580

553-
return PostLinkResult{std::move(BundleInfo), MDesc.releaseModulePtr()};
581+
return PostLinkResult{std::move(BundleInfo), std::move(Modules)};
554582
}
555583

556584
Expected<InputArgList>
@@ -607,21 +635,10 @@ jit_compiler::parseUserArgs(View<const char *> UserArgs) {
607635
}
608636
}
609637

610-
if (auto DCSMode = AL.getLastArgValue(OPT_fsycl_device_code_split_EQ, "none");
611-
DCSMode != "none" && DCSMode != "auto") {
612-
return createStringError("Device code splitting is not yet supported");
613-
}
614-
615638
if (!AL.hasFlag(OPT_fsycl_device_code_split_esimd,
616639
OPT_fno_sycl_device_code_split_esimd, true)) {
617640
return createStringError("ESIMD device code split cannot be deactivated");
618641
}
619642

620-
if (AL.hasFlag(OPT_fsycl_dead_args_optimization,
621-
OPT_fno_sycl_dead_args_optimization, false)) {
622-
return createStringError(
623-
"Dead argument optimization must be disabled for runtime compilation");
624-
}
625-
626643
return std::move(AL);
627644
}

sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "Kernel.h"
1313
#include "View.h"
1414

15+
#include <llvm/ADT/SmallVector.h>
1516
#include <llvm/IR/Module.h>
1617
#include <llvm/Option/ArgList.h>
1718
#include <llvm/Support/Error.h>
@@ -30,7 +31,8 @@ llvm::Error linkDeviceLibraries(llvm::Module &Module,
3031
const llvm::opt::InputArgList &UserArgList,
3132
std::string &BuildLog);
3233

33-
using PostLinkResult = std::pair<RTCBundleInfo, std::unique_ptr<llvm::Module>>;
34+
using PostLinkResult =
35+
std::pair<RTCBundleInfo, llvm::SmallVector<std::unique_ptr<llvm::Module>>>;
3436
llvm::Expected<PostLinkResult>
3537
performPostLink(std::unique_ptr<llvm::Module> Module,
3638
const llvm::opt::InputArgList &UserArgList);

sycl-jit/jit-compiler/lib/translation/KernelTranslation.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -222,18 +222,18 @@ llvm::Error KernelTranslator::translateKernel(SYCLKernelInfo &Kernel,
222222
return Error::success();
223223
}
224224

225-
llvm::Expected<RTCBundleBinaryInfo>
226-
KernelTranslator::translateBundleToSPIRV(llvm::Module &Mod,
225+
llvm::Expected<RTCDevImgBinaryInfo>
226+
KernelTranslator::translateDevImgToSPIRV(llvm::Module &Mod,
227227
JITContext &JITCtx) {
228228
llvm::Expected<KernelBinary *> BinaryOrError = translateToSPIRV(Mod, JITCtx);
229229
if (auto Error = BinaryOrError.takeError()) {
230230
return Error;
231231
}
232232
KernelBinary *Binary = *BinaryOrError;
233-
RTCBundleBinaryInfo BBI{BinaryFormat::SPIRV,
234-
Mod.getDataLayout().getPointerSizeInBits(),
235-
Binary->address(), Binary->size()};
236-
return BBI;
233+
RTCDevImgBinaryInfo DIBI{BinaryFormat::SPIRV,
234+
Mod.getDataLayout().getPointerSizeInBits(),
235+
Binary->address(), Binary->size()};
236+
return DIBI;
237237
}
238238

239239
llvm::Expected<KernelBinary *>

sycl-jit/jit-compiler/lib/translation/KernelTranslation.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ class KernelTranslator {
2727
static llvm::Error translateKernel(SYCLKernelInfo &Kernel, llvm::Module &Mod,
2828
JITContext &JITCtx, BinaryFormat Format);
2929

30-
static llvm::Expected<RTCBundleBinaryInfo>
31-
translateBundleToSPIRV(llvm::Module &Mod, JITContext &JITCtx);
30+
static llvm::Expected<RTCDevImgBinaryInfo>
31+
translateDevImgToSPIRV(llvm::Module &Mod, JITContext &JITCtx);
3232

3333
private:
3434
///

0 commit comments

Comments
 (0)