@@ -128,15 +128,12 @@ cdef class kernel_arg_type_attribute:
128128
129129cdef class LocalAccessor:
130130 """
131- LocalAccessor(ndim, dtype, dim0, dim1, dim2 )
131+ LocalAccessor(dtype, shape )
132132
133133 Python class for specifying the dimensionality and type of a
134134 ``sycl::local_accessor``, to be used as a kernel argument type.
135135
136136 Args:
137- ndim (size_t):
138- number of dimensions.
139- Can be between one and three.
140137 dtype (str):
141138 the data type of the local memory.
142139 The permitted values are
@@ -149,29 +146,41 @@ cdef class LocalAccessor:
149146 `'f4'`, `'f8'`,
150147 single- and double-precision floating-point types float and
151148 double
152- dim0 (size_t):
153- Size of the first dimension.
154- dim1 (size_t):
155- Size of the second dimension.
156- dim2 (size_t):
157- Size of the third dimension.
149+ shape (tuple, list):
150+ Size of LocalAccessor dimensions. Dimension of the LocalAccessor is
151+ determined by the length of the tuple. Must be of length 1, 2, or 3,
152+ and contain only non-negative integers.
158153
159154 Raises:
155+ TypeError:
156+ If the given shape is not a tuple or list.
160157 ValueError:
161- If the given dimension is not between one and three.
158+ If the given shape sequence is not between one and three elements long.
159+ TypeError:
160+ If the shape is not a sequence of integers.
161+ ValueError:
162+ If the shape contains a negative integer.
162163 ValueError:
163164 If the dtype string is unrecognized.
164165 """
165166 cdef _md_local_accessor lacc
166167
167- def __cinit__ (self , size_t ndim , str dtype , size_t dim0 , size_t dim1 , size_t dim2 ):
168+ def __cinit__ (self , str dtype , shape ):
169+ if not isinstance (shape, (list , tuple )):
170+ raise TypeError (f" `shape` must be a list or tuple, got {type(shape)}" )
171+ ndim = len (shape)
172+ if ndim < 1 or ndim > 3 :
173+ raise ValueError (" LocalAccessor must have dimension between one and three" )
174+ for s in shape:
175+ if not isinstance (s, numbers.Integral):
176+ raise TypeError (" LocalAccessor shape must be a sequence of integers" )
177+ if s < 0 :
178+ raise ValueError (" LocalAccessor dimensions must be non-negative" )
168179 self .lacc.ndim = ndim
169- self .lacc.dim0 = dim0
170- self .lacc.dim1 = dim1
171- self .lacc.dim2 = dim2
180+ self .lacc.dim0 = < size_t > shape[ 0 ]
181+ self .lacc.dim1 = < size_t > shape[ 1 ] if ndim > 1 else 1
182+ self .lacc.dim2 = < size_t > shape[ 2 ] if ndim > 2 else 1
172183
173- if ndim < 1 or ndim > 3 :
174- raise ValueError (" LocalAccessor must have dimension between one and three" )
175184 if dtype == ' i1' :
176185 self .lacc.dpctl_type_id = _arg_data_type._INT8_T
177186 elif dtype == ' u1' :
0 commit comments