@@ -28,6 +28,10 @@ a OpenCL source string or a SPIR-V binary file.
28
28
from libc.stdint cimport uint32_t
29
29
30
30
from dpctl._backend cimport ( # noqa: E211, E402;
31
+ DPCTLBuildOptionList_Append,
32
+ DPCTLBuildOptionList_Create,
33
+ DPCTLBuildOptionList_Delete,
34
+ DPCTLBuildOptionListRef,
31
35
DPCTLCString_Delete,
32
36
DPCTLKernel_Copy,
33
37
DPCTLKernel_Delete,
@@ -42,13 +46,24 @@ from dpctl._backend cimport ( # noqa: E211, E402;
42
46
DPCTLKernelBundle_Copy,
43
47
DPCTLKernelBundle_CreateFromOCLSource,
44
48
DPCTLKernelBundle_CreateFromSpirv,
49
+ DPCTLKernelBundle_CreateFromSYCLSource,
45
50
DPCTLKernelBundle_Delete,
46
51
DPCTLKernelBundle_GetKernel,
52
+ DPCTLKernelBundle_GetSyclKernel,
47
53
DPCTLKernelBundle_HasKernel,
54
+ DPCTLKernelBundle_HasSyclKernel,
55
+ DPCTLKernelNameList_Append,
56
+ DPCTLKernelNameList_Create,
57
+ DPCTLKernelNameList_Delete,
58
+ DPCTLKernelNameListRef,
48
59
DPCTLSyclContextRef,
49
60
DPCTLSyclDeviceRef,
50
61
DPCTLSyclKernelBundleRef,
51
62
DPCTLSyclKernelRef,
63
+ DPCTLVirtualHeaderList_Append,
64
+ DPCTLVirtualHeaderList_Create,
65
+ DPCTLVirtualHeaderList_Delete,
66
+ DPCTLVirtualHeaderListRef,
52
67
)
53
68
54
69
__all__ = [
@@ -197,9 +212,10 @@ cdef class SyclProgram:
197
212
"""
198
213
199
214
@staticmethod
200
- cdef SyclProgram _create(DPCTLSyclKernelBundleRef KBRef):
215
+ cdef SyclProgram _create(DPCTLSyclKernelBundleRef KBRef, bint is_sycl_source ):
201
216
cdef SyclProgram ret = SyclProgram.__new__ (SyclProgram)
202
217
ret._program_ref = KBRef
218
+ ret._is_sycl_source = is_sycl_source
203
219
return ret
204
220
205
221
def __dealloc__ (self ):
@@ -210,13 +226,19 @@ cdef class SyclProgram:
210
226
211
227
cpdef SyclKernel get_sycl_kernel(self , str kernel_name):
212
228
name = kernel_name.encode(' utf8' )
229
+ if self ._is_sycl_source:
230
+ return SyclKernel._create(
231
+ DPCTLKernelBundle_GetSyclKernel(self ._program_ref, name),
232
+ kernel_name)
213
233
return SyclKernel._create(
214
234
DPCTLKernelBundle_GetKernel(self ._program_ref, name),
215
235
kernel_name
216
236
)
217
237
218
238
def has_sycl_kernel (self , str kernel_name ):
219
239
name = kernel_name.encode(' utf8' )
240
+ if self ._is_sycl_source:
241
+ return DPCTLKernelBundle_HasSyclKernel(self ._program_ref, name)
220
242
return DPCTLKernelBundle_HasKernel(self ._program_ref, name)
221
243
222
244
def addressof_ref (self ):
@@ -272,7 +294,7 @@ cpdef create_program_from_source(SyclQueue q, str src, str copts=""):
272
294
if KBref is NULL :
273
295
raise SyclProgramCompilationError()
274
296
275
- return SyclProgram._create(KBref)
297
+ return SyclProgram._create(KBref, False )
276
298
277
299
278
300
cpdef create_program_from_spirv(SyclQueue q, const unsigned char [:] IL,
@@ -318,7 +340,107 @@ cpdef create_program_from_spirv(SyclQueue q, const unsigned char[:] IL,
318
340
if KBref is NULL :
319
341
raise SyclProgramCompilationError()
320
342
321
- return SyclProgram._create(KBref)
343
+ return SyclProgram._create(KBref, False )
344
+
345
+
346
+ cpdef create_program_from_sycl_source(SyclQueue q, unicode source, list headers = [], list registered_names = [], list copts = []):
347
+ """
348
+ Creates an executable SYCL kernel_bundle from SYCL source code.
349
+
350
+ This uses the DPC++ ``kernel_compiler`` extension to create a
351
+ ``sycl::kernel_bundle<sycl::bundle_state::executable>`` object from
352
+ SYCL source code.
353
+
354
+ Parameters:
355
+ q (:class:`dpctl.SyclQueue`)
356
+ The :class:`dpctl.SyclQueue` for which the
357
+ :class:`.SyclProgram` is going to be built.
358
+ source (unicode)
359
+ SYCL source code string.
360
+ headers (list)
361
+ Optional list of virtual headers, where each entry in the list
362
+ needs to be a tuple of header name and header content. See the
363
+ documentation of the ``include_files`` property in the DPC++
364
+ ``kernel_compiler`` extension for more information.
365
+ Default: []
366
+ registered_names (list, optional)
367
+ Optional list of kernel names to register. See the
368
+ documentation of the ``registered_names`` property in the DPC++
369
+ ``kernel_compiler`` extension for more information.
370
+ Default: []
371
+ copts (list)
372
+ Optional list of compilation flags that will be used
373
+ when compiling the program. Default: ``""``.
374
+
375
+ Returns:
376
+ program (:class:`.SyclProgram`)
377
+ A :class:`.SyclProgram` object wrapping the
378
+ ``sycl::kernel_bundle<sycl::bundle_state::executable>``
379
+ returned by the C API.
380
+
381
+ Raises:
382
+ SyclProgramCompilationError
383
+ If a SYCL kernel bundle could not be created.
384
+ """
385
+ cdef DPCTLSyclKernelBundleRef KBref
386
+ cdef DPCTLSyclContextRef CRef = q.get_sycl_context().get_context_ref()
387
+ cdef DPCTLSyclDeviceRef DRef = q.get_sycl_device().get_device_ref()
388
+ cdef bytes bSrc = source.encode(' utf8' )
389
+ cdef const char * Src = < const char * > bSrc
390
+ cdef DPCTLBuildOptionListRef BuildOpts = DPCTLBuildOptionList_Create()
391
+ cdef bytes bOpt
392
+ cdef const char * sOpt
393
+ cdef bytes bName
394
+ cdef const char * sName
395
+ cdef bytes bContent
396
+ cdef const char * sContent
397
+ for opt in copts:
398
+ if not isinstance (opt, unicode ):
399
+ DPCTLBuildOptionList_Delete(BuildOpts)
400
+ raise SyclProgramCompilationError()
401
+ bOpt = opt.encode(' utf8' )
402
+ sOpt = < const char * > bOpt
403
+ DPCTLBuildOptionList_Append(BuildOpts, sOpt)
404
+
405
+ cdef DPCTLKernelNameListRef KernelNames = DPCTLKernelNameList_Create()
406
+ for name in registered_names:
407
+ if not isinstance (name, unicode ):
408
+ DPCTLBuildOptionList_Delete(BuildOpts)
409
+ DPCTLKernelNameList_Delete(KernelNames)
410
+ raise SyclProgramCompilationError()
411
+ bName = name.encode(' utf8' )
412
+ sName = < const char * > bName
413
+ DPCTLKernelNameList_Append(KernelNames, sName)
414
+
415
+
416
+ cdef DPCTLVirtualHeaderListRef VirtualHeaders = DPCTLVirtualHeaderList_Create()
417
+ for name, content in headers:
418
+ if not isinstance (name, unicode ) or not isinstance (content, unicode ):
419
+ DPCTLBuildOptionList_Delete(BuildOpts)
420
+ DPCTLKernelNameList_Delete(KernelNames)
421
+ DPCTLVirtualHeaderList_Delete(VirtualHeaders)
422
+ raise SyclProgramCompilationError()
423
+ bName = name.encode(' utf8' )
424
+ sName = < const char * > bName
425
+ bContent = content.encode(' utf8' )
426
+ sContent = < const char * > bContent
427
+ DPCTLVirtualHeaderList_Append(VirtualHeaders, sName, sContent)
428
+
429
+ KBref = DPCTLKernelBundle_CreateFromSYCLSource(CRef, DRef, Src,
430
+ VirtualHeaders, KernelNames,
431
+ BuildOpts)
432
+
433
+ if KBref is NULL :
434
+ DPCTLBuildOptionList_Delete(BuildOpts)
435
+ DPCTLKernelNameList_Delete(KernelNames)
436
+ DPCTLVirtualHeaderList_Delete(VirtualHeaders)
437
+ raise SyclProgramCompilationError()
438
+
439
+ DPCTLBuildOptionList_Delete(BuildOpts)
440
+ DPCTLKernelNameList_Delete(KernelNames)
441
+ DPCTLVirtualHeaderList_Delete(VirtualHeaders)
442
+
443
+ return SyclProgram._create(KBref, True )
322
444
323
445
324
446
cdef api DPCTLSyclKernelBundleRef SyclProgram_GetKernelBundleRef(SyclProgram pro):
@@ -335,4 +457,4 @@ cdef api SyclProgram SyclProgram_Make(DPCTLSyclKernelBundleRef KBRef):
335
457
reference.
336
458
"""
337
459
cdef DPCTLSyclKernelBundleRef copied_KBRef = DPCTLKernelBundle_Copy(KBRef)
338
- return SyclProgram._create(copied_KBRef)
460
+ return SyclProgram._create(copied_KBRef, False )
0 commit comments