Skip to content

Commit e7f23f6

Browse files
Merge pull request #495 from MarcCote/enh_tractogram_operators
MRG: Support + and += operators for Tractogram. This PR adds the ability to concatenate two Tractogram objects using either `tractogram += other_tractogram` or `tractogram = tractogram1 + tractogram2`.
2 parents 1d27aef + 8f8fd5a commit e7f23f6

File tree

2 files changed

+338
-56
lines changed

2 files changed

+338
-56
lines changed

nibabel/streamlines/tests/test_tractogram.py

Lines changed: 252 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import sys
2+
import copy
23
import unittest
34
import numpy as np
45
import warnings
6+
import operator
7+
from collections import defaultdict
58

69
from nibabel.testing import assert_arrays_equal
710
from nibabel.testing import clear_and_catch_warnings
@@ -16,37 +19,99 @@
1619
DATA = {}
1720

1821

22+
def make_fake_streamline(nb_points, data_per_point_shapes=None,
                         data_for_streamline_shapes=None, rng=None):
    """ Make a single streamline according to provided requirements.

    Parameters
    ----------
    nb_points : int
        Number of points of the streamline.
    data_per_point_shapes : dict, optional
        Maps each data name to the shape (tuple) of the values attached to
        a single point. No per-point data is generated when omitted.
    data_for_streamline_shapes : dict, optional
        Maps each data name to the shape (tuple) of the values attached to
        the whole streamline. No such data is generated when omitted.
    rng : `numpy.random.RandomState`, optional
        Random number generator to draw from. A new unseeded one is created
        when omitted.

    Returns
    -------
    streamline : ndarray of float32, shape (nb_points, 3)
        Randomly drawn points.
    data_per_point : dict
        One float32 array of shape ``(nb_points,) + shape`` per name.
    data_for_streamline : dict
        One float32 array of shape ``shape`` per name.
    """
    if rng is None:
        rng = np.random.RandomState()

    # `None` stands for "no data"; avoids sharing a mutable default.
    if data_per_point_shapes is None:
        data_per_point_shapes = {}

    if data_for_streamline_shapes is None:
        data_for_streamline_shapes = {}

    streamline = rng.randn(nb_points, 3).astype("f4")

    data_per_point = {}
    for k, shape in data_per_point_shapes.items():
        full_shape = (nb_points,) + tuple(shape)
        data_per_point[k] = rng.randn(*full_shape).astype("f4")

    # Iterate over the requested shapes, not over the (still empty) output
    # dictionary, otherwise no per-streamline data is ever generated.
    data_for_streamline = {}
    for k, shape in data_for_streamline_shapes.items():
        # `randn()` returns a plain float for an empty shape, hence
        # `np.asarray` rather than `.astype`.
        data_for_streamline[k] = np.asarray(rng.randn(*shape), dtype="f4")

    return streamline, data_per_point, data_for_streamline
39+
40+
41+
def make_fake_tractogram(list_nb_points, data_per_point_shapes=None,
                         data_for_streamline_shapes=None, rng=None):
    """ Make multiple streamlines according to provided requirements.

    Parameters
    ----------
    list_nb_points : iterable of int
        Number of points of each streamline to generate.
    data_per_point_shapes : dict, optional
        Forwarded to `make_fake_streamline` for every streamline.
    data_for_streamline_shapes : dict, optional
        Forwarded to `make_fake_streamline` for every streamline.
    rng : optional
        Forwarded as is to `make_fake_streamline` for every streamline.

    Returns
    -------
    all_streamlines : list
        The generated streamlines, in the order of `list_nb_points`.
    all_data_per_point : defaultdict of list
        For each name, the per-point data of every streamline.
    all_data_per_streamline : defaultdict of list
        For each name, the per-streamline data of every streamline.
    """
    # Normalize here so `make_fake_streamline` always receives dictionaries
    # and no mutable default is shared between calls.
    if data_per_point_shapes is None:
        data_per_point_shapes = {}

    if data_for_streamline_shapes is None:
        data_for_streamline_shapes = {}

    all_streamlines = []
    all_data_per_point = defaultdict(list)
    all_data_per_streamline = defaultdict(list)
    for nb_points in list_nb_points:
        data = make_fake_streamline(nb_points, data_per_point_shapes,
                                    data_for_streamline_shapes, rng)
        streamline, data_per_point, data_for_streamline = data

        all_streamlines.append(streamline)
        for k, v in data_per_point.items():
            all_data_per_point[k].append(v)

        for k, v in data_for_streamline.items():
            all_data_per_streamline[k].append(v)

    return all_streamlines, all_data_per_point, all_data_per_streamline
60+
61+
62+
def make_dummy_streamline(nb_points):
    """ Make the streamlines that have been used to create test data files.

    Parameters
    ----------
    nb_points : int
        Number of points of the streamline. Only 1, 2 and 5 are supported.

    Returns
    -------
    streamline : ndarray of float32, shape (nb_points, 3)
        Consecutive values starting at 0.
    data_per_point : dict
        "fa" (shape (nb_points, 1)) and "colors" (shape (nb_points, 3)).
    data_for_streamline : dict
        "mean_curvature" and "mean_torsion" (shape (1,)) and "mean_colors"
        (shape (3,)).

    Raises
    ------
    ValueError
        If `nb_points` is not one of the supported values.
    """
    # nb_points -> (fa of each point, RGB color, mean curvature, mean torsion)
    dummies = {1: ([0.2], (1, 0, 0), 1.11, 1.22),
               2: ([0.3, 0.4], (0, 1, 0), 2.11, 2.22),
               5: ([0.5, 0.6, 0.6, 0.7, 0.8], (0, 0, 1), 3.11, 3.22)}

    if nb_points not in dummies:
        # Fail with a clear message instead of an unbound local variable.
        raise ValueError("No dummy streamline with {} points; expected one"
                         " of {}.".format(nb_points, sorted(dummies)))

    fa, color, curvature, torsion = dummies[nb_points]

    streamline = np.arange(nb_points*3, dtype="f4").reshape((nb_points, 3))
    data_per_point = {"fa": np.array(fa, dtype="f4").reshape((nb_points, 1)),
                      "colors": np.array([color]*nb_points, dtype="f4")}
    data_for_streamline = {"mean_curvature": np.array([curvature],
                                                      dtype="f4"),
                           "mean_torsion": np.array([torsion], dtype="f4"),
                           "mean_colors": np.array(color, dtype="f4")}

    return streamline, data_per_point, data_for_streamline
94+
95+
1996
def setup():
2097
global DATA
2198
DATA['rng'] = np.random.RandomState(1234)
22-
DATA['streamlines'] = [np.arange(1*3, dtype="f4").reshape((1, 3)),
23-
np.arange(2*3, dtype="f4").reshape((2, 3)),
24-
np.arange(5*3, dtype="f4").reshape((5, 3))]
25-
26-
DATA['fa'] = [np.array([[0.2]], dtype="f4"),
27-
np.array([[0.3],
28-
[0.4]], dtype="f4"),
29-
np.array([[0.5],
30-
[0.6],
31-
[0.6],
32-
[0.7],
33-
[0.8]], dtype="f4")]
34-
35-
DATA['colors'] = [np.array([(1, 0, 0)]*1, dtype="f4"),
36-
np.array([(0, 1, 0)]*2, dtype="f4"),
37-
np.array([(0, 0, 1)]*5, dtype="f4")]
38-
39-
DATA['mean_curvature'] = [np.array([1.11], dtype="f4"),
40-
np.array([2.11], dtype="f4"),
41-
np.array([3.11], dtype="f4")]
42-
43-
DATA['mean_torsion'] = [np.array([1.22], dtype="f4"),
44-
np.array([2.22], dtype="f4"),
45-
np.array([3.22], dtype="f4")]
46-
47-
DATA['mean_colors'] = [np.array([1, 0, 0], dtype="f4"),
48-
np.array([0, 1, 0], dtype="f4"),
49-
np.array([0, 0, 1], dtype="f4")]
99+
100+
DATA['streamlines'] = []
101+
DATA['fa'] = []
102+
DATA['colors'] = []
103+
DATA['mean_curvature'] = []
104+
DATA['mean_torsion'] = []
105+
DATA['mean_colors'] = []
106+
for nb_points in [1, 2, 5]:
107+
data = make_dummy_streamline(nb_points)
108+
streamline, data_per_point, data_for_streamline = data
109+
DATA['streamlines'].append(streamline)
110+
DATA['fa'].append(data_per_point['fa'])
111+
DATA['colors'].append(data_per_point['colors'])
112+
DATA['mean_curvature'].append(data_for_streamline['mean_curvature'])
113+
DATA['mean_torsion'].append(data_for_streamline['mean_torsion'])
114+
DATA['mean_colors'].append(data_for_streamline['mean_colors'])
50115

51116
DATA['data_per_point'] = {'colors': DATA['colors'],
52117
'fa': DATA['fa']}
@@ -63,17 +128,13 @@ def setup():
63128
affine_to_rasmm=np.eye(4))
64129

65130
DATA['streamlines_func'] = lambda: (e for e in DATA['streamlines'])
66-
fa_func = lambda: (e for e in DATA['fa'])
67-
colors_func = lambda: (e for e in DATA['colors'])
68-
mean_curvature_func = lambda: (e for e in DATA['mean_curvature'])
69-
mean_torsion_func = lambda: (e for e in DATA['mean_torsion'])
70-
mean_colors_func = lambda: (e for e in DATA['mean_colors'])
71-
72-
DATA['data_per_point_func'] = {'colors': colors_func,
73-
'fa': fa_func}
74-
DATA['data_per_streamline_func'] = {'mean_curvature': mean_curvature_func,
75-
'mean_torsion': mean_torsion_func,
76-
'mean_colors': mean_colors_func}
131+
DATA['data_per_point_func'] = {
132+
'colors': lambda: (e for e in DATA['colors']),
133+
'fa': lambda: (e for e in DATA['fa'])}
134+
DATA['data_per_streamline_func'] = {
135+
'mean_curvature': lambda: (e for e in DATA['mean_curvature']),
136+
'mean_torsion': lambda: (e for e in DATA['mean_torsion']),
137+
'mean_colors': lambda: (e for e in DATA['mean_colors'])}
77138

78139
DATA['lazy_tractogram'] = LazyTractogram(DATA['streamlines_func'],
79140
DATA['data_per_streamline_func'],
@@ -130,6 +191,11 @@ def assert_tractogram_equal(t1, t2):
130191
t2.data_per_streamline, t2.data_per_point)
131192

132193

194+
def extender(a, b):
    """ Extend `a` in place with `b` and return `a`.

    Gives `a.extend(b)` the same calling convention as `operator.iadd`,
    which also returns its (modified) first argument.
    """
    a.extend(b)
    return a
197+
198+
133199
class TestPerArrayDict(unittest.TestCase):
134200

135201
def test_per_array_dict_creation(self):
@@ -181,6 +247,53 @@ def test_getitem(self):
181247
assert_arrays_equal(sdict[-1][k], v[-1])
182248
assert_arrays_equal(sdict[[0, -1]][k], v[[0, -1]])
183249

250+
def test_extend(self):
    # Exercise `PerArrayDict.extend` with compatible, empty and
    # incompatible dictionaries.
    tractogram = DATA['tractogram']
    nb_streamlines = len(tractogram)
    sdict = PerArrayDict(nb_streamlines, DATA['data_per_streamline'])

    # Test compatible PerArrayDicts: same keys, scaled values.
    new_data = {'mean_curvature': 2 * np.array(DATA['mean_curvature']),
                'mean_torsion': 3 * np.array(DATA['mean_torsion']),
                'mean_colors': 4 * np.array(DATA['mean_colors'])}
    sdict2 = PerArrayDict(nb_streamlines, new_data)

    sdict.extend(sdict2)
    assert_equal(len(sdict), len(sdict2))
    for k in tractogram.data_per_streamline:
        # Original rows come first, the appended ones right after.
        assert_arrays_equal(sdict[k][:nb_streamlines],
                            tractogram.data_per_streamline[k])
        assert_arrays_equal(sdict[k][nb_streamlines:], new_data[k])

    # Extending with an empty PerArrayDict should change nothing.
    sdict_orig = copy.deepcopy(sdict)
    sdict.extend(PerArrayDict())
    for k in sdict_orig.keys():
        assert_arrays_equal(sdict[k], sdict_orig[k])

    # Test incompatible PerArrayDicts.
    incompatible = [
        # Other dict has more entries.
        {'mean_curvature': 2 * np.array(DATA['mean_curvature']),
         'mean_torsion': 3 * np.array(DATA['mean_torsion']),
         'mean_colors': 4 * np.array(DATA['mean_colors']),
         'other': 5 * np.array(DATA['mean_colors'])},
        # Other dict does not have the same entries (key mismatch).
        {'mean_curvature': 2 * np.array(DATA['mean_curvature']),
         'mean_torsion': 3 * np.array(DATA['mean_torsion']),
         'other': 4 * np.array(DATA['mean_colors'])},
        # Other dict has the right number of entries but wrong shape.
        {'mean_curvature': 2 * np.array(DATA['mean_curvature']),
         'mean_torsion': 3 * np.array(DATA['mean_torsion']),
         'mean_colors': 4 * np.array(DATA['mean_torsion'])},
    ]
    for bad_data in incompatible:
        sdict2 = PerArrayDict(nb_streamlines, bad_data)
        assert_raises(ValueError, sdict.extend, sdict2)
296+
184297

185298
class TestPerArraySequenceDict(unittest.TestCase):
186299

@@ -233,6 +346,62 @@ def test_getitem(self):
233346
assert_arrays_equal(sdict[-1][k], v[-1])
234347
assert_arrays_equal(sdict[[0, -1]][k], v[[0, -1]])
235348

349+
def test_extend(self):
    # Exercise `PerArraySequenceDict.extend` with compatible, empty and
    # incompatible dictionaries.
    tractogram = DATA['tractogram']
    nb_streamlines = len(tractogram)
    total_nb_rows = tractogram.streamlines.total_nb_rows
    sdict = PerArraySequenceDict(total_nb_rows, DATA['data_per_point'])

    list_nb_points = [2, 7, 4]
    colors_shape = DATA['colors'][0].shape[1:]
    fa_shape = DATA['fa'][0].shape[1:]

    def make_random_sdict(shapes):
        # Draw random per-point data having the requested `shapes` and
        # wrap it in a PerArraySequenceDict.
        _, data, _ = make_fake_tractogram(list_nb_points, shapes,
                                          rng=DATA['rng'])
        return PerArraySequenceDict(np.sum(list_nb_points), data), data

    # Test compatible PerArraySequenceDicts.
    sdict2, new_data = make_random_sdict({"colors": colors_shape,
                                          "fa": fa_shape})

    sdict.extend(sdict2)
    assert_equal(len(sdict), len(sdict2))
    for k in tractogram.data_per_point:
        # Original sequences come first, the appended ones right after.
        assert_arrays_equal(sdict[k][:nb_streamlines],
                            tractogram.data_per_point[k])
        assert_arrays_equal(sdict[k][nb_streamlines:], new_data[k])

    # Extending with an empty PerArraySequenceDict should change nothing.
    sdict_orig = copy.deepcopy(sdict)
    sdict.extend(PerArraySequenceDict())
    for k in sdict_orig.keys():
        assert_arrays_equal(sdict[k], sdict_orig[k])

    # Test incompatible PerArraySequenceDicts.
    incompatible_shapes = [
        # Other dict has more entries.
        {"colors": colors_shape,
         "fa": fa_shape,
         "other": (7,)},
        # Other dict does not have the same entries (key mismatch).
        {"colors": colors_shape,
         "other": fa_shape},
        # Other dict has the right number of entries but wrong shape.
        {"colors": colors_shape,
         "fa": fa_shape + (3,)},
    ]
    for shapes in incompatible_shapes:
        sdict2, _ = make_random_sdict(shapes)
        assert_raises(ValueError, sdict.extend, sdict2)
404+
236405

237406
class TestLazyDict(unittest.TestCase):
238407

@@ -570,6 +739,28 @@ def test_tractogram_to_world(self):
570739
tractogram.affine_to_rasmm = None
571740
assert_raises(ValueError, tractogram.to_world)
572741

742+
def test_tractogram_extend(self):
    # Load tractogram that contains some metadata.
    reference = DATA['tractogram'].copy()
    nb_streamlines = len(reference)

    # Each operation is paired with whether it must modify (and return)
    # its first operand rather than build a new object.
    operations = ((operator.add, False),
                  (operator.iadd, True),
                  (extender, True))
    for op, in_place in operations:
        first_arg = reference.copy()
        new_t = op(first_arg, reference)
        assert_equal(new_t is first_arg, in_place)
        # Both halves of the result must match the original tractogram.
        assert_tractogram_equal(new_t[:nb_streamlines], DATA['tractogram'])
        assert_tractogram_equal(new_t[nb_streamlines:], DATA['tractogram'])

    # Test extending an empty Tractogram.
    grown = Tractogram()
    grown += DATA['tractogram']
    assert_tractogram_equal(grown, DATA['tractogram'])

    # and the other way around.
    unchanged = DATA['tractogram'].copy()
    unchanged += Tractogram()
    assert_tractogram_equal(unchanged, DATA['tractogram'])
763+
573764

574765
class TestLazyTractogram(unittest.TestCase):
575766

@@ -580,11 +771,12 @@ def test_lazy_tractogram_creation(self):
580771
# Streamlines and other data as generators
581772
streamlines = (x for x in DATA['streamlines'])
582773
data_per_point = {"colors": (x for x in DATA['colors'])}
583-
data_per_streamline = {'mean_torsion': (x for x in DATA['mean_torsion']),
584-
'mean_colors': (x for x in DATA['mean_colors'])}
774+
data_per_streamline = {'torsion': (x for x in DATA['mean_torsion']),
775+
'colors': (x for x in DATA['mean_colors'])}
585776

586777
# Creating LazyTractogram with generators is not allowed as
587-
# generators get exhausted and are not reusable unlike generator function.
778+
# generators get exhausted and are not reusable unlike generator
779+
# function.
588780
assert_raises(TypeError, LazyTractogram, streamlines)
589781
assert_raises(TypeError, LazyTractogram,
590782
data_per_streamline=data_per_streamline)
@@ -610,12 +802,11 @@ def test_lazy_tractogram_creation(self):
610802

611803
def test_lazy_tractogram_from_data_func(self):
612804
# Create an empty `LazyTractogram` yielding nothing.
613-
_empty_data_gen = lambda: iter([])
614-
615-
tractogram = LazyTractogram.from_data_func(_empty_data_gen)
805+
tractogram = LazyTractogram.from_data_func(lambda: iter([]))
616806
check_tractogram(tractogram)
617807

618-
# Create `LazyTractogram` from a generator function yielding TractogramItem.
808+
# Create `LazyTractogram` from a generator function yielding
809+
# TractogramItem.
619810
data = [DATA['streamlines'], DATA['fa'], DATA['colors'],
620811
DATA['mean_curvature'], DATA['mean_torsion'],
621812
DATA['mean_colors']]
@@ -641,6 +832,13 @@ def test_lazy_tractogram_getitem(self):
641832
assert_raises(NotImplementedError,
642833
DATA['lazy_tractogram'].__getitem__, 0)
643834

835+
def test_lazy_tractogram_extend(self):
    # None of the concatenation operations is expected to be implemented
    # for LazyTractogram.
    other = DATA['lazy_tractogram'].copy()
    target = DATA['lazy_tractogram'].copy()

    operations = (operator.add, operator.iadd, extender)
    for op in operations:
        assert_raises(NotImplementedError, op, target, other)
841+
644842
def test_lazy_tractogram_len(self):
645843
modules = [module_tractogram] # Modules for which to catch warnings.
646844
with clear_and_catch_warnings(record=True, modules=modules) as w:
@@ -746,8 +944,8 @@ def test_lazy_tractogram_copy(self):
746944
# Check we copied the data and not simply created new references.
747945
assert_true(tractogram is not DATA['lazy_tractogram'])
748946

749-
# When copying LazyTractogram, the generator function yielding streamlines
750-
# should stay the same.
947+
# When copying LazyTractogram, the generator function yielding
948+
# streamlines should stay the same.
751949
assert_true(tractogram._streamlines
752950
is DATA['lazy_tractogram']._streamlines)
753951

@@ -759,12 +957,14 @@ def test_lazy_tractogram_copy(self):
759957
is not DATA['lazy_tractogram']._data_per_point)
760958

761959
for key in tractogram.data_per_streamline:
762-
assert_true(tractogram.data_per_streamline.store[key]
763-
is DATA['lazy_tractogram'].data_per_streamline.store[key])
960+
data = tractogram.data_per_streamline.store[key]
961+
expected = DATA['lazy_tractogram'].data_per_streamline.store[key]
962+
assert_true(data is expected)
764963

765964
for key in tractogram.data_per_point:
766-
assert_true(tractogram.data_per_point.store[key]
767-
is DATA['lazy_tractogram'].data_per_point.store[key])
965+
data = tractogram.data_per_point.store[key]
966+
expected = DATA['lazy_tractogram'].data_per_point.store[key]
967+
assert_true(data is expected)
768968

769969
# The affine should be a copy.
770970
assert_true(tractogram._affine_to_apply

0 commit comments

Comments
 (0)