@@ -128,62 +128,200 @@ def test_broadcast_raise(self, sh1, sh2):
128
128
func (dpnp , dp_a )
129
129
130
130
131
- @pytest .mark .usefixtures ("allow_fall_back_on_numpy" )
132
131
class TestConcatenate :
133
132
def test_returns_copy (self ):
134
- a = dpnp .array ( numpy . eye (3 ) )
133
+ a = dpnp .eye (3 )
135
134
b = dpnp .concatenate ([a ])
136
135
b [0 , 0 ] = 2
137
136
assert b [0 , 0 ] != a [0 , 0 ]
138
137
139
- def test_large_concatenate_axis_None (self ):
140
- x = dpnp .arange (1 , 100 )
141
- r = dpnp .concatenate (x , None )
142
- assert_array_equal (x , r )
143
- r = dpnp .concatenate (x , 100 )
144
- assert_array_equal (x , r )
138
+ @pytest .mark .parametrize ("ndim" , [1 , 2 , 3 ])
139
+ def test_axis_exceptions (self , ndim ):
140
+ dp_a = dpnp .ones ((1 ,) * ndim )
141
+ np_a = numpy .ones ((1 ,) * ndim )
142
+
143
+ dp_res = dpnp .concatenate ((dp_a , dp_a ), axis = 0 )
144
+ np_res = numpy .concatenate ((np_a , np_a ), axis = 0 )
145
+ assert_equal (dp_res .asnumpy (), np_res )
146
+
147
+ for axis in [ndim , - (ndim + 1 )]:
148
+ with pytest .raises (numpy .AxisError ):
149
+ dpnp .concatenate ((dp_a , dp_a ), axis = axis )
150
+ numpy .concatenate ((np_a , np_a ), axis = axis )
151
+
152
+ def test_scalar_exceptions (self ):
153
+ assert_raises (TypeError , dpnp .concatenate , (0 ,))
154
+ assert_raises (ValueError , numpy .concatenate , (0 ,))
155
+
156
+ for xp in [dpnp , numpy ]:
157
+ with pytest .raises (ValueError ):
158
+ xp .concatenate ((xp .array (0 ),))
159
+
160
+ def test_dims_exception (self ):
161
+ for xp in [dpnp , numpy ]:
162
+ with pytest .raises (ValueError ):
163
+ xp .concatenate ((xp .zeros (1 ), xp .zeros ((1 , 1 ))))
164
+
165
+ def test_shapes_match_exception (self ):
166
+ axis = list (range (3 ))
167
+ np_a = numpy .ones ((1 , 2 , 3 ))
168
+ np_b = numpy .ones ((2 , 2 , 3 ))
169
+
170
+ dp_a = dpnp .array (np_a )
171
+ dp_b = dpnp .array (np_b )
172
+
173
+ for _ in range (3 ):
174
+ # shapes must match except for concatenation axis
175
+ np_res = numpy .concatenate ((np_a , np_b ), axis = axis [0 ])
176
+ dp_res = dpnp .concatenate ((dp_a , dp_b ), axis = axis [0 ])
177
+ assert_equal (dp_res .asnumpy (), np_res )
178
+
179
+ for i in range (1 , 3 ):
180
+ with pytest .raises (ValueError ):
181
+ numpy .concatenate ((np_a , np_b ), axis = axis [i ])
182
+ dpnp .concatenate ((dp_a , dp_b ), axis = axis [i ])
183
+
184
+ np_a = numpy .moveaxis (np_a , - 1 , 0 )
185
+ dp_a = dpnp .moveaxis (dp_a , - 1 , 0 )
186
+
187
+ np_b = numpy .moveaxis (np_b , - 1 , 0 )
188
+ dp_b = dpnp .moveaxis (dp_b , - 1 , 0 )
189
+ axis .append (axis .pop (0 ))
190
+
191
+ def test_no_array_exception (self ):
192
+ with pytest .raises (ValueError ):
193
+ numpy .concatenate (())
194
+ dpnp .concatenate (())
195
+
196
+ @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_none = True ))
197
+ def test_concatenate_axis_None (self , dtype ):
198
+ stop , sh = (4 , (2 , 2 )) if dtype is not dpnp .bool else (2 , (2 , 1 ))
199
+ np_a = numpy .arange (stop , dtype = dtype ).reshape (sh )
200
+ dp_a = dpnp .arange (stop , dtype = dtype ).reshape (sh )
201
+
202
+ np_res = numpy .concatenate ((np_a , np_a ), axis = None )
203
+ dp_res = dpnp .concatenate ((dp_a , dp_a ), axis = None )
204
+ assert_equal (dp_res .asnumpy (), np_res )
205
+
206
+ @pytest .mark .parametrize (
207
+ "dtype" , get_all_dtypes (no_bool = True , no_none = True )
208
+ )
209
+ def test_large_concatenate_axis_None (self , dtype ):
210
+ start , stop = (1 , 100 )
211
+ np_a = numpy .arange (start , stop , dtype = dtype )
212
+ dp_a = dpnp .arange (start , stop , dtype = dtype )
213
+
214
+ np_res = numpy .concatenate (np_a , axis = None )
215
+ dp_res = dpnp .concatenate (dp_a , axis = None )
216
+ assert_array_equal (dp_res .asnumpy (), np_res )
217
+
218
+ # numpy doesn't raise an exception here but probably should
219
+ with pytest .raises (numpy .AxisError ):
220
+ dpnp .concatenate (dp_a , axis = 100 )
145
221
146
- def test_concatenate (self ):
222
+ @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_none = True ))
223
+ def test_concatenate (self , dtype ):
147
224
# Test concatenate function
148
225
# One sequence returns unmodified (but as array)
149
226
r4 = list (range (4 ))
150
- assert_array_equal (dpnp .concatenate ((r4 ,)), r4 )
151
- # Any sequence
152
- assert_array_equal (dpnp .concatenate ((tuple (r4 ),)), r4 )
153
- assert_array_equal (dpnp .concatenate ((dpnp .array (r4 ),)), r4 )
227
+ np_r4 = numpy .array (r4 , dtype = dtype )
228
+ dp_r4 = dpnp .array (r4 , dtype = dtype )
229
+
230
+ np_res = numpy .concatenate ((np_r4 ,))
231
+ dp_res = dpnp .concatenate ((dp_r4 ,))
232
+ assert_array_equal (dp_res .asnumpy (), np_res )
233
+
154
234
# 1D default concatenation
155
235
r3 = list (range (3 ))
156
- assert_array_equal (dpnp .concatenate ((r4 , r3 )), r4 + r3 )
157
- # Mixed sequence types
158
- assert_array_equal (dpnp .concatenate ((tuple (r4 ), r3 )), r4 + r3 )
159
- assert_array_equal (dpnp .concatenate ((dpnp .array (r4 ), r3 )), r4 + r3 )
236
+ np_r3 = numpy .array (r3 , dtype = dtype )
237
+ dp_r3 = dpnp .array (r3 , dtype = dtype )
238
+
239
+ np_res = numpy .concatenate ((np_r4 , np_r3 ))
240
+ dp_res = dpnp .concatenate ((dp_r4 , dp_r3 ))
241
+ assert_array_equal (dp_res .asnumpy (), np_res )
242
+
160
243
# Explicit axis specification
161
- assert_array_equal (dpnp .concatenate ((r4 , r3 ), 0 ), r4 + r3 )
244
+ np_res = numpy .concatenate ((np_r4 , np_r3 ), axis = 0 )
245
+ dp_res = dpnp .concatenate ((dp_r4 , dp_r3 ), axis = 0 )
246
+ assert_array_equal (dp_res .asnumpy (), np_res )
247
+
162
248
# Including negative
163
- assert_array_equal (dpnp .concatenate ((r4 , r3 ), - 1 ), r4 + r3 )
164
- # 2D
165
- a23 = dpnp .array ([[10 , 11 , 12 ], [13 , 14 , 15 ]])
166
- a13 = dpnp .array ([[0 , 1 , 2 ]])
167
- res = dpnp .array ([[10 , 11 , 12 ], [13 , 14 , 15 ], [0 , 1 , 2 ]])
168
- assert_array_equal (dpnp .concatenate ((a23 , a13 )), res )
169
- assert_array_equal (dpnp .concatenate ((a23 , a13 ), 0 ), res )
170
- assert_array_equal (dpnp .concatenate ((a23 .T , a13 .T ), 1 ), res .T )
171
- assert_array_equal (dpnp .concatenate ((a23 .T , a13 .T ), - 1 ), res .T )
249
+ np_res = numpy .concatenate ((np_r4 , np_r3 ), axis = - 1 )
250
+ dp_res = dpnp .concatenate ((dp_r4 , dp_r3 ), axis = - 1 )
251
+ assert_array_equal (dp_res .asnumpy (), np_res )
252
+
253
+ @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_none = True ))
254
+ def test_concatenate_2d (self , dtype ):
255
+ np_a23 = numpy .array ([[10 , 11 , 12 ], [13 , 14 , 15 ]], dtype = dtype )
256
+ np_a13 = numpy .array ([[0 , 1 , 2 ]], dtype = dtype )
257
+
258
+ dp_a23 = dpnp .array ([[10 , 11 , 12 ], [13 , 14 , 15 ]], dtype = dtype )
259
+ dp_a13 = dpnp .array ([[0 , 1 , 2 ]], dtype = dtype )
260
+
261
+ np_res = numpy .concatenate ((np_a23 , np_a13 ))
262
+ dp_res = dpnp .concatenate ((dp_a23 , dp_a13 ))
263
+ assert_array_equal (dp_res .asnumpy (), np_res )
264
+
265
+ np_res = numpy .concatenate ((np_a23 , np_a13 ), axis = 0 )
266
+ dp_res = dpnp .concatenate ((dp_a23 , dp_a13 ), axis = 0 )
267
+ assert_array_equal (dp_res .asnumpy (), np_res )
268
+
269
+ for axis in [1 , - 1 ]:
270
+ np_res = numpy .concatenate ((np_a23 .T , np_a13 .T ), axis = axis )
271
+ dp_res = dpnp .concatenate ((dp_a23 .T , dp_a13 .T ), axis = axis )
272
+ assert_array_equal (dp_res .asnumpy (), np_res )
273
+
172
274
# Arrays much match shape
173
- assert_raises (ValueError , dpnp .concatenate , (a23 .T , a13 .T ), 0 )
174
- # 3D
175
- res = dpnp .reshape (dpnp .arange (2 * 3 * 7 ), (2 , 3 , 7 ))
176
- a0 = res [..., :4 ]
177
- a1 = res [..., 4 :6 ]
178
- a2 = res [..., 6 :]
179
- assert_array_equal (dpnp .concatenate ((a0 , a1 , a2 ), 2 ), res )
180
- assert_array_equal (dpnp .concatenate ((a0 , a1 , a2 ), - 1 ), res )
181
- assert_array_equal (dpnp .concatenate ((a0 .T , a1 .T , a2 .T ), 0 ), res .T )
182
-
183
- out = dpnp .copy (res )
184
- rout = dpnp .concatenate ((a0 , a1 , a2 ), 2 , out = out )
185
- assert_ (out is rout )
186
- assert_equal (res , rout )
275
+ with pytest .raises (ValueError ):
276
+ numpy .concatenate ((np_a23 .T , np_a13 .T ), axis = 0 )
277
+ dpnp .concatenate ((dp_a23 .T , dp_a13 .T ), axis = 0 )
278
+
279
+ @pytest .mark .parametrize (
280
+ "dtype" , get_all_dtypes (no_bool = True , no_none = True )
281
+ )
282
+ def test_concatenate_3d (self , dtype ):
283
+ np_a = numpy .arange (2 * 3 * 7 , dtype = dtype ).reshape ((2 , 3 , 7 ))
284
+ np_a0 = np_a [..., :4 ]
285
+ np_a1 = np_a [..., 4 :6 ]
286
+ np_a2 = np_a [..., 6 :]
287
+
288
+ dp_a = dpnp .arange (2 * 3 * 7 , dtype = dtype ).reshape ((2 , 3 , 7 ))
289
+ dp_a0 = dp_a [..., :4 ]
290
+ dp_a1 = dp_a [..., 4 :6 ]
291
+ dp_a2 = dp_a [..., 6 :]
292
+
293
+ for axis in [2 , - 1 ]:
294
+ np_res = numpy .concatenate ((np_a0 , np_a1 , np_a2 ), axis = axis )
295
+ dp_res = dpnp .concatenate ((dp_a0 , dp_a1 , dp_a2 ), axis = axis )
296
+ assert_array_equal (dp_res .asnumpy (), np_res )
297
+
298
+ np_res = numpy .concatenate ((np_a0 .T , np_a1 .T , np_a2 .T ), axis = 0 )
299
+ dp_res = dpnp .concatenate ((dp_a0 .T , dp_a1 .T , dp_a2 .T ), axis = 0 )
300
+ assert_array_equal (dp_res .asnumpy (), np_res )
301
+
302
+ @pytest .mark .skip ("out keyword is currently unsupported" )
303
+ @pytest .mark .parametrize (
304
+ "dtype" , get_all_dtypes (no_bool = True , no_none = True )
305
+ )
306
+ def test_concatenate_out (self , dtype ):
307
+ np_a = numpy .arange (2 * 3 * 7 , dtype = dtype ).reshape ((2 , 3 , 7 ))
308
+ np_a0 = np_a [..., :4 ]
309
+ np_a1 = np_a [..., 4 :6 ]
310
+ np_a2 = np_a [..., 6 :]
311
+ np_out = numpy .empty_like (np_a )
312
+
313
+ dp_a = dpnp .arange (2 * 3 * 7 , dtype = dtype ).reshape ((2 , 3 , 7 ))
314
+ dp_a0 = dp_a [..., :4 ]
315
+ dp_a1 = dp_a [..., 4 :6 ]
316
+ dp_a2 = dp_a [..., 6 :]
317
+ dp_out = dpnp .empty_like (dp_a )
318
+
319
+ np_res = numpy .concatenate ((np_a0 , np_a1 , np_a2 ), axis = 2 , out = np_out )
320
+ dp_res = dpnp .concatenate ((dp_a0 , dp_a1 , dp_a2 ), axis = 2 , out = dp_out )
321
+
322
+ assert dp_out is dp_res
323
+ assert_array_equal (dp_out .asnumpy (), np_out )
324
+ assert_array_equal (dp_res .asnumpy (), np_res )
187
325
188
326
189
327
class TestHstack :
0 commit comments