Skip to content

Commit b5887df

Browse files
committed
LocalAccessor takes a sequence of non-negative integers instead of separate integer arguments
Update test for LocalAccessor kernel submission
1 parent 5fc86c6 commit b5887df

File tree

2 files changed

+27
-18
lines changed

2 files changed

+27
-18
lines changed

dpctl/_sycl_queue.pyx

+26-17
Original file line numberDiff line numberDiff line change
@@ -128,15 +128,12 @@ cdef class kernel_arg_type_attribute:
128128

129129
cdef class LocalAccessor:
130130
"""
131-
LocalAccessor(ndim, dtype, dim0, dim1, dim2)
131+
LocalAccessor(dtype, shape)
132132
133133
Python class for specifying the dimensionality and type of a
134134
``sycl::local_accessor``, to be used as a kernel argument type.
135135
136136
Args:
137-
ndim (size_t):
138-
number of dimensions.
139-
Can be between one and three.
140137
dtype (str):
141138
the data type of the local memory.
142139
The permitted values are
@@ -149,29 +146,41 @@ cdef class LocalAccessor:
149146
`'f4'`, `'f8'`,
150147
single- and double-precision floating-point types float and
151148
double
152-
dim0 (size_t):
153-
Size of the first dimension.
154-
dim1 (size_t):
155-
Size of the second dimension.
156-
dim2 (size_t):
157-
Size of the third dimension.
149+
shape (tuple, list):
150+
Size of LocalAccessor dimensions. Dimension of the LocalAccessor is
151+
determined by the length of the tuple. Must be of length 1, 2, or 3,
152+
and contain only non-negative integers.
158153
159154
Raises:
155+
TypeError:
156+
If the given shape is not a tuple or list.
160157
ValueError:
161-
If the given dimension is not between one and three.
158+
If the given shape sequence is not between one and three elements long.
159+
TypeError:
160+
If the shape is not a sequence of integers.
161+
ValueError:
162+
If the shape contains a negative integer.
162163
ValueError:
163164
If the dtype string is unrecognized.
164165
"""
165166
cdef _md_local_accessor lacc
166167

167-
def __cinit__(self, size_t ndim, str dtype, size_t dim0, size_t dim1, size_t dim2):
168+
def __cinit__(self, str dtype, shape):
169+
if not isinstance(shape, (list, tuple)):
170+
raise TypeError(f"`shape` must be a list or tuple, got {type(shape)}")
171+
ndim = len(shape)
172+
if ndim < 1 or ndim > 3:
173+
raise ValueError("LocalAccessor must have dimension between one and three")
174+
for s in shape:
175+
if not isinstance(s, numbers.Integral):
176+
raise TypeError("LocalAccessor shape must be a sequence of integers")
177+
if s < 0:
178+
raise ValueError("LocalAccessor dimensions must be non-negative")
168179
self.lacc.ndim = ndim
169-
self.lacc.dim0 = dim0
170-
self.lacc.dim1 = dim1
171-
self.lacc.dim2 = dim2
180+
self.lacc.dim0 = <size_t> shape[0]
181+
self.lacc.dim1 = <size_t> shape[1] if ndim > 1 else 1
182+
self.lacc.dim2 = <size_t> shape[2] if ndim > 2 else 1
172183

173-
if ndim < 1 or ndim > 3:
174-
raise ValueError("LocalAccessor must have dimension between one and three")
175184
if dtype == 'i1':
176185
self.lacc.dpctl_type_id = _arg_data_type._INT8_T
177186
elif dtype == 'u1':

dpctl/tests/test_sycl_kernel_submit.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def test_submit_local_accessor_arg():
308308
try:
309309
e = q.submit(
310310
krn,
311-
[x.usm_data, dpctl._sycl_queue.LocalAccessor(1, "i8", lws, 1, 1)],
311+
[x.usm_data, dpctl.LocalAccessor("i8", (lws, 1, 1))],
312312
[
313313
gws,
314314
],

0 commit comments

Comments
 (0)