Skip to content

Commit 13cbd01

Browse files
committed
Allow construction with data type
1 parent c75fd5c commit 13cbd01

File tree

3 files changed

+37
-6
lines changed

3 files changed

+37
-6
lines changed

.flake8

-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ per-file-ignores =
3131
dpctl/utils/_compute_follows_data.pyx: E999, E225, E227
3232
dpctl/utils/_onetrace_context.py: E501, W505
3333
dpctl/tensor/_array_api.py: E501, W505
34-
dpctl/experimental/_work_group_memory.pyx: E999
3534
examples/cython/sycl_buffer/syclbuffer/_syclbuffer.pyx: E999, E225, E402
3635
examples/cython/usm_memory/blackscholes/_blackscholes_usm.pyx: E999, E225, E226, E402
3736
examples/cython/use_dpctl_sycl/use_dpctl_sycl/_cython_api.pyx: E999, E225, E226, E402

dpctl/_sycl_queue.pyx

+36-4
Original file line numberDiff line numberDiff line change
@@ -1555,15 +1555,47 @@ cdef class WorkGroupMemory:
15551555
This is based on a DPC++ SYCL extension and only available in newer
15561556
versions. Use ``is_available()`` to check availability in your build.
15571557
1558+
There are multiple ways to create a `WorkGroupMemory`.
1559+
1560+
- If the constructor is invoked with just a single argument, this argument
1561+
is interpreted as the number of bytes to allocated in the shared local
1562+
memory.
1563+
1564+
- If the constructor is invoked with two arguments, the first argument is
1565+
interpreted as the datatype of the local memory, using the numpy type
1566+
naming scheme.
1567+
The second argument is interpreted as the number of elements to allocate.
1568+
The number of bytes to allocate is then computed from the byte size of
1569+
the data type and the element count.
1570+
15581571
Args:
1559-
nbytes (int)
1560-
number of bytes to allocate in local memory.
1561-
Expected to be positive.
1572+
args:
1573+
Variadic argument, see class documentation.
1574+
1575+
Raises:
1576+
TypeError: In case of incorrect arguments given to constructors,
1577+
unexpected types of input arguments.
15621578
"""
1563-
def __cinit__(self, Py_ssize_t nbytes):
1579+
def __cinit__(self, *args):
1580+
cdef size_t nbytes
15641581
if not DPCTLWorkGroupMemory_Available():
15651582
raise RuntimeError("Workgroup memory extension not available")
15661583

1584+
if not (0 < len(args) < 3):
1585+
raise TypeError("WorkGroupMemory constructor takes 1 or 2 "
1586+
f"arguments, but {len(args)} were given")
1587+
1588+
if len(args) == 1:
1589+
nbytes = <size_t>(args[0])
1590+
else:
1591+
dtype = <str>(args[0])
1592+
count = <size_t>(args[1])
1593+
ty = dtype[0]
1594+
if not ty in ["i", "u", "f"]:
1595+
raise TypeError(f"Unrecognized type value: '{dtype}'")
1596+
byte_size = <size_t>(int(dtype[1:]))
1597+
nbytes = count * byte_size
1598+
15671599
self._mem_ref = DPCTLWorkGroupMemory_Create(nbytes)
15681600

15691601
"""Check whether the work_group_memory extension is available"""

dpctl/tests/test_work_group_memory.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def test_submit_work_group_memory():
7878
[
7979
x.usm_data,
8080
y.usm_data,
81-
dpctl.WorkGroupMemory(local_size * x.itemsize),
81+
dpctl.WorkGroupMemory("i4", local_size),
8282
],
8383
[global_size],
8484
[local_size],

0 commit comments

Comments
 (0)