@@ -54,13 +54,17 @@ from ._backend cimport ( # noqa: E211
54
54
DPCTLSyclContextRef,
55
55
DPCTLSyclDeviceSelectorRef,
56
56
DPCTLSyclEventRef,
57
+ DPCTLWorkGroupMemory_Available,
58
+ DPCTLWorkGroupMemory_Create,
59
+ DPCTLWorkGroupMemory_Delete,
57
60
_arg_data_type,
58
61
_backend_type,
59
62
_queue_property_type,
60
63
)
61
64
from .memory._memory cimport _Memory
62
65
63
66
import ctypes
67
+ import numbers
64
68
65
69
from .enum_types import backend_type
66
70
@@ -250,6 +254,15 @@ cdef class _kernel_arg_type:
250
254
_arg_data_type._LOCAL_ACCESSOR
251
255
)
252
256
257
@property
def dpctl_work_group_memory(self):
    """
    Kernel argument type attribute for work-group memory arguments.

    Built by ``kernel_arg_type_attribute`` from this object's ``_name``,
    the name of this property and
    ``_arg_data_type._WORK_GROUP_MEMORY``.
    """
    return kernel_arg_type_attribute(
        self._name,
        " dpctl_work_group_memory",
        _arg_data_type._WORK_GROUP_MEMORY,
    )
265
+
253
266
254
267
kernel_arg_type = _kernel_arg_type()
255
268
@@ -849,6 +862,9 @@ cdef class SyclQueue(_SyclQueue):
849
862
elif isinstance (arg, _Memory):
850
863
kargs[idx]= < void * > (< size_t> arg._pointer)
851
864
kargty[idx] = _arg_data_type._VOID_PTR
865
+ elif isinstance (arg, WorkGroupMemory):
866
+ kargs[idx] = < void * > (< size_t> arg._ref)
867
+ kargty[idx] = _arg_data_type._WORK_GROUP_MEMORY
852
868
else :
853
869
ret = - 1
854
870
return ret
@@ -1524,3 +1540,89 @@ cdef api SyclQueue SyclQueue_Make(DPCTLSyclQueueRef QRef):
1524
1540
"""
1525
1541
cdef DPCTLSyclQueueRef copied_QRef = DPCTLQueue_Copy(QRef)
1526
1542
return SyclQueue._create(copied_QRef)
1543
+
1544
cdef class _WorkGroupMemory:
    """
    Extension type whose deallocator releases the ``_mem_ref`` handle
    through ``DPCTLWorkGroupMemory_Delete``.
    """
    def __dealloc__(self):
        # Nothing to release when the reference was never set.
        if not self._mem_ref:
            return
        DPCTLWorkGroupMemory_Delete(self._mem_ref)
1548
+
1549
cdef class WorkGroupMemory:
    """
    WorkGroupMemory(nbytes)
    Python class representing the ``work_group_memory`` class from the
    Workgroup Memory oneAPI SYCL extension for low-overhead allocation of local
    memory shared by the workitems in a workgroup.

    This class is intended to be used as kernel argument when launching
    kernels.

    This is based on a DPC++ SYCL extension and only available in newer
    versions. Use ``is_available()`` to check availability in your build.

    There are multiple ways to create a `WorkGroupMemory`.

    - If the constructor is invoked with just a single argument, this argument
      is interpreted as the number of bytes to allocate in the shared local
      memory.

    - If the constructor is invoked with two arguments, the first argument is
      interpreted as the datatype of the local memory, using the numpy type
      naming scheme: a kind character (``i``, ``u`` or ``f``) followed by
      the element size in bytes, e.g. ``"i4"`` or ``"f8"``.
      The second argument is interpreted as the number of elements to allocate.
      The number of bytes to allocate is then computed from the byte size of
      the data type and the element count.

    Args:
        args:
            Variadic argument, see class documentation.

    Raises:
        RuntimeError: If the work-group memory extension is not available.
        TypeError: In case of incorrect arguments given to constructors,
                   unexpected types of input arguments.
    """
    def __cinit__(self, *args):
        cdef size_t nbytes
        if not DPCTLWorkGroupMemory_Available():
            raise RuntimeError(" Workgroup memory extension not available")

        if not (0 < len(args) < 3):
            raise TypeError(" WorkGroupMemory constructor takes 1 or 2 "
                            f" arguments, but {len(args)} were given")

        if len(args) == 1:
            if not isinstance(args[0], numbers.Integral):
                # The fragments are concatenated into ONE message; a comma
                # between them would turn the exception args into a tuple.
                raise TypeError(" WorkGroupMemory single argument constructor"
                                " expects first argument to be `int`,"
                                f" but got {type(args[0])}")
            nbytes = <size_t>(args[0])
        else:
            if not isinstance(args[0], str):
                raise TypeError(" WorkGroupMemory constructor expects first"
                                f" argument to be `str`, but got {type(args[0])}")
            if not isinstance(args[1], numbers.Integral):
                raise TypeError(" WorkGroupMemory constructor expects second"
                                f" argument to be `int`, but got {type(args[1])}")
            dtype = <str>(args[0])
            count = <size_t>(args[1])
            # Slicing (rather than indexing) keeps an empty string on the
            # TypeError path instead of raising IndexError.
            if dtype[:1] not in ("i", "u", "f"):
                raise TypeError(f" Unrecognized type value: '{dtype}'")
            try:
                # The numeric suffix is the element size in bytes
                # (numpy naming: "i4" is a 4-byte integer).
                itemsize = int(dtype[1:])
            except ValueError:
                raise TypeError(f" Unrecognized type value: '{dtype}'")
            # Reject "i0" / "i-4": a non-positive element size is not a
            # type, and a negative value cannot be converted to size_t.
            if itemsize < 1:
                raise TypeError(f" Unrecognized type value: '{dtype}'")

            byte_size = <size_t>itemsize
            nbytes = count * byte_size

        self._mem_ref = DPCTLWorkGroupMemory_Create(nbytes)

    @staticmethod
    def is_available():
        """ Check whether the work_group_memory extension is available"""
        return DPCTLWorkGroupMemory_Available()

    @property
    def _ref(self):
        """ Returns the address of the C API ``DPCTLWorkGroupMemoryRef``
        pointer as a ``size_t``.
        """
        return <size_t>self._mem_ref
0 commit comments