@@ -32,10 +32,34 @@ void SYCLFuncDescriptor::declareFunction(ModuleOp &module, OpBuilder &b) {
32
32
funcRef = builder.getOrInsertFuncDecl (name, outputTy, argTys, module );
33
33
}
34
34
35
- Value SYCLFuncDescriptor::call (FuncId id, ValueRange args,
35
+ bool SYCLFuncDescriptor::isIdCtor (FuncId funcId) {
36
+ switch (funcId) {
37
+ case FuncId::Id1CtorDefault:
38
+ case FuncId::Id2CtorDefault:
39
+ case FuncId::Id3CtorDefault:
40
+ case FuncId::Id1CtorSizeT:
41
+ case FuncId::Id2CtorSizeT:
42
+ case FuncId::Id3CtorSizeT:
43
+ case FuncId::Id1CtorRange:
44
+ case FuncId::Id2CtorRange:
45
+ case FuncId::Id3CtorRange:
46
+ case FuncId::Id1CtorItem:
47
+ case FuncId::Id2CtorItem:
48
+ case FuncId::Id3CtorItem:
49
+ case FuncId::Id1CopyCtor:
50
+ case FuncId::Id2CopyCtor:
51
+ case FuncId::Id3CopyCtor:
52
+ return true ;
53
+ default :;
54
+ }
55
+
56
+ return false ;
57
+ }
58
+
59
+ Value SYCLFuncDescriptor::call (FuncId funcId, ValueRange args,
36
60
const SYCLFuncRegistry ®istry, OpBuilder &b,
37
61
Location loc) {
38
- const SYCLFuncDescriptor &funcDesc = registry.getFuncDesc (id );
62
+ const SYCLFuncDescriptor &funcDesc = registry.getFuncDesc (funcId );
39
63
LLVM_DEBUG (
40
64
llvm::dbgs () << " Creating SYCLFuncDescriptor::call to funcDesc.funcRef: "
41
65
<< funcDesc.funcRef << " \n " );
@@ -59,14 +83,56 @@ Value SYCLFuncDescriptor::call(FuncId id, ValueRange args,
59
83
60
84
SYCLFuncRegistry *SYCLFuncRegistry::instance = nullptr ;
61
85
62
- const SYCLFuncRegistry SYCLFuncRegistry::create (
63
- ModuleOp & module , OpBuilder &builder) {
86
+ const SYCLFuncRegistry SYCLFuncRegistry::create (ModuleOp & module ,
87
+ OpBuilder &builder) {
64
88
if (!instance)
65
89
instance = new SYCLFuncRegistry (module , builder);
66
90
67
91
return *instance;
68
92
}
69
93
94
+ SYCLFuncDescriptor::FuncId
95
+ SYCLFuncRegistry::getFuncId (SYCLFuncDescriptor::FuncIdKind funcIdKind,
96
+ Type retType, TypeRange argTypes) const {
97
+ assert (funcIdKind != SYCLFuncDescriptor::FuncIdKind::Unknown &&
98
+ " Invalid funcIdKind" );
99
+
100
+ // Determines whether the given funcId has kind that matches the given
101
+ // funcIdKind.
102
+ auto kindMatches = [](SYCLFuncDescriptor::FuncId funcId,
103
+ SYCLFuncDescriptor::FuncIdKind funcIdKind) {
104
+ bool foundMatch = false ;
105
+ switch (funcIdKind) {
106
+ case SYCLFuncDescriptor::FuncIdKind::IdCtor:
107
+ foundMatch = SYCLFuncDescriptor::isIdCtor (funcId);
108
+ break ;
109
+ default :
110
+ foundMatch = false ;
111
+ }
112
+ return foundMatch;
113
+ };
114
+
115
+ for (const auto &entry : registry) {
116
+ // Skip through entries that do not match the requested funcIdKind.
117
+ if (!kindMatches (entry.second .id , funcIdKind))
118
+ continue ;
119
+
120
+ // Ensure that the entry has return and arguments type that match the one
121
+ // provided.
122
+ if (retType != entry.second .outputTy ||
123
+ argTypes.size () != entry.second .argTys .size ())
124
+ continue ;
125
+ if (!std::equal (argTypes.begin (), argTypes.end (),
126
+ entry.second .argTys .begin ()))
127
+ continue ;
128
+
129
+ return entry.second .id ;
130
+ }
131
+
132
+ llvm_unreachable (" Unimplemented descriptor" );
133
+ return SYCLFuncDescriptor::FuncId::Unknown;
134
+ }
135
+
70
136
SYCLFuncRegistry::SYCLFuncRegistry (ModuleOp &module , OpBuilder &builder)
71
137
: registry() {
72
138
MLIRContext *context = module .getContext ();
@@ -144,6 +210,22 @@ SYCLFuncRegistry::SYCLFuncRegistry(ModuleOp &module, OpBuilder &builder)
144
210
SYCLFuncDescriptor::FuncId::Id3CtorItem,
145
211
" _ZN2cl4sycl2idILi3EEC2ILi3EEENSt9enable_ifIXeqT_Li3EEmE4typeEmm" ,
146
212
voidTy, {id3PtrTy, i64Ty, i64Ty, i64Ty}),
213
+
214
+ // cl::sycl::id<1>::id(cl::sycl::id<1> const&)
215
+ SYCLFuncDescriptor (
216
+ SYCLFuncDescriptor::FuncId::Id1CopyCtor,
217
+ " _ZN2cl4sycl2idILi1EEC1ERKS2_" ,
218
+ voidTy, {id1PtrTy, id1PtrTy}),
219
+ // cl::sycl::id<2>::id(cl::sycl::id<2> const&)
220
+ SYCLFuncDescriptor (
221
+ SYCLFuncDescriptor::FuncId::Id2CopyCtor,
222
+ " _ZN2cl4sycl2idILi2EEC1ERKS2_" ,
223
+ voidTy, {id2PtrTy, id2PtrTy}),
224
+ // cl::sycl::id<3>::id(cl::sycl::id<3> const&)
225
+ SYCLFuncDescriptor (
226
+ SYCLFuncDescriptor::FuncId::Id3CopyCtor,
227
+ " _ZN2cl4sycl2idILi3EEC1ERKS2_" ,
228
+ voidTy, {id3PtrTy, id3PtrTy}),
147
229
};
148
230
// clang-format on
149
231
0 commit comments