Skip to content

Commit f824222

Browse files
Merge pull request #588 from MarcCote/bf_586
MRG: fix #596 First point of `LazyTractogram` is not saved With this PR, we now assume tractogram objects can be iterated over only once at saving time. This PR fixes issue #586. Thanks to @nilgoyette for providing the missing unit test and part of the solution
2 parents 6d54618 + 94c5390 commit f824222

File tree

7 files changed

+109
-23
lines changed

7 files changed

+109
-23
lines changed

nibabel/streamlines/tck.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .tractogram_file import HeaderError, DataError
2020
from .tractogram import TractogramItem, Tractogram, LazyTractogram
2121
from .header import Field
22+
from .utils import peek_next
2223

2324
MEGABYTE = 1024 * 1024
2425

@@ -191,8 +192,16 @@ def save(self, fileobj):
191192
# Write temporary header that we will update at the end
192193
self._write_header(f, header)
193194

195+
# Make sure streamlines are in rasmm.
196+
tractogram = self.tractogram.to_world(lazy=True)
197+
# Assume looping over the streamlines can be done only once.
198+
tractogram = iter(tractogram)
199+
194200
try:
195-
first_item = next(iter(self.tractogram))
201+
# Use the first element to check
202+
# 1) the tractogram is not empty;
203+
# 2) quantity of information saved along each streamline.
204+
first_item, tractogram = peek_next(tractogram)
196205
except StopIteration:
197206
# Empty tractogram
198207
header[Field.NB_STREAMLINES] = 0
@@ -216,9 +225,6 @@ def save(self, fileobj):
216225
" alongside points. Dropping: {}".format(keys))
217226
warnings.warn(msg, DataWarning)
218227

219-
# Make sure streamlines are in rasmm.
220-
tractogram = self.tractogram.to_world(lazy=True)
221-
222228
for t in tractogram:
223229
data = np.r_[t.streamline, self.FIBER_DELIMITER]
224230
f.write(data.astype(dtype).tostring())

nibabel/streamlines/tests/test_array_sequence.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,11 @@ def test_arraysequence_append(self):
152152
seq.append(element)
153153
check_arr_seq(seq, [element])
154154

155+
# Append an empty array.
156+
seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification.
157+
seq.append([])
158+
check_arr_seq(seq, SEQ_DATA['seq'])
159+
155160
# Append an element with different shape.
156161
element = generate_data(nb_arrays=1,
157162
common_shape=SEQ_DATA['seq'].common_shape*2,

nibabel/streamlines/tests/test_streamlines.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,3 +272,19 @@ def test_load_unknown_format(self):
272272

273273
def test_save_unknown_format(self):
274274
assert_raises(ValueError, nib.streamlines.save, Tractogram(), "")
275+
276+
def test_save_from_generator(self):
277+
tractogram = Tractogram(DATA['streamlines'],
278+
affine_to_rasmm=np.eye(4))
279+
280+
# Just to create a generator
281+
for ext, _ in FORMATS.items():
282+
filtered = (s for s in tractogram.streamlines if True)
283+
lazy_tractogram = LazyTractogram(lambda: filtered,
284+
affine_to_rasmm=np.eye(4))
285+
286+
with InTemporaryDirectory():
287+
filename = 'streamlines' + ext
288+
nib.streamlines.save(lazy_tractogram, filename)
289+
tfile = nib.streamlines.load(filename, lazy_load=False)
290+
assert_tractogram_equal(tfile.tractogram, tractogram)

nibabel/streamlines/tests/test_tractogram.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from six.moves import zip
1414

1515
from .. import tractogram as module_tractogram
16+
from ..tractogram import is_data_dict, is_lazy_dict
1617
from ..tractogram import TractogramItem, Tractogram, LazyTractogram
1718
from ..tractogram import PerArrayDict, PerArraySequenceDict, LazyDict
1819

@@ -406,14 +407,21 @@ def test_extend(self):
406407
class TestLazyDict(unittest.TestCase):
407408

408409
def test_lazydict_creation(self):
409-
data_dict = LazyDict(DATA['data_per_streamline_func'])
410-
assert_equal(data_dict.keys(), DATA['data_per_streamline_func'].keys())
411-
for k in data_dict.keys():
412-
assert_array_equal(list(data_dict[k]),
413-
list(DATA['data_per_streamline'][k]))
410+
# Different ways of creating LazyDict
411+
lazy_dicts = []
412+
lazy_dicts += [LazyDict(DATA['data_per_streamline_func'])]
413+
lazy_dicts += [LazyDict(**DATA['data_per_streamline_func'])]
414414

415-
assert_equal(len(data_dict),
416-
len(DATA['data_per_streamline_func']))
415+
expected_keys = DATA['data_per_streamline_func'].keys()
416+
for data_dict in lazy_dicts:
417+
assert_true(is_lazy_dict(data_dict))
418+
assert_equal(data_dict.keys(), expected_keys)
419+
for k in data_dict.keys():
420+
assert_array_equal(list(data_dict[k]),
421+
list(DATA['data_per_streamline'][k]))
422+
423+
assert_equal(len(data_dict),
424+
len(DATA['data_per_streamline_func']))
417425

418426

419427
class TestTractogramItem(unittest.TestCase):
@@ -470,6 +478,9 @@ def test_tractogram_creation(self):
470478
DATA['data_per_streamline'],
471479
DATA['data_per_point'])
472480

481+
assert_true(is_data_dict(tractogram.data_per_streamline))
482+
assert_true(is_data_dict(tractogram.data_per_point))
483+
473484
# Create a tractogram from another tractogram attributes.
474485
tractogram2 = Tractogram(tractogram.streamlines,
475486
tractogram.data_per_streamline,
@@ -795,6 +806,9 @@ def test_lazy_tractogram_creation(self):
795806
DATA['data_per_streamline_func'],
796807
DATA['data_per_point_func'])
797808

809+
assert_true(is_lazy_dict(tractogram.data_per_streamline))
810+
assert_true(is_lazy_dict(tractogram.data_per_point))
811+
798812
[t for t in tractogram] # Force iteration through tractogram.
799813
assert_equal(len(tractogram), len(DATA['streamlines']))
800814

@@ -910,6 +924,22 @@ def test_lazy_tractogram_apply_affine(self):
910924
tractogram.affine_to_rasmm = None
911925
assert_raises(ValueError, tractogram.to_world)
912926

927+
# But calling apply_affine when affine_to_rasmm is None should work.
928+
tractogram = DATA['lazy_tractogram'].copy()
929+
tractogram.affine_to_rasmm = None
930+
transformed_tractogram = tractogram.apply_affine(affine)
931+
assert_array_equal(transformed_tractogram._affine_to_apply, affine)
932+
assert_true(transformed_tractogram.affine_to_rasmm is None)
933+
check_tractogram(transformed_tractogram,
934+
streamlines=[s*scaling for s in DATA['streamlines']],
935+
data_per_streamline=DATA['data_per_streamline'],
936+
data_per_point=DATA['data_per_point'])
937+
938+
# Calling apply_affine with lazy=False should fail for LazyTractogram.
939+
tractogram = DATA['lazy_tractogram'].copy()
940+
assert_raises(ValueError, tractogram.apply_affine,
941+
affine=np.eye(4), lazy=False)
942+
913943
def test_tractogram_to_world(self):
914944
tractogram = DATA['lazy_tractogram'].copy()
915945
affine = np.random.RandomState(1234).randn(4, 4)

nibabel/streamlines/tractogram.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def is_data_dict(obj):
1616

1717
def is_lazy_dict(obj):
1818
""" True if `obj` seems to implement the :class:`LazyDict` API """
19-
return is_data_dict(obj) and callable(obj.store.values()[0])
19+
return is_data_dict(obj) and callable(list(obj.store.values())[0])
2020

2121

2222
class SliceableDataDict(collections.MutableMapping):
@@ -77,7 +77,7 @@ class PerArrayDict(SliceableDataDict):
7777
7878
This container behaves like a standard dictionary but extends key access to
7979
allow keys for key access to be indices slicing into the contained ndarray
80-
values. The elements must also be ndarrays.
80+
values. The elements must also be ndarrays.
8181
8282
In addition, it makes sure the amount of data contained in those ndarrays
8383
matches the number of streamlines given at the instantiation of this
@@ -200,9 +200,6 @@ def __init__(self, *args, **kwargs):
200200
self.update(**args[0].store) # Copy the generator functions.
201201
return
202202

203-
if isinstance(args[0], SliceableDataDict):
204-
self.update(**args[0])
205-
206203
self.update(dict(*args, **kwargs))
207204

208205
def __getitem__(self, key):

nibabel/streamlines/trk.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
import os
77
import struct
8-
import warnings
98
import string
9+
import warnings
1010

1111
import numpy as np
1212
import nibabel as nib
@@ -21,6 +21,7 @@
2121
from .tractogram_file import DataError, HeaderError, HeaderWarning
2222
from .tractogram import TractogramItem, Tractogram, LazyTractogram
2323
from .header import Field
24+
from .utils import peek_next
2425

2526

2627
MAX_NB_NAMED_SCALARS_PER_POINT = 10
@@ -423,8 +424,23 @@ def save(self, fileobj):
423424
i4_dtype = np.dtype("<i4") # Always save in little-endian.
424425
f4_dtype = np.dtype("<f4") # Always save in little-endian.
425426

427+
# Since the TRK format requires the streamlines to be saved in
428+
# voxmm, we first transform them accordingly. The transformation
429+
# is performed lazily since `self.tractogram` might be a
430+
# LazyTractogram object, which means we might be able to loop
431+
# over the streamlines only once.
432+
tractogram = self.tractogram.to_world(lazy=True)
433+
affine_to_trackvis = get_affine_rasmm_to_trackvis(header)
434+
tractogram = tractogram.apply_affine(affine_to_trackvis, lazy=True)
435+
436+
# Create the iterator we'll be using for the rest of the funciton.
437+
tractogram = iter(tractogram)
438+
426439
try:
427-
first_item = next(iter(self.tractogram))
440+
# Use the first element to check
441+
# 1) the tractogram is not empty;
442+
# 2) quantity of information saved along each streamline.
443+
first_item, tractogram = peek_next(tractogram)
428444
except StopIteration:
429445
# Empty tractogram
430446
header[Field.NB_STREAMLINES] = 0
@@ -470,11 +486,6 @@ def save(self, fileobj):
470486
scalar_name[i] = encode_value_in_name(nb_values, name)
471487
header['scalar_name'][:] = scalar_name
472488

473-
# Make sure streamlines are in rasmm then send them to voxmm.
474-
tractogram = self.tractogram.to_world(lazy=True)
475-
affine_to_trackvis = get_affine_rasmm_to_trackvis(header)
476-
tractogram = tractogram.apply_affine(affine_to_trackvis, lazy=True)
477-
478489
for t in tractogram:
479490
if any((len(d) != len(t.streamline)
480491
for d in t.data_for_points.values())):

nibabel/streamlines/utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import itertools
2+
13
import nibabel
24

35

@@ -29,3 +31,22 @@ def get_affine_from_reference(ref):
2931

3032
# Assume `ref` is the name of a neuroimaging file.
3133
return nibabel.load(ref).affine
34+
35+
36+
def peek_next(iterable):
37+
""" Peek next element of iterable.
38+
39+
Parameters
40+
----------
41+
iterable
42+
Iterable to peek the next element from.
43+
44+
Returns
45+
-------
46+
next_item
47+
Element peeked from `iterable`.
48+
new_iterable
49+
Iterable behaving like if the original `iterable` was untouched.
50+
"""
51+
next_item = next(iterable)
52+
return next_item, itertools.chain([next_item], iterable)

0 commit comments

Comments
 (0)