1010
1111#include " clang-mlir.h"
1212#include " TypeUtils.h"
13+ #include " mlir/Conversion/SYCLToLLVM/SYCLFuncRegistry.h"
1314#include " mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1415#include " mlir/Dialect/DLTI/DLTI.h"
1516#include " mlir/Dialect/SCF/IR/SCF.h"
4546#include " mlir/Dialect/SYCL/IR/SYCLOpsDialect.h.inc"
4647#include " mlir/Dialect/SYCL/IR/SYCLOpsTypes.h"
4748
48- static bool DEBUG_FUNCTION = false ;
4949static bool BREAKPOINT_FUNCTION = false ;
5050
5151using namespace std ;
@@ -56,6 +56,7 @@ using namespace llvm::opt;
5656using namespace mlir ;
5757using namespace mlir ::arith;
5858using namespace mlir ::func;
59+ using namespace mlir ::sycl;
5960using namespace mlirclang ;
6061
6162static cl::opt<bool >
@@ -68,6 +69,10 @@ static cl::opt<bool> memRefABI("memref-abi", cl::init(true),
6869cl::opt<std::string> PrefixABI (" prefix-abi" , cl::init(" " ),
6970 cl::desc(" Prefix for emitted symbols" ));
7071
72+ static cl::opt<bool > DebugFunction (
73+ " debug-function" , cl::init(false ),
74+ cl::desc(" Print informations about functions being processed." ));
75+
7176static cl::opt<bool >
7277 CombinedStructABI (" struct-abi" , cl::init(true ),
7378 cl::desc(" Use literal LLVM ABI for structs" ));
@@ -111,6 +116,34 @@ MLIRScanner::MLIRScanner(MLIRASTConsumer &Glob,
111116 : Glob(Glob), module(module ), builder(module ->getContext ()),
112117 loc(builder.getUnknownLoc()), ThisCapture(nullptr ), LTInfo(LTInfo) {}
113118
119+ void MLIRScanner::initSupportedConstructors () {
120+ // List from SYCLFuncRegistry.cpp Please modify as new constructors are
121+ // added to that file.
122+ supportedCons.insert (" _ZN2cl4sycl2idILi1EEC1Ev" );
123+ supportedCons.insert (" _ZN2cl4sycl2idILi2EEC1Ev" );
124+ supportedCons.insert (" _ZN2cl4sycl2idILi3EEC1Ev" );
125+ supportedCons.insert (
126+ " _ZN2cl4sycl2idILi1EEC1ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeE" );
127+ supportedCons.insert (
128+ " _ZN2cl4sycl2idILi2EEC1ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeE" );
129+ supportedCons.insert (
130+ " _ZN2cl4sycl2idILi3EEC1ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeE" );
131+ supportedCons.insert (
132+ " _ZN2cl4sycl2idILi1EEC1ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeEm" );
133+ supportedCons.insert (
134+ " _ZN2cl4sycl2idILi2EEC1ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeEm" );
135+ supportedCons.insert (
136+ " _ZN2cl4sycl2idILi3EEC1ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeEm" );
137+ supportedCons.insert (
138+ " _ZN2cl4sycl2idILi1EEC1ILi1EEENSt9enable_ifIXeqT_Li1EEmE4typeEmm" );
139+ supportedCons.insert (
140+ " _ZN2cl4sycl2idILi2EEC1ILi2EEENSt9enable_ifIXeqT_Li2EEmE4typeEmm" );
141+ supportedCons.insert (
142+ " _ZN2cl4sycl2idILi3EEC1ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeEmm" );
143+ supportedCons.insert (" _ZN2cl4sycl6detail5arrayILi1EEC1ILi1EEENSt9enable_"
144+ " ifIXeqT_Li1EEmE4typeE" );
145+ }
146+
114147void MLIRScanner::init (mlir::func::FuncOp function, const FunctionDecl *fd) {
115148 this ->function = function;
116149 this ->EmittingFunctionDecl = fd;
@@ -120,6 +153,7 @@ void MLIRScanner::init(mlir::func::FuncOp function, const FunctionDecl *fd) {
120153 llvm::errs () << *fd << " \n " ;
121154 }
122155
156+ initSupportedConstructors ();
123157 setEntryAndAllocBlock (function.addEntryBlock ());
124158
125159 unsigned i = 0 ;
@@ -1363,6 +1397,16 @@ MLIRScanner::VisitCXXConstructExpr(clang::CXXConstructExpr *cons) {
13631397 return VisitConstructCommon (cons, /* name*/ nullptr , /* space*/ 0 );
13641398}
13651399
1400+ static void getMangledFuncName (std::string &name, const FunctionDecl *FD,
1401+ CodeGen::CodeGenModule &CGM) {
1402+ if (auto CC = dyn_cast<CXXConstructorDecl>(FD))
1403+ name = CGM.getMangledName (GlobalDecl (CC, CXXCtorType::Ctor_Complete)).str ();
1404+ else if (auto CC = dyn_cast<CXXDestructorDecl>(FD))
1405+ name = CGM.getMangledName (GlobalDecl (CC, CXXDtorType::Dtor_Complete)).str ();
1406+ else
1407+ name = CGM.getMangledName (FD).str ();
1408+ }
1409+
13661410ValueCategory MLIRScanner::VisitConstructCommon (clang::CXXConstructExpr *cons,
13671411 VarDecl *name, unsigned memtype,
13681412 mlir::Value op,
@@ -1439,11 +1483,33 @@ ValueCategory MLIRScanner::VisitConstructCommon(clang::CXXConstructExpr *cons,
14391483 assert (obj.isReference );
14401484 }
14411485
1442- // / If the constructor is part of the SYCL namespace, we do not want the
1486+ // / If the constructor is part of the SYCL namespace, we may not want the
14431487 // / GetOrCreateMLIRFunction to add this FuncOp to the functionsToEmit dequeu,
1444- // / since we will create it's equivalent with SYCL operations.
1445- const auto ShouldEmit = !mlirclang::isNamespaceSYCL (
1488+ // / since we will create it's equivalent with SYCL operations. Please note
1489+ // / that we still generate some constructors that we need for lowering some
1490+ // / sycl op. Therefore, in those case, we set ShouldEmit back to "true" by
1491+ // / looking them up in our "registry" of supported constructors.
1492+
1493+ bool ShouldEmit = !mlirclang::isNamespaceSYCL (
14461494 cons->getConstructor ()->getEnclosingNamespaceContext ());
1495+
1496+ if (const FunctionDecl *FuncDecl =
1497+ dyn_cast<FunctionDecl>(cons->getConstructor ())) {
1498+ std::string name;
1499+ getMangledFuncName (name, FuncDecl, Glob.CGM );
1500+ name = (PrefixABI + name);
1501+
1502+ if (DebugFunction) {
1503+ llvm::dbgs () << " Starting codegen of " << name << " \n " ;
1504+ }
1505+ if (isSupportedConstructor (name)) {
1506+ if (DebugFunction) {
1507+ llvm::dbgs () << " Function found in registry, continue codegen-ing...\n " ;
1508+ }
1509+ ShouldEmit = true ;
1510+ }
1511+ }
1512+
14471513 auto tocall =
14481514 Glob.GetOrCreateMLIRFunction (cons->getConstructor (), ShouldEmit);
14491515
@@ -4262,12 +4328,7 @@ mlir::LLVM::LLVMFuncOp MLIRASTConsumer::GetOrCreateFreeFunction() {
42624328mlir::LLVM::LLVMFuncOp
42634329MLIRASTConsumer::GetOrCreateLLVMFunction (const FunctionDecl *FD) {
42644330 std::string name;
4265- if (auto CC = dyn_cast<CXXConstructorDecl>(FD))
4266- name = CGM.getMangledName (GlobalDecl (CC, CXXCtorType::Ctor_Complete)).str ();
4267- else if (auto CC = dyn_cast<CXXDestructorDecl>(FD))
4268- name = CGM.getMangledName (GlobalDecl (CC, CXXDtorType::Dtor_Complete)).str ();
4269- else
4270- name = CGM.getMangledName (FD).str ();
4331+ getMangledFuncName (name, FD, CGM);
42714332
42724333 if (name != " malloc" && name != " free" )
42734334 name = (PrefixABI + name);
@@ -4630,25 +4691,20 @@ mlir::Value MLIRASTConsumer::GetOrCreateGlobalLLVMString(
46304691 return globalPtr;
46314692}
46324693
4633- mlir::func::FuncOp
4634- MLIRASTConsumer::GetOrCreateMLIRFunction (const FunctionDecl *FD,
4635- const bool ShouldEmit,
4636- bool getDeviceStub) {
4694+ mlir::func::FuncOp MLIRASTConsumer::GetOrCreateMLIRFunction (
4695+ const FunctionDecl *FD, const bool ShouldEmit, bool getDeviceStub) {
46374696 assert (FD->getTemplatedKind () !=
46384697 FunctionDecl::TemplatedKind::TK_FunctionTemplate);
46394698 assert (
46404699 FD->getTemplatedKind () !=
46414700 FunctionDecl::TemplatedKind::TK_DependentFunctionTemplateSpecialization);
4701+
46424702 std::string name;
46434703 if (getDeviceStub)
46444704 name =
46454705 CGM.getMangledName (GlobalDecl (FD, KernelReferenceKind::Kernel)).str ();
4646- else if (auto CC = dyn_cast<CXXConstructorDecl>(FD))
4647- name = CGM.getMangledName (GlobalDecl (CC, CXXCtorType::Ctor_Complete)).str ();
4648- else if (auto CC = dyn_cast<CXXDestructorDecl>(FD))
4649- name = CGM.getMangledName (GlobalDecl (CC, CXXDtorType::Dtor_Complete)).str ();
46504706 else
4651- name = CGM. getMangledName (FD). str ( );
4707+ getMangledFuncName ( name, FD, CGM);
46524708
46534709 name = (PrefixABI + name);
46544710
@@ -4855,7 +4911,7 @@ void MLIRASTConsumer::run() {
48554911 while (functionsToEmit.size ()) {
48564912 const FunctionDecl *FD = functionsToEmit.front ();
48574913
4858- if (BREAKPOINT_FUNCTION && DEBUG_FUNCTION ) {
4914+ if (BREAKPOINT_FUNCTION && DebugFunction ) {
48594915 printf (" \n " );
48604916 printf (" -- FUNCTION BEING EMITTED : \033 [0;32m %s \033 [0m -- \n " ,
48614917 FD->getNameAsString ().c_str ());
@@ -4870,14 +4926,7 @@ void MLIRASTConsumer::run() {
48704926 TK_DependentFunctionTemplateSpecialization);
48714927 std::string name;
48724928
4873- if (auto CC = dyn_cast<CXXConstructorDecl>(FD))
4874- name =
4875- CGM.getMangledName (GlobalDecl (CC, CXXCtorType::Ctor_Complete)).str ();
4876- else if (auto CC = dyn_cast<CXXDestructorDecl>(FD))
4877- name =
4878- CGM.getMangledName (GlobalDecl (CC, CXXDtorType::Dtor_Complete)).str ();
4879- else
4880- name = CGM.getMangledName (FD).str ();
4929+ getMangledFuncName (name, FD, CGM);
48814930
48824931 if (done.count (name))
48834932 continue ;
@@ -4886,7 +4935,7 @@ void MLIRASTConsumer::run() {
48864935 auto Function = GetOrCreateMLIRFunction (FD, true );
48874936 ms.init (Function, FD);
48884937
4889- if (BREAKPOINT_FUNCTION && DEBUG_FUNCTION ) {
4938+ if (BREAKPOINT_FUNCTION && DebugFunction ) {
48904939 printf (" \n " );
48914940 Function.dump ();
48924941 printf (" \n " );
@@ -4926,7 +4975,7 @@ void MLIRASTConsumer::HandleDeclContext(DeclContext *DC) {
49264975 HandleDeclContext (NS);
49274976 continue ;
49284977 }
4929- FunctionDecl *fd = dyn_cast<clang::FunctionDecl>(D);
4978+ const FunctionDecl *fd = dyn_cast<clang::FunctionDecl>(D);
49304979 if (!fd) {
49314980 continue ;
49324981 }
@@ -4953,14 +5002,7 @@ void MLIRASTConsumer::HandleDeclContext(DeclContext *DC) {
49535002 externLinkage = false ;
49545003
49555004 std::string name;
4956- if (auto CC = dyn_cast<CXXConstructorDecl>(fd))
4957- name =
4958- CGM.getMangledName (GlobalDecl (CC, CXXCtorType::Ctor_Complete)).str ();
4959- else if (auto CC = dyn_cast<CXXDestructorDecl>(fd))
4960- name =
4961- CGM.getMangledName (GlobalDecl (CC, CXXDtorType::Dtor_Complete)).str ();
4962- else
4963- name = CGM.getMangledName (fd).str ();
5005+ getMangledFuncName (name, fd, CGM);
49645006
49655007 // Don't create std functions unless necessary
49665008 if (StringRef (name).startswith (" _ZNKSt" ))
@@ -5002,7 +5044,7 @@ bool MLIRASTConsumer::HandleTopLevelDecl(DeclGroupRef dg) {
50025044 HandleDeclContext (NS);
50035045 continue ;
50045046 }
5005- FunctionDecl *fd = dyn_cast<clang::FunctionDecl>(*it);
5047+ const FunctionDecl *fd = dyn_cast<clang::FunctionDecl>(*it);
50065048 if (!fd) {
50075049 continue ;
50085050 }
@@ -5034,14 +5076,7 @@ bool MLIRASTConsumer::HandleTopLevelDecl(DeclGroupRef dg) {
50345076 externLinkage = false ;
50355077
50365078 std::string name;
5037- if (auto CC = dyn_cast<CXXConstructorDecl>(fd))
5038- name =
5039- CGM.getMangledName (GlobalDecl (CC, CXXCtorType::Ctor_Complete)).str ();
5040- else if (auto CC = dyn_cast<CXXDestructorDecl>(fd))
5041- name =
5042- CGM.getMangledName (GlobalDecl (CC, CXXDtorType::Dtor_Complete)).str ();
5043- else
5044- name = CGM.getMangledName (fd).str ();
5079+ getMangledFuncName (name, fd, CGM);
50455080
50465081 // Don't create std functions unless necessary
50475082 if (StringRef (name).startswith (" _ZNKSt" ))
0 commit comments