@@ -128,15 +128,12 @@ cdef class kernel_arg_type_attribute:
128
128
129
129
cdef class LocalAccessor:
130
130
"""
131
- LocalAccessor(ndim, dtype, dim0, dim1, dim2 )
131
+ LocalAccessor(dtype, shape )
132
132
133
133
Python class for specifying the dimensionality and type of a
134
134
``sycl::local_accessor``, to be used as a kernel argument type.
135
135
136
136
Args:
137
- ndim (size_t):
138
- number of dimensions.
139
- Can be between one and three.
140
137
dtype (str):
141
138
the data type of the local memory.
142
139
The permitted values are
@@ -149,29 +146,41 @@ cdef class LocalAccessor:
149
146
`'f4'`, `'f8'`,
150
147
single- and double-precision floating-point types float and
151
148
double
152
- dim0 (size_t):
153
- Size of the first dimension.
154
- dim1 (size_t):
155
- Size of the second dimension.
156
- dim2 (size_t):
157
- Size of the third dimension.
149
+ shape (tuple, list):
150
+ Size of LocalAccessor dimensions. Dimension of the LocalAccessor is
151
+ determined by the length of the tuple. Must be of length 1, 2, or 3,
152
+ and contain only non-negative integers.
158
153
159
154
Raises:
155
+ TypeError:
156
+ If the given shape is not a tuple or list.
160
157
ValueError:
161
- If the given dimension is not between one and three.
158
+ If the given shape sequence is not between one and three elements long.
159
+ TypeError:
160
+ If the shape is not a sequence of integers.
161
+ ValueError:
162
+ If the shape contains a negative integer.
162
163
ValueError:
163
164
If the dtype string is unrecognized.
164
165
"""
165
166
cdef _md_local_accessor lacc
166
167
167
- def __cinit__ (self , size_t ndim , str dtype , size_t dim0 , size_t dim1 , size_t dim2 ):
168
+ def __cinit__ (self , str dtype , shape ):
169
+ if not isinstance (shape, (list , tuple )):
170
+ raise TypeError (f" `shape` must be a list or tuple, got {type(shape)}" )
171
+ ndim = len (shape)
172
+ if ndim < 1 or ndim > 3 :
173
+ raise ValueError (" LocalAccessor must have dimension between one and three" )
174
+ for s in shape:
175
+ if not isinstance (s, numbers.Integral):
176
+ raise TypeError (" LocalAccessor shape must be a sequence of integers" )
177
+ if s < 0 :
178
+ raise ValueError (" LocalAccessor dimensions must be non-negative" )
168
179
self .lacc.ndim = ndim
169
- self .lacc.dim0 = dim0
170
- self .lacc.dim1 = dim1
171
- self .lacc.dim2 = dim2
180
+ self .lacc.dim0 = < size_t > shape[ 0 ]
181
+ self .lacc.dim1 = < size_t > shape[ 1 ] if ndim > 1 else 1
182
+ self .lacc.dim2 = < size_t > shape[ 2 ] if ndim > 2 else 1
172
183
173
- if ndim < 1 or ndim > 3 :
174
- raise ValueError (" LocalAccessor must have dimension between one and three" )
175
184
if dtype == ' i1' :
176
185
self .lacc.dpctl_type_id = _arg_data_type._INT8_T
177
186
elif dtype == ' u1' :
0 commit comments