Skip to content
3 changes: 3 additions & 0 deletions llvm/include/llvm/SYCLLowerIR/ESIMD/ESIMDUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace llvm {
namespace esimd {

constexpr char ATTR_DOUBLE_GRF[] = "esimd-double-grf";
constexpr char ESIMD_MARKER_MD[] = "sycl_explicit_simd";

using CallGraphNodeAction = std::function<void(Function *)>;
void traverseCallgraphUp(llvm::Function *F, CallGraphNodeAction NodeF,
Expand All @@ -34,6 +35,8 @@ void traverseCallgraphUp(Function *F, CallGraphNodeActionF ActionF,

// Tells whether given function is a ESIMD kernel.
bool isESIMDKernel(const Function &F);
// Tells whether given function is a ESIMD function.
bool isESIMD(const Function &F);

/// Reports and error with the message \p Msg concatenated with the optional
/// \p OptMsg if \p Condition is false.
Expand Down
6 changes: 6 additions & 0 deletions llvm/include/llvm/SYCLLowerIR/ESIMD/LowerESIMD.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ class SYCLLowerESIMDKernelPropsPass
PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
};

// Fixes ESIMD Kernel attributes for wrapper functions for ESIMD kernels
class SYCLFixupESIMDKernelWrapperMDPass
: public PassInfoMixin<SYCLFixupESIMDKernelWrapperMDPass> {
public:
PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
};
} // namespace llvm

#endif // LLVM_SYCLLOWERIR_LOWERESIMD_H
1 change: 1 addition & 0 deletions llvm/lib/Passes/PassRegistry.def
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ MODULE_PASS("SYCLMutatePrintfAddrspace", SYCLMutatePrintfAddrspacePass())
MODULE_PASS("SPIRITTAnnotations", SPIRITTAnnotationsPass())
MODULE_PASS("deadargelim-sycl", DeadArgumentEliminationSYCLPass())
MODULE_PASS("sycllowerwglocalmemory", SYCLLowerWGLocalMemoryPass())
MODULE_PASS("lower-esimd-kernel-attrs", SYCLFixupESIMDKernelWrapperMDPass())
#undef MODULE_PASS

#ifndef MODULE_PASS_WITH_PARAMS
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/SYCLLowerIR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ add_llvm_component_library(LLVMSYCLLowerIR
ESIMD/LowerESIMDVecArg.cpp
ESIMD/ESIMDUtils.cpp
ESIMD/ESIMDVerifier.cpp
ESIMD/LowerESIMDKernelAttrs.cpp
LowerInvokeSimd.cpp
LowerWGScope.cpp
LowerWGLocalMemory.cpp
Expand Down
12 changes: 8 additions & 4 deletions llvm/lib/SYCLLowerIR/ESIMD/ESIMDUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,10 @@ void traverseCallgraphUp(llvm::Function *F, CallGraphNodeAction ActionF,
} else {
auto *CI = cast<CallInst>(FCall);

if ((CI->getCalledFunction() != CurF) && ErrorOnNonCallUse) {
if ((CI->getCalledFunction() != CurF)) {
// CurF is used in a call, but not as the callee.
llvm::report_fatal_error(ErrMsg);
if (ErrorOnNonCallUse)
llvm::report_fatal_error(ErrMsg);
} else {
auto FCaller = CI->getFunction();

Expand All @@ -69,9 +70,12 @@ void traverseCallgraphUp(llvm::Function *F, CallGraphNodeAction ActionF,
}
}

bool isESIMD(const Function &F) {
return F.getMetadata(ESIMD_MARKER_MD) != nullptr;
}

bool isESIMDKernel(const Function &F) {
return (F.getCallingConv() == CallingConv::SPIR_KERNEL) &&
(F.getMetadata("sycl_explicit_simd") != nullptr);
return (F.getCallingConv() == CallingConv::SPIR_KERNEL) && isESIMD(F);
}

} // namespace esimd
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1215,6 +1215,7 @@ translateSpirvGlobalUses(LoadInst *LI, StringRef SpirvGlobalName,
// TODO: Implement support for the following intrinsics:
// uint32_t __spirv_BuiltIn NumSubgroups;
// uint32_t __spirv_BuiltIn SubgroupId;
// uint32_t __spirv_BuiltIn GlobalLinearId

// Translate those loads from _scalar_ SPIRV globals that can be replaced with
// a const value here.
Expand All @@ -1227,6 +1228,8 @@ translateSpirvGlobalUses(LoadInst *LI, StringRef SpirvGlobalName,
SpirvGlobalName == "SubgroupMaxSize") {
NewInst = llvm::Constant::getIntegerValue(LI->getType(),
llvm::APInt(32, 1, true));
} else if (SpirvGlobalName == "GlobalLinearId") {
NewInst = llvm::Constant::getNullValue(LI->getType());
}
if (NewInst) {
LI->replaceAllUsesWith(NewInst);
Expand Down
42 changes: 42 additions & 0 deletions llvm/lib/SYCLLowerIR/ESIMD/LowerESIMDKernelAttrs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//===---- LowerESIMDKernelAttrs - lower __esimd_set_kernel_attributes ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Finds and adds sycl_explicit_simd attributes to wrapper functions that wrap
// ESIMD kernel functions

#include "llvm/SYCLLowerIR/ESIMD/ESIMDUtils.h"
#include "llvm/SYCLLowerIR/ESIMD/LowerESIMD.h"

#include "llvm/IR/Module.h"
#include "llvm/Pass.h"

#define DEBUG_TYPE "LowerESIMDKernelAttrs"

using namespace llvm;

namespace llvm {
PreservedAnalyses
SYCLFixupESIMDKernelWrapperMDPass::run(Module &M, ModuleAnalysisManager &MAM) {
bool Modified = false;
for (Function &F : M) {
if (llvm::esimd::isESIMD(F)) {
llvm::esimd::traverseCallgraphUp(
&F,
[&](Function *GraphNode) {
if (!llvm::esimd::isESIMD(*GraphNode)) {
GraphNode->setMetadata(
llvm::esimd::ESIMD_MARKER_MD,
llvm::MDNode::get(GraphNode->getContext(), {}));
Modified = true;
}
},
false);
}
}
return Modified ? PreservedAnalyses::none() : PreservedAnalyses::all();
}
} // namespace llvm
4 changes: 4 additions & 0 deletions llvm/tools/sycl-post-link/sycl-post-link.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,10 @@ processInputModule(std::unique_ptr<Module> M) {
// if none were made.
bool Modified = false;

// Propagate ESIMD attribute to wrapper functions to prevent
// spurious splits and kernel link errors.
Modified |= runModulePass<SYCLFixupESIMDKernelWrapperMDPass>(*M);

// After linking device bitcode "llvm.used" holds references to the kernels
// that are defined in the device image. But after splitting device image into
// separate kernels we may end up with having references to kernel declaration
Expand Down