@@ -54,13 +54,17 @@ from ._backend cimport ( # noqa: E211
5454 DPCTLSyclContextRef,
5555 DPCTLSyclDeviceSelectorRef,
5656 DPCTLSyclEventRef,
57+ DPCTLWorkGroupMemory_Available,
58+ DPCTLWorkGroupMemory_Create,
59+ DPCTLWorkGroupMemory_Delete,
5760 _arg_data_type,
5861 _backend_type,
5962 _queue_property_type,
6063)
6164from .memory._memory cimport _Memory
6265
6366import ctypes
67+ import numbers
6468
6569from .enum_types import backend_type
6670
@@ -250,6 +254,15 @@ cdef class _kernel_arg_type:
250254 _arg_data_type._LOCAL_ACCESSOR
251255 )
252256
@property
def dpctl_work_group_memory(self):
    """Kernel-argument type attribute for work-group memory arguments.

    Returns whatever ``kernel_arg_type_attribute`` produces for this
    object's ``_name``, the attribute name ``"dpctl_work_group_memory"``
    and ``_arg_data_type._WORK_GROUP_MEMORY``.
    """
    # The attribute name must match this property's identifier exactly;
    # no surrounding whitespace.
    cdef str p_name = "dpctl_work_group_memory"
    return kernel_arg_type_attribute(
        self._name,
        p_name,
        _arg_data_type._WORK_GROUP_MEMORY
    )
265+
253266
254267kernel_arg_type = _kernel_arg_type()
255268
@@ -849,6 +862,9 @@ cdef class SyclQueue(_SyclQueue):
849862 elif isinstance (arg, _Memory):
850863 kargs[idx]= < void * > (< size_t> arg._pointer)
851864 kargty[idx] = _arg_data_type._VOID_PTR
865+ elif isinstance (arg, WorkGroupMemory):
866+ kargs[idx] = < void * > (< size_t> arg._ref)
867+ kargty[idx] = _arg_data_type._WORK_GROUP_MEMORY
852868 else :
853869 ret = - 1
854870 return ret
@@ -1524,3 +1540,89 @@ cdef api SyclQueue SyclQueue_Make(DPCTLSyclQueueRef QRef):
15241540 """
15251541 cdef DPCTLSyclQueueRef copied_QRef = DPCTLQueue_Copy(QRef)
15261542 return SyclQueue._create(copied_QRef)
1543+
cdef class _WorkGroupMemory:
    """Holder of a ``_mem_ref`` handle that is released on deallocation.

    NOTE(review): ``_mem_ref`` is not declared in this block; presumably it
    is declared in the accompanying ``.pxd`` file -- confirm.
    """
    def __dealloc__(self):
        # Nothing to release when no reference was ever stored.
        if not self._mem_ref:
            return
        DPCTLWorkGroupMemory_Delete(self._mem_ref)
1548+
cdef class WorkGroupMemory:
    """
    WorkGroupMemory(nbytes)
    Python class representing the ``work_group_memory`` class from the
    Workgroup Memory oneAPI SYCL extension for low-overhead allocation of local
    memory shared by the workitems in a workgroup.

    This class is intended to be used as kernel argument when launching
    kernels.

    This is based on a DPC++ SYCL extension and only available in newer
    versions. Use ``is_available()`` to check availability in your build.

    There are multiple ways to create a `WorkGroupMemory`.

    - If the constructor is invoked with just a single argument, this argument
      is interpreted as the number of bytes to allocate in the shared local
      memory.

    - If the constructor is invoked with two arguments, the first argument is
      interpreted as the datatype of the local memory, using the numpy type
      naming scheme: a kind character (``i``, ``u`` or ``f``) followed by
      an integer that is used as the per-element size in bytes, e.g.
      ``"f4"``.
      The second argument is interpreted as the number of elements to allocate.
      The number of bytes to allocate is then computed from the byte size of
      the data type and the element count.

    Args:
        args:
            Variadic argument, see class documentation.

    Raises:
        TypeError: In case of incorrect arguments given to constructors,
                   unexpected types of input arguments.
        RuntimeError: If the work-group memory extension is reported as
                      not available.
    """
    def __cinit__(self, *args):
        cdef size_t nbytes
        if not DPCTLWorkGroupMemory_Available():
            raise RuntimeError("Workgroup memory extension not available")

        if not (0 < len(args) < 3):
            raise TypeError("WorkGroupMemory constructor takes 1 or 2 "
                            f"arguments, but {len(args)} were given")

        if len(args) == 1:
            if not isinstance(args[0], numbers.Integral):
                # Adjacent literals only: a comma between the fragments
                # would hand TypeError two separate arguments instead of a
                # single message.
                raise TypeError("WorkGroupMemory single argument constructor"
                                " expects first argument to be `int`,"
                                f" but got {type(args[0])}")
            nbytes = <size_t>(args[0])
        else:
            if not isinstance(args[0], str):
                raise TypeError("WorkGroupMemory constructor expects first"
                                f" argument to be `str`, but got {type(args[0])}")
            if not isinstance(args[1], numbers.Integral):
                raise TypeError("WorkGroupMemory constructor expects second"
                                f" argument to be `int`, but got {type(args[1])}")
            dtype = <str>(args[0])
            count = <size_t>(args[1])
            # Slice instead of indexing so that an empty string is rejected
            # with TypeError rather than escaping as IndexError.
            if dtype[:1] not in ("i", "u", "f"):
                raise TypeError(f"Unrecognized type value: '{dtype}'")
            try:
                # The numeric suffix is used directly as the element size in
                # bytes (e.g. "f4" -> 4).
                itemsize = int(dtype[1:])
            except ValueError:
                raise TypeError(f"Unrecognized type value: '{dtype}'")

            byte_size = <size_t>itemsize
            nbytes = count * byte_size

        self._mem_ref = DPCTLWorkGroupMemory_Create(nbytes)

    @staticmethod
    def is_available():
        """Check whether the work_group_memory extension is available"""
        return DPCTLWorkGroupMemory_Available()

    @property
    def _ref(self):
        """Returns the address of the C API ``DPCTLWorkGroupMemoryRef``
        pointer as a ``size_t``.
        """
        return <size_t>self._mem_ref
0 commit comments