Skip to content

Commit 96da90b

Browse files
authored
Merge pull request #1984 from sommerlukas/work_group_memory
Add support for work_group_memory extension
2 parents 7aa6fb7 + 5dca7ba commit 96da90b

19 files changed

+855
-4
lines changed

dpctl/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
SyclKernelSubmitError,
5353
SyclQueue,
5454
SyclQueueCreationError,
55+
WorkGroupMemory,
5556
)
5657
from ._sycl_queue_manager import get_device_cached_queue
5758
from ._sycl_timer import SyclTimer
@@ -100,6 +101,7 @@
100101
"SyclKernelInvalidRangeError",
101102
"SyclKernelSubmitError",
102103
"SyclQueueCreationError",
104+
"WorkGroupMemory",
103105
]
104106
__all__ += [
105107
"get_device_cached_queue",

dpctl/_backend.pxd

+16-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ cdef extern from "syclinterface/dpctl_sycl_enum_types.h":
6969
_FLOAT 'DPCTL_FLOAT32_T',
7070
_DOUBLE 'DPCTL_FLOAT64_T',
7171
_VOID_PTR 'DPCTL_VOID_PTR',
72-
_LOCAL_ACCESSOR 'DPCTL_LOCAL_ACCESSOR'
72+
_LOCAL_ACCESSOR 'DPCTL_LOCAL_ACCESSOR',
73+
_WORK_GROUP_MEMORY 'DPCTL_WORK_GROUP_MEMORY'
7374

7475
ctypedef enum _queue_property_type 'DPCTLQueuePropertyType':
7576
_DEFAULT_PROPERTY 'DPCTL_DEFAULT_PROPERTY'
@@ -470,3 +471,17 @@ cdef extern from "syclinterface/dpctl_sycl_usm_interface.h":
470471
cdef DPCTLSyclDeviceRef DPCTLUSM_GetPointerDevice(
471472
DPCTLSyclUSMRef MRef,
472473
DPCTLSyclContextRef CRef)
474+
475+
cdef extern from "syclinterface/dpctl_sycl_extension_interface.h":
476+
cdef struct RawWorkGroupMemoryTy
477+
ctypedef RawWorkGroupMemoryTy RawWorkGroupMemory
478+
479+
cdef struct DPCTLOpaqueWorkGroupMemory
480+
ctypedef DPCTLOpaqueWorkGroupMemory *DPCTLSyclWorkGroupMemoryRef;
481+
482+
cdef DPCTLSyclWorkGroupMemoryRef DPCTLWorkGroupMemory_Create(size_t nbytes);
483+
484+
cdef void DPCTLWorkGroupMemory_Delete(
485+
DPCTLSyclWorkGroupMemoryRef Ref);
486+
487+
cdef bint DPCTLWorkGroupMemory_Available();

dpctl/_sycl_queue.pxd

+16-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@
2222

2323
from libcpp cimport bool as cpp_bool
2424

25-
from ._backend cimport DPCTLSyclDeviceRef, DPCTLSyclQueueRef, _arg_data_type
25+
from ._backend cimport (
26+
DPCTLSyclDeviceRef,
27+
DPCTLSyclQueueRef,
28+
DPCTLSyclWorkGroupMemoryRef,
29+
_arg_data_type,
30+
)
2631
from ._sycl_context cimport SyclContext
2732
from ._sycl_device cimport SyclDevice
2833
from ._sycl_event cimport SyclEvent
@@ -98,3 +103,13 @@ cdef public api class SyclQueue (_SyclQueue) [
98103
cpdef prefetch(self, ptr, size_t count=*)
99104
cpdef mem_advise(self, ptr, size_t count, int mem)
100105
cpdef SyclEvent submit_barrier(self, dependent_events=*)
106+
107+
cdef public api class _WorkGroupMemory [
108+
object Py_WorkGroupMemoryObject, type Py_WorkGroupMemoryType
109+
]:
110+
cdef DPCTLSyclWorkGroupMemoryRef _mem_ref
111+
112+
cdef public api class WorkGroupMemory(_WorkGroupMemory) [
113+
object PyWorkGroupMemoryObject, type PyWorkGroupMemoryType
114+
]:
115+
pass

dpctl/_sycl_queue.pyx

+102
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,17 @@ from ._backend cimport ( # noqa: E211
5454
DPCTLSyclContextRef,
5555
DPCTLSyclDeviceSelectorRef,
5656
DPCTLSyclEventRef,
57+
DPCTLWorkGroupMemory_Available,
58+
DPCTLWorkGroupMemory_Create,
59+
DPCTLWorkGroupMemory_Delete,
5760
_arg_data_type,
5861
_backend_type,
5962
_queue_property_type,
6063
)
6164
from .memory._memory cimport _Memory
6265

6366
import ctypes
67+
import numbers
6468

6569
from .enum_types import backend_type
6670

@@ -250,6 +254,15 @@ cdef class _kernel_arg_type:
250254
_arg_data_type._LOCAL_ACCESSOR
251255
)
252256

257+
@property
258+
def dpctl_work_group_memory(self):
259+
cdef str p_name = "dpctl_work_group_memory"
260+
return kernel_arg_type_attribute(
261+
self._name,
262+
p_name,
263+
_arg_data_type._WORK_GROUP_MEMORY
264+
)
265+
253266

254267
kernel_arg_type = _kernel_arg_type()
255268

@@ -849,6 +862,9 @@ cdef class SyclQueue(_SyclQueue):
849862
elif isinstance(arg, _Memory):
850863
kargs[idx]= <void*>(<size_t>arg._pointer)
851864
kargty[idx] = _arg_data_type._VOID_PTR
865+
elif isinstance(arg, WorkGroupMemory):
866+
kargs[idx] = <void*>(<size_t>arg._ref)
867+
kargty[idx] = _arg_data_type._WORK_GROUP_MEMORY
852868
else:
853869
ret = -1
854870
return ret
@@ -1524,3 +1540,89 @@ cdef api SyclQueue SyclQueue_Make(DPCTLSyclQueueRef QRef):
15241540
"""
15251541
cdef DPCTLSyclQueueRef copied_QRef = DPCTLQueue_Copy(QRef)
15261542
return SyclQueue._create(copied_QRef)
1543+
1544+
cdef class _WorkGroupMemory:
1545+
def __dealloc__(self):
1546+
if(self._mem_ref):
1547+
DPCTLWorkGroupMemory_Delete(self._mem_ref)
1548+
1549+
cdef class WorkGroupMemory:
1550+
"""
1551+
WorkGroupMemory(nbytes)
1552+
Python class representing the ``work_group_memory`` class from the
1553+
Workgroup Memory oneAPI SYCL extension for low-overhead allocation of local
1554+
memory shared by the workitems in a workgroup.
1555+
1556+
This class is intended be used as kernel argument when launching kernels.
1557+
1558+
This is based on a DPC++ SYCL extension and only available in newer
1559+
versions. Use ``is_available()`` to check availability in your build.
1560+
1561+
There are multiple ways to create a `WorkGroupMemory`.
1562+
1563+
- If the constructor is invoked with just a single argument, this argument
1564+
is interpreted as the number of bytes to allocated in the shared local
1565+
memory.
1566+
1567+
- If the constructor is invoked with two arguments, the first argument is
1568+
interpreted as the datatype of the local memory, using the numpy type
1569+
naming scheme.
1570+
The second argument is interpreted as the number of elements to allocate.
1571+
The number of bytes to allocate is then computed from the byte size of
1572+
the data type and the element count.
1573+
1574+
Args:
1575+
args:
1576+
Variadic argument, see class documentation.
1577+
1578+
Raises:
1579+
TypeError: In case of incorrect arguments given to constructors,
1580+
unexpected types of input arguments.
1581+
"""
1582+
def __cinit__(self, *args):
1583+
cdef size_t nbytes
1584+
if not DPCTLWorkGroupMemory_Available():
1585+
raise RuntimeError("Workgroup memory extension not available")
1586+
1587+
if not (0 < len(args) < 3):
1588+
raise TypeError("WorkGroupMemory constructor takes 1 or 2 "
1589+
f"arguments, but {len(args)} were given")
1590+
1591+
if len(args) == 1:
1592+
if not isinstance(args[0], numbers.Integral):
1593+
raise TypeError("WorkGroupMemory single argument constructor"
1594+
"expects first argument to be `int`",
1595+
f"but got {type(args[0])}")
1596+
nbytes = <size_t>(args[0])
1597+
else:
1598+
if not isinstance(args[0], str):
1599+
raise TypeError("WorkGroupMemory constructor expects first"
1600+
f"argument to be `str`, but got {type(args[0])}")
1601+
if not isinstance(args[1], numbers.Integral):
1602+
raise TypeError("WorkGroupMemory constructor expects second"
1603+
f"argument to be `int`, but got {type(args[1])}")
1604+
dtype = <str>(args[0])
1605+
count = <size_t>(args[1])
1606+
if not dtype[0] in ["i", "u", "f"]:
1607+
raise TypeError(f"Unrecognized type value: '{dtype}'")
1608+
try:
1609+
bit_width = int(dtype[1:])
1610+
except ValueError:
1611+
raise TypeError(f"Unrecognized type value: '{dtype}'")
1612+
1613+
byte_size = <size_t>bit_width
1614+
nbytes = count * byte_size
1615+
1616+
self._mem_ref = DPCTLWorkGroupMemory_Create(nbytes)
1617+
1618+
"""Check whether the work_group_memory extension is available"""
1619+
@staticmethod
1620+
def is_available():
1621+
return DPCTLWorkGroupMemory_Available()
1622+
1623+
property _ref:
1624+
"""Returns the address of the C API ``DPCTLWorkGroupMemoryRef``
1625+
pointer as a ``size_t``.
1626+
"""
1627+
def __get__(self):
1628+
return <size_t>self._mem_ref

dpctl/apis/include/dpctl_capi.h

+4-2
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,11 @@
2525
#pragma once
2626

2727
// clang-format off
28-
// Ordering of includes is important here. dpctl_sycl_types defines types
29-
// used by dpctl's Python C-API headers.
28+
// Ordering of includes is important here. dpctl_sycl_types and
29+
// dpctl_sycl_extension_interface define types used by dpctl's Python
30+
// C-API headers.
3031
#include "syclinterface/dpctl_sycl_types.h"
32+
#include "syclinterface/dpctl_sycl_extension_interface.h"
3133
#ifdef __cplusplus
3234
#define CYTHON_EXTERN_C extern "C"
3335
#else

dpctl/sycl.pxd

+13
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ cdef extern from "sycl/sycl.hpp" namespace "sycl":
4242
"sycl::kernel_bundle<sycl::bundle_state::executable>":
4343
pass
4444

45+
cdef extern from "syclinterface/dpctl_sycl_extension_interface.h":
46+
cdef struct RawWorkGroupMemoryTy
47+
ctypedef RawWorkGroupMemoryTy RawWorkGroupMemory
48+
4549
cdef extern from "syclinterface/dpctl_sycl_type_casters.hpp" \
4650
namespace "dpctl::syclinterface":
4751
# queue
@@ -67,3 +71,12 @@ cdef extern from "syclinterface/dpctl_sycl_type_casters.hpp" \
6771
"dpctl::syclinterface::wrap<sycl::event>" (const event *)
6872
cdef event * unwrap_event "dpctl::syclinterface::unwrap<sycl::event>" (
6973
dpctl_backend.DPCTLSyclEventRef)
74+
75+
# work group memory extension
76+
cdef dpctl_backend.DPCTLSyclWorkGroupMemoryRef wrap_work_group_memory \
77+
"dpctl::syclinterface::wrap<RawWorkGroupMemory>" \
78+
(const RawWorkGroupMemory *)
79+
80+
cdef RawWorkGroupMemory * unwrap_work_group_memory \
81+
"dpctl::syclinterface::unwrap<RawWorkGroupMemory>" (
82+
dpctl_backend.DPCTLSyclWorkGroupMemoryRef)
Binary file not shown.

dpctl/tests/test_sycl_kernel_submit.py

+1
Original file line numberDiff line numberDiff line change
@@ -278,3 +278,4 @@ def test_kernel_arg_type():
278278
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_float64)
279279
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_void_ptr)
280280
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_local_accessor)
281+
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_work_group_memory)

dpctl/tests/test_work_group_memory.py

+90
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""Defines unit test cases for the work_group_memory in a SYCL kernel"""
18+
19+
import os
20+
21+
import pytest
22+
23+
import dpctl
24+
import dpctl.tensor
25+
26+
27+
def get_spirv_abspath(fn):
28+
curr_dir = os.path.dirname(os.path.abspath(__file__))
29+
spirv_file = os.path.join(curr_dir, "input_files", fn)
30+
return spirv_file
31+
32+
33+
# The kernel in the SPIR-V file used in this test was generated from the
34+
# following SYCL source code:
35+
# #include <sycl/sycl.hpp>
36+
# using namespace sycl;
37+
# namespace syclexp = sycl::ext::oneapi::experimental;
38+
# namespace syclext = sycl::ext::oneapi;
39+
# using data_t = int32_t;
40+
#
41+
# extern "C" SYCL_EXTERNAL
42+
# SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<1>))
43+
# void local_mem_kernel(data_t* in, data_t* out,
44+
# syclexp::work_group_memory<data_t> mem){
45+
# auto* local_mem = &mem;
46+
# auto item = syclext::this_work_item::get_nd_item<1>();
47+
# size_t global_id = item.get_global_linear_id();
48+
# size_t local_id = item.get_local_linear_id();
49+
# local_mem[local_id] = in[global_id];
50+
# out[global_id] = local_mem[local_id];
51+
# }
52+
53+
54+
def test_submit_work_group_memory():
55+
if not dpctl.WorkGroupMemory.is_available():
56+
pytest.skip("Work group memory extension not supported")
57+
58+
try:
59+
q = dpctl.SyclQueue("level_zero")
60+
except dpctl.SyclQueueCreationError:
61+
pytest.skip("LevelZero queue could not be created")
62+
spirv_file = get_spirv_abspath("work-group-memory-kernel.spv")
63+
with open(spirv_file, "br") as spv:
64+
spv_bytes = spv.read()
65+
prog = dpctl.program.create_program_from_spirv(q, spv_bytes)
66+
kernel = prog.get_sycl_kernel("__sycl_kernel_local_mem_kernel")
67+
local_size = 16
68+
global_size = local_size * 8
69+
70+
x = dpctl.tensor.ones(global_size, dtype="int32")
71+
y = dpctl.tensor.zeros(global_size, dtype="int32")
72+
x.sycl_queue.wait()
73+
y.sycl_queue.wait()
74+
75+
try:
76+
q.submit(
77+
kernel,
78+
[
79+
x.usm_data,
80+
y.usm_data,
81+
dpctl.WorkGroupMemory("i4", local_size),
82+
],
83+
[global_size],
84+
[local_size],
85+
)
86+
q.wait()
87+
except dpctl._sycl_queue.SyclKernelSubmitError:
88+
pytest.skip(f"Kernel submission to {q.sycl_device} failed")
89+
90+
assert dpctl.tensor.all(x == y)

0 commit comments

Comments
 (0)