Skip to content

Expose LocalAccessor as kernel argument type #1991

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Mar 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dpctl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from ._sycl_event import SyclEvent
from ._sycl_platform import SyclPlatform, get_platforms, lsplatform
from ._sycl_queue import (
LocalAccessor,
SyclKernelInvalidRangeError,
SyclKernelSubmitError,
SyclQueue,
Expand Down Expand Up @@ -102,6 +103,7 @@
"SyclKernelSubmitError",
"SyclQueueCreationError",
"WorkGroupMemory",
"LocalAccessor",
]
__all__ += [
"get_device_cached_queue",
Expand Down
6 changes: 6 additions & 0 deletions dpctl/_backend.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,12 @@ cdef extern from "syclinterface/dpctl_sycl_kernel_bundle_interface.h":


cdef extern from "syclinterface/dpctl_sycl_queue_interface.h":
ctypedef struct _md_local_accessor 'MDLocalAccessor':
size_t ndim
_arg_data_type dpctl_type_id
size_t dim0
size_t dim1
size_t dim2
cdef bool DPCTLQueue_AreEq(const DPCTLSyclQueueRef QRef1,
const DPCTLSyclQueueRef QRef2)
cdef DPCTLSyclQueueRef DPCTLQueue_Create(
Expand Down
93 changes: 93 additions & 0 deletions dpctl/_sycl_queue.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ from ._backend cimport ( # noqa: E211
DPCTLWorkGroupMemory_Delete,
_arg_data_type,
_backend_type,
_md_local_accessor,
_queue_property_type,
)
from .memory._memory cimport _Memory
Expand Down Expand Up @@ -125,6 +126,95 @@ cdef class kernel_arg_type_attribute:
return self.attr_value


cdef class LocalAccessor:
"""
LocalAccessor(dtype, shape)

Python class for specifying the dimensionality and type of a
``sycl::local_accessor``, to be used as a kernel argument type.

Args:
dtype (str):
the data type of the local memory.
The permitted values are

`'i1'`, `'i2'`, `'i4'`, `'i8'`:
signed integral types int8_t, int16_t, int32_t, int64_t
`'u1'`, `'u2'`, `'u4'`, `'u8'`
unsigned integral types uint8_t, uint16_t, uint32_t,
uint64_t
`'f4'`, `'f8'`,
single- and double-precision floating-point types float and
double
shape (tuple, list):
Size of LocalAccessor dimensions. Dimension of the LocalAccessor is
determined by the length of the tuple. Must be of length 1, 2, or 3,
and contain only non-negative integers.

Raises:
TypeError:
If the given shape is not a tuple or list.
ValueError:
If the given shape sequence is not between one and three elements long.
TypeError:
If the shape is not a sequence of integers.
ValueError:
If the shape contains a negative integer.
ValueError:
If the dtype string is unrecognized.
"""
cdef _md_local_accessor lacc

def __cinit__(self, str dtype, shape):
if not isinstance(shape, (list, tuple)):
raise TypeError(f"`shape` must be a list or tuple, got {type(shape)}")
ndim = len(shape)
if ndim < 1 or ndim > 3:
raise ValueError("LocalAccessor must have dimension between one and three")
for s in shape:
if not isinstance(s, numbers.Integral):
raise TypeError("LocalAccessor shape must be a sequence of integers")
if s < 0:
raise ValueError("LocalAccessor dimensions must be non-negative")
self.lacc.ndim = ndim
self.lacc.dim0 = <size_t> shape[0]
self.lacc.dim1 = <size_t> shape[1] if ndim > 1 else 1
self.lacc.dim2 = <size_t> shape[2] if ndim > 2 else 1

if dtype == 'i1':
self.lacc.dpctl_type_id = _arg_data_type._INT8_T
elif dtype == 'u1':
self.lacc.dpctl_type_id = _arg_data_type._UINT8_T
elif dtype == 'i2':
self.lacc.dpctl_type_id = _arg_data_type._INT16_T
elif dtype == 'u2':
self.lacc.dpctl_type_id = _arg_data_type._UINT16_T
elif dtype == 'i4':
self.lacc.dpctl_type_id = _arg_data_type._INT32_T
elif dtype == 'u4':
self.lacc.dpctl_type_id = _arg_data_type._UINT32_T
elif dtype == 'i8':
self.lacc.dpctl_type_id = _arg_data_type._INT64_T
elif dtype == 'u8':
self.lacc.dpctl_type_id = _arg_data_type._UINT64_T
elif dtype == 'f4':
self.lacc.dpctl_type_id = _arg_data_type._FLOAT
elif dtype == 'f8':
self.lacc.dpctl_type_id = _arg_data_type._DOUBLE
else:
raise ValueError(f"Unrecognized type value: '{dtype}'")

def __repr__(self):
return f"LocalAccessor({self.lacc.ndim})"

cdef size_t addressof(self):
"""
Returns the address of the _md_local_accessor for this LocalAccessor
cast to ``size_t``.
"""
return <size_t>&self.lacc


cdef class _kernel_arg_type:
"""
An enumeration of supported kernel argument types in
Expand Down Expand Up @@ -865,6 +955,9 @@ cdef class SyclQueue(_SyclQueue):
elif isinstance(arg, WorkGroupMemory):
kargs[idx] = <void*>(<size_t>arg._ref)
kargty[idx] = _arg_data_type._WORK_GROUP_MEMORY
elif isinstance(arg, LocalAccessor):
kargs[idx] = <void*>((<LocalAccessor>arg).addressof())
kargty[idx] = _arg_data_type._LOCAL_ACCESSOR
else:
ret = -1
return ret
Expand Down
Binary file not shown.
Binary file not shown.
40 changes: 40 additions & 0 deletions dpctl/tests/test_sycl_kernel_submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""

import ctypes
import os

import numpy as np
import pytest
Expand Down Expand Up @@ -279,3 +280,42 @@ def test_kernel_arg_type():
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_void_ptr)
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_local_accessor)
_check_kernel_arg_type_instance(kernel_arg_type.dpctl_work_group_memory)


def get_spirv_abspath(fn):
curr_dir = os.path.dirname(os.path.abspath(__file__))
spirv_file = os.path.join(curr_dir, "input_files", fn)
return spirv_file


# the process for generating the .spv files in this test is documented in
# libsyclinterface/tests/test_sycl_queue_submit_local_accessor_arg.cpp
# in a comment starting on line 123
def test_submit_local_accessor_arg():
try:
q = dpctl.SyclQueue("level_zero")
except dpctl.SyclQueueCreationError:
pytest.skip("OpenCL queue could not be created")
fn = get_spirv_abspath("local_accessor_kernel_inttys_fp32.spv")
with open(fn, "br") as f:
spirv_bytes = f.read()
prog = dpctl_prog.create_program_from_spirv(q, spirv_bytes)
krn = prog.get_sycl_kernel("_ZTS14SyclKernel_SLMIlE")
lws = 32
gws = lws * 10
x = dpt.ones(gws, dtype="i8")
x.sycl_queue.wait()
try:
e = q.submit(
krn,
[x.usm_data, dpctl.LocalAccessor("i8", (lws,))],
[gws],
[lws],
)
e.wait()
except dpctl._sycl_queue.SyclKernelSubmitError:
pytest.skip(f"Kernel submission failed for device {q.sycl_device}")
expected = dpt.arange(1, x.size + 1, dtype=x.dtype, device=x.device) * (
2 * lws
)
assert dpt.all(x == expected)
Loading