Skip to content

BF: #596 First point of LazyTractogram is not saved #588

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 12, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions nibabel/streamlines/tck.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .tractogram_file import HeaderError, DataError
from .tractogram import TractogramItem, Tractogram, LazyTractogram
from .header import Field
from .utils import peek_next

MEGABYTE = 1024 * 1024

Expand Down Expand Up @@ -191,8 +192,16 @@ def save(self, fileobj):
# Write temporary header that we will update at the end
self._write_header(f, header)

# Make sure streamlines are in rasmm.
tractogram = self.tractogram.to_world(lazy=True)
# Assume looping over the streamlines can be done only once.
tractogram = iter(tractogram)

try:
first_item = next(iter(self.tractogram))
# Use the first element to check
# 1) the tractogram is not empty;
# 2) quantity of information saved along each streamline.
first_item, tractogram = peek_next(tractogram)
except StopIteration:
# Empty tractogram
header[Field.NB_STREAMLINES] = 0
Expand All @@ -216,9 +225,6 @@ def save(self, fileobj):
" alongside points. Dropping: {}".format(keys))
warnings.warn(msg, DataWarning)

# Make sure streamlines are in rasmm.
tractogram = self.tractogram.to_world(lazy=True)

for t in tractogram:
data = np.r_[t.streamline, self.FIBER_DELIMITER]
f.write(data.astype(dtype).tostring())
Expand Down
5 changes: 5 additions & 0 deletions nibabel/streamlines/tests/test_array_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,11 @@ def test_arraysequence_append(self):
seq.append(element)
check_arr_seq(seq, [element])

# Append an empty array.
seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification.
seq.append([])
check_arr_seq(seq, SEQ_DATA['seq'])

# Append an element with different shape.
element = generate_data(nb_arrays=1,
common_shape=SEQ_DATA['seq'].common_shape*2,
Expand Down
16 changes: 16 additions & 0 deletions nibabel/streamlines/tests/test_streamlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,19 @@ def test_load_unknown_format(self):

def test_save_unknown_format(self):
assert_raises(ValueError, nib.streamlines.save, Tractogram(), "")

def test_save_from_generator(self):
    """ Saving a LazyTractogram backed by a one-shot generator must keep
    every streamline (regression test: first point used to be dropped).
    """
    expected = Tractogram(DATA['streamlines'],
                          affine_to_rasmm=np.eye(4))

    for extension in FORMATS:
        # A generator expression can be consumed only once, mimicking a
        # data source that cannot be rewound.
        streamed = (s for s in expected.streamlines if True)
        lazy = LazyTractogram(lambda: streamed,
                              affine_to_rasmm=np.eye(4))

        with InTemporaryDirectory():
            fname = 'streamlines' + extension
            nib.streamlines.save(lazy, fname)
            loaded = nib.streamlines.load(fname, lazy_load=False)
            assert_tractogram_equal(loaded.tractogram, expected)
44 changes: 37 additions & 7 deletions nibabel/streamlines/tests/test_tractogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from six.moves import zip

from .. import tractogram as module_tractogram
from ..tractogram import is_data_dict, is_lazy_dict
from ..tractogram import TractogramItem, Tractogram, LazyTractogram
from ..tractogram import PerArrayDict, PerArraySequenceDict, LazyDict

Expand Down Expand Up @@ -406,14 +407,21 @@ def test_extend(self):
class TestLazyDict(unittest.TestCase):

def test_lazydict_creation(self):
data_dict = LazyDict(DATA['data_per_streamline_func'])
assert_equal(data_dict.keys(), DATA['data_per_streamline_func'].keys())
for k in data_dict.keys():
assert_array_equal(list(data_dict[k]),
list(DATA['data_per_streamline'][k]))
# Different ways of creating LazyDict
lazy_dicts = []
lazy_dicts += [LazyDict(DATA['data_per_streamline_func'])]
lazy_dicts += [LazyDict(**DATA['data_per_streamline_func'])]

assert_equal(len(data_dict),
len(DATA['data_per_streamline_func']))
expected_keys = DATA['data_per_streamline_func'].keys()
for data_dict in lazy_dicts:
assert_true(is_lazy_dict(data_dict))
assert_equal(data_dict.keys(), expected_keys)
for k in data_dict.keys():
assert_array_equal(list(data_dict[k]),
list(DATA['data_per_streamline'][k]))

assert_equal(len(data_dict),
len(DATA['data_per_streamline_func']))


class TestTractogramItem(unittest.TestCase):
Expand Down Expand Up @@ -470,6 +478,9 @@ def test_tractogram_creation(self):
DATA['data_per_streamline'],
DATA['data_per_point'])

assert_true(is_data_dict(tractogram.data_per_streamline))
assert_true(is_data_dict(tractogram.data_per_point))

# Create a tractogram from another tractogram attributes.
tractogram2 = Tractogram(tractogram.streamlines,
tractogram.data_per_streamline,
Expand Down Expand Up @@ -795,6 +806,9 @@ def test_lazy_tractogram_creation(self):
DATA['data_per_streamline_func'],
DATA['data_per_point_func'])

assert_true(is_lazy_dict(tractogram.data_per_streamline))
assert_true(is_lazy_dict(tractogram.data_per_point))

[t for t in tractogram] # Force iteration through tractogram.
assert_equal(len(tractogram), len(DATA['streamlines']))

Expand Down Expand Up @@ -910,6 +924,22 @@ def test_lazy_tractogram_apply_affine(self):
tractogram.affine_to_rasmm = None
assert_raises(ValueError, tractogram.to_world)

# But calling apply_affine when affine_to_rasmm is None should work.
tractogram = DATA['lazy_tractogram'].copy()
tractogram.affine_to_rasmm = None
transformed_tractogram = tractogram.apply_affine(affine)
assert_array_equal(transformed_tractogram._affine_to_apply, affine)
assert_true(transformed_tractogram.affine_to_rasmm is None)
check_tractogram(transformed_tractogram,
streamlines=[s*scaling for s in DATA['streamlines']],
data_per_streamline=DATA['data_per_streamline'],
data_per_point=DATA['data_per_point'])

# Calling apply_affine with lazy=False should fail for LazyTractogram.
tractogram = DATA['lazy_tractogram'].copy()
assert_raises(ValueError, tractogram.apply_affine,
affine=np.eye(4), lazy=False)

def test_tractogram_to_world(self):
tractogram = DATA['lazy_tractogram'].copy()
affine = np.random.RandomState(1234).randn(4, 4)
Expand Down
7 changes: 2 additions & 5 deletions nibabel/streamlines/tractogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def is_data_dict(obj):

def is_lazy_dict(obj):
    """ True if `obj` seems to implement the :class:`LazyDict` API """
    # `dict.values()` returns a non-subscriptable view on Python 3, so
    # materialize it as a list before probing whether the first stored
    # value is a callable (the LazyDict contract).
    return is_data_dict(obj) and callable(list(obj.store.values())[0])


class SliceableDataDict(collections.MutableMapping):
Expand Down Expand Up @@ -77,7 +77,7 @@ class PerArrayDict(SliceableDataDict):

This container behaves like a standard dictionary but extends key access to
allow keys for key access to be indices slicing into the contained ndarray
values. The elements must also be ndarrays.
values. The elements must also be ndarrays.

In addition, it makes sure the amount of data contained in those ndarrays
matches the number of streamlines given at the instantiation of this
Expand Down Expand Up @@ -200,9 +200,6 @@ def __init__(self, *args, **kwargs):
self.update(**args[0].store) # Copy the generator functions.
return

if isinstance(args[0], SliceableDataDict):
self.update(**args[0])

self.update(dict(*args, **kwargs))

def __getitem__(self, key):
Expand Down
25 changes: 18 additions & 7 deletions nibabel/streamlines/trk.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

import os
import struct
import warnings
import string
import warnings

import numpy as np
import nibabel as nib
Expand All @@ -21,6 +21,7 @@
from .tractogram_file import DataError, HeaderError, HeaderWarning
from .tractogram import TractogramItem, Tractogram, LazyTractogram
from .header import Field
from .utils import peek_next


MAX_NB_NAMED_SCALARS_PER_POINT = 10
Expand Down Expand Up @@ -423,8 +424,23 @@ def save(self, fileobj):
i4_dtype = np.dtype("<i4") # Always save in little-endian.
f4_dtype = np.dtype("<f4") # Always save in little-endian.

# Since the TRK format requires the streamlines to be saved in
# voxmm, we first transform them accordingly. The transformation
# is performed lazily since `self.tractogram` might be a
# LazyTractogram object, which means we might be able to loop
# over the streamlines only once.
tractogram = self.tractogram.to_world(lazy=True)
affine_to_trackvis = get_affine_rasmm_to_trackvis(header)
tractogram = tractogram.apply_affine(affine_to_trackvis, lazy=True)

# Create the iterator we'll be using for the rest of the function.
tractogram = iter(tractogram)

try:
first_item = next(iter(self.tractogram))
# Use the first element to check
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you could refactor this out into:

def peek_next(iterable):
    next_item = next(iterable)
    return next_item, itertools.chain([next_item], iterable)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. Will do.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

# 1) the tractogram is not empty;
# 2) quantity of information saved along each streamline.
first_item, tractogram = peek_next(tractogram)
except StopIteration:
# Empty tractogram
header[Field.NB_STREAMLINES] = 0
Expand Down Expand Up @@ -470,11 +486,6 @@ def save(self, fileobj):
scalar_name[i] = encode_value_in_name(nb_values, name)
header['scalar_name'][:] = scalar_name

# Make sure streamlines are in rasmm then send them to voxmm.
tractogram = self.tractogram.to_world(lazy=True)
affine_to_trackvis = get_affine_rasmm_to_trackvis(header)
tractogram = tractogram.apply_affine(affine_to_trackvis, lazy=True)

for t in tractogram:
if any((len(d) != len(t.streamline)
for d in t.data_for_points.values())):
Expand Down
21 changes: 21 additions & 0 deletions nibabel/streamlines/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import itertools

import nibabel


Expand Down Expand Up @@ -29,3 +31,22 @@ def get_affine_from_reference(ref):

# Assume `ref` is the name of a neuroimaging file.
return nibabel.load(ref).affine


def peek_next(iterable):
    """ Return the next element of `iterable` without losing it.

    Parameters
    ----------
    iterable
        Iterator to peek into.

    Returns
    -------
    head
        The element that was peeked.
    restored
        Iterator yielding `head` followed by the remaining elements,
        i.e. behaving as if `iterable` had never been advanced.

    Raises
    ------
    StopIteration
        If `iterable` is already exhausted.
    """
    # Consume exactly one element, then prepend it back so callers see
    # an untouched stream.
    head = next(iterable)
    restored = itertools.chain((head,), iterable)
    return head, restored