Skip to content

TEST: pytest conversion #864 #865 #870

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Feb 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
481fb10
TEST: pytest conversion #864 #865
robbisg Feb 4, 2020
1bc0402
TEST: test_tractogram to pytest #865 #864
robbisg Feb 4, 2020
af1b3af
Update nibabel/streamlines/tests/test_array_sequence.py
robbisg Feb 4, 2020
9dc037e
Update nibabel/streamlines/tests/test_array_sequence.py
robbisg Feb 4, 2020
336425e
Update nibabel/streamlines/tests/test_array_sequence.py
robbisg Feb 4, 2020
3150e28
Update nibabel/streamlines/tests/test_array_sequence.py
robbisg Feb 4, 2020
8d8fb8e
Update nibabel/streamlines/tests/test_array_sequence.py
robbisg Feb 4, 2020
9628747
Update nibabel/streamlines/tests/test_array_sequence.py
robbisg Feb 4, 2020
8e314bd
Update nibabel/streamlines/tests/test_array_sequence.py
robbisg Feb 4, 2020
df9f0cd
Update nibabel/streamlines/tests/test_array_sequence.py
robbisg Feb 4, 2020
41fa3b6
Update nibabel/streamlines/tests/test_array_sequence.py
robbisg Feb 4, 2020
552ca69
Update nibabel/streamlines/tests/test_tractogram.py
robbisg Feb 4, 2020
a9c306f
Update nibabel/streamlines/tests/test_tractogram.py
robbisg Feb 4, 2020
ca6bbc0
Update nibabel/streamlines/tests/test_tractogram.py
robbisg Feb 4, 2020
a1c640f
Update nibabel/streamlines/tests/test_tractogram.py
robbisg Feb 4, 2020
3e7f6ab
Update nibabel/streamlines/tests/test_tractogram.py
robbisg Feb 4, 2020
e2d5d50
Update nibabel/streamlines/tests/test_tractogram.py
robbisg Feb 4, 2020
ac87b0c
Update nibabel/streamlines/tests/test_tractogram.py
robbisg Feb 4, 2020
cc650b3
Update nibabel/streamlines/tests/test_tractogram.py
robbisg Feb 4, 2020
bf40fa4
Update nibabel/streamlines/tests/test_tractogram.py
robbisg Feb 4, 2020
50f9ea5
Update nibabel/streamlines/tests/test_tractogram.py
robbisg Feb 4, 2020
5a8e9b7
Update nibabel/streamlines/tests/test_tractogram.py
robbisg Feb 4, 2020
6824ec5
Update nibabel/streamlines/tests/test_tractogram.py
robbisg Feb 4, 2020
6a77483
Update test_array_sequence.py
robbisg Feb 4, 2020
6d3155f
Update test_tractogram.py
robbisg Feb 4, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 93 additions & 57 deletions nibabel/streamlines/tests/test_array_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import itertools
import numpy as np

from nose.tools import assert_equal, assert_raises, assert_true
from nibabel.testing import assert_arrays_equal
import pytest
from ...testing_pytest import assert_arrays_equal
from numpy.testing import assert_array_equal

from ..array_sequence import ArraySequence, is_array_sequence, concatenate
Expand All @@ -15,7 +15,7 @@
SEQ_DATA = {}


def setup():
def setup_module():
global SEQ_DATA
rng = np.random.RandomState(42)
SEQ_DATA['rng'] = rng
Expand All @@ -30,22 +30,25 @@ def generate_data(nb_arrays, common_shape, rng):


def check_empty_arr_seq(seq):
assert_equal(len(seq), 0)
assert_equal(len(seq._offsets), 0)
assert_equal(len(seq._lengths), 0)
assert len(seq) == 0
assert len(seq._offsets) == 0
assert len(seq._lengths) == 0
# assert_equal(seq._data.ndim, 0)
assert_equal(seq._data.ndim, 1)
assert_true(seq.common_shape == ())
assert seq._data.ndim == 1

# TODO: Check assert_true
# assert_true(seq.common_shape == ())


def check_arr_seq(seq, arrays):
lengths = list(map(len, arrays))
assert_true(is_array_sequence(seq))
assert_equal(len(seq), len(arrays))
assert_equal(len(seq._offsets), len(arrays))
assert_equal(len(seq._lengths), len(arrays))
assert_equal(seq._data.shape[1:], arrays[0].shape[1:])
assert_equal(seq.common_shape, arrays[0].shape[1:])
assert is_array_sequence(seq)
assert len(seq) == len(arrays)
assert len(seq._offsets) == len(arrays)
assert len(seq._lengths) == len(arrays)
assert seq._data.shape[1:] == arrays[0].shape[1:]
assert seq.common_shape == arrays[0].shape[1:]

assert_arrays_equal(seq, arrays)

# If seq is a view, then order of internal data is not guaranteed.
Expand All @@ -54,18 +57,20 @@ def check_arr_seq(seq, arrays):
assert_array_equal(sorted(seq._lengths), sorted(lengths))
else:
seq.shrink_data()
assert_equal(seq._data.shape[0], sum(lengths))

assert seq._data.shape[0] == sum(lengths)

assert_array_equal(seq._data, np.concatenate(arrays, axis=0))
assert_array_equal(seq._offsets, np.r_[0, np.cumsum(lengths)[:-1]])
assert_array_equal(seq._lengths, lengths)


def check_arr_seq_view(seq_view, seq):
assert_true(seq_view._is_view)
assert_true(seq_view is not seq)
assert_true(np.may_share_memory(seq_view._data, seq._data))
assert_true(seq_view._offsets is not seq._offsets)
assert_true(seq_view._lengths is not seq._lengths)
assert seq_view._is_view is True
assert seq_view is not seq
assert (np.may_share_memory(seq_view._data, seq._data)) is True
assert seq_view._offsets is not seq._offsets
assert seq_view._lengths is not seq._lengths


class TestArraySequence(unittest.TestCase):
Expand Down Expand Up @@ -97,8 +102,8 @@ def test_creating_arraysequence_from_generator(self):
seq_with_buffer = ArraySequence(gen_2, buffer_size=256)

# Check buffer size effect
assert_equal(seq_with_buffer.data.shape, seq.data.shape)
assert_true(seq_with_buffer._buffer_size > seq._buffer_size)
assert seq_with_buffer.data.shape == seq.data.shape
assert seq_with_buffer._buffer_size > seq._buffer_size

# Check generator result
check_arr_seq(seq, SEQ_DATA['data'])
Expand All @@ -121,26 +126,27 @@ def test_arraysequence_iter(self):
# Try iterating through a corrupted ArraySequence object.
seq = SEQ_DATA['seq'].copy()
seq._lengths = seq._lengths[::2]
assert_raises(ValueError, list, seq)
with pytest.raises(ValueError):
list(seq)

def test_arraysequence_copy(self):
orig = SEQ_DATA['seq']
seq = orig.copy()
n_rows = seq.total_nb_rows
assert_equal(n_rows, orig.total_nb_rows)
assert n_rows == orig.total_nb_rows
assert_array_equal(seq._data, orig._data[:n_rows])
assert_true(seq._data is not orig._data)
assert seq._data is not orig._data
assert_array_equal(seq._offsets, orig._offsets)
assert_true(seq._offsets is not orig._offsets)
assert seq._offsets is not orig._offsets
assert_array_equal(seq._lengths, orig._lengths)
assert_true(seq._lengths is not orig._lengths)
assert_equal(seq.common_shape, orig.common_shape)
assert seq._lengths is not orig._lengths
assert seq.common_shape == orig.common_shape

# Taking a copy of an `ArraySequence` generated by slicing.
# Only keep needed data.
seq = orig[::2].copy()
check_arr_seq(seq, SEQ_DATA['data'][::2])
assert_true(seq._data is not orig._data)
assert seq._data is not orig._data

def test_arraysequence_append(self):
element = generate_data(nb_arrays=1,
Expand Down Expand Up @@ -171,7 +177,9 @@ def test_arraysequence_append(self):
element = generate_data(nb_arrays=1,
common_shape=SEQ_DATA['seq'].common_shape*2,
rng=SEQ_DATA['rng'])[0]
assert_raises(ValueError, seq.append, element)

with pytest.raises(ValueError):
seq.append(element)

def test_arraysequence_extend(self):
new_data = generate_data(nb_arrays=10,
Expand Down Expand Up @@ -217,7 +225,8 @@ def test_arraysequence_extend(self):
common_shape=SEQ_DATA['seq'].common_shape*2,
rng=SEQ_DATA['rng'])
seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification.
assert_raises(ValueError, seq.extend, data)
with pytest.raises(ValueError):
seq.extend(data)

# Extend after extracting some slice
working_slice = seq[:2]
Expand Down Expand Up @@ -262,7 +271,9 @@ def test_arraysequence_getitem(self):
for i, keep in enumerate(selection) if keep])

# Test invalid indexing
assert_raises(TypeError, SEQ_DATA['seq'].__getitem__, 'abc')
with pytest.raises(TypeError):
SEQ_DATA['seq']['abc']
#SEQ_DATA['seq'].abc

# Get specific columns.
seq_view = SEQ_DATA['seq'][:, 2]
Expand All @@ -285,7 +296,7 @@ def test_arraysequence_setitem(self):
# Setitem with a scalar.
seq = SEQ_DATA['seq'].copy()
seq[:] = 0
assert_true(seq._data.sum() == 0)
assert seq._data.sum() == 0

# Setitem with a list of ndarray.
seq = SEQ_DATA['seq'] * 0
Expand All @@ -295,12 +306,12 @@ def test_arraysequence_setitem(self):
# Setitem using tuple indexing.
seq = ArraySequence(np.arange(900).reshape((50,6,3)))
seq[:, 0] = 0
assert_true(seq._data[:, 0].sum() == 0)
assert seq._data[:, 0].sum() == 0

# Setitem using tuple indexing.
seq = ArraySequence(np.arange(900).reshape((50,6,3)))
seq[range(len(seq))] = 0
assert_true(seq._data.sum() == 0)
assert seq._data.sum() == 0

# Setitem of a slice using another slice.
seq = ArraySequence(np.arange(900).reshape((50,6,3)))
Expand All @@ -309,20 +320,26 @@ def test_arraysequence_setitem(self):

# Setitem between array sequences with different number of sequences.
seq = ArraySequence(np.arange(900).reshape((50,6,3)))
assert_raises(ValueError, seq.__setitem__, slice(0, 4), seq[5:10])
with pytest.raises(ValueError):
seq.__setitem__(slice(0, 4), seq[5:10])


# Setitem between array sequences with different amount of points.
seq1 = ArraySequence(np.arange(10).reshape(5, 2))
seq2 = ArraySequence(np.arange(15).reshape(5, 3))
assert_raises(ValueError, seq1.__setitem__, slice(0, 5), seq2)
with pytest.raises(ValueError):
seq1[0:5] = seq2

# Setitem between array sequences with different common shape.
seq1 = ArraySequence(np.arange(12).reshape(2, 2, 3))
seq2 = ArraySequence(np.arange(8).reshape(2, 2, 2))
assert_raises(ValueError, seq1.__setitem__, slice(0, 2), seq2)

with pytest.raises(ValueError):
seq1[0:2] = seq2

# Invalid index.
assert_raises(TypeError, seq.__setitem__, object(), None)
with pytest.raises(TypeError):
seq.__setitem__(object(), None)

def test_arraysequence_operators(self):
# Disable division per zero warnings.
Expand All @@ -341,36 +358,42 @@ def test_arraysequence_operators(self):
def _test_unary(op, arrseq):
orig = arrseq.copy()
seq = getattr(orig, op)()
assert_true(seq is not orig)
assert seq is not orig
check_arr_seq(seq, [getattr(d, op)() for d in orig])

def _test_binary(op, arrseq, scalars, seqs, inplace=False):
for scalar in scalars:
orig = arrseq.copy()
seq = getattr(orig, op)(scalar)
assert_true((seq is orig) if inplace else (seq is not orig))

assert (seq is orig) == inplace

check_arr_seq(seq, [getattr(e, op)(scalar) for e in arrseq])

# Test math operators with another ArraySequence.
for other in seqs:
orig = arrseq.copy()
seq = getattr(orig, op)(other)
assert_true(seq is not SEQ_DATA['seq'])
assert seq is not SEQ_DATA['seq']
check_arr_seq(seq, [getattr(e1, op)(e2) for e1, e2 in zip(arrseq, other)])

# Operations between array sequences of different lengths.
orig = arrseq.copy()
assert_raises(ValueError, getattr(orig, op), orig[::2])
with pytest.raises(ValueError):
getattr(orig, op)(orig[::2])

# Operations between array sequences with different amount of data.
seq1 = ArraySequence(np.arange(10).reshape(5, 2))
seq2 = ArraySequence(np.arange(15).reshape(5, 3))
assert_raises(ValueError, getattr(seq1, op), seq2)
with pytest.raises(ValueError):
getattr(seq1, op)(seq2)

# Operations between array sequences with different common shape.
seq1 = ArraySequence(np.arange(12).reshape(2, 2, 3))
seq2 = ArraySequence(np.arange(8).reshape(2, 2, 2))
assert_raises(ValueError, getattr(seq1, op), seq2)
with pytest.raises(ValueError):
getattr(seq1, op)(seq2)



for op in ["__add__", "__sub__", "__mul__", "__mod__",
Expand All @@ -392,24 +415,36 @@ def _test_binary(op, arrseq, scalars, seqs, inplace=False):
continue # Going to deal with it separately.

_test_binary(op, seq_int, [42, -3, True, 0], [seq_int, seq_bool, -seq_int], inplace=True) # int <-- int
assert_raises(TypeError, _test_binary, op, seq_int, [0.5], [], inplace=True) # int <-- float
assert_raises(TypeError, _test_binary, op, seq_int, [], [seq], inplace=True) # int <-- float

with pytest.raises(TypeError):
_test_binary(op, seq_int, [0.5], [], inplace=True) # int <-- float
with pytest.raises(TypeError):
_test_binary(op, seq_int, [], [seq], inplace=True) # int <-- float


# __pow__ : Integers to negative integer powers are not allowed.
_test_binary("__pow__", seq, [42, -3, True, 0], [seq_int, seq_bool, -seq_int])
_test_binary("__ipow__", seq, [42, -3, True, 0], [seq_int, seq_bool, -seq_int], inplace=True)
assert_raises(ValueError, _test_binary, "__pow__", seq_int, [-3], [])
assert_raises(ValueError, _test_binary, "__ipow__", seq_int, [-3], [], inplace=True)


with pytest.raises(ValueError):
_test_binary("__pow__", seq_int, [-3], [])
with pytest.raises(ValueError):
_test_binary("__ipow__", seq_int, [-3], [], inplace=True)

# __itruediv__ is only valid with float arrseq.
for scalar in SCALARS + ARRSEQS:
assert_raises(TypeError, getattr(seq_int.copy(), "__itruediv__"), scalar)
seq_int_cp = seq_int.copy()
with pytest.raises(TypeError):
seq_int_cp /= scalar

# Bitwise operators
for op in ("__lshift__", "__rshift__", "__or__", "__and__", "__xor__"):
_test_binary(op, seq_bool, [42, -3, True, 0], [seq_int, seq_bool, -seq_int])
assert_raises(TypeError, _test_binary, op, seq_bool, [0.5], [])
assert_raises(TypeError, _test_binary, op, seq, [], [seq])

with pytest.raises(TypeError):
_test_binary(op, seq_bool, [0.5], [])
with pytest.raises(TypeError):
_test_binary(op, seq, [], [seq])

# Unary operators
for op in ["__neg__", "__abs__"]:
Expand All @@ -420,7 +455,8 @@ def _test_binary(op, arrseq, scalars, seqs, inplace=False):

_test_unary("__abs__", seq_bool)
_test_unary("__invert__", seq_bool)
assert_raises(TypeError, _test_unary, "__invert__", seq)
with pytest.raises(TypeError):
_test_unary("__invert__", seq)

# Restore flags.
np.seterr(**flags)
Expand All @@ -440,7 +476,7 @@ def test_arraysequence_repr(self):
txt1 = repr(seq)
np.set_printoptions(threshold=nb_arrays//2)
txt2 = repr(seq)
assert_true(len(txt2) < len(txt1))
assert len(txt2) < len(txt1)
np.set_printoptions(threshold=bkp_threshold)

def test_save_and_load_arraysequence(self):
Expand Down Expand Up @@ -483,10 +519,10 @@ def test_concatenate():
new_seq = concatenate(seqs, axis=1)
seq._data += 100 # Modifying the 'seq' shouldn't change 'new_seq'.
check_arr_seq(new_seq, SEQ_DATA['data'])
assert_true(not new_seq._is_view)
assert new_seq._is_view is not True

seq = SEQ_DATA['seq']
seqs = [seq[:, [i]] for i in range(seq.common_shape[0])]
new_seq = concatenate(seqs, axis=0)
assert_true(len(new_seq), seq.common_shape[0] * len(seq))
assert len(new_seq) == seq.common_shape[0] * len(seq)
assert_array_equal(new_seq._data, seq._data.T.reshape((-1, 1)))
Loading