5
5
import itertools
6
6
import numpy as np
7
7
8
- from nose . tools import assert_equal , assert_raises , assert_true
9
- from nibabel .testing import assert_arrays_equal
8
+ import pytest
9
+ from nibabel .testing_pytest import assert_arrays_equal
10
10
from numpy .testing import assert_array_equal
11
11
12
12
import pytest ; pytestmark = pytest .mark .skip ()
17
17
SEQ_DATA = {}
18
18
19
19
20
- def setup ():
20
+ def setup_module ():
21
21
global SEQ_DATA
22
22
rng = np .random .RandomState (42 )
23
23
SEQ_DATA ['rng' ] = rng
@@ -32,22 +32,25 @@ def generate_data(nb_arrays, common_shape, rng):
32
32
33
33
34
34
def check_empty_arr_seq (seq ):
35
- assert_equal ( len (seq ), 0 )
36
- assert_equal ( len (seq ._offsets ), 0 )
37
- assert_equal ( len (seq ._lengths ), 0 )
35
+ assert len (seq ) == 0
36
+ assert len (seq ._offsets ) == 0
37
+ assert len (seq ._lengths ) == 0
38
38
# assert_equal(seq._data.ndim, 0)
39
- assert_equal (seq ._data .ndim , 1 )
40
- assert_true (seq .common_shape == ())
39
+ assert seq ._data .ndim == 1
40
+
41
+ # TODO: Check assert_true
42
+ # assert_true(seq.common_shape == ())
41
43
42
44
43
45
def check_arr_seq (seq , arrays ):
44
46
lengths = list (map (len , arrays ))
45
- assert_true (is_array_sequence (seq ))
46
- assert_equal (len (seq ), len (arrays ))
47
- assert_equal (len (seq ._offsets ), len (arrays ))
48
- assert_equal (len (seq ._lengths ), len (arrays ))
49
- assert_equal (seq ._data .shape [1 :], arrays [0 ].shape [1 :])
50
- assert_equal (seq .common_shape , arrays [0 ].shape [1 :])
47
+ assert is_array_sequence (seq ) == True
48
+ assert len (seq ) == len (arrays )
49
+ assert len (seq ._offsets ) == len (arrays )
50
+ assert len (seq ._lengths ) == len (arrays )
51
+ assert seq ._data .shape [1 :] == arrays [0 ].shape [1 :]
52
+ assert seq .common_shape == arrays [0 ].shape [1 :]
53
+
51
54
assert_arrays_equal (seq , arrays )
52
55
53
56
# If seq is a view, then order of internal data is not guaranteed.
@@ -56,18 +59,20 @@ def check_arr_seq(seq, arrays):
56
59
assert_array_equal (sorted (seq ._lengths ), sorted (lengths ))
57
60
else :
58
61
seq .shrink_data ()
59
- assert_equal (seq ._data .shape [0 ], sum (lengths ))
62
+
63
+ assert seq ._data .shape [0 ] == sum (lengths )
64
+
60
65
assert_array_equal (seq ._data , np .concatenate (arrays , axis = 0 ))
61
66
assert_array_equal (seq ._offsets , np .r_ [0 , np .cumsum (lengths )[:- 1 ]])
62
67
assert_array_equal (seq ._lengths , lengths )
63
68
64
69
65
70
def check_arr_seq_view (seq_view , seq ):
66
- assert_true ( seq_view ._is_view )
67
- assert_true (seq_view is not seq )
68
- assert_true (np .may_share_memory (seq_view ._data , seq ._data ))
69
- assert_true ( seq_view ._offsets is not seq ._offsets )
70
- assert_true ( seq_view ._lengths is not seq ._lengths )
71
+ assert seq_view ._is_view is True
72
+ assert (seq_view is not seq ) is True
73
+ assert (np .may_share_memory (seq_view ._data , seq ._data )) is True
74
+ assert seq_view ._offsets is not seq ._offsets
75
+ assert seq_view ._lengths is not seq ._lengths
71
76
72
77
73
78
class TestArraySequence (unittest .TestCase ):
@@ -99,8 +104,8 @@ def test_creating_arraysequence_from_generator(self):
99
104
seq_with_buffer = ArraySequence (gen_2 , buffer_size = 256 )
100
105
101
106
# Check buffer size effect
102
- assert_equal ( seq_with_buffer .data .shape , seq .data .shape )
103
- assert_true ( seq_with_buffer ._buffer_size > seq ._buffer_size )
107
+ assert seq_with_buffer .data .shape == seq .data .shape
108
+ assert seq_with_buffer ._buffer_size > seq ._buffer_size
104
109
105
110
# Check generator result
106
111
check_arr_seq (seq , SEQ_DATA ['data' ])
@@ -123,26 +128,27 @@ def test_arraysequence_iter(self):
123
128
# Try iterating through a corrupted ArraySequence object.
124
129
seq = SEQ_DATA ['seq' ].copy ()
125
130
seq ._lengths = seq ._lengths [::2 ]
126
- assert_raises (ValueError , list , seq )
131
+ with pytest .raises (ValueError ):
132
+ list (seq )
127
133
128
134
def test_arraysequence_copy (self ):
129
135
orig = SEQ_DATA ['seq' ]
130
136
seq = orig .copy ()
131
137
n_rows = seq .total_nb_rows
132
- assert_equal ( n_rows , orig .total_nb_rows )
138
+ assert n_rows == orig .total_nb_rows
133
139
assert_array_equal (seq ._data , orig ._data [:n_rows ])
134
- assert_true ( seq ._data is not orig ._data )
140
+ assert seq ._data is not orig ._data
135
141
assert_array_equal (seq ._offsets , orig ._offsets )
136
- assert_true ( seq ._offsets is not orig ._offsets )
142
+ assert seq ._offsets is not orig ._offsets
137
143
assert_array_equal (seq ._lengths , orig ._lengths )
138
- assert_true ( seq ._lengths is not orig ._lengths )
139
- assert_equal ( seq .common_shape , orig .common_shape )
144
+ assert seq ._lengths is not orig ._lengths
145
+ assert seq .common_shape == orig .common_shape
140
146
141
147
# Taking a copy of an `ArraySequence` generated by slicing.
142
148
# Only keep needed data.
143
149
seq = orig [::2 ].copy ()
144
150
check_arr_seq (seq , SEQ_DATA ['data' ][::2 ])
145
- assert_true ( seq ._data is not orig ._data )
151
+ assert seq ._data is not orig ._data
146
152
147
153
def test_arraysequence_append (self ):
148
154
element = generate_data (nb_arrays = 1 ,
@@ -173,7 +179,9 @@ def test_arraysequence_append(self):
173
179
element = generate_data (nb_arrays = 1 ,
174
180
common_shape = SEQ_DATA ['seq' ].common_shape * 2 ,
175
181
rng = SEQ_DATA ['rng' ])[0 ]
176
- assert_raises (ValueError , seq .append , element )
182
+
183
+ with pytest .raises (ValueError ):
184
+ seq .append (element )
177
185
178
186
def test_arraysequence_extend (self ):
179
187
new_data = generate_data (nb_arrays = 10 ,
@@ -219,7 +227,8 @@ def test_arraysequence_extend(self):
219
227
common_shape = SEQ_DATA ['seq' ].common_shape * 2 ,
220
228
rng = SEQ_DATA ['rng' ])
221
229
seq = SEQ_DATA ['seq' ].copy () # Copy because of in-place modification.
222
- assert_raises (ValueError , seq .extend , data )
230
+ with pytest .raises (ValueError ):
231
+ seq .extend (data )
223
232
224
233
# Extend after extracting some slice
225
234
working_slice = seq [:2 ]
@@ -264,7 +273,9 @@ def test_arraysequence_getitem(self):
264
273
for i , keep in enumerate (selection ) if keep ])
265
274
266
275
# Test invalid indexing
267
- assert_raises (TypeError , SEQ_DATA ['seq' ].__getitem__ , 'abc' )
276
+ with pytest .raises (TypeError ):
277
+ SEQ_DATA ['seq' ].__getitem__ ('abc' )
278
+ #SEQ_DATA['seq'].abc
268
279
269
280
# Get specific columns.
270
281
seq_view = SEQ_DATA ['seq' ][:, 2 ]
@@ -287,7 +298,7 @@ def test_arraysequence_setitem(self):
287
298
# Setitem with a scalar.
288
299
seq = SEQ_DATA ['seq' ].copy ()
289
300
seq [:] = 0
290
- assert_true ( seq ._data .sum () == 0 )
301
+ assert seq ._data .sum () == 0
291
302
292
303
# Setitem with a list of ndarray.
293
304
seq = SEQ_DATA ['seq' ] * 0
@@ -297,12 +308,12 @@ def test_arraysequence_setitem(self):
297
308
# Setitem using tuple indexing.
298
309
seq = ArraySequence (np .arange (900 ).reshape ((50 ,6 ,3 )))
299
310
seq [:, 0 ] = 0
300
- assert_true ( seq ._data [:, 0 ].sum () == 0 )
311
+ assert seq ._data [:, 0 ].sum () == 0
301
312
302
313
# Setitem using tuple indexing.
303
314
seq = ArraySequence (np .arange (900 ).reshape ((50 ,6 ,3 )))
304
315
seq [range (len (seq ))] = 0
305
- assert_true ( seq ._data .sum () == 0 )
316
+ assert seq ._data .sum () == 0
306
317
307
318
# Setitem of a slice using another slice.
308
319
seq = ArraySequence (np .arange (900 ).reshape ((50 ,6 ,3 )))
@@ -311,20 +322,26 @@ def test_arraysequence_setitem(self):
311
322
312
323
# Setitem between array sequences with different number of sequences.
313
324
seq = ArraySequence (np .arange (900 ).reshape ((50 ,6 ,3 )))
314
- assert_raises (ValueError , seq .__setitem__ , slice (0 , 4 ), seq [5 :10 ])
325
+ with pytest .raises (ValueError ):
326
+ seq .__setitem__ (slice (0 , 4 ), seq [5 :10 ])
327
+
315
328
316
329
# Setitem between array sequences with different amount of points.
317
330
seq1 = ArraySequence (np .arange (10 ).reshape (5 , 2 ))
318
331
seq2 = ArraySequence (np .arange (15 ).reshape (5 , 3 ))
319
- assert_raises (ValueError , seq1 .__setitem__ , slice (0 , 5 ), seq2 )
332
+ with pytest .raises (ValueError ):
333
+ seq1 .__setitem__ (slice (0 , 5 ), seq2 )
320
334
321
335
# Setitem between array sequences with different common shape.
322
336
seq1 = ArraySequence (np .arange (12 ).reshape (2 , 2 , 3 ))
323
337
seq2 = ArraySequence (np .arange (8 ).reshape (2 , 2 , 2 ))
324
- assert_raises (ValueError , seq1 .__setitem__ , slice (0 , 2 ), seq2 )
338
+
339
+ with pytest .raises (ValueError ):
340
+ seq1 .__setitem__ (slice (0 , 2 ), seq2 )
325
341
326
342
# Invalid index.
327
- assert_raises (TypeError , seq .__setitem__ , object (), None )
343
+ with pytest .raises (TypeError ):
344
+ seq .__setitem__ (object (), None )
328
345
329
346
def test_arraysequence_operators (self ):
330
347
# Disable division per zero warnings.
@@ -343,36 +360,45 @@ def test_arraysequence_operators(self):
343
360
def _test_unary (op , arrseq ):
344
361
orig = arrseq .copy ()
345
362
seq = getattr (orig , op )()
346
- assert_true ( seq is not orig )
363
+ assert seq is not orig
347
364
check_arr_seq (seq , [getattr (d , op )() for d in orig ])
348
365
349
366
def _test_binary (op , arrseq , scalars , seqs , inplace = False ):
350
367
for scalar in scalars :
351
368
orig = arrseq .copy ()
352
369
seq = getattr (orig , op )(scalar )
353
- assert_true ((seq is orig ) if inplace else (seq is not orig ))
370
+
371
+ if inplace :
372
+ assert seq is orig
373
+ else :
374
+ assert seq is not orig
375
+
354
376
check_arr_seq (seq , [getattr (e , op )(scalar ) for e in arrseq ])
355
377
356
378
# Test math operators with another ArraySequence.
357
379
for other in seqs :
358
380
orig = arrseq .copy ()
359
381
seq = getattr (orig , op )(other )
360
- assert_true ( seq is not SEQ_DATA ['seq' ])
382
+ assert seq is not SEQ_DATA ['seq' ]
361
383
check_arr_seq (seq , [getattr (e1 , op )(e2 ) for e1 , e2 in zip (arrseq , other )])
362
384
363
385
# Operations between array sequences of different lengths.
364
386
orig = arrseq .copy ()
365
- assert_raises (ValueError , getattr (orig , op ), orig [::2 ])
387
+ with pytest .raises (ValueError ):
388
+ getattr (orig , op )(orig [::2 ])
366
389
367
390
# Operations between array sequences with different amount of data.
368
391
seq1 = ArraySequence (np .arange (10 ).reshape (5 , 2 ))
369
392
seq2 = ArraySequence (np .arange (15 ).reshape (5 , 3 ))
370
- assert_raises (ValueError , getattr (seq1 , op ), seq2 )
393
+ with pytest .raises (ValueError ):
394
+ getattr (seq1 , op )(seq2 )
371
395
372
396
# Operations between array sequences with different common shape.
373
397
seq1 = ArraySequence (np .arange (12 ).reshape (2 , 2 , 3 ))
374
398
seq2 = ArraySequence (np .arange (8 ).reshape (2 , 2 , 2 ))
375
- assert_raises (ValueError , getattr (seq1 , op ), seq2 )
399
+ with pytest .raises (ValueError ):
400
+ getattr (seq1 , op )(seq2 )
401
+
376
402
377
403
378
404
for op in ["__add__" , "__sub__" , "__mul__" , "__mod__" ,
@@ -394,24 +420,33 @@ def _test_binary(op, arrseq, scalars, seqs, inplace=False):
394
420
continue # Going to deal with it separately.
395
421
396
422
_test_binary (op , seq_int , [42 , - 3 , True , 0 ], [seq_int , seq_bool , - seq_int ], inplace = True ) # int <-- int
397
- assert_raises (TypeError , _test_binary , op , seq_int , [0.5 ], [], inplace = True ) # int <-- float
398
- assert_raises (TypeError , _test_binary , op , seq_int , [], [seq ], inplace = True ) # int <-- float
423
+
424
+ with pytest .raises (TypeError ):
425
+ _test_binary (op , seq_int , [0.5 ], [], inplace = True ) # int <-- float
426
+ _test_binary (op , seq_int , [], [seq ], inplace = True ) # int <-- float
427
+
399
428
400
429
# __pow__ : Integers to negative integer powers are not allowed.
401
430
_test_binary ("__pow__" , seq , [42 , - 3 , True , 0 ], [seq_int , seq_bool , - seq_int ])
402
431
_test_binary ("__ipow__" , seq , [42 , - 3 , True , 0 ], [seq_int , seq_bool , - seq_int ], inplace = True )
403
- assert_raises (ValueError , _test_binary , "__pow__" , seq_int , [- 3 ], [])
404
- assert_raises (ValueError , _test_binary , "__ipow__" , seq_int , [- 3 ], [], inplace = True )
405
-
432
+
433
+ with pytest .raises (ValueError ):
434
+ _test_binary ("__pow__" , seq_int , [- 3 ], [])
435
+ _test_binary ("__ipow__" , seq_int , [- 3 ], [], inplace = True )
436
+
406
437
# __itruediv__ is only valid with float arrseq.
407
438
for scalar in SCALARS + ARRSEQS :
408
- assert_raises (TypeError , getattr (seq_int .copy (), "__itruediv__" ), scalar )
439
+ with pytest .raises (TypeError ):
440
+ seq_int_cp = seq_int .copy ()
441
+ seq_int_cp .__itruediv__ (scalar )
409
442
410
443
# Bitwise operators
411
444
for op in ("__lshift__" , "__rshift__" , "__or__" , "__and__" , "__xor__" ):
412
445
_test_binary (op , seq_bool , [42 , - 3 , True , 0 ], [seq_int , seq_bool , - seq_int ])
413
- assert_raises (TypeError , _test_binary , op , seq_bool , [0.5 ], [])
414
- assert_raises (TypeError , _test_binary , op , seq , [], [seq ])
446
+
447
+ with pytest .raises (TypeError ):
448
+ _test_binary (op , seq_bool , [0.5 ], [])
449
+ _test_binary (op , seq , [], [seq ])
415
450
416
451
# Unary operators
417
452
for op in ["__neg__" , "__abs__" ]:
@@ -422,7 +457,8 @@ def _test_binary(op, arrseq, scalars, seqs, inplace=False):
422
457
423
458
_test_unary ("__abs__" , seq_bool )
424
459
_test_unary ("__invert__" , seq_bool )
425
- assert_raises (TypeError , _test_unary , "__invert__" , seq )
460
+ with pytest .raises (TypeError ):
461
+ _test_unary ("__invert__" , seq )
426
462
427
463
# Restore flags.
428
464
np .seterr (** flags )
@@ -442,7 +478,7 @@ def test_arraysequence_repr(self):
442
478
txt1 = repr (seq )
443
479
np .set_printoptions (threshold = nb_arrays // 2 )
444
480
txt2 = repr (seq )
445
- assert_true ( len (txt2 ) < len (txt1 ) )
481
+ assert len (txt2 ) < len (txt1 )
446
482
np .set_printoptions (threshold = bkp_threshold )
447
483
448
484
def test_save_and_load_arraysequence (self ):
@@ -485,10 +521,10 @@ def test_concatenate():
485
521
new_seq = concatenate (seqs , axis = 1 )
486
522
seq ._data += 100 # Modifying the 'seq' shouldn't change 'new_seq'.
487
523
check_arr_seq (new_seq , SEQ_DATA ['data' ])
488
- assert_true ( not new_seq ._is_view )
524
+ assert new_seq ._is_view is not True
489
525
490
526
seq = SEQ_DATA ['seq' ]
491
527
seqs = [seq [:, [i ]] for i in range (seq .common_shape [0 ])]
492
528
new_seq = concatenate (seqs , axis = 0 )
493
- assert_true ( len (new_seq ), seq .common_shape [0 ] * len (seq ) )
529
+ assert len (new_seq ) == seq .common_shape [0 ] * len (seq )
494
530
assert_array_equal (new_seq ._data , seq ._data .T .reshape ((- 1 , 1 )))
0 commit comments