Skip to content

Commit 9869a76

Browse files
committed
Support compilation from SYCL source code
1 parent ab98697 commit 9869a76

12 files changed

+785
-8
lines changed

dpctl/_backend.pxd

+40
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,9 @@ cdef extern from "syclinterface/dpctl_sycl_device_interface.h":
221221
cdef uint64_t DPCTLDevice_GetMaxMemAllocSize(const DPCTLSyclDeviceRef DRef)
222222
cdef DPCTLSyclDeviceRef DPCTLDevice_GetCompositeDevice(const DPCTLSyclDeviceRef DRef)
223223
cdef DPCTLDeviceVectorRef DPCTLDevice_GetComponentDevices(const DPCTLSyclDeviceRef DRef)
224+
cdef bool DPCTLDevice_CanCompileSPIRV(const DPCTLSyclDeviceRef DRef)
225+
cdef bool DPCTLDevice_CanCompileOpenCL(const DPCTLSyclDeviceRef DRef)
226+
cdef bool DPCTLDevice_CanCompileSYCL(const DPCTLSyclDeviceRef DRef)
224227

225228

226229
cdef extern from "syclinterface/dpctl_sycl_device_manager.h":
@@ -367,6 +370,43 @@ cdef extern from "syclinterface/dpctl_sycl_kernel_bundle_interface.h":
367370
cdef DPCTLSyclKernelBundleRef DPCTLKernelBundle_Copy(
368371
const DPCTLSyclKernelBundleRef KBRef)
369372

373+
cdef struct DPCTLBuildOptionList
374+
cdef struct DPCTLKernelNameList
375+
cdef struct DPCTLVirtualHeaderList
376+
ctypedef DPCTLBuildOptionList* DPCTLBuildOptionListRef
377+
ctypedef DPCTLKernelNameList* DPCTLKernelNameListRef
378+
ctypedef DPCTLVirtualHeaderList* DPCTLVirtualHeaderListRef
379+
380+
cdef DPCTLBuildOptionListRef DPCTLBuildOptionList_Create()
381+
cdef void DPCTLBuildOptionList_Delete(DPCTLBuildOptionListRef Ref)
382+
cdef void DPCTLBuildOptionList_Append(DPCTLBuildOptionListRef Ref,
383+
const char *Option)
384+
385+
cdef DPCTLKernelNameListRef DPCTLKernelNameList_Create()
386+
cdef void DPCTLKernelNameList_Delete(DPCTLKernelNameListRef Ref)
387+
cdef void DPCTLKernelNameList_Append(DPCTLKernelNameListRef Ref,
388+
const char *Option)
389+
390+
cdef DPCTLVirtualHeaderListRef DPCTLVirtualHeaderList_Create()
391+
cdef void DPCTLVirtualHeaderList_Delete(DPCTLVirtualHeaderListRef Ref)
392+
cdef void DPCTLVirtualHeaderList_Append(DPCTLVirtualHeaderListRef Ref,
393+
const char *Name,
394+
const char *Content)
395+
396+
cdef DPCTLSyclKernelBundleRef DPCTLKernelBundle_CreateFromSYCLSource(
397+
const DPCTLSyclContextRef Ctx,
398+
const DPCTLSyclDeviceRef Dev,
399+
const char *Source,
400+
DPCTLVirtualHeaderListRef Headers,
401+
DPCTLKernelNameListRef Names,
402+
DPCTLBuildOptionListRef BuildOptions)
403+
404+
cdef DPCTLSyclKernelRef DPCTLKernelBundle_GetSyclKernel(DPCTLSyclKernelBundleRef KBRef,
405+
const char *KernelName)
406+
407+
cdef bool DPCTLKernelBundle_HasSyclKernel(DPCTLSyclKernelBundleRef KBRef,
408+
const char *KernelName);
409+
370410

371411
cdef extern from "syclinterface/dpctl_sycl_queue_interface.h":
372412
ctypedef struct _md_local_accessor 'MDLocalAccessor':

dpctl/_sycl_device.pxd

+1
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,4 @@ cdef public api class SyclDevice(_SyclDevice) [
5959
cdef int get_overall_ordinal(self)
6060
cdef int get_backend_ordinal(self)
6161
cdef int get_backend_and_device_type_ordinal(self)
62+
cpdef bint can_compile(self, str language)

dpctl/_sycl_device.pyx

+32
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ from ._backend cimport ( # noqa: E211
2525
DPCTLCString_Delete,
2626
DPCTLDefaultSelector_Create,
2727
DPCTLDevice_AreEq,
28+
DPCTLDevice_CanCompileOpenCL,
29+
DPCTLDevice_CanCompileSPIRV,
30+
DPCTLDevice_CanCompileSYCL,
2831
DPCTLDevice_Copy,
2932
DPCTLDevice_CreateFromSelector,
3033
DPCTLDevice_CreateSubDevicesByAffinity,
@@ -2164,6 +2167,35 @@ cdef class SyclDevice(_SyclDevice):
21642167
raise ValueError("device could not be found")
21652168
return dev_id
21662169

2170+
cpdef bint can_compile(self, str language):
2171+
"""
2172+
Check whether it is possible to create an executable kernel_bundle
2173+
for this device from the given source language.
2174+
2175+
Parameters:
2176+
language
2177+
Input language. Possible values are "spirv" for SPIR-V binary
2178+
files, "opencl" for OpenCL C device code and "sycl" for SYCL
2179+
device code.
2180+
2181+
Returns:
2182+
bool:
2183+
True if compilation is supported, False otherwise.
2184+
2185+
Raises:
2186+
ValueError:
2187+
If an unknown source language is used.
2188+
"""
2189+
if language == "spirv" or language == "spv":
2190+
return DPCTLDevice_CanCompileSYCL(self._device_ref)
2191+
if language == "opencl" or language == "ocl":
2192+
return DPCTLDevice_CanCompileOpenCL(self._device_ref)
2193+
if language == "sycl":
2194+
return DPCTLDevice_CanCompileSYCL(self._device_ref)
2195+
2196+
raise ValueError(f"Unknown source language {language}")
2197+
2198+
21672199

21682200
cdef api DPCTLSyclDeviceRef SyclDevice_GetDeviceRef(SyclDevice dev):
21692201
"""

dpctl/program/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@
2626
SyclProgramCompilationError,
2727
create_program_from_source,
2828
create_program_from_spirv,
29+
create_program_from_sycl_source,
2930
)
3031

3132
__all__ = [
3233
"create_program_from_source",
3334
"create_program_from_spirv",
35+
"create_program_from_sycl_source",
3436
"SyclKernel",
3537
"SyclProgram",
3638
"SyclProgramCompilationError",

dpctl/program/_program.pxd

+5-1
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,17 @@ cdef api class SyclProgram [object PySyclProgramObject, type PySyclProgramType]:
4949
binary file.
5050
'''
5151
cdef DPCTLSyclKernelBundleRef _program_ref
52+
cdef bint _is_sycl_source
5253

5354
@staticmethod
54-
cdef SyclProgram _create (DPCTLSyclKernelBundleRef pref)
55+
cdef SyclProgram _create (DPCTLSyclKernelBundleRef pref, bint _is_sycl_source)
5556
cdef DPCTLSyclKernelBundleRef get_program_ref (self)
5657
cpdef SyclKernel get_sycl_kernel(self, str kernel_name)
5758

5859

5960
cpdef create_program_from_source (SyclQueue q, unicode source, unicode copts=*)
6061
cpdef create_program_from_spirv (SyclQueue q, const unsigned char[:] IL,
6162
unicode copts=*)
63+
cpdef create_program_from_sycl_source(SyclQueue q, unicode source,
64+
list headers=*, list registered_names=*,
65+
list copts=*)

dpctl/program/_program.pyx

+126-4
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ a OpenCL source string or a SPIR-V binary file.
2828
from libc.stdint cimport uint32_t
2929

3030
from dpctl._backend cimport ( # noqa: E211, E402;
31+
DPCTLBuildOptionList_Append,
32+
DPCTLBuildOptionList_Create,
33+
DPCTLBuildOptionList_Delete,
34+
DPCTLBuildOptionListRef,
3135
DPCTLCString_Delete,
3236
DPCTLKernel_Copy,
3337
DPCTLKernel_Delete,
@@ -42,13 +46,24 @@ from dpctl._backend cimport ( # noqa: E211, E402;
4246
DPCTLKernelBundle_Copy,
4347
DPCTLKernelBundle_CreateFromOCLSource,
4448
DPCTLKernelBundle_CreateFromSpirv,
49+
DPCTLKernelBundle_CreateFromSYCLSource,
4550
DPCTLKernelBundle_Delete,
4651
DPCTLKernelBundle_GetKernel,
52+
DPCTLKernelBundle_GetSyclKernel,
4753
DPCTLKernelBundle_HasKernel,
54+
DPCTLKernelBundle_HasSyclKernel,
55+
DPCTLKernelNameList_Append,
56+
DPCTLKernelNameList_Create,
57+
DPCTLKernelNameList_Delete,
58+
DPCTLKernelNameListRef,
4859
DPCTLSyclContextRef,
4960
DPCTLSyclDeviceRef,
5061
DPCTLSyclKernelBundleRef,
5162
DPCTLSyclKernelRef,
63+
DPCTLVirtualHeaderList_Append,
64+
DPCTLVirtualHeaderList_Create,
65+
DPCTLVirtualHeaderList_Delete,
66+
DPCTLVirtualHeaderListRef,
5267
)
5368

5469
__all__ = [
@@ -197,9 +212,10 @@ cdef class SyclProgram:
197212
"""
198213

199214
@staticmethod
200-
cdef SyclProgram _create(DPCTLSyclKernelBundleRef KBRef):
215+
cdef SyclProgram _create(DPCTLSyclKernelBundleRef KBRef, bint is_sycl_source):
201216
cdef SyclProgram ret = SyclProgram.__new__(SyclProgram)
202217
ret._program_ref = KBRef
218+
ret._is_sycl_source = is_sycl_source
203219
return ret
204220

205221
def __dealloc__(self):
@@ -210,13 +226,19 @@ cdef class SyclProgram:
210226

211227
cpdef SyclKernel get_sycl_kernel(self, str kernel_name):
212228
name = kernel_name.encode('utf8')
229+
if self._is_sycl_source:
230+
return SyclKernel._create(
231+
DPCTLKernelBundle_GetSyclKernel(self._program_ref, name),
232+
kernel_name)
213233
return SyclKernel._create(
214234
DPCTLKernelBundle_GetKernel(self._program_ref, name),
215235
kernel_name
216236
)
217237

218238
def has_sycl_kernel(self, str kernel_name):
219239
name = kernel_name.encode('utf8')
240+
if self._is_sycl_source:
241+
return DPCTLKernelBundle_HasSyclKernel(self._program_ref, name)
220242
return DPCTLKernelBundle_HasKernel(self._program_ref, name)
221243

222244
def addressof_ref(self):
@@ -272,7 +294,7 @@ cpdef create_program_from_source(SyclQueue q, str src, str copts=""):
272294
if KBref is NULL:
273295
raise SyclProgramCompilationError()
274296

275-
return SyclProgram._create(KBref)
297+
return SyclProgram._create(KBref, False)
276298

277299

278300
cpdef create_program_from_spirv(SyclQueue q, const unsigned char[:] IL,
@@ -318,7 +340,107 @@ cpdef create_program_from_spirv(SyclQueue q, const unsigned char[:] IL,
318340
if KBref is NULL:
319341
raise SyclProgramCompilationError()
320342

321-
return SyclProgram._create(KBref)
343+
return SyclProgram._create(KBref, False)
344+
345+
346+
cpdef create_program_from_sycl_source(SyclQueue q, unicode source, list headers=[], list registered_names=[], list copts=[]):
347+
"""
348+
Creates an executable SYCL kernel_bundle from SYCL source code.
349+
350+
This uses the DPC++ ``kernel_compiler`` extension to create a
351+
``sycl::kernel_bundle<sycl::bundle_state::executable>`` object from
352+
SYCL source code.
353+
354+
Parameters:
355+
q (:class:`dpctl.SyclQueue`)
356+
The :class:`dpctl.SyclQueue` for which the
357+
:class:`.SyclProgram` is going to be built.
358+
source (unicode)
359+
SYCL source code string.
360+
headers (list)
361+
Optional list of virtual headers, where each entry in the list
362+
needs to be a tuple of header name and header content. See the
363+
documentation of the ``include_files`` property in the DPC++
364+
``kernel_compiler`` extension for more information.
365+
Default: []
366+
registered_names (list, optional)
367+
Optional list of kernel names to register. See the
368+
documentation of the ``registered_names`` property in the DPC++
369+
``kernel_compiler`` extension for more information.
370+
Default: []
371+
copts (list)
372+
Optional list of compilation flags that will be used
373+
when compiling the program. Default: ``""``.
374+
375+
Returns:
376+
program (:class:`.SyclProgram`)
377+
A :class:`.SyclProgram` object wrapping the
378+
``sycl::kernel_bundle<sycl::bundle_state::executable>``
379+
returned by the C API.
380+
381+
Raises:
382+
SyclProgramCompilationError
383+
If a SYCL kernel bundle could not be created.
384+
"""
385+
cdef DPCTLSyclKernelBundleRef KBref
386+
cdef DPCTLSyclContextRef CRef = q.get_sycl_context().get_context_ref()
387+
cdef DPCTLSyclDeviceRef DRef = q.get_sycl_device().get_device_ref()
388+
cdef bytes bSrc = source.encode('utf8')
389+
cdef const char *Src = <const char*>bSrc
390+
cdef DPCTLBuildOptionListRef BuildOpts = DPCTLBuildOptionList_Create()
391+
cdef bytes bOpt
392+
cdef const char* sOpt
393+
cdef bytes bName
394+
cdef const char* sName
395+
cdef bytes bContent
396+
cdef const char* sContent
397+
for opt in copts:
398+
if not isinstance(opt, unicode):
399+
DPCTLBuildOptionList_Delete(BuildOpts)
400+
raise SyclProgramCompilationError()
401+
bOpt = opt.encode('utf8')
402+
sOpt = <const char*>bOpt
403+
DPCTLBuildOptionList_Append(BuildOpts, sOpt)
404+
405+
cdef DPCTLKernelNameListRef KernelNames = DPCTLKernelNameList_Create()
406+
for name in registered_names:
407+
if not isinstance(name, unicode):
408+
DPCTLBuildOptionList_Delete(BuildOpts)
409+
DPCTLKernelNameList_Delete(KernelNames)
410+
raise SyclProgramCompilationError()
411+
bName = name.encode('utf8')
412+
sName = <const char*>bName
413+
DPCTLKernelNameList_Append(KernelNames, sName)
414+
415+
416+
cdef DPCTLVirtualHeaderListRef VirtualHeaders = DPCTLVirtualHeaderList_Create()
417+
for name, content in headers:
418+
if not isinstance(name, unicode) or not isinstance(content, unicode):
419+
DPCTLBuildOptionList_Delete(BuildOpts)
420+
DPCTLKernelNameList_Delete(KernelNames)
421+
DPCTLVirtualHeaderList_Delete(VirtualHeaders)
422+
raise SyclProgramCompilationError()
423+
bName = name.encode('utf8')
424+
sName = <const char*>bName
425+
bContent = content.encode('utf8')
426+
sContent = <const char*>bContent
427+
DPCTLVirtualHeaderList_Append(VirtualHeaders, sName, sContent)
428+
429+
KBref = DPCTLKernelBundle_CreateFromSYCLSource(CRef, DRef, Src,
430+
VirtualHeaders, KernelNames,
431+
BuildOpts)
432+
433+
if KBref is NULL:
434+
DPCTLBuildOptionList_Delete(BuildOpts)
435+
DPCTLKernelNameList_Delete(KernelNames)
436+
DPCTLVirtualHeaderList_Delete(VirtualHeaders)
437+
raise SyclProgramCompilationError()
438+
439+
DPCTLBuildOptionList_Delete(BuildOpts)
440+
DPCTLKernelNameList_Delete(KernelNames)
441+
DPCTLVirtualHeaderList_Delete(VirtualHeaders)
442+
443+
return SyclProgram._create(KBref, True)
322444

323445

324446
cdef api DPCTLSyclKernelBundleRef SyclProgram_GetKernelBundleRef(SyclProgram pro):
@@ -335,4 +457,4 @@ cdef api SyclProgram SyclProgram_Make(DPCTLSyclKernelBundleRef KBRef):
335457
reference.
336458
"""
337459
cdef DPCTLSyclKernelBundleRef copied_KBRef = DPCTLKernelBundle_Copy(KBRef)
338-
return SyclProgram._create(copied_KBRef)
460+
return SyclProgram._create(copied_KBRef, False)

0 commit comments

Comments
 (0)