Skip to content

Commit 5ac336e

Browse files
committed
Add support for work_group_memory extension
Extend kernel argument handling to add support for the work_group_memory extension, allowing users to dynamically allocate local memory for a kernel. Signed-off-by: Lukas Sommer <[email protected]>
1 parent 0d0ff97 commit 5ac336e

22 files changed

+813
-1
lines changed

.flake8

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ per-file-ignores =
3131
dpctl/utils/_compute_follows_data.pyx: E999, E225, E227
3232
dpctl/utils/_onetrace_context.py: E501, W505
3333
dpctl/tensor/_array_api.py: E501, W505
34+
dpctl/experimental/_work_group_memory.pyx: E999
3435
examples/cython/sycl_buffer/syclbuffer/_syclbuffer.pyx: E999, E225, E402
3536
examples/cython/usm_memory/blackscholes/_blackscholes_usm.pyx: E999, E225, E226, E402
3637
examples/cython/use_dpctl_sycl/use_dpctl_sycl/_cython_api.pyx: E999, E225, E226, E402

dpctl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,3 +207,4 @@ add_subdirectory(program)
207207
add_subdirectory(memory)
208208
add_subdirectory(tensor)
209209
add_subdirectory(utils)
210+
add_subdirectory(experimental)

dpctl/_backend.pxd

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ cdef extern from "syclinterface/dpctl_sycl_enum_types.h":
6969
_FLOAT 'DPCTL_FLOAT32_T',
7070
_DOUBLE 'DPCTL_FLOAT64_T',
7171
_VOID_PTR 'DPCTL_VOID_PTR',
72-
_LOCAL_ACCESSOR 'DPCTL_LOCAL_ACCESSOR'
72+
_LOCAL_ACCESSOR 'DPCTL_LOCAL_ACCESSOR',
73+
_WORK_GROUP_MEMORY 'DPCTL_WORK_GROUP_MEMORY'
7374

7475
ctypedef enum _queue_property_type 'DPCTLQueuePropertyType':
7576
_DEFAULT_PROPERTY 'DPCTL_DEFAULT_PROPERTY'
@@ -468,3 +469,18 @@ cdef extern from "syclinterface/dpctl_sycl_usm_interface.h":
468469
cdef DPCTLSyclDeviceRef DPCTLUSM_GetPointerDevice(
469470
DPCTLSyclUSMRef MRef,
470471
DPCTLSyclContextRef CRef)
472+
473+
cdef extern from "syclinterface/dpctl_sycl_extension_interface.h":
474+
cdef struct RawWorkGroupMemoryTy
475+
ctypedef RawWorkGroupMemoryTy RawWorkGroupMemory
476+
477+
478+
cdef struct DPCTLOpaqueWorkGroupMemory
479+
ctypedef DPCTLOpaqueWorkGroupMemory *DPCTLSyclWorkGroupMemoryRef;
480+
481+
cdef DPCTLSyclWorkGroupMemoryRef DPCTLWorkGroupMemory_Create(size_t nbytes);
482+
483+
cdef void DPCTLWorkGroupMemory_Delete(
484+
DPCTLSyclWorkGroupMemoryRef Ref);
485+
486+
cdef bint DPCTLWorkGroupMemory_Available();

dpctl/_sycl_queue.pyx

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ from ._backend cimport ( # noqa: E211
5858
_backend_type,
5959
_queue_property_type,
6060
)
61+
from .experimental._work_group_memory cimport WorkGroupMemory
6162
from .memory._memory cimport _Memory
6263

6364
import ctypes
@@ -250,6 +251,15 @@ cdef class _kernel_arg_type:
250251
_arg_data_type._LOCAL_ACCESSOR
251252
)
252253

254+
@property
255+
def dpctl_work_group_memory(self):
256+
cdef str p_name = "dpctl_work_group_memory"
257+
return kernel_arg_type_attribute(
258+
self._name,
259+
p_name,
260+
_arg_data_type._WORK_GROUP_MEMORY
261+
)
262+
253263

254264
kernel_arg_type = _kernel_arg_type()
255265

@@ -849,6 +859,9 @@ cdef class SyclQueue(_SyclQueue):
849859
elif isinstance(arg, _Memory):
850860
kargs[idx]= <void*>(<size_t>arg._pointer)
851861
kargty[idx] = _arg_data_type._VOID_PTR
862+
elif isinstance(arg, WorkGroupMemory):
863+
kargs[idx] = <void*>(<size_t>arg._ref)
864+
kargty[idx] = _arg_data_type._WORK_GROUP_MEMORY
852865
else:
853866
ret = -1
854867
return ret

dpctl/experimental/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
file(GLOB _cython_sources *.pyx)
2+
foreach(_cy_file ${_cython_sources})
3+
get_filename_component(_trgt ${_cy_file} NAME_WLE)
4+
build_dpctl_ext(${_trgt} ${_cy_file} "dpctl/experimental" RELATIVE_PATH "..")
5+
target_include_directories(${_trgt} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include)
6+
target_link_libraries(DpctlCAPI INTERFACE ${_trgt}_headers)
7+
endforeach()

dpctl/experimental/__init__.pxd

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
# distutils: language = c++
18+
# cython: language_level=3
19+
20+
"""This file declares the extension types and functions for the Cython API
21+
implemented in dpctl.experimental.*.pyx.
22+
"""
23+
24+
25+
from dpctl.experimental._work_group_memory cimport *

dpctl/experimental/__init__.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""
18+
**Data Parallel Control Experimental" provides Python objects to interface
19+
with different experimental SYCL language extensions defined by the DPC++
20+
SYCL implementation.
21+
"""
22+
23+
from ._work_group_memory import WorkGroupMemory
24+
25+
__all__ = [
26+
"WorkGroupMemory",
27+
]
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
# distutils: language = c++
18+
# cython: language_level=3
19+
20+
from .._backend cimport DPCTLSyclWorkGroupMemoryRef
21+
22+
23+
cdef public api class _WorkGroupMemory [
24+
object Py_WorkGroupMemoryObject, type Py_WorkGroupMemoryType
25+
]:
26+
cdef DPCTLSyclWorkGroupMemoryRef _mem_ref
27+
28+
cdef public api class WorkGroupMemory(_WorkGroupMemory) [
29+
object PyWorkGroupMemoryObject, type PyWorkGroupMemoryType
30+
]:
31+
pass
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
# distutils: language = c++
18+
# cython: language_level=3
19+
# cython: linetrace=True
20+
21+
from .._backend cimport (
22+
DPCTLWorkGroupMemory_Available,
23+
DPCTLWorkGroupMemory_Create,
24+
DPCTLWorkGroupMemory_Delete,
25+
)
26+
27+
28+
cdef class _WorkGroupMemory:
29+
def __dealloc__(self):
30+
if(self._mem_ref):
31+
DPCTLWorkGroupMemory_Delete(self._mem_ref)
32+
33+
cdef class WorkGroupMemory:
34+
"""
35+
WorkGroupMemory(nbytes)
36+
Python class representing the ``work_group_memory`` class from the
37+
Workgroup Memory oneAPI SYCL extension for low-overhead allocation of local
38+
memory shared by the workitems in a workgroup.
39+
40+
Args:
41+
nbytes (int)
42+
number of bytes to allocate in local memory.
43+
Expected to be positive.
44+
"""
45+
def __cinit__(self, Py_ssize_t nbytes):
46+
if not DPCTLWorkGroupMemory_Available():
47+
raise RuntimeError("Workgroup memory extension not available")
48+
49+
self._mem_ref = DPCTLWorkGroupMemory_Create(nbytes)
50+
51+
@staticmethod
52+
def is_available():
53+
return DPCTLWorkGroupMemory_Available()
54+
55+
property _ref:
56+
"""Returns the address of the C API ``DPCTLWorkGroupMemoryRef``
57+
pointer as a ``size_t``.
58+
"""
59+
def __get__(self):
60+
return <size_t>self._mem_ref

dpctl/sycl.pxd

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,12 @@ cdef extern from "syclinterface/dpctl_sycl_type_casters.hpp" \
6767
"dpctl::syclinterface::wrap<sycl::event>" (const event *)
6868
cdef event * unwrap_event "dpctl::syclinterface::unwrap<sycl::event>" (
6969
dpctl_backend.DPCTLSyclEventRef)
70+
71+
# work group memory extension[
72+
cdef dpctl_backend.DPCTLSyclWorkGroupMemoryRef wrap_work_group_memory \
73+
"dpctl::syclinterface::wrap<RawWorkGroupMemory>" \
74+
(const RawWorkGroupMemory *)
75+
76+
cdef RawWorkGroupMemory * unwrap_work_group_memory \
77+
"dpctl::syclinterface::unwrap<RawWorkGroupMemory>" (
78+
dpctl_backend.DPCTLSyclWorkGroupMemoryRef)

0 commit comments

Comments
 (0)