Skip to content

Commit aec94cf

Browse files
committed
Addressed @matthew-brett's comments
1 parent bd3f921 commit aec94cf

File tree

2 files changed

+81
-26
lines changed

2 files changed

+81
-26
lines changed

nibabel/streamlines/tests/test_tractogram.py

Lines changed: 64 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import unittest
33
import numpy as np
44
import warnings
5+
import operator
56

67
from nibabel.testing import assert_arrays_equal
78
from nibabel.testing import clear_and_catch_warnings
@@ -130,6 +131,11 @@ def assert_tractogram_equal(t1, t2):
130131
t2.data_per_streamline, t2.data_per_point)
131132

132133

134+
def extender(a, b):
135+
a.extend(b)
136+
return a
137+
138+
133139
class TestPerArrayDict(unittest.TestCase):
134140

135141
def test_per_array_dict_creation(self):
@@ -184,18 +190,40 @@ def test_getitem(self):
184190
def test_extend(self):
185191
sdict = PerArrayDict(len(DATA['tractogram']),
186192
DATA['data_per_streamline'])
193+
194+
new_data = {'mean_curvature': 2 * np.array(DATA['mean_curvature']),
195+
'mean_torsion': 3 * np.array(DATA['mean_torsion']),
196+
'mean_colors': 4 * np.array(DATA['mean_colors'])}
187197
sdict2 = PerArrayDict(len(DATA['tractogram']),
188-
DATA['data_per_streamline'])
198+
new_data)
189199

190200
sdict.extend(sdict2)
191201
assert_equal(len(sdict), len(sdict2))
192-
for k, v in DATA['tractogram'].data_per_streamline.items():
193-
assert_arrays_equal(sdict[k][:len(DATA['tractogram'])], v)
194-
assert_arrays_equal(sdict[k][len(DATA['tractogram']):], v)
202+
for k in DATA['tractogram'].data_per_streamline:
203+
assert_arrays_equal(sdict[k][:len(DATA['tractogram'])],
204+
DATA['tractogram'].data_per_streamline[k])
205+
assert_arrays_equal(sdict[k][len(DATA['tractogram']):],
206+
new_data[k])
195207

196208
# Test incompatible PerArrayDicts.
209+
# Other dict is missing entries.
197210
assert_raises(ValueError, sdict.extend, PerArrayDict())
198211

212+
# Other dict has more entries.
213+
new_data = {'mean_curvature': 2 * np.array(DATA['mean_curvature']),
214+
'mean_torsion': 3 * np.array(DATA['mean_torsion']),
215+
'mean_colors': 4 * np.array(DATA['mean_colors']),
216+
'other': 5 * np.array(DATA['mean_colors'])}
217+
sdict2 = PerArrayDict(len(DATA['tractogram']), new_data)
218+
assert_raises(ValueError, sdict.extend, sdict2)
219+
220+
# Other dict has the right number of entries but wrong shape.
221+
new_data = {'mean_curvature': 2 * np.array(DATA['mean_curvature']),
222+
'mean_torsion': 3 * np.array(DATA['mean_torsion']),
223+
'other': 4 * np.array(DATA['mean_torsion'])}
224+
sdict2 = PerArrayDict(len(DATA['tractogram']), new_data)
225+
assert_raises(ValueError, sdict.extend, sdict2)
226+
199227

200228
class TestPerArraySequenceDict(unittest.TestCase):
201229

@@ -251,17 +279,36 @@ def test_getitem(self):
251279
def test_extend(self):
252280
total_nb_rows = DATA['tractogram'].streamlines.total_nb_rows
253281
sdict = PerArraySequenceDict(total_nb_rows, DATA['data_per_point'])
254-
sdict2 = PerArraySequenceDict(total_nb_rows, DATA['data_per_point'])
282+
283+
new_data = {'colors': 2 * np.array(DATA['colors']),
284+
'fa': 3 * np.array(DATA['fa'])}
285+
sdict2 = PerArraySequenceDict(total_nb_rows, new_data)
255286

256287
sdict.extend(sdict2)
257288
assert_equal(len(sdict), len(sdict2))
258-
for k, v in DATA['tractogram'].data_per_point.items():
259-
assert_arrays_equal(sdict[k][:len(DATA['tractogram'])], v)
260-
assert_arrays_equal(sdict[k][len(DATA['tractogram']):], v)
289+
for k in DATA['tractogram'].data_per_point:
290+
assert_arrays_equal(sdict[k][:len(DATA['tractogram'])],
291+
DATA['tractogram'].data_per_point[k])
292+
assert_arrays_equal(sdict[k][len(DATA['tractogram']):],
293+
new_data[k])
261294

262295
# Test incompatible PerArrayDicts.
296+
# Other dict is missing entries.
263297
assert_raises(ValueError, sdict.extend, PerArraySequenceDict())
264298

299+
# Other dict has more entries.
300+
new_data = {'colors': 2 * np.array(DATA['colors']),
301+
'fa': 3 * np.array(DATA['fa']),
302+
'other': 4 * np.array(DATA['fa'])}
303+
sdict2 = PerArraySequenceDict(total_nb_rows, new_data)
304+
assert_raises(ValueError, sdict.extend, sdict2)
305+
306+
# Other dict has the right number of entries but wrong shape.
307+
new_data = {'colors': 2 * np.array(DATA['colors']),
308+
'other': 2 * np.array(DATA['colors']),}
309+
sdict2 = PerArraySequenceDict(total_nb_rows, new_data)
310+
assert_raises(ValueError, sdict.extend, sdict2)
311+
265312

266313
class TestLazyDict(unittest.TestCase):
267314

@@ -603,19 +650,12 @@ def test_tractogram_extend(self):
603650
# Load tractogram that contains some metadata.
604651
t = DATA['tractogram'].copy()
605652

606-
# Double the tractogram.
607-
new_t = t + t
608-
assert_equal(len(new_t), 2*len(t))
609-
assert_tractogram_equal(new_t[:len(t)], DATA['tractogram'])
610-
assert_tractogram_equal(new_t[len(t):], DATA['tractogram'])
611-
612-
# Double the tractogram inplace.
613-
new_t = DATA['tractogram'].copy()
614-
new_t += t
615-
assert_equal(len(new_t), 2*len(t))
616-
assert_tractogram_equal(new_t[:len(t)], DATA['tractogram'])
617-
assert_tractogram_equal(new_t[len(t):], DATA['tractogram'])
618-
653+
for op, in_place in ((operator.add, False), (operator.iadd, True), (extender, True)):
654+
first_arg = t.copy()
655+
new_t = op(first_arg, t)
656+
assert_equal(new_t is first_arg, in_place)
657+
assert_tractogram_equal(new_t[:len(t)], DATA['tractogram'])
658+
assert_tractogram_equal(new_t[len(t):], DATA['tractogram'])
619659

620660
class TestLazyTractogram(unittest.TestCase):
621661

@@ -690,7 +730,9 @@ def test_lazy_tractogram_getitem(self):
690730
def test_lazy_tractogram_extend(self):
691731
t = DATA['lazy_tractogram'].copy()
692732
new_t = DATA['lazy_tractogram'].copy()
693-
assert_raises(NotImplementedError, new_t.__iadd__, t)
733+
734+
for op in (operator.add, operator.iadd, extender):
735+
assert_raises(NotImplementedError, op, new_t, t)
694736

695737
def test_lazy_tractogram_len(self):
696738
modules = [module_tractogram] # Modules for which to catch warnings.

nibabel/streamlines/tractogram.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,13 @@ def extend(self, other):
123123
other : :class:`PerArrayDict` object
124124
Its data will be appended to the data of this dictionary.
125125
126+
Returns
127+
-------
128+
None
129+
126130
Notes
127131
-----
128-
The entries in both dictionaries must match.
132+
The keys in both dictionaries must be the same.
129133
"""
130134
if sorted(self.keys()) != sorted(other.keys()):
131135
msg = ("Entry mismatched between the two PerArrayDict objects."
@@ -173,9 +177,13 @@ def extend(self, other):
173177
other : :class:`PerArraySequenceDict` object
174178
Its data will be appended to the data of this dictionary.
175179
180+
Returns
181+
-------
182+
None
183+
176184
Notes
177185
-----
178-
The entries in both dictionaries must match.
186+
The keys in both dictionaries must be the same.
179187
"""
180188
if sorted(self.keys()) != sorted(other.keys()):
181189
msg = ("Key mismatched between the two PerArrayDict objects."
@@ -481,10 +489,15 @@ def extend(self, other):
481489
other : :class:`Tractogram` object
482490
Its data will be appended to the data of this tractogram.
483491
492+
Returns
493+
-------
494+
None
495+
484496
Notes
485497
-----
486-
The entries of `self.data_per_streamline` and `self.data_per_point`
487-
must match those contained in the other tractogram.
498+
The entries in both dictionaries `self.data_per_streamline` and
499+
`self.data_per_point` must match respectively those contained in the .
500+
the other tractogram.
488501
"""
489502
self.streamlines.extend(other.streamlines)
490503
self.data_per_streamline.extend(other.data_per_streamline)

0 commit comments

Comments
 (0)