|
2 | 2 | import unittest
|
3 | 3 | import numpy as np
|
4 | 4 | import warnings
|
| 5 | +import operator |
5 | 6 |
|
6 | 7 | from nibabel.testing import assert_arrays_equal
|
7 | 8 | from nibabel.testing import clear_and_catch_warnings
|
@@ -130,6 +131,11 @@ def assert_tractogram_equal(t1, t2):
|
130 | 131 | t2.data_per_streamline, t2.data_per_point)
|
131 | 132 |
|
132 | 133 |
|
| 134 | +def extender(a, b): |
| 135 | + a.extend(b) |
| 136 | + return a |
| 137 | + |
| 138 | + |
133 | 139 | class TestPerArrayDict(unittest.TestCase):
|
134 | 140 |
|
135 | 141 | def test_per_array_dict_creation(self):
|
@@ -184,18 +190,40 @@ def test_getitem(self):
|
184 | 190 | def test_extend(self):
|
185 | 191 | sdict = PerArrayDict(len(DATA['tractogram']),
|
186 | 192 | DATA['data_per_streamline'])
|
| 193 | + |
| 194 | + new_data = {'mean_curvature': 2 * np.array(DATA['mean_curvature']), |
| 195 | + 'mean_torsion': 3 * np.array(DATA['mean_torsion']), |
| 196 | + 'mean_colors': 4 * np.array(DATA['mean_colors'])} |
187 | 197 | sdict2 = PerArrayDict(len(DATA['tractogram']),
|
188 |
| - DATA['data_per_streamline']) |
| 198 | + new_data) |
189 | 199 |
|
190 | 200 | sdict.extend(sdict2)
|
191 | 201 | assert_equal(len(sdict), len(sdict2))
|
192 |
| - for k, v in DATA['tractogram'].data_per_streamline.items(): |
193 |
| - assert_arrays_equal(sdict[k][:len(DATA['tractogram'])], v) |
194 |
| - assert_arrays_equal(sdict[k][len(DATA['tractogram']):], v) |
| 202 | + for k in DATA['tractogram'].data_per_streamline: |
| 203 | + assert_arrays_equal(sdict[k][:len(DATA['tractogram'])], |
| 204 | + DATA['tractogram'].data_per_streamline[k]) |
| 205 | + assert_arrays_equal(sdict[k][len(DATA['tractogram']):], |
| 206 | + new_data[k]) |
195 | 207 |
|
196 | 208 | # Test incompatible PerArrayDicts.
|
| 209 | + # Other dict is missing entries. |
197 | 210 | assert_raises(ValueError, sdict.extend, PerArrayDict())
|
198 | 211 |
|
| 212 | + # Other dict has more entries. |
| 213 | + new_data = {'mean_curvature': 2 * np.array(DATA['mean_curvature']), |
| 214 | + 'mean_torsion': 3 * np.array(DATA['mean_torsion']), |
| 215 | + 'mean_colors': 4 * np.array(DATA['mean_colors']), |
| 216 | + 'other': 5 * np.array(DATA['mean_colors'])} |
| 217 | + sdict2 = PerArrayDict(len(DATA['tractogram']), new_data) |
| 218 | + assert_raises(ValueError, sdict.extend, sdict2) |
| 219 | + |
| 220 | + # Other dict has the right number of entries but wrong shape. |
| 221 | + new_data = {'mean_curvature': 2 * np.array(DATA['mean_curvature']), |
| 222 | + 'mean_torsion': 3 * np.array(DATA['mean_torsion']), |
| 223 | + 'other': 4 * np.array(DATA['mean_torsion'])} |
| 224 | + sdict2 = PerArrayDict(len(DATA['tractogram']), new_data) |
| 225 | + assert_raises(ValueError, sdict.extend, sdict2) |
| 226 | + |
199 | 227 |
|
200 | 228 | class TestPerArraySequenceDict(unittest.TestCase):
|
201 | 229 |
|
@@ -251,17 +279,36 @@ def test_getitem(self):
|
251 | 279 | def test_extend(self):
|
252 | 280 | total_nb_rows = DATA['tractogram'].streamlines.total_nb_rows
|
253 | 281 | sdict = PerArraySequenceDict(total_nb_rows, DATA['data_per_point'])
|
254 |
| - sdict2 = PerArraySequenceDict(total_nb_rows, DATA['data_per_point']) |
| 282 | + |
| 283 | + new_data = {'colors': 2 * np.array(DATA['colors']), |
| 284 | + 'fa': 3 * np.array(DATA['fa'])} |
| 285 | + sdict2 = PerArraySequenceDict(total_nb_rows, new_data) |
255 | 286 |
|
256 | 287 | sdict.extend(sdict2)
|
257 | 288 | assert_equal(len(sdict), len(sdict2))
|
258 |
| - for k, v in DATA['tractogram'].data_per_point.items(): |
259 |
| - assert_arrays_equal(sdict[k][:len(DATA['tractogram'])], v) |
260 |
| - assert_arrays_equal(sdict[k][len(DATA['tractogram']):], v) |
| 289 | + for k in DATA['tractogram'].data_per_point: |
| 290 | + assert_arrays_equal(sdict[k][:len(DATA['tractogram'])], |
| 291 | + DATA['tractogram'].data_per_point[k]) |
| 292 | + assert_arrays_equal(sdict[k][len(DATA['tractogram']):], |
| 293 | + new_data[k]) |
261 | 294 |
|
262 | 295 | # Test incompatible PerArrayDicts.
|
| 296 | + # Other dict is missing entries. |
263 | 297 | assert_raises(ValueError, sdict.extend, PerArraySequenceDict())
|
264 | 298 |
|
| 299 | + # Other dict has more entries. |
| 300 | + new_data = {'colors': 2 * np.array(DATA['colors']), |
| 301 | + 'fa': 3 * np.array(DATA['fa']), |
| 302 | + 'other': 4 * np.array(DATA['fa'])} |
| 303 | + sdict2 = PerArraySequenceDict(total_nb_rows, new_data) |
| 304 | + assert_raises(ValueError, sdict.extend, sdict2) |
| 305 | + |
| 306 | + # Other dict has the right number of entries but wrong shape. |
| 307 | + new_data = {'colors': 2 * np.array(DATA['colors']), |
| 308 | + 'other': 2 * np.array(DATA['colors']),} |
| 309 | + sdict2 = PerArraySequenceDict(total_nb_rows, new_data) |
| 310 | + assert_raises(ValueError, sdict.extend, sdict2) |
| 311 | + |
265 | 312 |
|
266 | 313 | class TestLazyDict(unittest.TestCase):
|
267 | 314 |
|
@@ -603,19 +650,12 @@ def test_tractogram_extend(self):
|
603 | 650 | # Load tractogram that contains some metadata.
|
604 | 651 | t = DATA['tractogram'].copy()
|
605 | 652 |
|
606 |
| - # Double the tractogram. |
607 |
| - new_t = t + t |
608 |
| - assert_equal(len(new_t), 2*len(t)) |
609 |
| - assert_tractogram_equal(new_t[:len(t)], DATA['tractogram']) |
610 |
| - assert_tractogram_equal(new_t[len(t):], DATA['tractogram']) |
611 |
| - |
612 |
| - # Double the tractogram inplace. |
613 |
| - new_t = DATA['tractogram'].copy() |
614 |
| - new_t += t |
615 |
| - assert_equal(len(new_t), 2*len(t)) |
616 |
| - assert_tractogram_equal(new_t[:len(t)], DATA['tractogram']) |
617 |
| - assert_tractogram_equal(new_t[len(t):], DATA['tractogram']) |
618 |
| - |
| 653 | + for op, in_place in ((operator.add, False), (operator.iadd, True), (extender, True)): |
| 654 | + first_arg = t.copy() |
| 655 | + new_t = op(first_arg, t) |
| 656 | + assert_equal(new_t is first_arg, in_place) |
| 657 | + assert_tractogram_equal(new_t[:len(t)], DATA['tractogram']) |
| 658 | + assert_tractogram_equal(new_t[len(t):], DATA['tractogram']) |
619 | 659 |
|
620 | 660 | class TestLazyTractogram(unittest.TestCase):
|
621 | 661 |
|
@@ -690,7 +730,9 @@ def test_lazy_tractogram_getitem(self):
|
690 | 730 | def test_lazy_tractogram_extend(self):
|
691 | 731 | t = DATA['lazy_tractogram'].copy()
|
692 | 732 | new_t = DATA['lazy_tractogram'].copy()
|
693 |
| - assert_raises(NotImplementedError, new_t.__iadd__, t) |
| 733 | + |
| 734 | + for op in (operator.add, operator.iadd, extender): |
| 735 | + assert_raises(NotImplementedError, op, new_t, t) |
694 | 736 |
|
695 | 737 | def test_lazy_tractogram_len(self):
|
696 | 738 | modules = [module_tractogram] # Modules for which to catch warnings.
|
|
0 commit comments