Skip to content

Commit 01b1b7e

Browse files
committed
TEST: pytest conversion #864 #865
1 parent 4ee9e34 commit 01b1b7e

File tree

1 file changed

+93
-57
lines changed

1 file changed

+93
-57
lines changed

nibabel/streamlines/tests/test_array_sequence.py

Lines changed: 93 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import itertools
66
import numpy as np
77

8-
from nose.tools import assert_equal, assert_raises, assert_true
9-
from nibabel.testing import assert_arrays_equal
8+
import pytest
9+
from nibabel.testing_pytest import assert_arrays_equal
1010
from numpy.testing import assert_array_equal
1111

1212
import pytest; pytestmark = pytest.mark.skip()
@@ -17,7 +17,7 @@
1717
SEQ_DATA = {}
1818

1919

20-
def setup():
20+
def setup_module():
2121
global SEQ_DATA
2222
rng = np.random.RandomState(42)
2323
SEQ_DATA['rng'] = rng
@@ -32,22 +32,25 @@ def generate_data(nb_arrays, common_shape, rng):
3232

3333

3434
def check_empty_arr_seq(seq):
35-
assert_equal(len(seq), 0)
36-
assert_equal(len(seq._offsets), 0)
37-
assert_equal(len(seq._lengths), 0)
35+
assert len(seq) == 0
36+
assert len(seq._offsets) == 0
37+
assert len(seq._lengths) == 0
3838
# assert_equal(seq._data.ndim, 0)
39-
assert_equal(seq._data.ndim, 1)
40-
assert_true(seq.common_shape == ())
39+
assert seq._data.ndim == 1
40+
41+
# TODO: Check assert_true
42+
# assert_true(seq.common_shape == ())
4143

4244

4345
def check_arr_seq(seq, arrays):
4446
lengths = list(map(len, arrays))
45-
assert_true(is_array_sequence(seq))
46-
assert_equal(len(seq), len(arrays))
47-
assert_equal(len(seq._offsets), len(arrays))
48-
assert_equal(len(seq._lengths), len(arrays))
49-
assert_equal(seq._data.shape[1:], arrays[0].shape[1:])
50-
assert_equal(seq.common_shape, arrays[0].shape[1:])
47+
assert is_array_sequence(seq) == True
48+
assert len(seq) == len(arrays)
49+
assert len(seq._offsets) == len(arrays)
50+
assert len(seq._lengths) == len(arrays)
51+
assert seq._data.shape[1:] == arrays[0].shape[1:]
52+
assert seq.common_shape == arrays[0].shape[1:]
53+
5154
assert_arrays_equal(seq, arrays)
5255

5356
# If seq is a view, then order of internal data is not guaranteed.
@@ -56,18 +59,20 @@ def check_arr_seq(seq, arrays):
5659
assert_array_equal(sorted(seq._lengths), sorted(lengths))
5760
else:
5861
seq.shrink_data()
59-
assert_equal(seq._data.shape[0], sum(lengths))
62+
63+
assert seq._data.shape[0] == sum(lengths)
64+
6065
assert_array_equal(seq._data, np.concatenate(arrays, axis=0))
6166
assert_array_equal(seq._offsets, np.r_[0, np.cumsum(lengths)[:-1]])
6267
assert_array_equal(seq._lengths, lengths)
6368

6469

6570
def check_arr_seq_view(seq_view, seq):
66-
assert_true(seq_view._is_view)
67-
assert_true(seq_view is not seq)
68-
assert_true(np.may_share_memory(seq_view._data, seq._data))
69-
assert_true(seq_view._offsets is not seq._offsets)
70-
assert_true(seq_view._lengths is not seq._lengths)
71+
assert seq_view._is_view is True
72+
assert (seq_view is not seq) is True
73+
assert (np.may_share_memory(seq_view._data, seq._data)) is True
74+
assert seq_view._offsets is not seq._offsets
75+
assert seq_view._lengths is not seq._lengths
7176

7277

7378
class TestArraySequence(unittest.TestCase):
@@ -99,8 +104,8 @@ def test_creating_arraysequence_from_generator(self):
99104
seq_with_buffer = ArraySequence(gen_2, buffer_size=256)
100105

101106
# Check buffer size effect
102-
assert_equal(seq_with_buffer.data.shape, seq.data.shape)
103-
assert_true(seq_with_buffer._buffer_size > seq._buffer_size)
107+
assert seq_with_buffer.data.shape == seq.data.shape
108+
assert seq_with_buffer._buffer_size > seq._buffer_size
104109

105110
# Check generator result
106111
check_arr_seq(seq, SEQ_DATA['data'])
@@ -123,26 +128,27 @@ def test_arraysequence_iter(self):
123128
# Try iterating through a corrupted ArraySequence object.
124129
seq = SEQ_DATA['seq'].copy()
125130
seq._lengths = seq._lengths[::2]
126-
assert_raises(ValueError, list, seq)
131+
with pytest.raises(ValueError):
132+
list(seq)
127133

128134
def test_arraysequence_copy(self):
129135
orig = SEQ_DATA['seq']
130136
seq = orig.copy()
131137
n_rows = seq.total_nb_rows
132-
assert_equal(n_rows, orig.total_nb_rows)
138+
assert n_rows == orig.total_nb_rows
133139
assert_array_equal(seq._data, orig._data[:n_rows])
134-
assert_true(seq._data is not orig._data)
140+
assert seq._data is not orig._data
135141
assert_array_equal(seq._offsets, orig._offsets)
136-
assert_true(seq._offsets is not orig._offsets)
142+
assert seq._offsets is not orig._offsets
137143
assert_array_equal(seq._lengths, orig._lengths)
138-
assert_true(seq._lengths is not orig._lengths)
139-
assert_equal(seq.common_shape, orig.common_shape)
144+
assert seq._lengths is not orig._lengths
145+
assert seq.common_shape == orig.common_shape
140146

141147
# Taking a copy of an `ArraySequence` generated by slicing.
142148
# Only keep needed data.
143149
seq = orig[::2].copy()
144150
check_arr_seq(seq, SEQ_DATA['data'][::2])
145-
assert_true(seq._data is not orig._data)
151+
assert seq._data is not orig._data
146152

147153
def test_arraysequence_append(self):
148154
element = generate_data(nb_arrays=1,
@@ -173,7 +179,9 @@ def test_arraysequence_append(self):
173179
element = generate_data(nb_arrays=1,
174180
common_shape=SEQ_DATA['seq'].common_shape*2,
175181
rng=SEQ_DATA['rng'])[0]
176-
assert_raises(ValueError, seq.append, element)
182+
183+
with pytest.raises(ValueError):
184+
seq.append(element)
177185

178186
def test_arraysequence_extend(self):
179187
new_data = generate_data(nb_arrays=10,
@@ -219,7 +227,8 @@ def test_arraysequence_extend(self):
219227
common_shape=SEQ_DATA['seq'].common_shape*2,
220228
rng=SEQ_DATA['rng'])
221229
seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification.
222-
assert_raises(ValueError, seq.extend, data)
230+
with pytest.raises(ValueError):
231+
seq.extend(data)
223232

224233
# Extend after extracting some slice
225234
working_slice = seq[:2]
@@ -264,7 +273,9 @@ def test_arraysequence_getitem(self):
264273
for i, keep in enumerate(selection) if keep])
265274

266275
# Test invalid indexing
267-
assert_raises(TypeError, SEQ_DATA['seq'].__getitem__, 'abc')
276+
with pytest.raises(TypeError):
277+
SEQ_DATA['seq'].__getitem__('abc')
278+
#SEQ_DATA['seq'].abc
268279

269280
# Get specific columns.
270281
seq_view = SEQ_DATA['seq'][:, 2]
@@ -287,7 +298,7 @@ def test_arraysequence_setitem(self):
287298
# Setitem with a scalar.
288299
seq = SEQ_DATA['seq'].copy()
289300
seq[:] = 0
290-
assert_true(seq._data.sum() == 0)
301+
assert seq._data.sum() == 0
291302

292303
# Setitem with a list of ndarray.
293304
seq = SEQ_DATA['seq'] * 0
@@ -297,12 +308,12 @@ def test_arraysequence_setitem(self):
297308
# Setitem using tuple indexing.
298309
seq = ArraySequence(np.arange(900).reshape((50,6,3)))
299310
seq[:, 0] = 0
300-
assert_true(seq._data[:, 0].sum() == 0)
311+
assert seq._data[:, 0].sum() == 0
301312

302313
# Setitem using tuple indexing.
303314
seq = ArraySequence(np.arange(900).reshape((50,6,3)))
304315
seq[range(len(seq))] = 0
305-
assert_true(seq._data.sum() == 0)
316+
assert seq._data.sum() == 0
306317

307318
# Setitem of a slice using another slice.
308319
seq = ArraySequence(np.arange(900).reshape((50,6,3)))
@@ -311,20 +322,26 @@ def test_arraysequence_setitem(self):
311322

312323
# Setitem between array sequences with different number of sequences.
313324
seq = ArraySequence(np.arange(900).reshape((50,6,3)))
314-
assert_raises(ValueError, seq.__setitem__, slice(0, 4), seq[5:10])
325+
with pytest.raises(ValueError):
326+
seq.__setitem__(slice(0, 4), seq[5:10])
327+
315328

316329
# Setitem between array sequences with different amount of points.
317330
seq1 = ArraySequence(np.arange(10).reshape(5, 2))
318331
seq2 = ArraySequence(np.arange(15).reshape(5, 3))
319-
assert_raises(ValueError, seq1.__setitem__, slice(0, 5), seq2)
332+
with pytest.raises(ValueError):
333+
seq1.__setitem__(slice(0, 5), seq2)
320334

321335
# Setitem between array sequences with different common shape.
322336
seq1 = ArraySequence(np.arange(12).reshape(2, 2, 3))
323337
seq2 = ArraySequence(np.arange(8).reshape(2, 2, 2))
324-
assert_raises(ValueError, seq1.__setitem__, slice(0, 2), seq2)
338+
339+
with pytest.raises(ValueError):
340+
seq1.__setitem__(slice(0, 2), seq2)
325341

326342
# Invalid index.
327-
assert_raises(TypeError, seq.__setitem__, object(), None)
343+
with pytest.raises(TypeError):
344+
seq.__setitem__(object(), None)
328345

329346
def test_arraysequence_operators(self):
330347
# Disable division per zero warnings.
@@ -343,36 +360,45 @@ def test_arraysequence_operators(self):
343360
def _test_unary(op, arrseq):
344361
orig = arrseq.copy()
345362
seq = getattr(orig, op)()
346-
assert_true(seq is not orig)
363+
assert seq is not orig
347364
check_arr_seq(seq, [getattr(d, op)() for d in orig])
348365

349366
def _test_binary(op, arrseq, scalars, seqs, inplace=False):
350367
for scalar in scalars:
351368
orig = arrseq.copy()
352369
seq = getattr(orig, op)(scalar)
353-
assert_true((seq is orig) if inplace else (seq is not orig))
370+
371+
if inplace:
372+
assert seq is orig
373+
else:
374+
assert seq is not orig
375+
354376
check_arr_seq(seq, [getattr(e, op)(scalar) for e in arrseq])
355377

356378
# Test math operators with another ArraySequence.
357379
for other in seqs:
358380
orig = arrseq.copy()
359381
seq = getattr(orig, op)(other)
360-
assert_true(seq is not SEQ_DATA['seq'])
382+
assert seq is not SEQ_DATA['seq']
361383
check_arr_seq(seq, [getattr(e1, op)(e2) for e1, e2 in zip(arrseq, other)])
362384

363385
# Operations between array sequences of different lengths.
364386
orig = arrseq.copy()
365-
assert_raises(ValueError, getattr(orig, op), orig[::2])
387+
with pytest.raises(ValueError):
388+
getattr(orig, op)(orig[::2])
366389

367390
# Operations between array sequences with different amount of data.
368391
seq1 = ArraySequence(np.arange(10).reshape(5, 2))
369392
seq2 = ArraySequence(np.arange(15).reshape(5, 3))
370-
assert_raises(ValueError, getattr(seq1, op), seq2)
393+
with pytest.raises(ValueError):
394+
getattr(seq1, op)(seq2)
371395

372396
# Operations between array sequences with different common shape.
373397
seq1 = ArraySequence(np.arange(12).reshape(2, 2, 3))
374398
seq2 = ArraySequence(np.arange(8).reshape(2, 2, 2))
375-
assert_raises(ValueError, getattr(seq1, op), seq2)
399+
with pytest.raises(ValueError):
400+
getattr(seq1, op)(seq2)
401+
376402

377403

378404
for op in ["__add__", "__sub__", "__mul__", "__mod__",
@@ -394,24 +420,33 @@ def _test_binary(op, arrseq, scalars, seqs, inplace=False):
394420
continue # Going to deal with it separately.
395421

396422
_test_binary(op, seq_int, [42, -3, True, 0], [seq_int, seq_bool, -seq_int], inplace=True) # int <-- int
397-
assert_raises(TypeError, _test_binary, op, seq_int, [0.5], [], inplace=True) # int <-- float
398-
assert_raises(TypeError, _test_binary, op, seq_int, [], [seq], inplace=True) # int <-- float
423+
424+
with pytest.raises(TypeError):
425+
_test_binary(op, seq_int, [0.5], [], inplace=True) # int <-- float
426+
_test_binary(op, seq_int, [], [seq], inplace=True) # int <-- float
427+
399428

400429
# __pow__ : Integers to negative integer powers are not allowed.
401430
_test_binary("__pow__", seq, [42, -3, True, 0], [seq_int, seq_bool, -seq_int])
402431
_test_binary("__ipow__", seq, [42, -3, True, 0], [seq_int, seq_bool, -seq_int], inplace=True)
403-
assert_raises(ValueError, _test_binary, "__pow__", seq_int, [-3], [])
404-
assert_raises(ValueError, _test_binary, "__ipow__", seq_int, [-3], [], inplace=True)
405-
432+
433+
with pytest.raises(ValueError):
434+
_test_binary("__pow__", seq_int, [-3], [])
435+
_test_binary("__ipow__", seq_int, [-3], [], inplace=True)
436+
406437
# __itruediv__ is only valid with float arrseq.
407438
for scalar in SCALARS + ARRSEQS:
408-
assert_raises(TypeError, getattr(seq_int.copy(), "__itruediv__"), scalar)
439+
with pytest.raises(TypeError):
440+
seq_int_cp = seq_int.copy()
441+
seq_int_cp.__itruediv__(scalar)
409442

410443
# Bitwise operators
411444
for op in ("__lshift__", "__rshift__", "__or__", "__and__", "__xor__"):
412445
_test_binary(op, seq_bool, [42, -3, True, 0], [seq_int, seq_bool, -seq_int])
413-
assert_raises(TypeError, _test_binary, op, seq_bool, [0.5], [])
414-
assert_raises(TypeError, _test_binary, op, seq, [], [seq])
446+
447+
with pytest.raises(TypeError):
448+
_test_binary(op, seq_bool, [0.5], [])
449+
_test_binary(op, seq, [], [seq])
415450

416451
# Unary operators
417452
for op in ["__neg__", "__abs__"]:
@@ -422,7 +457,8 @@ def _test_binary(op, arrseq, scalars, seqs, inplace=False):
422457

423458
_test_unary("__abs__", seq_bool)
424459
_test_unary("__invert__", seq_bool)
425-
assert_raises(TypeError, _test_unary, "__invert__", seq)
460+
with pytest.raises(TypeError):
461+
_test_unary("__invert__", seq)
426462

427463
# Restore flags.
428464
np.seterr(**flags)
@@ -442,7 +478,7 @@ def test_arraysequence_repr(self):
442478
txt1 = repr(seq)
443479
np.set_printoptions(threshold=nb_arrays//2)
444480
txt2 = repr(seq)
445-
assert_true(len(txt2) < len(txt1))
481+
assert len(txt2) < len(txt1)
446482
np.set_printoptions(threshold=bkp_threshold)
447483

448484
def test_save_and_load_arraysequence(self):
@@ -485,10 +521,10 @@ def test_concatenate():
485521
new_seq = concatenate(seqs, axis=1)
486522
seq._data += 100 # Modifying the 'seq' shouldn't change 'new_seq'.
487523
check_arr_seq(new_seq, SEQ_DATA['data'])
488-
assert_true(not new_seq._is_view)
524+
assert new_seq._is_view is not True
489525

490526
seq = SEQ_DATA['seq']
491527
seqs = [seq[:, [i]] for i in range(seq.common_shape[0])]
492528
new_seq = concatenate(seqs, axis=0)
493-
assert_true(len(new_seq), seq.common_shape[0] * len(seq))
529+
assert len(new_seq) == seq.common_shape[0] * len(seq)
494530
assert_array_equal(new_seq._data, seq._data.T.reshape((-1, 1)))

0 commit comments

Comments
 (0)