import sys
+ import copy
import unittest
import numpy as np
import warnings
+ import operator
+ from collections import defaultdict

from nibabel.testing import assert_arrays_equal
from nibabel.testing import clear_and_catch_warnings
DATA = {}


+ def make_fake_streamline(nb_points, data_per_point_shapes={},
+                          data_for_streamline_shapes={}, rng=None):
+     """ Make a single streamline according to provided requirements. """
+     if rng is None:
+         rng = np.random.RandomState()
+
+     streamline = rng.randn(nb_points, 3).astype("f4")
+
+     data_per_point = {}
+     for k, shape in data_per_point_shapes.items():
+         data_per_point[k] = rng.randn(*((nb_points,) + shape)).astype("f4")
+
+     data_for_streamline = {}
+     for k, shape in data_for_streamline_shapes.items():
+         data_for_streamline[k] = rng.randn(*shape).astype("f4")
+
+     return streamline, data_per_point, data_for_streamline
+
+
+ def make_fake_tractogram(list_nb_points, data_per_point_shapes={},
+                          data_for_streamline_shapes={}, rng=None):
+     """ Make multiple streamlines according to provided requirements. """
+     all_streamlines = []
+     all_data_per_point = defaultdict(lambda: [])
+     all_data_per_streamline = defaultdict(lambda: [])
+     for nb_points in list_nb_points:
+         data = make_fake_streamline(nb_points, data_per_point_shapes,
+                                     data_for_streamline_shapes, rng)
+         streamline, data_per_point, data_for_streamline = data
+
+         all_streamlines.append(streamline)
+         for k, v in data_per_point.items():
+             all_data_per_point[k].append(v)
+
+         for k, v in data_for_streamline.items():
+             all_data_per_streamline[k].append(v)
+
+     return all_streamlines, all_data_per_point, all_data_per_streamline
+
+
+ def make_dummy_streamline(nb_points):
+     """ Make the streamlines that have been used to create test data files. """
+     if nb_points == 1:
+         streamline = np.arange(1 * 3, dtype="f4").reshape((1, 3))
+         data_per_point = {"fa": np.array([[0.2]], dtype="f4"),
+                           "colors": np.array([(1, 0, 0)] * 1, dtype="f4")}
+         data_for_streamline = {"mean_curvature": np.array([1.11], dtype="f4"),
+                                "mean_torsion": np.array([1.22], dtype="f4"),
+                                "mean_colors": np.array([1, 0, 0], dtype="f4")}
+
+     elif nb_points == 2:
+         streamline = np.arange(2 * 3, dtype="f4").reshape((2, 3))
+         data_per_point = {"fa": np.array([[0.3],
+                                           [0.4]], dtype="f4"),
+                           "colors": np.array([(0, 1, 0)] * 2, dtype="f4")}
+         data_for_streamline = {"mean_curvature": np.array([2.11], dtype="f4"),
+                                "mean_torsion": np.array([2.22], dtype="f4"),
+                                "mean_colors": np.array([0, 1, 0], dtype="f4")}
+
+     elif nb_points == 5:
+         streamline = np.arange(5 * 3, dtype="f4").reshape((5, 3))
+         data_per_point = {"fa": np.array([[0.5],
+                                           [0.6],
+                                           [0.6],
+                                           [0.7],
+                                           [0.8]], dtype="f4"),
+                           "colors": np.array([(0, 0, 1)] * 5, dtype="f4")}
+         data_for_streamline = {"mean_curvature": np.array([3.11], dtype="f4"),
+                                "mean_torsion": np.array([3.22], dtype="f4"),
+                                "mean_colors": np.array([0, 0, 1], dtype="f4")}
+
+     return streamline, data_per_point, data_for_streamline
+
+
def setup():
    global DATA
    DATA['rng'] = np.random.RandomState(1234)
-     DATA['streamlines'] = [np.arange(1 * 3, dtype="f4").reshape((1, 3)),
-                            np.arange(2 * 3, dtype="f4").reshape((2, 3)),
-                            np.arange(5 * 3, dtype="f4").reshape((5, 3))]
-
-     DATA['fa'] = [np.array([[0.2]], dtype="f4"),
-                   np.array([[0.3],
-                             [0.4]], dtype="f4"),
-                   np.array([[0.5],
-                             [0.6],
-                             [0.6],
-                             [0.7],
-                             [0.8]], dtype="f4")]
-
-     DATA['colors'] = [np.array([(1, 0, 0)] * 1, dtype="f4"),
-                       np.array([(0, 1, 0)] * 2, dtype="f4"),
-                       np.array([(0, 0, 1)] * 5, dtype="f4")]
-
-     DATA['mean_curvature'] = [np.array([1.11], dtype="f4"),
-                               np.array([2.11], dtype="f4"),
-                               np.array([3.11], dtype="f4")]
-
-     DATA['mean_torsion'] = [np.array([1.22], dtype="f4"),
-                             np.array([2.22], dtype="f4"),
-                             np.array([3.22], dtype="f4")]
-
-     DATA['mean_colors'] = [np.array([1, 0, 0], dtype="f4"),
-                            np.array([0, 1, 0], dtype="f4"),
-                            np.array([0, 0, 1], dtype="f4")]
+
+     DATA['streamlines'] = []
+     DATA['fa'] = []
+     DATA['colors'] = []
+     DATA['mean_curvature'] = []
+     DATA['mean_torsion'] = []
+     DATA['mean_colors'] = []
+     for nb_points in [1, 2, 5]:
+         data = make_dummy_streamline(nb_points)
+         streamline, data_per_point, data_for_streamline = data
+         DATA['streamlines'].append(streamline)
+         DATA['fa'].append(data_per_point['fa'])
+         DATA['colors'].append(data_per_point['colors'])
+         DATA['mean_curvature'].append(data_for_streamline['mean_curvature'])
+         DATA['mean_torsion'].append(data_for_streamline['mean_torsion'])
+         DATA['mean_colors'].append(data_for_streamline['mean_colors'])

    DATA['data_per_point'] = {'colors': DATA['colors'],
                              'fa': DATA['fa']}
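
A quick usage sketch of the fake-data helpers added above (illustrative only, not part of the commit; it assumes `make_fake_tractogram` from the hunk above is in scope, and the shape choices are arbitrary):

import numpy as np

rng = np.random.RandomState(42)
streamlines, per_point, per_streamline = make_fake_tractogram(
    [2, 7, 4],
    data_per_point_shapes={"fa": (1,)},
    data_for_streamline_shapes={"mean_colors": (3,)},
    rng=rng)
# Three streamlines of 2, 7 and 4 points each.
assert len(streamlines) == 3
# Per-point data has one row per point of the matching streamline.
assert per_point["fa"][1].shape == (7, 1)
# Per-streamline data has one fixed-size array per streamline.
assert per_streamline["mean_colors"][0].shape == (3,)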
@@ -63,17 +128,13 @@ def setup():
                                    affine_to_rasmm=np.eye(4))

    DATA['streamlines_func'] = lambda: (e for e in DATA['streamlines'])
-     fa_func = lambda: (e for e in DATA['fa'])
-     colors_func = lambda: (e for e in DATA['colors'])
-     mean_curvature_func = lambda: (e for e in DATA['mean_curvature'])
-     mean_torsion_func = lambda: (e for e in DATA['mean_torsion'])
-     mean_colors_func = lambda: (e for e in DATA['mean_colors'])
-
-     DATA['data_per_point_func'] = {'colors': colors_func,
-                                    'fa': fa_func}
-     DATA['data_per_streamline_func'] = {'mean_curvature': mean_curvature_func,
-                                         'mean_torsion': mean_torsion_func,
-                                         'mean_colors': mean_colors_func}
+     DATA['data_per_point_func'] = {
+         'colors': lambda: (e for e in DATA['colors']),
+         'fa': lambda: (e for e in DATA['fa'])}
+     DATA['data_per_streamline_func'] = {
+         'mean_curvature': lambda: (e for e in DATA['mean_curvature']),
+         'mean_torsion': lambda: (e for e in DATA['mean_torsion']),
+         'mean_colors': lambda: (e for e in DATA['mean_colors'])}

    DATA['lazy_tractogram'] = LazyTractogram(DATA['streamlines_func'],
                                             DATA['data_per_streamline_func'],
@@ -130,6 +191,11 @@ def assert_tractogram_equal(t1, t2):
                     t2.data_per_streamline, t2.data_per_point)


+ def extender(a, b):
+     a.extend(b)
+     return a
+
+
class TestPerArrayDict(unittest.TestCase):

    def test_per_array_dict_creation(self):
@@ -181,6 +247,53 @@ def test_getitem(self):
            assert_arrays_equal(sdict[-1][k], v[-1])
            assert_arrays_equal(sdict[[0, -1]][k], v[[0, -1]])

+     def test_extend(self):
+         sdict = PerArrayDict(len(DATA['tractogram']),
+                              DATA['data_per_streamline'])
+
+         new_data = {'mean_curvature': 2 * np.array(DATA['mean_curvature']),
+                     'mean_torsion': 3 * np.array(DATA['mean_torsion']),
+                     'mean_colors': 4 * np.array(DATA['mean_colors'])}
+         sdict2 = PerArrayDict(len(DATA['tractogram']),
+                               new_data)
+
+         sdict.extend(sdict2)
+         assert_equal(len(sdict), len(sdict2))
+         for k in DATA['tractogram'].data_per_streamline:
+             assert_arrays_equal(sdict[k][:len(DATA['tractogram'])],
+                                 DATA['tractogram'].data_per_streamline[k])
+             assert_arrays_equal(sdict[k][len(DATA['tractogram']):],
+                                 new_data[k])
+
+         # Extending with an empty PerArrayDict should change nothing.
+         sdict_orig = copy.deepcopy(sdict)
+         sdict.extend(PerArrayDict())
+         for k in sdict_orig.keys():
+             assert_arrays_equal(sdict[k], sdict_orig[k])
+
+         # Test incompatible PerArrayDicts.
+         # Other dict has more entries.
+         new_data = {'mean_curvature': 2 * np.array(DATA['mean_curvature']),
+                     'mean_torsion': 3 * np.array(DATA['mean_torsion']),
+                     'mean_colors': 4 * np.array(DATA['mean_colors']),
+                     'other': 5 * np.array(DATA['mean_colors'])}
+         sdict2 = PerArrayDict(len(DATA['tractogram']), new_data)
+         assert_raises(ValueError, sdict.extend, sdict2)
+
+         # Other dict does not have the same entries (keys mismatch).
+         new_data = {'mean_curvature': 2 * np.array(DATA['mean_curvature']),
+                     'mean_torsion': 3 * np.array(DATA['mean_torsion']),
+                     'other': 4 * np.array(DATA['mean_colors'])}
+         sdict2 = PerArrayDict(len(DATA['tractogram']), new_data)
+         assert_raises(ValueError, sdict.extend, sdict2)
+
+         # Other dict has the right number of entries but the wrong shape.
+         new_data = {'mean_curvature': 2 * np.array(DATA['mean_curvature']),
+                     'mean_torsion': 3 * np.array(DATA['mean_torsion']),
+                     'mean_colors': 4 * np.array(DATA['mean_torsion'])}
+         sdict2 = PerArrayDict(len(DATA['tractogram']), new_data)
+         assert_raises(ValueError, sdict.extend, sdict2)
+

class TestPerArraySequenceDict(unittest.TestCase):

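For readers skimming the `test_extend` added in the hunk above: a rough sketch of the behaviour it encodes, written against plain dicts of NumPy arrays rather than nibabel's `PerArrayDict` (the function name and error messages here are invented for illustration):

import numpy as np

def extend_per_array(d1, d2):
    """Concatenate d2's arrays onto d1's, key by key (illustrative only)."""
    if len(d2) == 0:
        return d1  # extending with an empty dict changes nothing
    if set(d1) != set(d2):
        raise ValueError("keys mismatch")
    for k in d1:
        if d1[k].shape[1:] != d2[k].shape[1:]:
            raise ValueError("wrong element shape for key {!r}".format(k))
    for k in d1:
        d1[k] = np.concatenate([d1[k], d2[k]])
    return d1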
@@ -233,6 +346,62 @@ def test_getitem(self):
            assert_arrays_equal(sdict[-1][k], v[-1])
            assert_arrays_equal(sdict[[0, -1]][k], v[[0, -1]])

+     def test_extend(self):
+         total_nb_rows = DATA['tractogram'].streamlines.total_nb_rows
+         sdict = PerArraySequenceDict(total_nb_rows, DATA['data_per_point'])
+
+         # Test compatible PerArraySequenceDicts.
+         list_nb_points = [2, 7, 4]
+         data_per_point_shapes = {"colors": DATA['colors'][0].shape[1:],
+                                  "fa": DATA['fa'][0].shape[1:]}
+         _, new_data, _ = make_fake_tractogram(list_nb_points,
+                                               data_per_point_shapes,
+                                               rng=DATA['rng'])
+         sdict2 = PerArraySequenceDict(np.sum(list_nb_points), new_data)
+
+         sdict.extend(sdict2)
+         assert_equal(len(sdict), len(sdict2))
+         for k in DATA['tractogram'].data_per_point:
+             assert_arrays_equal(sdict[k][:len(DATA['tractogram'])],
+                                 DATA['tractogram'].data_per_point[k])
+             assert_arrays_equal(sdict[k][len(DATA['tractogram']):],
+                                 new_data[k])
+
+         # Extending with an empty PerArraySequenceDict should change nothing.
+         sdict_orig = copy.deepcopy(sdict)
+         sdict.extend(PerArraySequenceDict())
+         for k in sdict_orig.keys():
+             assert_arrays_equal(sdict[k], sdict_orig[k])
+
+         # Test incompatible PerArraySequenceDicts.
+         # Other dict has more entries.
+         data_per_point_shapes = {"colors": DATA['colors'][0].shape[1:],
+                                  "fa": DATA['fa'][0].shape[1:],
+                                  "other": (7,)}
+         _, new_data, _ = make_fake_tractogram(list_nb_points,
+                                               data_per_point_shapes,
+                                               rng=DATA['rng'])
+         sdict2 = PerArraySequenceDict(np.sum(list_nb_points), new_data)
+         assert_raises(ValueError, sdict.extend, sdict2)
+
+         # Other dict does not have the same entries (keys mismatch).
+         data_per_point_shapes = {"colors": DATA['colors'][0].shape[1:],
+                                  "other": DATA['fa'][0].shape[1:]}
+         _, new_data, _ = make_fake_tractogram(list_nb_points,
+                                               data_per_point_shapes,
+                                               rng=DATA['rng'])
+         sdict2 = PerArraySequenceDict(np.sum(list_nb_points), new_data)
+         assert_raises(ValueError, sdict.extend, sdict2)
+
+         # Other dict has the right number of entries but the wrong shape.
+         data_per_point_shapes = {"colors": DATA['colors'][0].shape[1:],
+                                  "fa": DATA['fa'][0].shape[1:] + (3,)}
+         _, new_data, _ = make_fake_tractogram(list_nb_points,
+                                               data_per_point_shapes,
+                                               rng=DATA['rng'])
+         sdict2 = PerArraySequenceDict(np.sum(list_nb_points), new_data)
+         assert_raises(ValueError, sdict.extend, sdict2)
+

class TestLazyDict(unittest.TestCase):

@@ -570,6 +739,28 @@ def test_tractogram_to_world(self):
        tractogram.affine_to_rasmm = None
        assert_raises(ValueError, tractogram.to_world)

+     def test_tractogram_extend(self):
+         # Start with a tractogram that contains some metadata.
+         t = DATA['tractogram'].copy()
+
+         for op, in_place in ((operator.add, False), (operator.iadd, True),
+                              (extender, True)):
+             first_arg = t.copy()
+             new_t = op(first_arg, t)
+             assert_equal(new_t is first_arg, in_place)
+             assert_tractogram_equal(new_t[:len(t)], DATA['tractogram'])
+             assert_tractogram_equal(new_t[len(t):], DATA['tractogram'])
+
+         # Test extending an empty Tractogram.
+         t = Tractogram()
+         t += DATA['tractogram']
+         assert_tractogram_equal(t, DATA['tractogram'])
+
+         # ... and the other way around.
+         t = DATA['tractogram'].copy()
+         t += Tractogram()
+         assert_tractogram_equal(t, DATA['tractogram'])
+

class TestLazyTractogram(unittest.TestCase):

@@ -580,11 +771,12 @@ def test_lazy_tractogram_creation(self):
        # Streamlines and other data as generators
        streamlines = (x for x in DATA['streamlines'])
        data_per_point = {"colors": (x for x in DATA['colors'])}
-         data_per_streamline = {'mean_torsion': (x for x in DATA['mean_torsion']),
-                                'mean_colors': (x for x in DATA['mean_colors'])}
+         data_per_streamline = {'torsion': (x for x in DATA['mean_torsion']),
+                                'colors': (x for x in DATA['mean_colors'])}

        # Creating LazyTractogram with generators is not allowed as
-         # generators get exhausted and are not reusable unlike generator function.
+         # generators get exhausted and are not reusable, unlike generator
+         # functions.
        assert_raises(TypeError, LazyTractogram, streamlines)
        assert_raises(TypeError, LazyTractogram,
                      data_per_streamline=data_per_streamline)
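
The comment in the hunk above states the key constraint: LazyTractogram expects callables that produce a fresh generator on demand, not already-built generators. A minimal illustration of the difference (plain Python, unrelated to nibabel's API and not part of the commit):

gen = (i for i in range(3))                # a generator: single use
assert list(gen) == [0, 1, 2]
assert list(gen) == []                     # exhausted, nothing left to yield

gen_func = lambda: (i for i in range(3))   # a generator function: reusable
assert list(gen_func()) == [0, 1, 2]
assert list(gen_func()) == [0, 1, 2]       # each call builds a new generator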
@@ -610,12 +802,11 @@ def test_lazy_tractogram_creation(self):

    def test_lazy_tractogram_from_data_func(self):
        # Create an empty `LazyTractogram` yielding nothing.
-         _empty_data_gen = lambda: iter([])
-
-         tractogram = LazyTractogram.from_data_func(_empty_data_gen)
+         tractogram = LazyTractogram.from_data_func(lambda: iter([]))
        check_tractogram(tractogram)

-         # Create `LazyTractogram` from a generator function yielding TractogramItem.
+         # Create `LazyTractogram` from a generator function yielding
+         # TractogramItem.
        data = [DATA['streamlines'], DATA['fa'], DATA['colors'],
                DATA['mean_curvature'], DATA['mean_torsion'],
                DATA['mean_colors']]
641
832
assert_raises (NotImplementedError ,
642
833
DATA ['lazy_tractogram' ].__getitem__ , 0 )
643
834
835
+ def test_lazy_tractogram_extend (self ):
836
+ t = DATA ['lazy_tractogram' ].copy ()
837
+ new_t = DATA ['lazy_tractogram' ].copy ()
838
+
839
+ for op in (operator .add , operator .iadd , extender ):
840
+ assert_raises (NotImplementedError , op , new_t , t )
841
+
644
842
def test_lazy_tractogram_len (self ):
645
843
modules = [module_tractogram ] # Modules for which to catch warnings.
646
844
with clear_and_catch_warnings (record = True , modules = modules ) as w :
@@ -746,8 +944,8 @@ def test_lazy_tractogram_copy(self):
        # Check we copied the data and not simply created new references.
        assert_true(tractogram is not DATA['lazy_tractogram'])

-         # When copying LazyTractogram, the generator function yielding streamlines
-         # should stay the same.
+         # When copying LazyTractogram, the generator function yielding
+         # streamlines should stay the same.
        assert_true(tractogram._streamlines
                    is DATA['lazy_tractogram']._streamlines)

@@ -759,12 +957,14 @@ def test_lazy_tractogram_copy(self):
                    is not DATA['lazy_tractogram']._data_per_point)

        for key in tractogram.data_per_streamline:
-             assert_true(tractogram.data_per_streamline.store[key]
-                         is DATA['lazy_tractogram'].data_per_streamline.store[key])
+             data = tractogram.data_per_streamline.store[key]
+             expected = DATA['lazy_tractogram'].data_per_streamline.store[key]
+             assert_true(data is expected)

        for key in tractogram.data_per_point:
-             assert_true(tractogram.data_per_point.store[key]
-                         is DATA['lazy_tractogram'].data_per_point.store[key])
+             data = tractogram.data_per_point.store[key]
+             expected = DATA['lazy_tractogram'].data_per_point.store[key]
+             assert_true(data is expected)

        # The affine should be a copy.
        assert_true(tractogram._affine_to_apply