@@ -1555,15 +1555,47 @@ cdef class WorkGroupMemory:
1555
1555
This is based on a DPC++ SYCL extension and only available in newer
1556
1556
versions. Use ``is_available()`` to check availability in your build.
1557
1557
1558
+ There are multiple ways to create a `WorkGroupMemory`.
1559
+
1560
+ - If the constructor is invoked with just a single argument, this argument
1561
+ is interpreted as the number of bytes to allocated in the shared local
1562
+ memory.
1563
+
1564
+ - If the constructor is invoked with two arguments, the first argument is
1565
+ interpreted as the datatype of the local memory, using the numpy type
1566
+ naming scheme.
1567
+ The second argument is interpreted as the number of elements to allocate.
1568
+ The number of bytes to allocate is then computed from the byte size of
1569
+ the data type and the element count.
1570
+
1558
1571
Args:
1559
- nbytes (int)
1560
- number of bytes to allocate in local memory.
1561
- Expected to be positive.
1572
+ args:
1573
+ Variadic argument, see class documentation.
1574
+
1575
+ Raises:
1576
+ TypeError: In case of incorrect arguments given to constructors,
1577
+ unexpected types of input arguments.
1562
1578
"""
1563
- def __cinit__ (self , Py_ssize_t nbytes ):
1579
+ def __cinit__ (self , *args ):
1580
+ cdef size_t nbytes
1564
1581
if not DPCTLWorkGroupMemory_Available():
1565
1582
raise RuntimeError (" Workgroup memory extension not available" )
1566
1583
1584
+ if not (0 < len (args) < 3 ):
1585
+ raise TypeError (" WorkGroupMemory constructor takes 1 or 2 "
1586
+ f" arguments, but {len(args)} were given" )
1587
+
1588
+ if len (args) == 1 :
1589
+ nbytes = < size_t> (args[0 ])
1590
+ else :
1591
+ dtype = < str > (args[0 ])
1592
+ count = < size_t> (args[1 ])
1593
+ ty = dtype[0 ]
1594
+ if not ty in [" i" , " u" , " f" ]:
1595
+ raise TypeError (f" Unrecognized type value: '{dtype}'" )
1596
+ byte_size = < size_t> (int (dtype[1 :]))
1597
+ nbytes = count * byte_size
1598
+
1567
1599
self ._mem_ref = DPCTLWorkGroupMemory_Create(nbytes)
1568
1600
1569
1601
""" Check whether the work_group_memory extension is available"""
0 commit comments