diff --git a/Changelog b/Changelog index 8a7f93b687..a1e6bdf299 100644 --- a/Changelog +++ b/Changelog @@ -36,6 +36,9 @@ References like "pr/298" refer to github pull request numbers. are raising a DataError if the track is truncated when ``strict=True`` (the default), rather than a TypeError when trying to create the points array. + * New API for managing streamlines and their different file formats. This + adds a new module ``nibabel.streamlines`` that will eventually deprecate + the current trackvis reader found in ``nibabel.trackvis``. * 2.0.2 (Monday 23 November 2015) @@ -251,7 +254,7 @@ References like "pr/298" refer to github pull request numbers. the ability to transform to the image with data closest to the cononical image orientation (first axis left-to-right, second back-to-front, third down-to-up) (MB, Jonathan Taylor) - * Gifti format read and write support (preliminary) (Stephen Gerhard) + * Gifti format read and write support (preliminary) (Stephen Gerhard) * Added utilities to use nipy-style data packages, by rip then edit of nipy data package code (MB) * Some improvements to release support (Jarrod Millman, MB, Fernando Perez) @@ -469,7 +472,7 @@ visiting the URL:: * Removed functionality for "NiftiImage.save() raises an IOError exception when writing the image file fails." (Yaroslav Halchenko) - * Added ability to force a filetype when setting the filename or saving + * Added ability to force a filetype when setting the filename or saving a file. * Reverse the order of the 'header' and 'load' argument in the NiftiImage constructor. 'header' is now first as it seems to be used more often. @@ -481,7 +484,7 @@ visiting the URL:: * 0.20070301.2 (Thu, 1 Mar 2007) - * Fixed wrong link to the source tarball in README.html. + * Fixed wrong link to the source tarball in README.html. * 0.20070301.1 (Thu, 1 Mar 2007) diff --git a/nibabel/__init__.py b/nibabel/__init__.py index 4d8791d7d9..91ccace44b 100644 --- a/nibabel/__init__.py +++ b/nibabel/__init__.py @@ -64,6 +64,7 @@ from .imageclasses import class_map, ext_map, all_image_classes from . import trackvis from . import mriutils +from . import streamlines from . import viewers # be friendly on systems with ancient numpy -- no tests, but at least diff --git a/nibabel/benchmarks/bench_streamlines.py b/nibabel/benchmarks/bench_streamlines.py new file mode 100644 index 0000000000..95ba79fb61 --- /dev/null +++ b/nibabel/benchmarks/bench_streamlines.py @@ -0,0 +1,95 @@ +""" Benchmarks for load and save of streamlines + +Run benchmarks with:: + + import nibabel as nib + nib.bench() + +If you have doctests enabled by default in nose (with a noserc file or +environment variable), and you have a numpy version <= 1.6.1, this will also run +the doctests, let's hope they pass. 
+
+Run this benchmark with:
+
+    nosetests -s --match '(?:^|[\\b_\\.//-])[Bb]ench' /path/to/bench_streamlines.py
+"""
+from __future__ import division, print_function
+
+import numpy as np
+
+from nibabel.externals.six.moves import zip
+from nibabel.tmpdirs import InTemporaryDirectory
+
+from numpy.testing import assert_array_equal
+from nibabel.streamlines import Tractogram
+from nibabel.streamlines import TrkFile
+
+import nibabel as nib
+import nibabel.trackvis as tv
+
+from numpy.testing import measure
+
+
+def bench_load_trk():
+    rng = np.random.RandomState(42)
+    dtype = 'float32'
+    NB_STREAMLINES = 5000
+    NB_POINTS = 1000
+    points = [rng.rand(NB_POINTS, 3).astype(dtype)
+              for i in range(NB_STREAMLINES)]
+    scalars = [rng.rand(NB_POINTS, 10).astype(dtype)
+               for i in range(NB_STREAMLINES)]
+
+    repeat = 10
+
+    with InTemporaryDirectory():
+        trk_file = "tmp.trk"
+        tractogram = Tractogram(points, affine_to_rasmm=np.eye(4))
+        TrkFile(tractogram).save(trk_file)
+
+        streamlines_old = [d[0] - 0.5
+                           for d in tv.read(trk_file, points_space="rasmm")[0]]
+        mtime_old = measure('tv.read(trk_file, points_space="rasmm")', repeat)
+        print("Old: Loaded {:,} streamlines in {:6.2f}".format(NB_STREAMLINES,
+                                                               mtime_old))
+
+        trk = nib.streamlines.load(trk_file, lazy_load=False)
+        streamlines_new = trk.streamlines
+        mtime_new = measure('nib.streamlines.load(trk_file, lazy_load=False)',
+                            repeat)
+        print("\nNew: Loaded {:,} streamlines in {:6.2f}".format(NB_STREAMLINES,
+                                                                 mtime_new))
+        print("Speedup of {:.2f}".format(mtime_old / mtime_new))
+        for s1, s2 in zip(streamlines_new, streamlines_old):
+            assert_array_equal(s1, s2)
+
+    # Points and scalars
+    with InTemporaryDirectory():
+
+        trk_file = "tmp.trk"
+        tractogram = Tractogram(points,
+                                data_per_point={'scalars': scalars},
+                                affine_to_rasmm=np.eye(4))
+        TrkFile(tractogram).save(trk_file)
+
+        streamlines_old = [d[0] - 0.5
+                           for d in tv.read(trk_file, points_space="rasmm")[0]]
+
+        scalars_old = [d[1]
+                       for d in tv.read(trk_file, points_space="rasmm")[0]]
+        mtime_old = measure('tv.read(trk_file, points_space="rasmm")', repeat)
+        msg = "Old: Loaded {:,} streamlines with scalars in {:6.2f}"
+        print(msg.format(NB_STREAMLINES, mtime_old))
+
+        trk = nib.streamlines.load(trk_file, lazy_load=False)
+        scalars_new = trk.tractogram.data_per_point['scalars']
+        mtime_new = measure('nib.streamlines.load(trk_file, lazy_load=False)',
+                            repeat)
+        msg = "New: Loaded {:,} streamlines with scalars in {:6.2f}"
+        print(msg.format(NB_STREAMLINES, mtime_new))
+        print("Speedup of {:.2f}".format(mtime_old / mtime_new))
+        for s1, s2 in zip(scalars_new, scalars_old):
+            assert_array_equal(s1, s2)
+
+if __name__ == '__main__':
+    bench_load_trk()
diff --git a/nibabel/externals/six.py b/nibabel/externals/six.py
index eae31454ae..b23166effd 100644
--- a/nibabel/externals/six.py
+++ b/nibabel/externals/six.py
@@ -143,6 +143,7 @@ class _MovedItems(types.ModuleType):
     MovedAttribute("StringIO", "StringIO", "io"),
     MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"),
     MovedAttribute("zip", "itertools", "builtins", "izip", "zip"),
+    MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"),
 
     MovedModule("builtins", "__builtin__"),
     MovedModule("configparser", "ConfigParser"),
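
The `zip_longest` entry above only registers the Python 2/3 moved name; a
quick check of what it provides (a sketch, assuming the vendored six is
importable as below)::

    from nibabel.externals.six.moves import zip_longest

    # Pads the shorter iterable with None on both Python 2 and 3.
    print(list(zip_longest([1, 2, 3], 'ab')))
    # -> [(1, 'a'), (2, 'b'), (3, None)]
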
diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py
new file mode 100644
index 0000000000..124bdf8f50
--- /dev/null
+++ b/nibabel/streamlines/__init__.py
@@ -0,0 +1,133 @@
+import os
+import warnings
+from ..externals.six import string_types
+
+from .header import Field
+from .array_sequence import ArraySequence
+from .tractogram import Tractogram, LazyTractogram
+from .tractogram_file import ExtensionWarning
+
+from .trk import TrkFile
+
+# List of all supported formats
+FORMATS = {".trk": TrkFile}
+
+
+def is_supported(fileobj):
+    """ Checks if the file-like object is supported by NiBabel.
+
+    Parameters
+    ----------
+    fileobj : string or file-like object
+        If string, a filename; otherwise an open file-like object pointing
+        to a streamlines file (and ready to read from the beginning of the
+        header)
+
+    Returns
+    -------
+    is_supported : boolean
+    """
+    return detect_format(fileobj) is not None
+
+
+def detect_format(fileobj):
+    """ Returns the TractogramFile class guessed from the file-like object.
+
+    Parameters
+    ----------
+    fileobj : string or file-like object
+        If string, a filename; otherwise an open file-like object pointing
+        to a tractogram file (and ready to read from the beginning of the
+        header)
+
+    Returns
+    -------
+    tractogram_file : :class:`TractogramFile` class
+        The class type guessed from the content of `fileobj`.
+    """
+    for format in FORMATS.values():
+        try:
+            if format.is_correct_format(fileobj):
+                return format
+        except IOError:
+            pass
+
+    if isinstance(fileobj, string_types):
+        _, ext = os.path.splitext(fileobj)
+        return FORMATS.get(ext.lower())
+
+    return None
+
+
+def load(fileobj, lazy_load=False):
+    """ Loads streamlines in *RAS+* and *mm* space from a file-like object.
+
+    Parameters
+    ----------
+    fileobj : string or file-like object
+        If string, a filename; otherwise an open file-like object
+        pointing to a streamlines file (and ready to read from the beginning
+        of the streamlines file's header).
+    lazy_load : {False, True}, optional
+        If True, load streamlines in a lazy manner, i.e. they will not be
+        kept in memory and only be loaded when needed.
+        Otherwise, load all streamlines in memory.
+
+    Returns
+    -------
+    tractogram_file : :class:`TractogramFile` object
+        Returns an instance of a :class:`TractogramFile` containing data and
+        metadata of the tractogram loaded from `fileobj`.
+
+    Notes
+    -----
+    The streamline coordinate (0,0,0) refers to the center of the voxel.
+    """
+    tractogram_file = detect_format(fileobj)
+
+    if tractogram_file is None:
+        raise ValueError("Unknown format for 'fileobj': {}".format(fileobj))
+
+    return tractogram_file.load(fileobj, lazy_load=lazy_load)
+
+
+def save(tractogram, filename, **kwargs):
+    """ Saves a tractogram to a file.
+
+    Parameters
+    ----------
+    tractogram : :class:`Tractogram` object or :class:`TractogramFile` object
+        If :class:`Tractogram` object, the file format will be guessed from
+        `filename` and a :class:`TractogramFile` object will be created using
+        provided keyword arguments.
+        If :class:`TractogramFile` object, the file format is known and will
+        be used to save its content to `filename`.
+    filename : str
+        Name of the file where the tractogram will be saved.
+    \*\*kwargs : keyword arguments
+        Keyword arguments passed to :class:`TractogramFile` constructor.
+        Should not be specified if `tractogram` is already an instance of
+        :class:`TractogramFile`.
+    """
+    tractogram_file_class = detect_format(filename)
+    if isinstance(tractogram, Tractogram):
+        if tractogram_file_class is None:
+            msg = "Unknown tractogram file format: '{}'".format(filename)
+            raise ValueError(msg)
+
+        tractogram_file = tractogram_file_class(tractogram, **kwargs)
+
+    else:  # Assume it's a TractogramFile object.
+        tractogram_file = tractogram
+        if (tractogram_file_class is None or
+                not isinstance(tractogram_file, tractogram_file_class)):
+            msg = ("The extension you specified is unusual for the provided"
+                   " 'TractogramFile' object.")
+            warnings.warn(msg, ExtensionWarning)
+
+        if len(kwargs) > 0:
+            msg = ("A 'TractogramFile' object was provided, no need for"
+                   " keyword arguments.")
+            raise ValueError(msg)
+
+    tractogram_file.save(filename)
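
Together, `detect_format`, `load` and `save` above form a small round-trip
API. A minimal sketch of the intended usage (the filename `example.trk` is
hypothetical)::

    import numpy as np
    import nibabel as nib
    from nibabel.streamlines import Tractogram

    # Two streamlines with different numbers of points, already in RAS+ mm.
    streamlines = [np.random.rand(10, 3).astype('f4'),
                   np.random.rand(15, 3).astype('f4')]
    tractogram = Tractogram(streamlines, affine_to_rasmm=np.eye(4))

    nib.streamlines.save(tractogram, "example.trk")  # format guessed from ".trk"
    trk = nib.streamlines.load("example.trk", lazy_load=False)
    print(len(trk.tractogram.streamlines))  # -> 2
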
diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py
new file mode 100644
index 0000000000..b41ceb0f90
--- /dev/null
+++ b/nibabel/streamlines/array_sequence.py
@@ -0,0 +1,382 @@
+from __future__ import division
+
+import numbers
+from operator import mul
+from functools import reduce
+
+import numpy as np
+
+MEGABYTE = 1024 * 1024
+
+
+def is_array_sequence(obj):
+    """ Return True if `obj` is an array sequence. """
+    try:
+        return obj.is_array_sequence
+    except AttributeError:
+        return False
+
+
+def is_ndarray_of_int_or_bool(obj):
+    return (isinstance(obj, np.ndarray) and
+            (np.issubdtype(obj.dtype, np.integer) or
+             np.issubdtype(obj.dtype, np.bool)))
+
+
+class _BuildCache(object):
+    def __init__(self, arr_seq, common_shape, dtype):
+        self.offsets = list(arr_seq._offsets)
+        self.lengths = list(arr_seq._lengths)
+        self.next_offset = arr_seq._get_next_offset()
+        self.bytes_per_buf = arr_seq._buffer_size * MEGABYTE
+        # Use the passed dtype only if null data array
+        self.dtype = dtype if arr_seq._data.size == 0 else arr_seq._data.dtype
+        if arr_seq.common_shape != () and common_shape != arr_seq.common_shape:
+            raise ValueError(
+                "All dimensions, except the first one, must match exactly")
+        self.common_shape = common_shape
+        n_in_row = reduce(mul, common_shape, 1)
+        bytes_per_row = n_in_row * dtype.itemsize
+        self.rows_per_buf = self.bytes_per_buf / bytes_per_row
+
+    def update_seq(self, arr_seq):
+        arr_seq._offsets = np.array(self.offsets)
+        arr_seq._lengths = np.array(self.lengths)
+
+
+class ArraySequence(object):
+    """ Sequence of ndarrays having variable first dimension sizes.
+
+    This is a container that can store multiple ndarrays where each ndarray
+    might have a different first dimension size but a *common* size for the
+    remaining dimensions.
+
+    More generally, an instance of :class:`ArraySequence` of length $N$ is
+    composed of $N$ ndarrays of shape $(d_1, d_2, ..., d_D)$ where $d_1$
+    can vary in length between arrays but $(d_2, ..., d_D)$ have to be the
+    same for every ndarray.
+    """
+
+    def __init__(self, iterable=None, buffer_size=4):
+        """ Initialize array sequence instance
+
+        Parameters
+        ----------
+        iterable : None or iterable or :class:`ArraySequence`, optional
+            If None, create an empty :class:`ArraySequence` object.
+            If iterable, create a :class:`ArraySequence` object initialized
+            from array-like objects yielded by the iterable.
+            If :class:`ArraySequence`, create a view (no memory is allocated).
+            For an actual copy use :meth:`.copy` instead.
+        buffer_size : float, optional
+            Size (in MB) for memory allocation when `iterable` is a generator.
+        """
+        # Create new empty `ArraySequence` object.
+        self._is_view = False
+        self._data = np.array([])
+        self._offsets = np.array([], dtype=np.intp)
+        self._lengths = np.array([], dtype=np.intp)
+        self._buffer_size = buffer_size
+        self._build_cache = None
+
+        if iterable is None:
+            return
+
+        if is_array_sequence(iterable):
+            # Create a view.
+            self._data = iterable._data
+            self._offsets = iterable._offsets
+            self._lengths = iterable._lengths
+            self._is_view = True
+            return
+
+        self.extend(iterable)
+
+    @property
+    def is_array_sequence(self):
+        return True
+
+    @property
+    def common_shape(self):
+        """ Matching shape of the elements in this array sequence. """
+        return self._data.shape[1:]
+
+    @property
+    def total_nb_rows(self):
+        """ Total number of rows in this array sequence. """
+        return np.sum(self._lengths)
+
+    @property
+    def data(self):
+        """ Elements in this array sequence. """
+        return self._data
+
+    def _get_next_offset(self):
+        """ Offset in ``self._data`` at which to write the next element """
+        if len(self._offsets) == 0:
+            return 0
+        imax = np.argmax(self._offsets)
+        return self._offsets[imax] + self._lengths[imax]
+
+    def append(self, element, cache_build=False):
+        """ Appends `element` to this array sequence.
+
+        Append can be a lot faster if it knows that it is appending several
+        elements instead of a single element.  In that case it can cache the
+        parameters it uses between append operations, in a "build cache".  To
+        tell append to do this, use ``cache_build=True``.  If you use
+        ``cache_build=True``, you need to finalize the append operations with
+        :meth:`finalize_append`.
+
+        Parameters
+        ----------
+        element : ndarray
+            Element to append.  The shape must match that of already inserted
+            elements, except for the first dimension.
+        cache_build : {False, True}
+            Whether to save the build cache from this append routine.  If True,
+            append can assume it is the only player updating `self`, and the
+            caller must finalize `self` after all append operations, with
+            ``self.finalize_append()``.
+
+        Returns
+        -------
+        None
+
+        Notes
+        -----
+        If you need to add multiple elements, you should consider
+        `ArraySequence.extend`.
+        """
+        element = np.asarray(element)
+        if element.size == 0:
+            return
+        el_shape = element.shape
+        n_items, common_shape = el_shape[0], el_shape[1:]
+        build_cache = self._build_cache
+        in_cached_build = build_cache is not None
+        if not in_cached_build:  # One shot append, not part of sequence
+            build_cache = _BuildCache(self, common_shape, element.dtype)
+        next_offset = build_cache.next_offset
+        req_rows = next_offset + n_items
+        if self._data.shape[0] < req_rows:
+            self._resize_data_to(req_rows, build_cache)
+        self._data[next_offset:req_rows] = element
+        build_cache.offsets.append(next_offset)
+        build_cache.lengths.append(n_items)
+        build_cache.next_offset = req_rows
+        if in_cached_build:
+            return
+        if cache_build:
+            self._build_cache = build_cache
+        else:
+            build_cache.update_seq(self)
+
+    def finalize_append(self):
+        """ Finalize process of appending several elements to `self`
+
+        :meth:`append` can be a lot faster if it knows that it is appending
+        several elements instead of a single element.  To tell the append
+        method this is the case, use ``cache_build=True``.  This method
+        finalizes the series of append operations after a call to
+        :meth:`append` with ``cache_build=True``.
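+
+        A typical bulk-append pattern (sketch; ``blocks`` stands for any
+        iterable of ndarrays with a matching common shape)::
+
+            seq = ArraySequence()
+            for block in blocks:
+                seq.append(block, cache_build=True)
+            seq.finalize_append()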
+ """ + if self._build_cache is None: + return + self._build_cache.update_seq(self) + self._build_cache = None + + def _resize_data_to(self, n_rows, build_cache): + """ Resize data array if required """ + # Calculate new data shape, rounding up to nearest buffer size + n_bufs = np.ceil(n_rows / build_cache.rows_per_buf) + extended_n_rows = int(n_bufs * build_cache.rows_per_buf) + new_shape = (extended_n_rows,) + build_cache.common_shape + if self._data.size == 0: + self._data = np.empty(new_shape, dtype=build_cache.dtype) + else: + self._data.resize(new_shape) + + def shrink_data(self): + self._data.resize((self._get_next_offset(),) + self.common_shape) + + def extend(self, elements): + """ Appends all `elements` to this array sequence. + + Parameters + ---------- + elements : iterable of ndarrays or :class:`ArraySequence` object + If iterable of ndarrays, each ndarray will be concatenated along + the first dimension then appended to the data of this + ArraySequence. + If :class:`ArraySequence` object, its data are simply appended to + the data of this ArraySequence. + + Returns + ------- + None + + Notes + ----- + The shape of the elements to be added must match the one of the data of + this :class:`ArraySequence` except for the first dimension. + """ + # If possible try pre-allocating memory. + try: + iter_len = len(elements) + except TypeError: + pass + else: # We do know the iterable length + if iter_len == 0: + return + e0 = np.asarray(elements[0]) + n_elements = np.sum([len(e) for e in elements]) + self._build_cache = _BuildCache(self, e0.shape[1:], e0.dtype) + self._resize_data_to(self._get_next_offset() + n_elements, + self._build_cache) + + for e in elements: + self.append(e, cache_build=True) + + self.finalize_append() + + def copy(self): + """ Creates a copy of this :class:`ArraySequence` object. + + Returns + ------- + seq_copy : :class:`ArraySequence` instance + Copy of `self`. + + Notes + ----- + We do not simply deepcopy this object because we have a chance to use + less memory. For example, if the array sequence being copied is the + result of a slicing operation on an array sequence. + """ + seq = self.__class__() + total_lengths = np.sum(self._lengths) + seq._data = np.empty((total_lengths,) + self._data.shape[1:], + dtype=self._data.dtype) + + next_offset = 0 + offsets = [] + for offset, length in zip(self._offsets, self._lengths): + offsets.append(next_offset) + chunk = self._data[offset:offset + length] + seq._data[next_offset:next_offset + length] = chunk + next_offset += length + + seq._offsets = np.asarray(offsets) + seq._lengths = self._lengths.copy() + + return seq + + def __getitem__(self, idx): + """ Get sequence(s) through standard or advanced numpy indexing. + + Parameters + ---------- + idx : int or slice or list or ndarray + If int, index of the element to retrieve. + If slice, use slicing to retrieve elements. + If list, indices of the elements to retrieve. + If ndarray with dtype int, indices of the elements to retrieve. + If ndarray with dtype bool, only retrieve selected elements. + + Returns + ------- + ndarray or :class:`ArraySequence` + If `idx` is an int, returns the selected sequence. + Otherwise, returns a :class:`ArraySequence` object which is a view + of the selected sequences. 
+ """ + if isinstance(idx, (numbers.Integral, np.integer)): + start = self._offsets[idx] + return self._data[start:start + self._lengths[idx]] + + seq = self.__class__() + seq._is_view = True + if isinstance(idx, tuple): + off_idx = idx[0] + seq._data = self._data.__getitem__((slice(None),) + idx[1:]) + else: + off_idx = idx + seq._data = self._data + + if isinstance(off_idx, slice): # Standard list slicing + seq._offsets = self._offsets[off_idx] + seq._lengths = self._lengths[off_idx] + return seq + + if isinstance(off_idx, list) or is_ndarray_of_int_or_bool(off_idx): + # Fancy indexing + seq._offsets = self._offsets[off_idx] + seq._lengths = self._lengths[off_idx] + return seq + + raise TypeError("Index must be either an int, a slice, a list of int" + " or a ndarray of bool! Not " + str(type(idx))) + + def __iter__(self): + if len(self._lengths) != len(self._offsets): + raise ValueError("ArraySequence object corrupted:" + " len(self._lengths) != len(self._offsets)") + + for offset, lengths in zip(self._offsets, self._lengths): + yield self._data[offset: offset + lengths] + + def __len__(self): + return len(self._offsets) + + def __repr__(self): + if len(self) > np.get_printoptions()['threshold']: + # Show only the first and last edgeitems. + edgeitems = np.get_printoptions()['edgeitems'] + data = str(list(self[:edgeitems]))[:-1] + data += ", ..., " + data += str(list(self[-edgeitems:]))[1:] + else: + data = str(list(self)) + + return "{name}({data})".format(name=self.__class__.__name__, + data=data) + + def save(self, filename): + """ Saves this :class:`ArraySequence` object to a .npz file. """ + np.savez(filename, + data=self._data, + offsets=self._offsets, + lengths=self._lengths) + + @classmethod + def load(cls, filename): + """ Loads a :class:`ArraySequence` object from a .npz file. """ + content = np.load(filename) + seq = cls() + seq._data = content["data"] + seq._offsets = content["offsets"] + seq._lengths = content["lengths"] + return seq + + +def create_arraysequences_from_generator(gen, n): + """ Creates :class:`ArraySequence` objects from a generator yielding tuples + + Parameters + ---------- + gen : generator + Generator yielding a size `n` tuple containing the values to put in the + array sequences. + n : int + Number of :class:`ArraySequences` object to create. + """ + seqs = [ArraySequence() for _ in range(n)] + for data in gen: + for i, seq in enumerate(seqs): + if data[i].nbytes > 0: + seq.append(data[i], cache_build=True) + + for seq in seqs: + seq.finalize_append() + return seqs diff --git a/nibabel/streamlines/header.py b/nibabel/streamlines/header.py new file mode 100644 index 0000000000..c654b1234f --- /dev/null +++ b/nibabel/streamlines/header.py @@ -0,0 +1,19 @@ + +class Field: + """ Header fields common to multiple streamlines file formats. + + In IPython, use `nibabel.streamlines.Field??` to list them. 
+ """ + NB_STREAMLINES = "nb_streamlines" + STEP_SIZE = "step_size" + METHOD = "method" + NB_SCALARS_PER_POINT = "nb_scalars_per_point" + NB_PROPERTIES_PER_STREAMLINE = "nb_properties_per_streamline" + NB_POINTS = "nb_points" + VOXEL_SIZES = "voxel_sizes" + DIMENSIONS = "dimensions" + MAGIC_NUMBER = "magic_number" + ORIGIN = "origin" + VOXEL_TO_RASMM = "voxel_to_rasmm" + VOXEL_ORDER = "voxel_order" + ENDIANNESS = "endianness" diff --git a/nibabel/streamlines/tests/__init__.py b/nibabel/streamlines/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nibabel/streamlines/tests/test_array_sequence.py b/nibabel/streamlines/tests/test_array_sequence.py new file mode 100644 index 0000000000..a2ebd3a22e --- /dev/null +++ b/nibabel/streamlines/tests/test_array_sequence.py @@ -0,0 +1,301 @@ +import os +import sys +import unittest +import tempfile +import numpy as np + +from nose.tools import assert_equal, assert_raises, assert_true +from nibabel.testing import assert_arrays_equal +from numpy.testing import assert_array_equal + +from ..array_sequence import ArraySequence, is_array_sequence + + +SEQ_DATA = {} + + +def setup(): + global SEQ_DATA + rng = np.random.RandomState(42) + SEQ_DATA['rng'] = rng + SEQ_DATA['data'] = generate_data(nb_arrays=5, common_shape=(3,), rng=rng) + SEQ_DATA['seq'] = ArraySequence(SEQ_DATA['data']) + + +def generate_data(nb_arrays, common_shape, rng): + data = [rng.rand(*(rng.randint(3, 20),) + common_shape) + for _ in range(nb_arrays)] + return data + + +def check_empty_arr_seq(seq): + assert_equal(len(seq), 0) + assert_equal(len(seq._offsets), 0) + assert_equal(len(seq._lengths), 0) + # assert_equal(seq._data.ndim, 0) + assert_equal(seq._data.ndim, 1) + assert_true(seq.common_shape == ()) + + +def check_arr_seq(seq, arrays): + lengths = list(map(len, arrays)) + assert_true(is_array_sequence(seq)) + assert_equal(len(seq), len(arrays)) + assert_equal(len(seq._offsets), len(arrays)) + assert_equal(len(seq._lengths), len(arrays)) + assert_equal(seq._data.shape[1:], arrays[0].shape[1:]) + assert_equal(seq.common_shape, arrays[0].shape[1:]) + assert_arrays_equal(seq, arrays) + + # If seq is a view, then order of internal data is not guaranteed. + if seq._is_view: + # The only thing we can check is the _lengths. + assert_array_equal(sorted(seq._lengths), sorted(lengths)) + else: + seq.shrink_data() + assert_equal(seq._data.shape[0], sum(lengths)) + assert_array_equal(seq._data, np.concatenate(arrays, axis=0)) + assert_array_equal(seq._offsets, np.r_[0, np.cumsum(lengths)[:-1]]) + assert_array_equal(seq._lengths, lengths) + + +def check_arr_seq_view(seq_view, seq): + assert_true(seq_view._is_view) + assert_true(seq_view is not seq) + assert_true(np.may_share_memory(seq_view._data, seq._data)) + assert_true(seq_view._offsets is not seq._offsets) + assert_true(seq_view._lengths is not seq._lengths) + + +class TestArraySequence(unittest.TestCase): + + def test_creating_empty_arraysequence(self): + check_empty_arr_seq(ArraySequence()) + + def test_creating_arraysequence_from_list(self): + # Empty list + check_empty_arr_seq(ArraySequence([])) + + # List of ndarrays. + N = 5 + for ndim in range(1, N+1): + common_shape = tuple([SEQ_DATA['rng'].randint(1, 10) + for _ in range(ndim-1)]) + data = generate_data(nb_arrays=5, common_shape=common_shape, + rng=SEQ_DATA['rng']) + check_arr_seq(ArraySequence(data), data) + + # Force ArraySequence constructor to use buffering. + buffer_size = 1. 
/ 1024**2 # 1 bytes + check_arr_seq(ArraySequence(iter(SEQ_DATA['data']), buffer_size), + SEQ_DATA['data']) + + def test_creating_arraysequence_from_generator(self): + gen = (e for e in SEQ_DATA['data']) + check_arr_seq(ArraySequence(gen), SEQ_DATA['data']) + + # Already consumed generator + check_empty_arr_seq(ArraySequence(gen)) + + def test_creating_arraysequence_from_arraysequence(self): + seq = ArraySequence(SEQ_DATA['data']) + check_arr_seq(ArraySequence(seq), SEQ_DATA['data']) + + # From an empty ArraySequence + seq = ArraySequence() + check_empty_arr_seq(ArraySequence(seq)) + + def test_arraysequence_iter(self): + assert_arrays_equal(SEQ_DATA['seq'], SEQ_DATA['data']) + + # Try iterating through a corrupted ArraySequence object. + seq = SEQ_DATA['seq'].copy() + seq._lengths = seq._lengths[::2] + assert_raises(ValueError, list, seq) + + def test_arraysequence_copy(self): + orig = SEQ_DATA['seq'] + seq = orig.copy() + n_rows = seq.total_nb_rows + assert_equal(n_rows, orig.total_nb_rows) + assert_array_equal(seq._data, orig._data[:n_rows]) + assert_true(seq._data is not orig._data) + assert_array_equal(seq._offsets, orig._offsets) + assert_true(seq._offsets is not orig._offsets) + assert_array_equal(seq._lengths, orig._lengths) + assert_true(seq._lengths is not orig._lengths) + assert_equal(seq.common_shape, orig.common_shape) + + # Taking a copy of an `ArraySequence` generated by slicing. + # Only keep needed data. + seq = orig[::2].copy() + check_arr_seq(seq, SEQ_DATA['data'][::2]) + assert_true(seq._data is not orig._data) + + def test_arraysequence_append(self): + element = generate_data(nb_arrays=1, + common_shape=SEQ_DATA['seq'].common_shape, + rng=SEQ_DATA['rng'])[0] + + # Append a new element. + seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification. + seq.append(element) + check_arr_seq(seq, SEQ_DATA['data'] + [element]) + + # Append a list of list. + seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification. + seq.append(element.tolist()) + check_arr_seq(seq, SEQ_DATA['data'] + [element]) + + # Append to an empty ArraySequence. + seq = ArraySequence() + seq.append(element) + check_arr_seq(seq, [element]) + + # Append an element with different shape. + element = generate_data(nb_arrays=1, + common_shape=SEQ_DATA['seq'].common_shape*2, + rng=SEQ_DATA['rng'])[0] + assert_raises(ValueError, seq.append, element) + + def test_arraysequence_extend(self): + new_data = generate_data(nb_arrays=10, + common_shape=SEQ_DATA['seq'].common_shape, + rng=SEQ_DATA['rng']) + + # Extend with an empty list. + seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification. + seq.extend([]) + check_arr_seq(seq, SEQ_DATA['data']) + + # Extend with a list of ndarrays. + seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification. + seq.extend(new_data) + check_arr_seq(seq, SEQ_DATA['data'] + new_data) + + # Extend with a generator. + seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification. + seq.extend((d for d in new_data)) + check_arr_seq(seq, SEQ_DATA['data'] + new_data) + + # Extend with another `ArraySequence` object. + seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification. + seq.extend(ArraySequence(new_data)) + check_arr_seq(seq, SEQ_DATA['data'] + new_data) + + # Extend with an `ArraySequence` view (e.g. been sliced). + # Need to make sure we extend only the data we need. + seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification. 
+        seq.extend(ArraySequence(new_data)[::2])
+        check_arr_seq(seq, SEQ_DATA['data'] + new_data[::2])
+
+        # Test extending an empty ArraySequence
+        seq = ArraySequence()
+        seq.extend(ArraySequence())
+        check_empty_arr_seq(seq)
+
+        seq.extend(SEQ_DATA['seq'])
+        check_arr_seq(seq, SEQ_DATA['data'])
+
+        # Extend with elements of different shape.
+        data = generate_data(nb_arrays=10,
+                             common_shape=SEQ_DATA['seq'].common_shape*2,
+                             rng=SEQ_DATA['rng'])
+        seq = SEQ_DATA['seq'].copy()  # Copy because of in-place modification.
+        assert_raises(ValueError, seq.extend, data)
+
+    def test_arraysequence_getitem(self):
+        # Get one item
+        for i, e in enumerate(SEQ_DATA['seq']):
+            assert_array_equal(SEQ_DATA['seq'][i], e)
+
+            if sys.version_info < (3,):
+                assert_array_equal(SEQ_DATA['seq'][long(i)], e)
+
+        # Get all items using indexing (creates a view).
+        indices = list(range(len(SEQ_DATA['seq'])))
+        seq_view = SEQ_DATA['seq'][indices]
+        check_arr_seq_view(seq_view, SEQ_DATA['seq'])
+        # We took all elements so the view should match the original.
+        check_arr_seq(seq_view, SEQ_DATA['seq'])
+
+        # Get multiple items using ndarray of dtype integer.
+        for dtype in [np.int8, np.int16, np.int32, np.int64]:
+            seq_view = SEQ_DATA['seq'][np.array(indices, dtype=dtype)]
+            check_arr_seq_view(seq_view, SEQ_DATA['seq'])
+            # We took all elements so the view should match the original.
+            check_arr_seq(seq_view, SEQ_DATA['seq'])
+
+        # Get multiple items out of order (creates a view).
+        SEQ_DATA['rng'].shuffle(indices)
+        seq_view = SEQ_DATA['seq'][indices]
+        check_arr_seq_view(seq_view, SEQ_DATA['seq'])
+        check_arr_seq(seq_view, [SEQ_DATA['data'][i] for i in indices])
+
+        # Get slice (this will create a view).
+        seq_view = SEQ_DATA['seq'][::2]
+        check_arr_seq_view(seq_view, SEQ_DATA['seq'])
+        check_arr_seq(seq_view, SEQ_DATA['data'][::2])
+
+        # Use advanced indexing with ndarray of data type bool.
+        selection = np.array([False, True, True, False, True])
+        seq_view = SEQ_DATA['seq'][selection]
+        check_arr_seq_view(seq_view, SEQ_DATA['seq'])
+        check_arr_seq(seq_view,
+                      [SEQ_DATA['data'][i]
+                       for i, keep in enumerate(selection) if keep])
+
+        # Test invalid indexing
+        assert_raises(TypeError, SEQ_DATA['seq'].__getitem__, 'abc')
+
+        # Get specific columns.
+        seq_view = SEQ_DATA['seq'][:, 2]
+        check_arr_seq_view(seq_view, SEQ_DATA['seq'])
+        check_arr_seq(seq_view, [d[:, 2] for d in SEQ_DATA['data']])
+
+        # Combining multiple slicing and indexing operations.
+        seq_view = SEQ_DATA['seq'][::-2][:, 2]
+        check_arr_seq_view(seq_view, SEQ_DATA['seq'])
+        check_arr_seq(seq_view, [d[:, 2] for d in SEQ_DATA['data'][::-2]])
+
+    def test_arraysequence_repr(self):
+        # Test that calling repr on an ArraySequence object does not fail.
+        repr(SEQ_DATA['seq'])
+
+        # Test calling repr when the number of arrays is bigger than NumPy's
+        # print option threshold.
+        nb_arrays = 50
+        seq = ArraySequence(generate_data(nb_arrays, common_shape=(1,),
+                                          rng=SEQ_DATA['rng']))
+
+        bkp_threshold = np.get_printoptions()['threshold']
+        np.set_printoptions(threshold=nb_arrays*2)
+        txt1 = repr(seq)
+        np.set_printoptions(threshold=nb_arrays//2)
+        txt2 = repr(seq)
+        assert_true(len(txt2) < len(txt1))
+        np.set_printoptions(threshold=bkp_threshold)
+
+    def test_save_and_load_arraysequence(self):
+        # Test saving and loading an empty ArraySequence.
+ with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f: + seq = ArraySequence() + seq.save(f) + f.seek(0, os.SEEK_SET) + loaded_seq = ArraySequence.load(f) + assert_array_equal(loaded_seq._data, seq._data) + assert_array_equal(loaded_seq._offsets, seq._offsets) + assert_array_equal(loaded_seq._lengths, seq._lengths) + + # Test saving and loading a ArraySequence. + with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f: + seq = SEQ_DATA['seq'] + seq.save(f) + f.seek(0, os.SEEK_SET) + loaded_seq = ArraySequence.load(f) + assert_array_equal(loaded_seq._data, seq._data) + assert_array_equal(loaded_seq._offsets, seq._offsets) + assert_array_equal(loaded_seq._lengths, seq._lengths) + + # Make sure we can add new elements to it. + loaded_seq.append(SEQ_DATA['data'][0]) diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py new file mode 100644 index 0000000000..c2f1c066d3 --- /dev/null +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -0,0 +1,267 @@ +import os +import unittest +import tempfile +import numpy as np + +from os.path import join as pjoin + +import nibabel as nib +from nibabel.externals.six import BytesIO +from nibabel.tmpdirs import InTemporaryDirectory + +from nibabel.testing import data_path +from nibabel.testing import clear_and_catch_warnings +from nose.tools import assert_equal, assert_raises, assert_true, assert_false + +from .test_tractogram import assert_tractogram_equal +from ..tractogram import Tractogram, LazyTractogram +from ..tractogram_file import TractogramFile, ExtensionWarning +from .. import trk + +DATA = {} + + +def setup(): + global DATA + DATA['empty_filenames'] = [pjoin(data_path, "empty" + ext) + for ext in nib.streamlines.FORMATS.keys()] + DATA['simple_filenames'] = [pjoin(data_path, "simple" + ext) + for ext in nib.streamlines.FORMATS.keys()] + DATA['complex_filenames'] = [pjoin(data_path, "complex" + ext) + for ext in nib.streamlines.FORMATS.keys()] + + DATA['streamlines'] = [np.arange(1*3, dtype="f4").reshape((1, 3)), + np.arange(2*3, dtype="f4").reshape((2, 3)), + np.arange(5*3, dtype="f4").reshape((5, 3))] + + fa = [np.array([[0.2]], dtype="f4"), + np.array([[0.3], + [0.4]], dtype="f4"), + np.array([[0.5], + [0.6], + [0.6], + [0.7], + [0.8]], dtype="f4")] + + colors = [np.array([(1, 0, 0)]*1, dtype="f4"), + np.array([(0, 1, 0)]*2, dtype="f4"), + np.array([(0, 0, 1)]*5, dtype="f4")] + + mean_curvature = [np.array([1.11], dtype="f4"), + np.array([2.11], dtype="f4"), + np.array([3.11], dtype="f4")] + + mean_torsion = [np.array([1.22], dtype="f4"), + np.array([2.22], dtype="f4"), + np.array([3.22], dtype="f4")] + + mean_colors = [np.array([1, 0, 0], dtype="f4"), + np.array([0, 1, 0], dtype="f4"), + np.array([0, 0, 1], dtype="f4")] + + DATA['data_per_point'] = {'colors': colors, + 'fa': fa} + DATA['data_per_streamline'] = {'mean_curvature': mean_curvature, + 'mean_torsion': mean_torsion, + 'mean_colors': mean_colors} + + DATA['empty_tractogram'] = Tractogram(affine_to_rasmm=np.eye(4)) + DATA['simple_tractogram'] = Tractogram(DATA['streamlines'], + affine_to_rasmm=np.eye(4)) + DATA['complex_tractogram'] = Tractogram(DATA['streamlines'], + DATA['data_per_streamline'], + DATA['data_per_point'], + affine_to_rasmm=np.eye(4)) + + +def test_is_supported_detect_format(): + # Test is_supported and detect_format functions + # Empty file/string + f = BytesIO() + assert_false(nib.streamlines.is_supported(f)) + assert_false(nib.streamlines.is_supported("")) + assert_true(nib.streamlines.detect_format(f) is 
None)
+    assert_true(nib.streamlines.detect_format("") is None)
+
+    # Valid file without extension
+    for tfile_cls in nib.streamlines.FORMATS.values():
+        f = BytesIO()
+        f.write(tfile_cls.MAGIC_NUMBER)
+        f.seek(0, os.SEEK_SET)
+        assert_true(nib.streamlines.is_supported(f))
+        assert_true(nib.streamlines.detect_format(f) is tfile_cls)
+
+    # Wrong extension but right magic number
+    for tfile_cls in nib.streamlines.FORMATS.values():
+        with tempfile.TemporaryFile(mode="w+b", suffix=".txt") as f:
+            f.write(tfile_cls.MAGIC_NUMBER)
+            f.seek(0, os.SEEK_SET)
+            assert_true(nib.streamlines.is_supported(f))
+            assert_true(nib.streamlines.detect_format(f) is tfile_cls)
+
+    # Good extension but wrong magic number
+    for ext, tfile_cls in nib.streamlines.FORMATS.items():
+        with tempfile.TemporaryFile(mode="w+b", suffix=ext) as f:
+            f.write(b"pass")
+            f.seek(0, os.SEEK_SET)
+            assert_false(nib.streamlines.is_supported(f))
+            assert_true(nib.streamlines.detect_format(f) is None)
+
+    # Wrong extension, string only
+    f = "my_tractogram.asd"
+    assert_false(nib.streamlines.is_supported(f))
+    assert_true(nib.streamlines.detect_format(f) is None)
+
+    # Good extension, string only
+    for ext, tfile_cls in nib.streamlines.FORMATS.items():
+        f = "my_tractogram" + ext
+        assert_true(nib.streamlines.is_supported(f))
+        assert_equal(nib.streamlines.detect_format(f), tfile_cls)
+
+    # Extension should not be case-sensitive.
+    for ext, tfile_cls in nib.streamlines.FORMATS.items():
+        f = "my_tractogram" + ext.upper()
+        assert_true(nib.streamlines.detect_format(f) is tfile_cls)
+
+
+class TestLoadSave(unittest.TestCase):
+
+    def test_load_empty_file(self):
+        for lazy_load in [False, True]:
+            for empty_filename in DATA['empty_filenames']:
+                tfile = nib.streamlines.load(empty_filename,
+                                             lazy_load=lazy_load)
+                assert_true(isinstance(tfile, TractogramFile))
+
+                if lazy_load:
+                    assert_true(type(tfile.tractogram) is LazyTractogram)
+                else:
+                    assert_true(type(tfile.tractogram) is Tractogram)
+
+                assert_tractogram_equal(tfile.tractogram,
+                                        DATA['empty_tractogram'])
+
+    def test_load_simple_file(self):
+        for lazy_load in [False, True]:
+            for simple_filename in DATA['simple_filenames']:
+                tfile = nib.streamlines.load(simple_filename,
+                                             lazy_load=lazy_load)
+                assert_true(isinstance(tfile, TractogramFile))
+
+                if lazy_load:
+                    assert_true(type(tfile.tractogram) is LazyTractogram)
+                else:
+                    assert_true(type(tfile.tractogram) is Tractogram)
+
+                assert_tractogram_equal(tfile.tractogram,
+                                        DATA['simple_tractogram'])
+
+    def test_load_complex_file(self):
+        for lazy_load in [False, True]:
+            for complex_filename in DATA['complex_filenames']:
+                tfile = nib.streamlines.load(complex_filename,
+                                             lazy_load=lazy_load)
+                assert_true(isinstance(tfile, TractogramFile))
+
+                if lazy_load:
+                    assert_true(type(tfile.tractogram) is LazyTractogram)
+                else:
+                    assert_true(type(tfile.tractogram) is Tractogram)
+
+                tractogram = Tractogram(DATA['streamlines'],
+                                        affine_to_rasmm=np.eye(4))
+
+                if tfile.SUPPORTS_DATA_PER_POINT:
+                    tractogram.data_per_point = DATA['data_per_point']
+
+                if tfile.SUPPORTS_DATA_PER_STREAMLINE:
+                    data = DATA['data_per_streamline']
+                    tractogram.data_per_streamline = data
+
+                assert_tractogram_equal(tfile.tractogram,
+                                        tractogram)
+
+    def test_save_tractogram_file(self):
+        tractogram = Tractogram(DATA['streamlines'],
+                                affine_to_rasmm=np.eye(4))
+        trk_file = trk.TrkFile(tractogram)
+
+        # No need for keyword arguments.
+        assert_raises(ValueError, nib.streamlines.save,
+                      trk_file, "dummy.trk", header={})
+
+        # Wrong extension.
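+        # Saving a TrkFile under a ".tck" name is an extension/format
+        # mismatch: `save` warns with ExtensionWarning, and the extra
+        # keyword argument additionally triggers the ValueError below.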
+ with clear_and_catch_warnings(record=True, + modules=[nib.streamlines]) as w: + trk_file = trk.TrkFile(tractogram) + assert_raises(ValueError, nib.streamlines.save, + trk_file, "dummy.tck", header={}) + + assert_equal(len(w), 1) + assert_true(issubclass(w[0].category, ExtensionWarning)) + assert_true("extension" in str(w[0].message)) + + with InTemporaryDirectory(): + nib.streamlines.save(trk_file, "dummy.trk") + tfile = nib.streamlines.load("dummy.trk", lazy_load=False) + assert_tractogram_equal(tfile.tractogram, tractogram) + + def test_save_empty_file(self): + tractogram = Tractogram(affine_to_rasmm=np.eye(4)) + for ext, cls in nib.streamlines.FORMATS.items(): + with InTemporaryDirectory(): + filename = 'streamlines' + ext + nib.streamlines.save(tractogram, filename) + tfile = nib.streamlines.load(filename, lazy_load=False) + assert_tractogram_equal(tfile.tractogram, tractogram) + + def test_save_simple_file(self): + tractogram = Tractogram(DATA['streamlines'], + affine_to_rasmm=np.eye(4)) + for ext, cls in nib.streamlines.FORMATS.items(): + with InTemporaryDirectory(): + filename = 'streamlines' + ext + nib.streamlines.save(tractogram, filename) + tfile = nib.streamlines.load(filename, lazy_load=False) + assert_tractogram_equal(tfile.tractogram, tractogram) + + def test_save_complex_file(self): + complex_tractogram = Tractogram(DATA['streamlines'], + DATA['data_per_streamline'], + DATA['data_per_point'], + affine_to_rasmm=np.eye(4)) + + for ext, cls in nib.streamlines.FORMATS.items(): + with InTemporaryDirectory(): + filename = 'streamlines' + ext + + with clear_and_catch_warnings(record=True, + modules=[trk]) as w: + nib.streamlines.save(complex_tractogram, filename) + + # If streamlines format does not support saving data + # per point or data per streamline, a warning message + # should be issued. + if not (cls.SUPPORTS_DATA_PER_POINT and + cls.SUPPORTS_DATA_PER_STREAMLINE): + assert_equal(len(w), 1) + assert_true(issubclass(w[0].category, Warning)) + + tractogram = Tractogram(DATA['streamlines'], + affine_to_rasmm=np.eye(4)) + + if cls.SUPPORTS_DATA_PER_POINT: + tractogram.data_per_point = DATA['data_per_point'] + + if cls.SUPPORTS_DATA_PER_STREAMLINE: + data = DATA['data_per_streamline'] + tractogram.data_per_streamline = data + + tfile = nib.streamlines.load(filename, lazy_load=False) + assert_tractogram_equal(tfile.tractogram, tractogram) + + def test_load_unknown_format(self): + assert_raises(ValueError, nib.streamlines.load, "") + + def test_save_unknown_format(self): + assert_raises(ValueError, nib.streamlines.save, Tractogram(), "") diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py new file mode 100644 index 0000000000..76f06dff0e --- /dev/null +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -0,0 +1,768 @@ +import sys +import unittest +import numpy as np +import warnings + +from nibabel.testing import assert_arrays_equal +from nibabel.testing import clear_and_catch_warnings +from nose.tools import assert_equal, assert_raises, assert_true +from numpy.testing import assert_array_equal, assert_array_almost_equal +from nibabel.externals.six.moves import zip + +from .. 
import tractogram as module_tractogram +from ..tractogram import TractogramItem, Tractogram, LazyTractogram +from ..tractogram import PerArrayDict, PerArraySequenceDict, LazyDict + +DATA = {} + + +def setup(): + global DATA + DATA['streamlines'] = [np.arange(1*3, dtype="f4").reshape((1, 3)), + np.arange(2*3, dtype="f4").reshape((2, 3)), + np.arange(5*3, dtype="f4").reshape((5, 3))] + + DATA['fa'] = [np.array([[0.2]], dtype="f4"), + np.array([[0.3], + [0.4]], dtype="f4"), + np.array([[0.5], + [0.6], + [0.6], + [0.7], + [0.8]], dtype="f4")] + + DATA['colors'] = [np.array([(1, 0, 0)]*1, dtype="f4"), + np.array([(0, 1, 0)]*2, dtype="f4"), + np.array([(0, 0, 1)]*5, dtype="f4")] + + DATA['mean_curvature'] = [np.array([1.11], dtype="f4"), + np.array([2.11], dtype="f4"), + np.array([3.11], dtype="f4")] + + DATA['mean_torsion'] = [np.array([1.22], dtype="f4"), + np.array([2.22], dtype="f4"), + np.array([3.22], dtype="f4")] + + DATA['mean_colors'] = [np.array([1, 0, 0], dtype="f4"), + np.array([0, 1, 0], dtype="f4"), + np.array([0, 0, 1], dtype="f4")] + + DATA['data_per_point'] = {'colors': DATA['colors'], + 'fa': DATA['fa']} + DATA['data_per_streamline'] = {'mean_curvature': DATA['mean_curvature'], + 'mean_torsion': DATA['mean_torsion'], + 'mean_colors': DATA['mean_colors']} + + DATA['empty_tractogram'] = Tractogram(affine_to_rasmm=np.eye(4)) + DATA['simple_tractogram'] = Tractogram(DATA['streamlines'], + affine_to_rasmm=np.eye(4)) + DATA['tractogram'] = Tractogram(DATA['streamlines'], + DATA['data_per_streamline'], + DATA['data_per_point'], + affine_to_rasmm=np.eye(4)) + + DATA['streamlines_func'] = lambda: (e for e in DATA['streamlines']) + fa_func = lambda: (e for e in DATA['fa']) + colors_func = lambda: (e for e in DATA['colors']) + mean_curvature_func = lambda: (e for e in DATA['mean_curvature']) + mean_torsion_func = lambda: (e for e in DATA['mean_torsion']) + mean_colors_func = lambda: (e for e in DATA['mean_colors']) + + DATA['data_per_point_func'] = {'colors': colors_func, + 'fa': fa_func} + DATA['data_per_streamline_func'] = {'mean_curvature': mean_curvature_func, + 'mean_torsion': mean_torsion_func, + 'mean_colors': mean_colors_func} + + DATA['lazy_tractogram'] = LazyTractogram(DATA['streamlines_func'], + DATA['data_per_streamline_func'], + DATA['data_per_point_func'], + affine_to_rasmm=np.eye(4)) + + +def check_tractogram_item(tractogram_item, + streamline, + data_for_streamline={}, + data_for_points={}): + + assert_array_equal(tractogram_item.streamline, streamline) + + assert_equal(len(tractogram_item.data_for_streamline), + len(data_for_streamline)) + for key in data_for_streamline.keys(): + assert_array_equal(tractogram_item.data_for_streamline[key], + data_for_streamline[key]) + + assert_equal(len(tractogram_item.data_for_points), len(data_for_points)) + for key in data_for_points.keys(): + assert_arrays_equal(tractogram_item.data_for_points[key], + data_for_points[key]) + + +def assert_tractogram_item_equal(t1, t2): + check_tractogram_item(t1, t2.streamline, + t2.data_for_streamline, t2.data_for_points) + + +def check_tractogram(tractogram, + streamlines=[], + data_per_streamline={}, + data_per_point={}): + streamlines = list(streamlines) + assert_equal(len(tractogram), len(streamlines)) + assert_arrays_equal(tractogram.streamlines, streamlines) + [t for t in tractogram] # Force iteration through tractogram. 
+ + assert_equal(len(tractogram.data_per_streamline), len(data_per_streamline)) + for key in data_per_streamline.keys(): + assert_arrays_equal(tractogram.data_per_streamline[key], + data_per_streamline[key]) + + assert_equal(len(tractogram.data_per_point), len(data_per_point)) + for key in data_per_point.keys(): + assert_arrays_equal(tractogram.data_per_point[key], + data_per_point[key]) + + +def assert_tractogram_equal(t1, t2): + check_tractogram(t1, t2.streamlines, + t2.data_per_streamline, t2.data_per_point) + + +class TestPerArrayDict(unittest.TestCase): + + def test_per_array_dict_creation(self): + # Create a PerArrayDict object using another + # PerArrayDict object. + nb_streamlines = len(DATA['tractogram']) + data_per_streamline = DATA['tractogram'].data_per_streamline + data_dict = PerArrayDict(nb_streamlines, data_per_streamline) + assert_equal(data_dict.keys(), data_per_streamline.keys()) + for k in data_dict.keys(): + assert_array_equal(data_dict[k], data_per_streamline[k]) + + del data_dict['mean_curvature'] + assert_equal(len(data_dict), + len(data_per_streamline)-1) + + # Create a PerArrayDict object using an existing dict object. + data_per_streamline = DATA['data_per_streamline'] + data_dict = PerArrayDict(nb_streamlines, data_per_streamline) + assert_equal(data_dict.keys(), data_per_streamline.keys()) + for k in data_dict.keys(): + assert_array_equal(data_dict[k], data_per_streamline[k]) + + del data_dict['mean_curvature'] + assert_equal(len(data_dict), len(data_per_streamline)-1) + + # Create a PerArrayDict object using keyword arguments. + data_per_streamline = DATA['data_per_streamline'] + data_dict = PerArrayDict(nb_streamlines, **data_per_streamline) + assert_equal(data_dict.keys(), data_per_streamline.keys()) + for k in data_dict.keys(): + assert_array_equal(data_dict[k], data_per_streamline[k]) + + del data_dict['mean_curvature'] + assert_equal(len(data_dict), len(data_per_streamline)-1) + + def test_getitem(self): + sdict = PerArrayDict(len(DATA['tractogram']), + DATA['data_per_streamline']) + + assert_raises(KeyError, sdict.__getitem__, 'invalid') + + # Test slicing and advanced indexing. + for k, v in DATA['tractogram'].data_per_streamline.items(): + assert_true(k in sdict) + assert_arrays_equal(sdict[k], v) + assert_arrays_equal(sdict[::2][k], v[::2]) + assert_arrays_equal(sdict[::-1][k], v[::-1]) + assert_arrays_equal(sdict[-1][k], v[-1]) + assert_arrays_equal(sdict[[0, -1]][k], v[[0, -1]]) + + +class TestPerArraySequenceDict(unittest.TestCase): + + def test_per_array_sequence_dict_creation(self): + # Create a PerArraySequenceDict object using another + # PerArraySequenceDict object. + total_nb_rows = DATA['tractogram'].streamlines.total_nb_rows + data_per_point = DATA['tractogram'].data_per_point + data_dict = PerArraySequenceDict(total_nb_rows, data_per_point) + assert_equal(data_dict.keys(), data_per_point.keys()) + for k in data_dict.keys(): + assert_arrays_equal(data_dict[k], data_per_point[k]) + + del data_dict['fa'] + assert_equal(len(data_dict), + len(data_per_point)-1) + + # Create a PerArraySequenceDict object using an existing dict object. + data_per_point = DATA['data_per_point'] + data_dict = PerArraySequenceDict(total_nb_rows, data_per_point) + assert_equal(data_dict.keys(), data_per_point.keys()) + for k in data_dict.keys(): + assert_arrays_equal(data_dict[k], data_per_point[k]) + + del data_dict['fa'] + assert_equal(len(data_dict), len(data_per_point)-1) + + # Create a PerArraySequenceDict object using keyword arguments. 
+ data_per_point = DATA['data_per_point'] + data_dict = PerArraySequenceDict(total_nb_rows, **data_per_point) + assert_equal(data_dict.keys(), data_per_point.keys()) + for k in data_dict.keys(): + assert_arrays_equal(data_dict[k], data_per_point[k]) + + del data_dict['fa'] + assert_equal(len(data_dict), len(data_per_point)-1) + + def test_getitem(self): + total_nb_rows = DATA['tractogram'].streamlines.total_nb_rows + sdict = PerArraySequenceDict(total_nb_rows, DATA['data_per_point']) + + assert_raises(KeyError, sdict.__getitem__, 'invalid') + + # Test slicing and advanced indexing. + for k, v in DATA['tractogram'].data_per_point.items(): + assert_true(k in sdict) + assert_arrays_equal(sdict[k], v) + assert_arrays_equal(sdict[::2][k], v[::2]) + assert_arrays_equal(sdict[::-1][k], v[::-1]) + assert_arrays_equal(sdict[-1][k], v[-1]) + assert_arrays_equal(sdict[[0, -1]][k], v[[0, -1]]) + + +class TestLazyDict(unittest.TestCase): + + def test_lazydict_creation(self): + data_dict = LazyDict(DATA['data_per_streamline_func']) + assert_equal(data_dict.keys(), DATA['data_per_streamline_func'].keys()) + for k in data_dict.keys(): + assert_array_equal(list(data_dict[k]), + list(DATA['data_per_streamline'][k])) + + assert_equal(len(data_dict), + len(DATA['data_per_streamline_func'])) + + +class TestTractogramItem(unittest.TestCase): + + def test_creating_tractogram_item(self): + rng = np.random.RandomState(42) + streamline = rng.rand(rng.randint(10, 50), 3) + colors = rng.rand(len(streamline), 3) + mean_curvature = 1.11 + mean_color = np.array([0, 1, 0], dtype="f4") + + data_for_streamline = {"mean_curvature": mean_curvature, + "mean_color": mean_color} + + data_for_points = {"colors": colors} + + # Create a tractogram item with a streamline, data. + t = TractogramItem(streamline, data_for_streamline, data_for_points) + assert_equal(len(t), len(streamline)) + assert_array_equal(t.streamline, streamline) + assert_array_equal(list(t), streamline) + assert_array_equal(t.data_for_streamline['mean_curvature'], + mean_curvature) + assert_array_equal(t.data_for_streamline['mean_color'], + mean_color) + assert_array_equal(t.data_for_points['colors'], + colors) + + +class TestTractogram(unittest.TestCase): + + def test_tractogram_creation(self): + # Create an empty tractogram. + tractogram = Tractogram() + check_tractogram(tractogram) + assert_true(tractogram.affine_to_rasmm is None) + + # Create a tractogram with only streamlines + tractogram = Tractogram(streamlines=DATA['streamlines']) + check_tractogram(tractogram, DATA['streamlines']) + + # Create a tractogram with a given affine_to_rasmm. + affine = np.diag([1, 2, 3, 1]) + tractogram = Tractogram(affine_to_rasmm=affine) + assert_array_equal(tractogram.affine_to_rasmm, affine) + + # Create a tractogram with streamlines and other data. + tractogram = Tractogram(DATA['streamlines'], + DATA['data_per_streamline'], + DATA['data_per_point']) + + check_tractogram(tractogram, + DATA['streamlines'], + DATA['data_per_streamline'], + DATA['data_per_point']) + + # Create a tractogram from another tractogram attributes. + tractogram2 = Tractogram(tractogram.streamlines, + tractogram.data_per_streamline, + tractogram.data_per_point) + + assert_tractogram_equal(tractogram2, tractogram) + + # Create a tractogram from a LazyTractogram object. 
+ tractogram = LazyTractogram(DATA['streamlines_func'], + DATA['data_per_streamline_func'], + DATA['data_per_point_func']) + + tractogram2 = Tractogram(tractogram.streamlines, + tractogram.data_per_streamline, + tractogram.data_per_point) + + # Inconsistent number of scalars between streamlines + wrong_data = [[(1, 0, 0)]*1, + [(0, 1, 0), (0, 1)], + [(0, 0, 1)]*5] + + data_per_point = {'wrong_data': wrong_data} + assert_raises(ValueError, Tractogram, DATA['streamlines'], + data_per_point=data_per_point) + + # Inconsistent number of scalars between streamlines + wrong_data = [[(1, 0, 0)]*1, + [(0, 1)]*2, + [(0, 0, 1)]*5] + + data_per_point = {'wrong_data': wrong_data} + assert_raises(ValueError, Tractogram, DATA['streamlines'], + data_per_point=data_per_point) + + def test_setting_affine_to_rasmm(self): + tractogram = DATA['tractogram'].copy() + affine = np.diag(range(4)) + + # Test assigning None. + tractogram.affine_to_rasmm = None + assert_true(tractogram.affine_to_rasmm is None) + + # Test assigning a valid ndarray (should make a copy). + tractogram.affine_to_rasmm = affine + assert_true(tractogram.affine_to_rasmm is not affine) + + # Test assigning a list of lists. + tractogram.affine_to_rasmm = affine.tolist() + assert_array_equal(tractogram.affine_to_rasmm, affine) + + # Test assigning a ndarray with wrong shape. + assert_raises(ValueError, setattr, tractogram, + "affine_to_rasmm", affine[::2]) + + def test_tractogram_getitem(self): + # Retrieve TractogramItem by their index. + for i, t in enumerate(DATA['tractogram']): + assert_tractogram_item_equal(DATA['tractogram'][i], t) + + if sys.version_info < (3,): + assert_tractogram_item_equal(DATA['tractogram'][long(i)], t) + + # Get one TractogramItem out of two. + tractogram_view = DATA['simple_tractogram'][::2] + check_tractogram(tractogram_view, DATA['streamlines'][::2]) + + # Use slicing. + r_tractogram = DATA['tractogram'][::-1] + check_tractogram(r_tractogram, + DATA['streamlines'][::-1], + DATA['tractogram'].data_per_streamline[::-1], + DATA['tractogram'].data_per_point[::-1]) + + def test_tractogram_add_new_data(self): + # Tractogram with only streamlines + t = DATA['simple_tractogram'].copy() + t.data_per_point['fa'] = DATA['fa'] + t.data_per_point['colors'] = DATA['colors'] + t.data_per_streamline['mean_curvature'] = DATA['mean_curvature'] + t.data_per_streamline['mean_torsion'] = DATA['mean_torsion'] + t.data_per_streamline['mean_colors'] = DATA['mean_colors'] + assert_tractogram_equal(t, DATA['tractogram']) + + # Retrieve tractogram by their index. + for i, item in enumerate(t): + assert_tractogram_item_equal(t[i], item) + + # Use slicing. + r_tractogram = t[::-1] + check_tractogram(r_tractogram, + t.streamlines[::-1], + t.data_per_streamline[::-1], + t.data_per_point[::-1]) + + # Add new data to a tractogram for which its `streamlines` is a view. + t = Tractogram(DATA['streamlines']*2, affine_to_rasmm=np.eye(4)) + t = t[:len(DATA['streamlines'])] # Create a view of `streamlines` + t.data_per_point['fa'] = DATA['fa'] + t.data_per_point['colors'] = DATA['colors'] + t.data_per_streamline['mean_curvature'] = DATA['mean_curvature'] + t.data_per_streamline['mean_torsion'] = DATA['mean_torsion'] + t.data_per_streamline['mean_colors'] = DATA['mean_colors'] + assert_tractogram_equal(t, DATA['tractogram']) + + def test_tractogram_copy(self): + # Create a copy of a tractogram. + tractogram = DATA['tractogram'].copy() + + # Check we copied the data and not simply created new references. 
+ assert_true(tractogram is not DATA['tractogram']) + assert_true(tractogram.streamlines + is not DATA['tractogram'].streamlines) + assert_true(tractogram.data_per_streamline + is not DATA['tractogram'].data_per_streamline) + assert_true(tractogram.data_per_point + is not DATA['tractogram'].data_per_point) + + for key in tractogram.data_per_streamline: + assert_true(tractogram.data_per_streamline[key] + is not DATA['tractogram'].data_per_streamline[key]) + + for key in tractogram.data_per_point: + assert_true(tractogram.data_per_point[key] + is not DATA['tractogram'].data_per_point[key]) + + # Check the values of the data are the same. + assert_tractogram_equal(tractogram, DATA['tractogram']) + + def test_creating_invalid_tractogram(self): + # Not enough data_per_point for all the points of all streamlines. + scalars = [[(1, 0, 0)]*1, + [(0, 1, 0)]*2, + [(0, 0, 1)]*3] # Last streamlines has 5 points. + + assert_raises(ValueError, Tractogram, DATA['streamlines'], + data_per_point={'scalars': scalars}) + + # Not enough data_per_streamline for all streamlines. + properties = [np.array([1.11, 1.22], dtype="f4"), + np.array([3.11, 3.22], dtype="f4")] + + assert_raises(ValueError, Tractogram, DATA['streamlines'], + data_per_streamline={'properties': properties}) + + # Inconsistent dimension for a data_per_point. + scalars = [[(1, 0, 0)]*1, + [(0, 1)]*2, + [(0, 0, 1)]*5] + + assert_raises(ValueError, Tractogram, DATA['streamlines'], + data_per_point={'scalars': scalars}) + + # Inconsistent dimension for a data_per_streamline. + properties = [[1.11, 1.22], + [2.11], + [3.11, 3.22]] + + assert_raises(ValueError, Tractogram, DATA['streamlines'], + data_per_streamline={'properties': properties}) + + # Too many dimension for a data_per_streamline. + properties = [np.array([[1.11], [1.22]], dtype="f4"), + np.array([[2.11], [2.22]], dtype="f4"), + np.array([[3.11], [3.22]], dtype="f4")] + + assert_raises(ValueError, Tractogram, DATA['streamlines'], + data_per_streamline={'properties': properties}) + + def test_tractogram_apply_affine(self): + tractogram = DATA['tractogram'].copy() + affine = np.eye(4) + scaling = np.array((1, 2, 3), dtype=float) + affine[range(3), range(3)] = scaling + + # Apply the affine to the streamline in a lazy manner. + transformed_tractogram = tractogram.apply_affine(affine, lazy=True) + assert_true(type(transformed_tractogram) is LazyTractogram) + check_tractogram(transformed_tractogram, + streamlines=[s*scaling for s in DATA['streamlines']], + data_per_streamline=DATA['data_per_streamline'], + data_per_point=DATA['data_per_point']) + assert_array_equal(transformed_tractogram.affine_to_rasmm, + np.dot(np.eye(4), np.linalg.inv(affine))) + # Make sure streamlines of the original tractogram have not been + # modified. + assert_arrays_equal(tractogram.streamlines, DATA['streamlines']) + + # Apply the affine to the streamlines in-place. + transformed_tractogram = tractogram.apply_affine(affine) + assert_true(transformed_tractogram is tractogram) + check_tractogram(tractogram, + streamlines=[s*scaling for s in DATA['streamlines']], + data_per_streamline=DATA['data_per_streamline'], + data_per_point=DATA['data_per_point']) + + # Apply affine again and check the affine_to_rasmm. + transformed_tractogram = tractogram.apply_affine(affine) + assert_array_equal(transformed_tractogram.affine_to_rasmm, + np.dot(np.eye(4), np.dot(np.linalg.inv(affine), + np.linalg.inv(affine)))) + + # Check that applying an affine and its inverse give us back the + # original streamlines. 
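+        # Float round-off accumulates across the forward and inverse
+        # transforms, hence the almost-equal comparisons below.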
+        tractogram = DATA['tractogram'].copy()
+        affine = np.random.RandomState(1234).randn(4, 4)
+        affine[-1] = [0, 0, 0, 1]  # Remove perspective projection.
+
+        tractogram.apply_affine(affine)
+        tractogram.apply_affine(np.linalg.inv(affine))
+        assert_array_almost_equal(tractogram.affine_to_rasmm, np.eye(4))
+        for s1, s2 in zip(tractogram.streamlines, DATA['streamlines']):
+            assert_array_almost_equal(s1, s2)
+
+        # Test applying the identity transformation.
+        tractogram = DATA['tractogram'].copy()
+        tractogram.apply_affine(np.eye(4))
+        for s1, s2 in zip(tractogram.streamlines, DATA['streamlines']):
+            assert_array_almost_equal(s1, s2)
+
+        # Test removing affine_to_rasmm.
+        tractogram = DATA['tractogram'].copy()
+        tractogram.affine_to_rasmm = None
+        tractogram.apply_affine(affine)
+        assert_true(tractogram.affine_to_rasmm is None)
+
+    def test_tractogram_to_world(self):
+        tractogram = DATA['tractogram'].copy()
+        affine = np.random.RandomState(1234).randn(4, 4)
+        affine[-1] = [0, 0, 0, 1]  # Remove perspective projection.
+
+        # Apply the affine to the streamlines, then bring them back
+        # to world space in a lazy manner.
+        transformed_tractogram = tractogram.apply_affine(affine)
+        assert_array_equal(transformed_tractogram.affine_to_rasmm,
+                           np.linalg.inv(affine))
+
+        tractogram_world = transformed_tractogram.to_world(lazy=True)
+        assert_true(type(tractogram_world) is LazyTractogram)
+        assert_array_almost_equal(tractogram_world.affine_to_rasmm,
+                                  np.eye(4))
+        for s1, s2 in zip(tractogram_world.streamlines, DATA['streamlines']):
+            assert_array_almost_equal(s1, s2)
+
+        # Bring the streamlines back to world space in an in-place manner.
+        tractogram_world = transformed_tractogram.to_world()
+        assert_true(tractogram_world is tractogram)
+        assert_array_almost_equal(tractogram.affine_to_rasmm, np.eye(4))
+        for s1, s2 in zip(tractogram.streamlines, DATA['streamlines']):
+            assert_array_almost_equal(s1, s2)
+
+        # Calling to_world twice should do nothing.
+        tractogram_world2 = transformed_tractogram.to_world()
+        assert_true(tractogram_world2 is tractogram)
+        assert_array_almost_equal(tractogram.affine_to_rasmm, np.eye(4))
+        for s1, s2 in zip(tractogram.streamlines, DATA['streamlines']):
+            assert_array_almost_equal(s1, s2)
+
+        # Calling to_world when affine_to_rasmm is None should fail.
+        tractogram = DATA['tractogram'].copy()
+        tractogram.affine_to_rasmm = None
+        assert_raises(ValueError, tractogram.to_world)
+
+
+class TestLazyTractogram(unittest.TestCase):
+
+    def test_lazy_tractogram_creation(self):
+        # To create a tractogram from arrays, use `Tractogram`.
+        assert_raises(TypeError, LazyTractogram, DATA['streamlines'])
+
+        # Streamlines and other data as generators.
+        streamlines = (x for x in DATA['streamlines'])
+        data_per_point = {"colors": (x for x in DATA['colors'])}
+        data_per_streamline = {'mean_torsion': (x for x in DATA['mean_torsion']),
+                               'mean_colors': (x for x in DATA['mean_colors'])}
+
+        # Creating a LazyTractogram with generators is not allowed: unlike
+        # generator functions, generators get exhausted and are not reusable.
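+        # (A generator function, e.g. ``lambda: iter(data)``, can be called
+        # repeatedly to produce a fresh iterator each time.)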
+        assert_raises(TypeError, LazyTractogram, streamlines)
+        assert_raises(TypeError, LazyTractogram,
+                      data_per_streamline=data_per_streamline)
+        assert_raises(TypeError, LazyTractogram, DATA['streamlines'],
+                      data_per_point=data_per_point)
+
+        # Empty `LazyTractogram`
+        tractogram = LazyTractogram()
+        check_tractogram(tractogram)
+        assert_true(tractogram.affine_to_rasmm is None)
+
+        # Create tractogram with streamlines and other data.
+        tractogram = LazyTractogram(DATA['streamlines_func'],
+                                    DATA['data_per_streamline_func'],
+                                    DATA['data_per_point_func'])
+
+        [t for t in tractogram]  # Force iteration through tractogram.
+        assert_equal(len(tractogram), len(DATA['streamlines']))
+
+        # Generator functions get re-called and create new iterators.
+        for i in range(2):
+            assert_tractogram_equal(tractogram, DATA['tractogram'])
+
+    def test_lazy_tractogram_from_data_func(self):
+        # Create an empty `LazyTractogram` yielding nothing.
+        _empty_data_gen = lambda: iter([])
+
+        tractogram = LazyTractogram.from_data_func(_empty_data_gen)
+        check_tractogram(tractogram)
+
+        # Create a `LazyTractogram` from a generator function yielding
+        # TractogramItem objects.
+        data = [DATA['streamlines'], DATA['fa'], DATA['colors'],
+                DATA['mean_curvature'], DATA['mean_torsion'],
+                DATA['mean_colors']]
+
+        def _data_gen():
+            for d in zip(*data):
+                data_for_points = {'fa': d[1],
+                                   'colors': d[2]}
+                data_for_streamline = {'mean_curvature': d[3],
+                                       'mean_torsion': d[4],
+                                       'mean_colors': d[5]}
+                yield TractogramItem(d[0],
+                                     data_for_streamline,
+                                     data_for_points)
+
+        tractogram = LazyTractogram.from_data_func(_data_gen)
+        assert_tractogram_equal(tractogram, DATA['tractogram'])
+
+        # Creating a LazyTractogram from something other than a generator
+        # function should raise an error.
+        assert_raises(TypeError, LazyTractogram.from_data_func, _data_gen())
+
+    def test_lazy_tractogram_getitem(self):
+        assert_raises(NotImplementedError,
+                      DATA['lazy_tractogram'].__getitem__, 0)
+
+    def test_lazy_tractogram_len(self):
+        modules = [module_tractogram]  # Modules for which to catch warnings.
+        with clear_and_catch_warnings(record=True, modules=modules) as w:
+            warnings.simplefilter("always")  # Always trigger warnings.
+
+            # Calling `len` will create new generators each time.
+            tractogram = LazyTractogram(DATA['streamlines_func'])
+            assert_true(tractogram._nb_streamlines is None)
+
+            # This should produce a warning message.
+            assert_equal(len(tractogram), len(DATA['streamlines']))
+            assert_equal(tractogram._nb_streamlines, len(DATA['streamlines']))
+            assert_equal(len(w), 1)
+
+            tractogram = LazyTractogram(DATA['streamlines_func'])
+
+            # New instances should still produce a warning message.
+            assert_equal(len(tractogram), len(DATA['streamlines']))
+            assert_equal(len(w), 2)
+            assert_true(issubclass(w[-1].category, Warning))
+
+            # Calling `len` again should *not* produce a warning.
+            assert_equal(len(tractogram), len(DATA['streamlines']))
+            assert_equal(len(w), 2)
+
+        with clear_and_catch_warnings(record=True, modules=modules) as w:
+            # Once we have iterated through the tractogram, we know its
+            # length.
+            tractogram = LazyTractogram(DATA['streamlines_func'])
+
+            assert_true(tractogram._nb_streamlines is None)
+            [t for t in tractogram]  # Force iteration through tractogram.
+            assert_equal(tractogram._nb_streamlines, len(DATA['streamlines']))
+            # This should *not* produce a warning.
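+            # (the length was cached in `_nb_streamlines` during iteration).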
+            assert_equal(len(tractogram), len(DATA['streamlines']))
+            assert_equal(len(w), 0)
+
+    def test_lazy_tractogram_apply_affine(self):
+        affine = np.eye(4)
+        scaling = np.array((1, 2, 3), dtype=float)
+        affine[range(3), range(3)] = scaling
+
+        tractogram = DATA['lazy_tractogram'].copy()
+
+        transformed_tractogram = tractogram.apply_affine(affine)
+        assert_true(transformed_tractogram is not tractogram)
+        assert_array_equal(tractogram._affine_to_apply, np.eye(4))
+        assert_array_equal(tractogram.affine_to_rasmm, np.eye(4))
+        assert_array_equal(transformed_tractogram._affine_to_apply, affine)
+        assert_array_equal(transformed_tractogram.affine_to_rasmm,
+                           np.dot(np.eye(4), np.linalg.inv(affine)))
+        check_tractogram(transformed_tractogram,
+                         streamlines=[s*scaling for s in DATA['streamlines']],
+                         data_per_streamline=DATA['data_per_streamline'],
+                         data_per_point=DATA['data_per_point'])
+
+        # Apply affine again and check the affine_to_rasmm.
+        transformed_tractogram = transformed_tractogram.apply_affine(affine)
+        assert_array_equal(transformed_tractogram._affine_to_apply,
+                           np.dot(affine, affine))
+        assert_array_equal(transformed_tractogram.affine_to_rasmm,
+                           np.dot(np.eye(4), np.dot(np.linalg.inv(affine),
+                                                    np.linalg.inv(affine))))
+
+        # Calling to_world when affine_to_rasmm is None should fail.
+        tractogram = DATA['lazy_tractogram'].copy()
+        tractogram.affine_to_rasmm = None
+        assert_raises(ValueError, tractogram.to_world)
+
+    def test_tractogram_to_world(self):
+        tractogram = DATA['lazy_tractogram'].copy()
+        affine = np.random.RandomState(1234).randn(4, 4)
+        affine[-1] = [0, 0, 0, 1]  # Remove perspective projection.
+
+        # Apply the affine to the streamlines, then bring them back
+        # to world space in a lazy manner.
+        transformed_tractogram = tractogram.apply_affine(affine)
+        assert_array_equal(transformed_tractogram.affine_to_rasmm,
+                           np.linalg.inv(affine))
+
+        tractogram_world = transformed_tractogram.to_world()
+        assert_true(tractogram_world is not transformed_tractogram)
+        assert_array_almost_equal(tractogram_world.affine_to_rasmm,
+                                  np.eye(4))
+        for s1, s2 in zip(tractogram_world.streamlines, DATA['streamlines']):
+            assert_array_almost_equal(s1, s2)
+
+        # Calling to_world twice should do nothing.
+        tractogram_world = tractogram_world.to_world()
+        assert_array_almost_equal(tractogram_world.affine_to_rasmm, np.eye(4))
+        for s1, s2 in zip(tractogram_world.streamlines, DATA['streamlines']):
+            assert_array_almost_equal(s1, s2)
+
+        # Calling to_world when affine_to_rasmm is None should fail.
+        tractogram = DATA['lazy_tractogram'].copy()
+        tractogram.affine_to_rasmm = None
+        assert_raises(ValueError, tractogram.to_world)
+
+    def test_lazy_tractogram_copy(self):
+        # Create a copy of the lazy tractogram.
+        tractogram = DATA['lazy_tractogram'].copy()
+
+        # Check we copied the data and not simply created new references.
+        assert_true(tractogram is not DATA['lazy_tractogram'])
+
+        # When copying a LazyTractogram, the generator function yielding
+        # streamlines should stay the same.
+        assert_true(tractogram._streamlines
+                    is DATA['lazy_tractogram']._streamlines)
+
+        # Copying a LazyTractogram creates new internal LazyDict objects,
+        # but the generator functions they contain should stay the same.
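+        # In other words, the copy is shallow with respect to the stored
+        # callables.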
+        assert_true(tractogram._data_per_streamline
+                    is not DATA['lazy_tractogram']._data_per_streamline)
+        assert_true(tractogram._data_per_point
+                    is not DATA['lazy_tractogram']._data_per_point)
+
+        for key in tractogram.data_per_streamline:
+            assert_true(tractogram.data_per_streamline.store[key]
+                        is DATA['lazy_tractogram'].data_per_streamline.store[key])
+
+        for key in tractogram.data_per_point:
+            assert_true(tractogram.data_per_point.store[key]
+                        is DATA['lazy_tractogram'].data_per_point.store[key])
+
+        # The affine should be a copy.
+        assert_true(tractogram._affine_to_apply
+                    is not DATA['lazy_tractogram']._affine_to_apply)
+        assert_array_equal(tractogram._affine_to_apply,
+                           DATA['lazy_tractogram']._affine_to_apply)
+
+        # Check that the data are equivalent.
+        assert_tractogram_equal(tractogram, DATA['tractogram'])
diff --git a/nibabel/streamlines/tests/test_tractogram_file.py b/nibabel/streamlines/tests/test_tractogram_file.py
new file mode 100644
index 0000000000..b2995a124a
--- /dev/null
+++ b/nibabel/streamlines/tests/test_tractogram_file.py
@@ -0,0 +1,53 @@
+from nose.tools import assert_raises
+
+from ..tractogram import Tractogram
+from ..tractogram_file import TractogramFile
+
+
+def test_subclassing_tractogram_file():
+
+    # Missing 'save' method
+    class DummyTractogramFile(TractogramFile):
+        @classmethod
+        def is_correct_format(cls, fileobj):
+            return False
+
+        @classmethod
+        def load(cls, fileobj, lazy_load=True):
+            return None
+
+    assert_raises(TypeError, DummyTractogramFile, Tractogram())
+
+    # Missing 'load' method
+    class DummyTractogramFile(TractogramFile):
+        @classmethod
+        def is_correct_format(cls, fileobj):
+            return False
+
+        def save(self, fileobj):
+            pass
+
+    assert_raises(TypeError, DummyTractogramFile, Tractogram())
+
+
+def test_tractogram_file():
+    assert_raises(NotImplementedError, TractogramFile.is_correct_format, "")
+    assert_raises(NotImplementedError, TractogramFile.load, "")
+
+    # Test calling the abstract 'save' method of a `TractogramFile` object.
+    class DummyTractogramFile(TractogramFile):
+        @classmethod
+        def is_correct_format(cls, fileobj):
+            return False
+
+        @classmethod
+        def load(cls, fileobj, lazy_load=True):
+            return None
+
+        def save(self, fileobj):
+            pass
+
+    assert_raises(NotImplementedError,
+                  super(DummyTractogramFile,
+                        DummyTractogramFile(Tractogram())).save, "")
diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py
new file mode 100644
index 0000000000..f890021689
--- /dev/null
+++ b/nibabel/streamlines/tests/test_trk.py
@@ -0,0 +1,495 @@
+import os
+import copy
+import unittest
+import numpy as np
+from os.path import join as pjoin
+
+from nibabel.externals.six import BytesIO
+
+from nibabel.testing import data_path
+from nibabel.testing import clear_and_catch_warnings, assert_arr_dict_equal
+from nose.tools import assert_equal, assert_raises, assert_true
+from numpy.testing import assert_array_equal
+
+from .test_tractogram import assert_tractogram_equal
+from ..tractogram import Tractogram
+from ..tractogram_file import HeaderError, HeaderWarning
+
+from ..
import trk as trk_module +from ..trk import TrkFile, encode_value_in_name, decode_value_from_name +from ..header import Field + +DATA = {} + + +def setup(): + global DATA + + DATA['empty_trk_fname'] = pjoin(data_path, "empty.trk") + # simple.trk contains only streamlines + DATA['simple_trk_fname'] = pjoin(data_path, "simple.trk") + # standard.trk contains only streamlines + DATA['standard_trk_fname'] = pjoin(data_path, "standard.trk") + # standard.LPS.trk contains only streamlines + DATA['standard_LPS_trk_fname'] = pjoin(data_path, "standard.LPS.trk") + + # complex.trk contains streamlines, scalars and properties + DATA['complex_trk_fname'] = pjoin(data_path, "complex.trk") + DATA['complex_trk_big_endian_fname'] = pjoin(data_path, + "complex_big_endian.trk") + + DATA['streamlines'] = [np.arange(1*3, dtype="f4").reshape((1, 3)), + np.arange(2*3, dtype="f4").reshape((2, 3)), + np.arange(5*3, dtype="f4").reshape((5, 3))] + + DATA['fa'] = [np.array([[0.2]], dtype="f4"), + np.array([[0.3], + [0.4]], dtype="f4"), + np.array([[0.5], + [0.6], + [0.6], + [0.7], + [0.8]], dtype="f4")] + + DATA['colors'] = [np.array([(1, 0, 0)]*1, dtype="f4"), + np.array([(0, 1, 0)]*2, dtype="f4"), + np.array([(0, 0, 1)]*5, dtype="f4")] + + DATA['mean_curvature'] = [np.array([1.11], dtype="f4"), + np.array([2.11], dtype="f4"), + np.array([3.11], dtype="f4")] + + DATA['mean_torsion'] = [np.array([1.22], dtype="f4"), + np.array([2.22], dtype="f4"), + np.array([3.22], dtype="f4")] + + DATA['mean_colors'] = [np.array([1, 0, 0], dtype="f4"), + np.array([0, 1, 0], dtype="f4"), + np.array([0, 0, 1], dtype="f4")] + + DATA['data_per_point'] = {'colors': DATA['colors'], + 'fa': DATA['fa']} + DATA['data_per_streamline'] = {'mean_curvature': DATA['mean_curvature'], + 'mean_torsion': DATA['mean_torsion'], + 'mean_colors': DATA['mean_colors']} + + DATA['empty_tractogram'] = Tractogram(affine_to_rasmm=np.eye(4)) + DATA['simple_tractogram'] = Tractogram(DATA['streamlines'], + affine_to_rasmm=np.eye(4)) + DATA['complex_tractogram'] = Tractogram(DATA['streamlines'], + DATA['data_per_streamline'], + DATA['data_per_point'], + affine_to_rasmm=np.eye(4)) + + +class TestTRK(unittest.TestCase): + + def test_load_empty_file(self): + for lazy_load in [False, True]: + trk = TrkFile.load(DATA['empty_trk_fname'], lazy_load=lazy_load) + assert_tractogram_equal(trk.tractogram, DATA['empty_tractogram']) + + def test_load_simple_file(self): + for lazy_load in [False, True]: + trk = TrkFile.load(DATA['simple_trk_fname'], lazy_load=lazy_load) + assert_tractogram_equal(trk.tractogram, DATA['simple_tractogram']) + + def test_load_complex_file(self): + for lazy_load in [False, True]: + trk = TrkFile.load(DATA['complex_trk_fname'], lazy_load=lazy_load) + assert_tractogram_equal(trk.tractogram, DATA['complex_tractogram']) + + def test_load_file_with_wrong_information(self): + trk_file = open(DATA['simple_trk_fname'], 'rb').read() + + # Simulate a TRK file where `count` was not provided. + count = np.array(0, dtype="int32").tostring() + new_trk_file = trk_file[:1000-12] + count + trk_file[1000-8:] + trk = TrkFile.load(BytesIO(new_trk_file), lazy_load=False) + assert_tractogram_equal(trk.tractogram, DATA['simple_tractogram']) + + # Simulate a TRK where `vox_to_ras` is not recorded (i.e. all zeros). 
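+        # (In the version-2 header, `vox_to_ras` is a 4x4 float32 block
+        # stored at byte offset 440, hence the slice boundaries below.)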
+ vox_to_ras = np.zeros((4, 4), dtype=np.float32).tostring() + new_trk_file = trk_file[:440] + vox_to_ras + trk_file[440+64:] + with clear_and_catch_warnings(record=True, modules=[trk_module]) as w: + trk = TrkFile.load(BytesIO(new_trk_file)) + assert_equal(len(w), 1) + assert_true(issubclass(w[0].category, HeaderWarning)) + assert_true("identity" in str(w[0].message)) + assert_array_equal(trk.affine, np.eye(4)) + + # Simulate a TRK where `vox_to_ras` is invalid. + vox_to_ras = np.zeros((4, 4), dtype=np.float32) + vox_to_ras[3, 3] = 1 + vox_to_ras = vox_to_ras.tostring() + new_trk_file = trk_file[:440] + vox_to_ras + trk_file[440+64:] + with clear_and_catch_warnings(record=True, modules=[trk_module]) as w: + assert_raises(HeaderError, TrkFile.load, BytesIO(new_trk_file)) + + # Simulate a TRK file where `voxel_order` was not provided. + voxel_order = np.zeros(1, dtype="|S3").tostring() + new_trk_file = trk_file[:948] + voxel_order + trk_file[948+3:] + with clear_and_catch_warnings(record=True, modules=[trk_module]) as w: + TrkFile.load(BytesIO(new_trk_file)) + assert_equal(len(w), 1) + assert_true(issubclass(w[0].category, HeaderWarning)) + assert_true("LPS" in str(w[0].message)) + + # Simulate a TRK file with an unsupported version. + version = np.int32(123).tostring() + new_trk_file = trk_file[:992] + version + trk_file[992+4:] + assert_raises(HeaderError, TrkFile.load, BytesIO(new_trk_file)) + + # Simulate a TRK file with a wrong hdr_size. + hdr_size = np.int32(1234).tostring() + new_trk_file = trk_file[:996] + hdr_size + trk_file[996+4:] + assert_raises(HeaderError, TrkFile.load, BytesIO(new_trk_file)) + + # Simulate a TRK file with a wrong scalar_name. + trk_file = open(DATA['complex_trk_fname'], 'rb').read() + noise = np.int32(42).tostring() + new_trk_file = trk_file[:47] + noise + trk_file[47+4:] + assert_raises(HeaderError, TrkFile.load, BytesIO(new_trk_file)) + + # Simulate a TRK file with a wrong property_name. + noise = np.int32(42).tostring() + new_trk_file = trk_file[:254] + noise + trk_file[254+4:] + assert_raises(HeaderError, TrkFile.load, BytesIO(new_trk_file)) + + def test_load_complex_file_in_big_endian(self): + trk_file = open(DATA['complex_trk_big_endian_fname'], 'rb').read() + # We use hdr_size as an indicator of little vs big endian. 
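+        # (hdr_size is always 1000, so its raw bytes reveal the byte order.)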
+ hdr_size_big_endian = np.array(1000, dtype=">i4").tostring() + assert_equal(trk_file[996:996+4], hdr_size_big_endian) + + for lazy_load in [False, True]: + trk = TrkFile.load(DATA['complex_trk_big_endian_fname'], + lazy_load=lazy_load) + assert_tractogram_equal(trk.tractogram, DATA['complex_tractogram']) + + def test_tractogram_file_properties(self): + trk = TrkFile.load(DATA['simple_trk_fname']) + assert_equal(trk.streamlines, trk.tractogram.streamlines) + assert_array_equal(trk.affine, trk.header[Field.VOXEL_TO_RASMM]) + + def test_write_empty_file(self): + tractogram = Tractogram(affine_to_rasmm=np.eye(4)) + + trk_file = BytesIO() + trk = TrkFile(tractogram) + trk.save(trk_file) + trk_file.seek(0, os.SEEK_SET) + + new_trk = TrkFile.load(trk_file) + assert_tractogram_equal(new_trk.tractogram, tractogram) + + new_trk_orig = TrkFile.load(DATA['empty_trk_fname']) + assert_tractogram_equal(new_trk.tractogram, new_trk_orig.tractogram) + + trk_file.seek(0, os.SEEK_SET) + assert_equal(trk_file.read(), + open(DATA['empty_trk_fname'], 'rb').read()) + + def test_write_simple_file(self): + tractogram = Tractogram(DATA['streamlines'], + affine_to_rasmm=np.eye(4)) + + trk_file = BytesIO() + trk = TrkFile(tractogram) + trk.save(trk_file) + trk_file.seek(0, os.SEEK_SET) + + new_trk = TrkFile.load(trk_file) + assert_tractogram_equal(new_trk.tractogram, tractogram) + + new_trk_orig = TrkFile.load(DATA['simple_trk_fname']) + assert_tractogram_equal(new_trk.tractogram, new_trk_orig.tractogram) + + trk_file.seek(0, os.SEEK_SET) + assert_equal(trk_file.read(), + open(DATA['simple_trk_fname'], 'rb').read()) + + def test_write_complex_file(self): + # With scalars + tractogram = Tractogram(DATA['streamlines'], + data_per_point=DATA['data_per_point'], + affine_to_rasmm=np.eye(4)) + + trk_file = BytesIO() + trk = TrkFile(tractogram) + trk.save(trk_file) + trk_file.seek(0, os.SEEK_SET) + + new_trk = TrkFile.load(trk_file, lazy_load=False) + assert_tractogram_equal(new_trk.tractogram, tractogram) + + # With properties + data_per_streamline = DATA['data_per_streamline'] + tractogram = Tractogram(DATA['streamlines'], + data_per_streamline=data_per_streamline, + affine_to_rasmm=np.eye(4)) + + trk = TrkFile(tractogram) + trk_file = BytesIO() + trk.save(trk_file) + trk_file.seek(0, os.SEEK_SET) + + new_trk = TrkFile.load(trk_file, lazy_load=False) + assert_tractogram_equal(new_trk.tractogram, tractogram) + + # With scalars and properties + data_per_streamline = DATA['data_per_streamline'] + tractogram = Tractogram(DATA['streamlines'], + data_per_point=DATA['data_per_point'], + data_per_streamline=data_per_streamline, + affine_to_rasmm=np.eye(4)) + + trk_file = BytesIO() + trk = TrkFile(tractogram) + trk.save(trk_file) + trk_file.seek(0, os.SEEK_SET) + + new_trk = TrkFile.load(trk_file, lazy_load=False) + assert_tractogram_equal(new_trk.tractogram, tractogram) + + new_trk_orig = TrkFile.load(DATA['complex_trk_fname']) + assert_tractogram_equal(new_trk.tractogram, new_trk_orig.tractogram) + + trk_file.seek(0, os.SEEK_SET) + assert_equal(trk_file.read(), + open(DATA['complex_trk_fname'], 'rb').read()) + + def test_load_write_file(self): + for fname in [DATA['empty_trk_fname'], + DATA['simple_trk_fname'], + DATA['complex_trk_fname']]: + for lazy_load in [False, True]: + trk = TrkFile.load(fname, lazy_load=lazy_load) + trk_file = BytesIO() + trk.save(trk_file) + + new_trk = TrkFile.load(fname, lazy_load=False) + assert_tractogram_equal(new_trk.tractogram, trk.tractogram) + + trk_file.seek(0, os.SEEK_SET) + + def 
test_load_write_LPS_file(self): + # Load the RAS and LPS version of the standard. + trk_RAS = TrkFile.load(DATA['standard_trk_fname'], lazy_load=False) + trk_LPS = TrkFile.load(DATA['standard_LPS_trk_fname'], lazy_load=False) + assert_tractogram_equal(trk_LPS.tractogram, trk_RAS.tractogram) + + # Write back the standard. + trk_file = BytesIO() + trk = TrkFile(trk_LPS.tractogram, trk_LPS.header) + trk.save(trk_file) + trk_file.seek(0, os.SEEK_SET) + + new_trk = TrkFile.load(trk_file) + + assert_arr_dict_equal(new_trk.header, trk.header) + assert_tractogram_equal(new_trk.tractogram, trk.tractogram) + + new_trk_orig = TrkFile.load(DATA['standard_LPS_trk_fname']) + assert_tractogram_equal(new_trk.tractogram, new_trk_orig.tractogram) + + trk_file.seek(0, os.SEEK_SET) + assert_equal(trk_file.read(), + open(DATA['standard_LPS_trk_fname'], 'rb').read()) + + # Test writing a file where the header is missing the + # Field.VOXEL_ORDER. + trk_file = BytesIO() + + # For TRK file format, the default voxel order is LPS. + header = copy.deepcopy(trk_LPS.header) + header[Field.VOXEL_ORDER] = b"" + + trk = TrkFile(trk_LPS.tractogram, header) + trk.save(trk_file) + trk_file.seek(0, os.SEEK_SET) + + new_trk = TrkFile.load(trk_file) + + assert_arr_dict_equal(new_trk.header, trk_LPS.header) + assert_tractogram_equal(new_trk.tractogram, trk.tractogram) + + new_trk_orig = TrkFile.load(DATA['standard_LPS_trk_fname']) + assert_tractogram_equal(new_trk.tractogram, new_trk_orig.tractogram) + + trk_file.seek(0, os.SEEK_SET) + assert_equal(trk_file.read(), + open(DATA['standard_LPS_trk_fname'], 'rb').read()) + + def test_write_optional_header_fields(self): + # The TRK file format doesn't support additional header fields. + # If provided, they will be ignored. + tractogram = Tractogram(affine_to_rasmm=np.eye(4)) + + trk_file = BytesIO() + header = {'extra': 1234} + trk = TrkFile(tractogram, header) + trk.save(trk_file) + trk_file.seek(0, os.SEEK_SET) + + new_trk = TrkFile.load(trk_file) + assert_true("extra" not in new_trk.header) + + def test_write_too_many_scalars_and_properties(self): + # TRK supports up to 10 data_per_point. + data_per_point = {} + for i in range(10): + data_per_point['#{0}'.format(i)] = DATA['fa'] + + tractogram = Tractogram(DATA['streamlines'], + data_per_point=data_per_point, + affine_to_rasmm=np.eye(4)) + + trk_file = BytesIO() + trk = TrkFile(tractogram) + trk.save(trk_file) + trk_file.seek(0, os.SEEK_SET) + + new_trk = TrkFile.load(trk_file, lazy_load=False) + assert_tractogram_equal(new_trk.tractogram, tractogram) + + # More than 10 data_per_point should raise an error. + data_per_point['#{0}'.format(i+1)] = DATA['fa'] + + tractogram = Tractogram(DATA['streamlines'], + data_per_point=data_per_point, + affine_to_rasmm=np.eye(4)) + + trk = TrkFile(tractogram) + assert_raises(ValueError, trk.save, BytesIO()) + + # TRK supports up to 10 data_per_streamline. + data_per_streamline = {} + for i in range(10): + data_per_streamline['#{0}'.format(i)] = DATA['mean_torsion'] + + tractogram = Tractogram(DATA['streamlines'], + data_per_streamline=data_per_streamline, + affine_to_rasmm=np.eye(4)) + + trk_file = BytesIO() + trk = TrkFile(tractogram) + trk.save(trk_file) + trk_file.seek(0, os.SEEK_SET) + + new_trk = TrkFile.load(trk_file, lazy_load=False) + assert_tractogram_equal(new_trk.tractogram, tractogram) + + # More than 10 data_per_streamline should raise an error. 
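+        # (The TRK header reserves exactly 10 fixed-size slots for property
+        # names.)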
+        data_per_streamline['#{0}'.format(i+1)] = DATA['mean_torsion']
+
+        tractogram = Tractogram(DATA['streamlines'],
+                                data_per_streamline=data_per_streamline)
+
+        trk = TrkFile(tractogram)
+        assert_raises(ValueError, trk.save, BytesIO())
+
+    def test_write_scalars_and_properties_name_too_long(self):
+        # TRK supports data_per_point names up to 20 characters.
+        # However, when a data_per_point entry stores more than one value
+        # per point, the trailing characters of the name encode that number,
+        # which effectively limits the name to 18 characters. Longer names
+        # make `save` raise a ValueError.
+        for nb_chars in range(22):
+            data_per_point = {'A'*nb_chars: DATA['colors']}
+            tractogram = Tractogram(DATA['streamlines'],
+                                    data_per_point=data_per_point,
+                                    affine_to_rasmm=np.eye(4))
+
+            trk = TrkFile(tractogram)
+            if nb_chars > 18:
+                assert_raises(ValueError, trk.save, BytesIO())
+            else:
+                trk.save(BytesIO())
+
+            data_per_point = {'A'*nb_chars: DATA['fa']}
+            tractogram = Tractogram(DATA['streamlines'],
+                                    data_per_point=data_per_point,
+                                    affine_to_rasmm=np.eye(4))
+
+            trk = TrkFile(tractogram)
+            if nb_chars > 20:
+                assert_raises(ValueError, trk.save, BytesIO())
+            else:
+                trk.save(BytesIO())
+
+        # TRK supports data_per_streamline names up to 20 characters.
+        # However, when a data_per_streamline entry stores more than one
+        # value per streamline, the trailing characters of the name encode
+        # that number, which effectively limits the name to 18 characters.
+        # Longer names make `save` raise a ValueError.
+        for nb_chars in range(22):
+            data_per_streamline = {'A'*nb_chars: DATA['mean_colors']}
+            tractogram = Tractogram(DATA['streamlines'],
+                                    data_per_streamline=data_per_streamline,
+                                    affine_to_rasmm=np.eye(4))
+
+            trk = TrkFile(tractogram)
+            if nb_chars > 18:
+                assert_raises(ValueError, trk.save, BytesIO())
+            else:
+                trk.save(BytesIO())
+
+            data_per_streamline = {'A'*nb_chars: DATA['mean_torsion']}
+            tractogram = Tractogram(DATA['streamlines'],
+                                    data_per_streamline=data_per_streamline,
+                                    affine_to_rasmm=np.eye(4))
+
+            trk = TrkFile(tractogram)
+            if nb_chars > 20:
+                assert_raises(ValueError, trk.save, BytesIO())
+            else:
+                trk.save(BytesIO())
+
+    def test_str(self):
+        trk = TrkFile.load(DATA['complex_trk_fname'])
+        str(trk)  # Simply test that calling it does not fail.
+
+    def test_header_read_restore(self):
+        # Test that reading a header restores the file position.
+        trk_fname = DATA['simple_trk_fname']
+        bio = BytesIO()
+        bio.write(b'Along my very merry way')
+        hdr_pos = bio.tell()
+        hdr_from_fname = TrkFile._read_header(trk_fname)
+        with open(trk_fname, 'rb') as fobj:
+            bio.write(fobj.read())
+        bio.seek(hdr_pos)
+        # Check header is as expected.
+        hdr_from_fname['_offset_data'] += hdr_pos  # Correct for start position.
+        assert_arr_dict_equal(TrkFile._read_header(bio), hdr_from_fname)
+        # Check the file object's position has not changed.
+        assert_equal(bio.tell(), hdr_pos)
+
+
+def test_encode_names():
+    # Test function for encoding numbers into property names.
+    b0 = b'\x00'
+    assert_equal(encode_value_in_name(0, 'foo', 10),
+                 b'foo' + b0 * 7)
+    assert_equal(encode_value_in_name(1, 'foo', 10),
+                 b'foo' + b0 * 7)
+    assert_equal(encode_value_in_name(8, 'foo', 10),
+                 b'foo' + b0 + b'8' + b0 * 5)
+    assert_equal(encode_value_in_name(40, 'foobar', 10),
+                 b'foobar' + b0 + b'40' + b0)
+    assert_equal(encode_value_in_name(1, 'foobarbazz', 10), b'foobarbazz')
+    assert_raises(ValueError, encode_value_in_name, 1, 'foobarbazzz', 10)
+    assert_raises(ValueError, encode_value_in_name, 2, 'foobarbaz', 10)
+    assert_equal(encode_value_in_name(2, 'foobarba', 10), b'foobarba\x002')
+
+
+def test_decode_names():
+    # Test function for decoding a name string into (name, number).
+    b0 = b'\x00'
+    assert_equal(decode_value_from_name(b''), ('', 0))
+    assert_equal(decode_value_from_name(b'foo' + b0 * 7), ('foo', 1))
+    assert_equal(decode_value_from_name(b'foo\x008' + b0 * 5), ('foo', 8))
+    assert_equal(decode_value_from_name(b'foobar\x0010\x00'), ('foobar', 10))
+    assert_raises(ValueError, decode_value_from_name, b'foobar\x0010\x01')
+    assert_raises(HeaderError, decode_value_from_name, b'foo\x0010\x00111')
diff --git a/nibabel/streamlines/tests/test_utils.py b/nibabel/streamlines/tests/test_utils.py
new file mode 100644
index 0000000000..939ee9bb9e
--- /dev/null
+++ b/nibabel/streamlines/tests/test_utils.py
@@ -0,0 +1,26 @@
+import os
+import numpy as np
+import nibabel as nib
+
+from nibabel.testing import data_path
+from numpy.testing import assert_array_equal
+from nose.tools import assert_raises
+
+from ..utils import get_affine_from_reference
+
+
+def test_get_affine_from_reference():
+    filename = os.path.join(data_path, 'example_nifti2.nii.gz')
+    img = nib.load(filename)
+    affine = img.affine
+
+    # Get the affine from a numpy array.
+    assert_array_equal(get_affine_from_reference(affine), affine)
+    wrong_ref = np.array([[1, 2, 3], [4, 5, 6]])
+    assert_raises(ValueError, get_affine_from_reference, wrong_ref)
+
+    # Get the affine from a `SpatialImage`.
+    assert_array_equal(get_affine_from_reference(img), affine)
+
+    # Get the affine from a `SpatialImage` by its filename.
+    assert_array_equal(get_affine_from_reference(filename), affine)
diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py
new file mode 100644
index 0000000000..c33f707d1c
--- /dev/null
+++ b/nibabel/streamlines/tractogram.py
@@ -0,0 +1,745 @@
+import copy
+import numbers
+import numpy as np
+import collections
+from warnings import warn
+
+from nibabel.affines import apply_affine
+
+from .array_sequence import ArraySequence
+
+
+def is_data_dict(obj):
+    """ True if `obj` seems to implement the :class:`DataDict` API """
+    return hasattr(obj, 'store')
+
+
+def is_lazy_dict(obj):
+    """ True if `obj` seems to implement the :class:`LazyDict` API """
+    return is_data_dict(obj) and callable(list(obj.store.values())[0])
+
+
+class SliceableDataDict(collections.MutableMapping):
+    """ Dictionary for which key access can do slicing on the values.
+
+    This container behaves like a standard dictionary, but also allows keys
+    to be indices or slices, in which case the indexing is applied to every
+    contained ndarray value.
+
+    Parameters
+    ----------
+    \*args :
+    \*\*kwargs :
+        Positional and keyword arguments, passed straight through the ``dict``
+        constructor.
+    """
+    def __init__(self, *args, **kwargs):
+        self.store = dict()
+        self.update(dict(*args, **kwargs))
+
+    def __getitem__(self, key):
+        try:
+            return self.store[key]
+        except (KeyError, TypeError, IndexError):
+            pass  # Maybe it is an integer or a slicing object.
+
+        # Try to interpret key as an index/slice for every data element, in
+        # which case we perform (maybe advanced) indexing on every element of
+        # the dictionary.
+        idx = key
+        new_dict = type(self)()
+        try:
+            for k, v in self.items():
+                new_dict[k] = v[idx]
+        except (TypeError, ValueError, IndexError):
+            pass
+        else:
+            return new_dict
+
+        # Key was not a valid index/slice after all.
+        return self.store[key]  # Will raise the proper error.
+
+    def __delitem__(self, key):
+        del self.store[key]
+
+    def __iter__(self):
+        return iter(self.store)
+
+    def __len__(self):
+        return len(self.store)
+
+
+class PerArrayDict(SliceableDataDict):
+    """ Dictionary for which key access can do slicing on the values.
+
+    This container behaves like a standard dictionary, but also allows keys
+    to be indices or slices, in which case the indexing is applied to every
+    contained value. The values must also be ndarrays.
+
+    In addition, it makes sure the amount of data contained in those ndarrays
+    matches the number of streamlines given at the instantiation of this
+    instance.
+
+    Parameters
+    ----------
+    n_rows : None or int, optional
+        Number of rows per value in each (key, value) pair, or None if not
+        specified.
+    \*args :
+    \*\*kwargs :
+        Positional and keyword arguments, passed straight through the ``dict``
+        constructor.
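+
+    Examples
+    --------
+    A minimal sketch of the intended usage (key names are illustrative)::
+
+        dpd = PerArrayDict(n_rows=2)
+        dpd['mean_curvature'] = [1.11, 2.11]  # stored as a (2, 1) array
+        sub = dpd[::2]                        # slicing applies to every value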
+ """ + def __init__(self, n_rows=None, *args, **kwargs): + self.n_rows = n_rows + super(PerArrayDict, self).__init__(*args, **kwargs) + + def __setitem__(self, key, value): + value = np.asarray(list(value)) + + if value.ndim == 1 and value.dtype != object: + # Reshape without copy + value.shape = ((len(value), 1)) + + if value.ndim != 2: + raise ValueError("data_per_streamline must be a 2D array.") + + # We make sure there is the right amount of values + if self.n_rows is not None and len(value) != self.n_rows: + msg = ("The number of values ({0}) should match n_elements " + "({1}).").format(len(value), self.n_rows) + raise ValueError(msg) + + self.store[key] = value + + +class PerArraySequenceDict(PerArrayDict): + """ Dictionary for which key access can do slicing on the values. + + This container behaves like a standard dictionary but extends key access to + allow keys for key access to be indices slicing into the contained ndarray + values. The elements must also be :class:`ArraySequence`. + + In addition, it makes sure the amount of data contained in those array + sequences matches the number of elements given at the instantiation + of the instance. + """ + def __setitem__(self, key, value): + value = ArraySequence(value) + + # We make sure there is the right amount of data. + if (self.n_rows is not None and + value.total_nb_rows != self.n_rows): + msg = ("The number of values ({0}) should match " + "({1}).").format(value.total_nb_rows, self.n_rows) + raise ValueError(msg) + + self.store[key] = value + + +class LazyDict(collections.MutableMapping): + """ Dictionary of generator functions. + + This container behaves like a dictionary but it makes sure its elements are + callable objects that it assumes are generator functions yielding values. + When getting the element associated with a given key, the element (i.e. a + generator function) is first called before being returned. + """ + def __init__(self, *args, **kwargs): + self.store = dict() + # Use the 'update' method to set the keys. + if len(args) == 1: + if args[0] is None: + return + + if isinstance(args[0], LazyDict): + self.update(**args[0].store) # Copy the generator functions. + return + + if isinstance(args[0], SliceableDataDict): + self.update(**args[0]) + + self.update(dict(*args, **kwargs)) + + def __getitem__(self, key): + return self.store[key]() + + def __setitem__(self, key, value): + if value is not None and not callable(value): # TODO: why None? + raise TypeError("`value` must be a generator function or None.") + self.store[key] = value + + def __delitem__(self, key): + del self.store[key] + + def __iter__(self): + return iter(self.store) + + def __len__(self): + return len(self.store) + + +class TractogramItem(object): + """ Class containing information about one streamline. + + :class:`TractogramItem` objects have three public attributes: `streamline`, + `data_for_streamline`, and `data_for_points`. + + Parameters + ---------- + streamline : ndarray shape (N, 3) + Points of this streamline represented as an ndarray of shape (N, 3) + where N is the number of points. + data_for_streamline : dict + Dictionary containing some data associated with this particular + streamline. Each key ``k`` is mapped to a ndarray of shape (Pt,), where + ``Pt`` is the dimension of the data associated with key ``k``. + data_for_points : dict + Dictionary containing some data associated to each point of this + particular streamline. 
+        Each key ``k`` is mapped to a ndarray of shape (Nt, Mk), where
+        ``Nt`` is the number of points of this streamline and ``Mk`` is the
+        dimension of the data associated with key ``k``.
+    """
+    def __init__(self, streamline, data_for_streamline, data_for_points):
+        self.streamline = np.asarray(streamline)
+        self.data_for_streamline = data_for_streamline
+        self.data_for_points = data_for_points
+
+    def __iter__(self):
+        return iter(self.streamline)
+
+    def __len__(self):
+        return len(self.streamline)
+
+
+class Tractogram(object):
+    """ Container for streamlines and their data information.
+
+    Streamlines of a tractogram can be in any coordinate system of your
+    choice as long as you provide the correct `affine_to_rasmm` matrix, at
+    construction time, that brings the streamlines back to *RAS+*, *mm*
+    space, where the coordinates (0,0,0) correspond to the center of the
+    voxel (as opposed to the corner of the voxel).
+
+    Attributes
+    ----------
+    streamlines : :class:`ArraySequence` object
+        Sequence of $T$ streamlines. Each streamline is an ndarray of
+        shape ($N_t$, 3) where $N_t$ is the number of points of
+        streamline $t$.
+    data_per_streamline : :class:`PerArrayDict` object
+        Dictionary where the items are (str, 2D array). Each key represents
+        a piece of information $i$ to be kept alongside every streamline,
+        and its associated value is a 2D array of shape ($T$, $P_i$) where
+        $T$ is the number of streamlines and $P_i$ is the number of values
+        to store for that particular piece of information $i$.
+    data_per_point : :class:`PerArraySequenceDict` object
+        Dictionary where the items are (str, :class:`ArraySequence`). Each
+        key represents a piece of information $i$ to be kept alongside every
+        point of every streamline, and its associated value is an iterable
+        of ndarrays of shape ($N_t$, $M_i$) where $N_t$ is the number of
+        points for a particular streamline $t$ and $M_i$ is the number of
+        values to store for that particular piece of information $i$.
+    """
+    def __init__(self, streamlines=None,
+                 data_per_streamline=None,
+                 data_per_point=None,
+                 affine_to_rasmm=None):
+        """
+        Parameters
+        ----------
+        streamlines : iterable of ndarrays or :class:`ArraySequence`, optional
+            Sequence of $T$ streamlines. Each streamline is an ndarray of
+            shape ($N_t$, 3) where $N_t$ is the number of points of
+            streamline $t$.
+        data_per_streamline : dict of iterable of ndarrays, optional
+            Dictionary where the items are (str, iterable).
+            Each key represents a piece of information $i$ to be kept
+            alongside every streamline, and its associated value is an
+            iterable of ndarrays of shape ($P_i$,) where $P_i$ is the number
+            of scalar values to store for that particular piece of
+            information $i$.
+        data_per_point : dict of iterable of ndarrays, optional
+            Dictionary where the items are (str, iterable).
+            Each key represents a piece of information $i$ to be kept
+            alongside every point of every streamline, and its associated
+            value is an iterable of ndarrays of shape ($N_t$, $M_i$) where
+            $N_t$ is the number of points for a particular streamline $t$
+            and $M_i$ is the number of scalar values to store for that
+            particular piece of information $i$.
+        affine_to_rasmm : ndarray of shape (4, 4) or None, optional
+            Transformation matrix that brings the streamlines contained in
+            this tractogram to *RAS+* and *mm* space where coordinate (0,0,0)
+            refers to the center of the voxel. By default, the streamlines
+            are in an unknown space, i.e. affine_to_rasmm is None.
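+
+        Examples
+        --------
+        A minimal sketch (coordinates are illustrative)::
+
+            import numpy as np
+            streamlines = [np.array([[0., 0., 0.], [1., 1., 1.]], "f4")]
+            t = Tractogram(streamlines, affine_to_rasmm=np.eye(4))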
+ """ + self._set_streamlines(streamlines) + self.data_per_streamline = data_per_streamline + self.data_per_point = data_per_point + self.affine_to_rasmm = affine_to_rasmm + + @property + def streamlines(self): + return self._streamlines + + def _set_streamlines(self, value): + self._streamlines = ArraySequence(value) + + @property + def data_per_streamline(self): + return self._data_per_streamline + + @data_per_streamline.setter + def data_per_streamline(self, value): + self._data_per_streamline = PerArrayDict( + len(self.streamlines), {} if value is None else value) + + @property + def data_per_point(self): + return self._data_per_point + + @data_per_point.setter + def data_per_point(self, value): + self._data_per_point = PerArraySequenceDict( + self.streamlines.total_nb_rows, {} if value is None else value) + + @property + def affine_to_rasmm(self): + """ Affine bringing streamlines in this tractogram to RAS+mm. """ + return copy.deepcopy(self._affine_to_rasmm) + + @affine_to_rasmm.setter + def affine_to_rasmm(self, value): + if value is not None: + value = np.array(value) + if value.shape != (4, 4): + msg = ("Affine matrix has a shape of (4, 4) but a ndarray with" + "shape {} was provided instead.").format(value.shape) + raise ValueError(msg) + + self._affine_to_rasmm = value + + def __iter__(self): + for i in range(len(self.streamlines)): + yield self[i] + + def __getitem__(self, idx): + pts = self.streamlines[idx] + + data_per_streamline = {} + for key in self.data_per_streamline: + data_per_streamline[key] = self.data_per_streamline[key][idx] + + data_per_point = {} + for key in self.data_per_point: + data_per_point[key] = self.data_per_point[key][idx] + + if isinstance(idx, (numbers.Integral, np.integer)): + return TractogramItem(pts, data_per_streamline, data_per_point) + + return Tractogram(pts, data_per_streamline, data_per_point) + + def __len__(self): + return len(self.streamlines) + + def copy(self): + """ Returns a copy of this :class:`Tractogram` object. """ + return copy.deepcopy(self) + + def apply_affine(self, affine, lazy=False): + """ Applies an affine transformation on the points of each streamline. + + If `lazy` is not specified, this is performed *in-place*. + + Parameters + ---------- + affine : ndarray of shape (4, 4) + Transformation that will be applied to every streamline. + lazy : {False, True}, optional + If True, streamlines are *not* transformed in-place and a + :class:`LazyTractogram` object is returned. Otherwise, streamlines + are modified in-place. + + Returns + ------- + tractogram : :class:`Tractogram` or :class:`LazyTractogram` object + Tractogram where the streamlines have been transformed according + to the given affine transformation. If the `lazy` option is true, + it returns a :class:`LazyTractogram` object, otherwise it returns a + reference to this :class:`Tractogram` object with updated + streamlines. + """ + if lazy: + lazy_tractogram = LazyTractogram.from_tractogram(self) + return lazy_tractogram.apply_affine(affine) + + if len(self.streamlines) == 0: + return self + + if np.all(affine == np.eye(4)): + return self # No transformation. + + BUFFER_SIZE = 10000000 # About 128 Mb since pts shape is 3. + for start in range(0, len(self.streamlines.data), BUFFER_SIZE): + end = start + BUFFER_SIZE + pts = self.streamlines._data[start:end] + self.streamlines.data[start:end] = apply_affine(affine, pts) + + if self.affine_to_rasmm is not None: + # Update the affine that brings back the streamlines to RASmm. 
+            self.affine_to_rasmm = np.dot(self.affine_to_rasmm,
+                                          np.linalg.inv(affine))
+
+        return self
+
+    def to_world(self, lazy=False):
+        """ Brings the streamlines to world space (i.e. RAS+ and mm).
+
+        If `lazy` is not specified, this is performed *in-place*.
+
+        Parameters
+        ----------
+        lazy : {False, True}, optional
+            If True, streamlines are *not* transformed in-place and a
+            :class:`LazyTractogram` object is returned. Otherwise,
+            streamlines are modified in-place.
+
+        Returns
+        -------
+        tractogram : :class:`Tractogram` or :class:`LazyTractogram` object
+            Tractogram where the streamlines have been sent to world space.
+            If the `lazy` option is true, it returns a :class:`LazyTractogram`
+            object, otherwise it returns a reference to this
+            :class:`Tractogram` object with updated streamlines.
+        """
+        if self.affine_to_rasmm is None:
+            msg = ("Streamlines are in an unknown space. This error can be"
+                   " avoided by setting the 'affine_to_rasmm' property.")
+            raise ValueError(msg)
+
+        return self.apply_affine(self.affine_to_rasmm, lazy=lazy)
+
+
+class LazyTractogram(Tractogram):
+    """ Lazy container for streamlines and their data information.
+
+    This container behaves lazily as it uses generator functions to manage
+    streamlines and their data information. This container is thus
+    memory-friendly since it does not require having all the data loaded in
+    memory.
+
+    Streamlines of a lazy tractogram can be in any coordinate system of your
+    choice as long as you provide the correct `affine_to_rasmm` matrix, at
+    construction time, that brings the streamlines back to *RAS+*, *mm*
+    space, where the coordinates (0,0,0) correspond to the center of the
+    voxel (as opposed to the corner of the voxel).
+
+    Attributes
+    ----------
+    streamlines : generator function
+        Generator function yielding streamlines. Each streamline is an
+        ndarray of shape ($N_t$, 3) where $N_t$ is the number of points of
+        streamline $t$.
+    data_per_streamline : instance of :class:`LazyDict`
+        Dictionary where the items are (str, instantiated generator).
+        Each key represents a piece of information $i$ to be kept alongside
+        every streamline, and its associated value is a generator function
+        yielding that information via ndarrays of shape ($P_i$,) where $P_i$
+        is the number of values to store for that particular piece of
+        information $i$.
+    data_per_point : :class:`LazyDict` object
+        Dictionary where the items are (str, instantiated generator). Each
+        key represents a piece of information $i$ to be kept alongside every
+        point of every streamline, and its associated value is a generator
+        function yielding that information via ndarrays of shape
+        ($N_t$, $M_i$) where $N_t$ is the number of points for a particular
+        streamline $t$ and $M_i$ is the number of values to store for that
+        particular piece of information $i$.
+
+    Notes
+    -----
+    LazyTractogram objects do not currently support indexing.
+    LazyTractogram objects are suited for operations that can be linearized
+    such as applying an affine transformation or converting streamlines from
+    one file format to another.
+    """
+    def __init__(self, streamlines=None,
+                 data_per_streamline=None,
+                 data_per_point=None,
+                 affine_to_rasmm=None):
+        """
+        Parameters
+        ----------
+        streamlines : generator function, optional
+            Generator function yielding streamlines. Each streamline is an
+            ndarray of shape ($N_t$, 3) where $N_t$ is the number of points
+            of streamline $t$.
+        data_per_streamline : dict of generator functions, optional
+            Dictionary where the items are (str, generator function).
+            Each key represents a piece of information $i$ to be kept
+            alongside every streamline, and its associated value is a
+            generator function yielding that information via ndarrays of
+            shape ($P_i$,) where $P_i$ is the number of values to store for
+            that particular piece of information $i$.
+        data_per_point : dict of generator functions, optional
+            Dictionary where the items are (str, generator function).
+            Each key represents a piece of information $i$ to be kept
+            alongside every point of every streamline, and its associated
+            value is a generator function yielding that information via
+            ndarrays of shape ($N_t$, $M_i$) where $N_t$ is the number of
+            points for a particular streamline $t$ and $M_i$ is the number
+            of values to store for that particular piece of information $i$.
+        affine_to_rasmm : ndarray of shape (4, 4) or None, optional
+            Transformation matrix that brings the streamlines contained in
+            this tractogram to *RAS+* and *mm* space where coordinate (0,0,0)
+            refers to the center of the voxel. By default, the streamlines
+            are in an unknown space, i.e. affine_to_rasmm is None.
+        """
+        super(LazyTractogram, self).__init__(streamlines,
+                                             data_per_streamline,
+                                             data_per_point,
+                                             affine_to_rasmm)
+        self._nb_streamlines = None
+        self._data = None
+        self._affine_to_apply = np.eye(4)
+
+    @classmethod
+    def from_tractogram(cls, tractogram):
+        """ Creates a :class:`LazyTractogram` object from a :class:`Tractogram` object.
+
+        Parameters
+        ----------
+        tractogram : :class:`Tractogram` object
+            Tractogram from which to create a :class:`LazyTractogram` object.
+
+        Returns
+        -------
+        lazy_tractogram : :class:`LazyTractogram` object
+            New lazy tractogram.
+        """
+        lazy_tractogram = cls(lambda: tractogram.streamlines.copy())
+
+        # Set data_per_streamline using data_func.
+        def _gen(key):
+            return lambda: iter(tractogram.data_per_streamline[key])
+
+        for k in tractogram.data_per_streamline:
+            lazy_tractogram._data_per_streamline[k] = _gen(k)
+
+        # Set data_per_point using data_func.
+        def _gen(key):
+            return lambda: iter(tractogram.data_per_point[key])
+
+        for k in tractogram.data_per_point:
+            lazy_tractogram._data_per_point[k] = _gen(k)
+
+        lazy_tractogram._nb_streamlines = len(tractogram)
+        lazy_tractogram.affine_to_rasmm = tractogram.affine_to_rasmm
+        return lazy_tractogram
+
+    @classmethod
+    def from_data_func(cls, data_func):
+        """ Creates an instance from a generator function.
+
+        The generator function must yield :class:`TractogramItem` objects.
+
+        Parameters
+        ----------
+        data_func : generator function yielding :class:`TractogramItem` objects
+            Generator function that, whenever called, starts yielding
+            :class:`TractogramItem` objects that will be used to instantiate
+            a :class:`LazyTractogram`.
+
+        Returns
+        -------
+        lazy_tractogram : :class:`LazyTractogram` object
+            New lazy tractogram.
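+
+        Examples
+        --------
+        A sketch, assuming `items` is a reusable sequence of
+        :class:`TractogramItem` objects::
+
+            gen_fn = lambda: iter(items)
+            lazy = LazyTractogram.from_data_func(gen_fn)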
+ """ + if not callable(data_func): + raise TypeError("`data_func` must be a generator function.") + + lazy_tractogram = cls() + lazy_tractogram._data = data_func + + try: + first_item = next(data_func()) + + # Set data_per_streamline using data_func + def _gen(key): + return lambda: (t.data_for_streamline[key] + for t in data_func()) + + data_per_streamline_keys = first_item.data_for_streamline.keys() + for k in data_per_streamline_keys: + lazy_tractogram._data_per_streamline[k] = _gen(k) + + # Set data_per_point using data_func + def _gen(key): + return lambda: (t.data_for_points[key] for t in data_func()) + + data_per_point_keys = first_item.data_for_points.keys() + for k in data_per_point_keys: + lazy_tractogram._data_per_point[k] = _gen(k) + + except StopIteration: + pass + + return lazy_tractogram + + @property + def streamlines(self): + streamlines_gen = iter([]) + if self._streamlines is not None: + streamlines_gen = self._streamlines() + elif self._data is not None: + streamlines_gen = (t.streamline for t in self._data()) + + # Check if we need to apply an affine. + if not np.allclose(self._affine_to_apply, np.eye(4)): + def _apply_affine(): + for s in streamlines_gen: + yield apply_affine(self._affine_to_apply, s) + + return _apply_affine() + + return streamlines_gen + + def _set_streamlines(self, value): + if value is not None and not callable(value): + raise TypeError("`streamlines` must be a generator function.") + self._streamlines = value + + @property + def data_per_streamline(self): + return self._data_per_streamline + + @data_per_streamline.setter + def data_per_streamline(self, value): + self._data_per_streamline = LazyDict(value) + + @property + def data_per_point(self): + return self._data_per_point + + @data_per_point.setter + def data_per_point(self, value): + self._data_per_point = LazyDict(value) + + @property + def data(self): + if self._data is not None: + return self._data() + + def _gen_data(): + data_per_streamline_generators = {} + for k, v in self.data_per_streamline.items(): + data_per_streamline_generators[k] = iter(v) + + data_per_point_generators = {} + for k, v in self.data_per_point.items(): + data_per_point_generators[k] = iter(v) + + for s in self.streamlines: + data_for_streamline = {} + for k, v in data_per_streamline_generators.items(): + data_for_streamline[k] = next(v) + + data_for_points = {} + for k, v in data_per_point_generators.items(): + data_for_points[k] = next(v) + + yield TractogramItem(s, data_for_streamline, data_for_points) + + return _gen_data() + + def __getitem__(self, idx): + raise NotImplementedError('LazyTractogram does not support indexing.') + + def __iter__(self): + count = 0 + for tractogram_item in self.data: + yield tractogram_item + count += 1 + + # Keep how many streamlines there are in this tractogram. + self._nb_streamlines = count + + def __len__(self): + # Check if we know how many streamlines there are. + if self._nb_streamlines is None: + warn("Number of streamlines will be determined manually by looping" + " through the streamlines. If you know the actual number of" + " streamlines, you might want to set it beforehand via" + " `self.header.nb_streamlines`.", Warning) + # Count the number of streamlines. + self._nb_streamlines = sum(1 for _ in self.streamlines) + + return self._nb_streamlines + + def copy(self): + """ Returns a copy of this :class:`LazyTractogram` object. 
""" + tractogram = LazyTractogram(self._streamlines, + self._data_per_streamline, + self._data_per_point, + self.affine_to_rasmm) + tractogram._nb_streamlines = self._nb_streamlines + tractogram._data = self._data + tractogram._affine_to_apply = self._affine_to_apply.copy() + return tractogram + + def apply_affine(self, affine, lazy=True): + """ Applies an affine transformation to the streamlines. + + The transformation given by the `affine` matrix is applied after any + other pending transformations to the streamline points. + + Parameters + ---------- + affine : 2D array (4,4) + Transformation matrix that will be applied on each streamline. + lazy : True, optional + Should always be True for :class:`LazyTractogram` object. Doing + otherwise will raise a ValueError. + + Returns + ------- + lazy_tractogram : :class:`LazyTractogram` object + A copy of this :class:`LazyTractogram` instance but with a + transformation to be applied on the streamlines. + """ + if not lazy: + msg = "LazyTractogram only supports lazy transformations." + raise ValueError(msg) + + tractogram = self.copy() # New instance. + + # Update the affine that will be applied when returning streamlines. + tractogram._affine_to_apply = np.dot(affine, self._affine_to_apply) + + if tractogram.affine_to_rasmm is not None: + # Update the affine that brings back the streamlines to RASmm. + tractogram.affine_to_rasmm = np.dot(self.affine_to_rasmm, + np.linalg.inv(affine)) + return tractogram + + def to_world(self, lazy=True): + """ Brings the streamlines to world space (i.e. RAS+ and mm). + + The transformation is applied after any other pending transformations + to the streamline points. + + Parameters + ---------- + lazy : True, optional + Should always be True for :class:`LazyTractogram` object. Doing + otherwise will raise a ValueError. + + Returns + ------- + lazy_tractogram : :class:`LazyTractogram` object + A copy of this :class:`LazyTractogram` instance but with a + transformation to be applied on the streamlines. + """ + if self.affine_to_rasmm is None: + msg = ("Streamlines are in a unknown space. This error can be" + " avoided by setting the 'affine_to_rasmm' property.") + raise ValueError(msg) + + return self.apply_affine(self.affine_to_rasmm, lazy=lazy) diff --git a/nibabel/streamlines/tractogram_file.py b/nibabel/streamlines/tractogram_file.py new file mode 100644 index 0000000000..a1dc4e83fb --- /dev/null +++ b/nibabel/streamlines/tractogram_file.py @@ -0,0 +1,108 @@ +""" Define abstract interface for Tractogram file classes +""" +from abc import ABCMeta, abstractmethod +from nibabel.externals.six import with_metaclass + +from .header import Field + + +class ExtensionWarning(Warning): + """ Base class for warnings about tractogram file extension. """ + + +class HeaderWarning(Warning): + """ Base class for warnings about tractogram file header. """ + + +class HeaderError(Exception): + """ Raised when a tractogram file header contains invalid information. """ + + +class DataError(Exception): + """ Raised when data is missing or inconsistent in a tractogram file. """ + + +class abstractclassmethod(classmethod): + __isabstractmethod__ = True + + def __init__(self, callable): + callable.__isabstractmethod__ = True + super(abstractclassmethod, self).__init__(callable) + + +class TractogramFile(with_metaclass(ABCMeta)): + """ Convenience class to encapsulate tractogram file format. 
""" + + def __init__(self, tractogram, header=None): + self._tractogram = tractogram + self._header = {} if header is None else header + + @property + def tractogram(self): + return self._tractogram + + @property + def streamlines(self): + return self.tractogram.streamlines + + @property + def header(self): + return self._header + + @property + def affine(self): + """ voxmm -> rasmm affine. """ + return self.header.get(Field.VOXEL_TO_RASMM) + + @abstractclassmethod + def is_correct_format(cls, fileobj): + """ Checks if the file has the right streamlines file format. + + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to a streamlines file (and ready to read from the + beginning of the header). + + Returns + ------- + is_correct_format : {True, False} + Returns True if `fileobj` is in the right streamlines file format, + otherwise returns False. + """ + raise NotImplementedError() + + @abstractclassmethod + def load(cls, fileobj, lazy_load=True): + """ Loads streamlines from a filename or file-like object. + + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to a streamlines file (and ready to read from the + beginning of the header). + lazy_load : {False, True}, optional + If True, load streamlines in a lazy manner i.e. they will not be + kept in memory. Otherwise, load all streamlines in memory. + + Returns + ------- + tractogram_file : :class:`TractogramFile` object + Returns an object containing tractogram data and header + information. + """ + raise NotImplementedError() + + @abstractmethod + def save(self, fileobj): + """ Saves streamlines to a filename or file-like object. + + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + opened and ready to write. 
+ """ + raise NotImplementedError() diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py new file mode 100644 index 0000000000..9eeef0f1cd --- /dev/null +++ b/nibabel/streamlines/trk.py @@ -0,0 +1,753 @@ +from __future__ import division + +# Definition of trackvis header structure: +# http://www.trackvis.org/docs/?subsect=fileformat + +import os +import struct +import warnings +import string + +import numpy as np +import nibabel as nib + +from nibabel.openers import Opener +from nibabel.py3k import asstr +from nibabel.volumeutils import (native_code, swapped_code) +from nibabel.orientations import (aff2axcodes, axcodes2ornt) + +from .array_sequence import create_arraysequences_from_generator +from .tractogram_file import TractogramFile +from .tractogram_file import DataError, HeaderError, HeaderWarning +from .tractogram import TractogramItem, Tractogram, LazyTractogram +from .header import Field + + +MAX_NB_NAMED_SCALARS_PER_POINT = 10 +MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE = 10 + +# See http://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html +header_1_dtd = [(Field.MAGIC_NUMBER, 'S6'), + (Field.DIMENSIONS, 'h', 3), + (Field.VOXEL_SIZES, 'f4', 3), + (Field.ORIGIN, 'f4', 3), + (Field.NB_SCALARS_PER_POINT, 'h'), + ('scalar_name', 'S20', MAX_NB_NAMED_SCALARS_PER_POINT), + (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'), + ('property_name', 'S20', + MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE), + ('reserved', 'S508'), + (Field.VOXEL_ORDER, 'S4'), + ('pad2', 'S4'), + ('image_orientation_patient', 'f4', 6), + ('pad1', 'S2'), + ('invert_x', 'S1'), + ('invert_y', 'S1'), + ('invert_z', 'S1'), + ('swap_xy', 'S1'), + ('swap_yz', 'S1'), + ('swap_zx', 'S1'), + (Field.NB_STREAMLINES, 'i4'), + ('version', 'i4'), + ('hdr_size', 'i4'), + ] + +# Version 2 adds a 4x4 matrix giving the affine transformtation going +# from voxel coordinates in the referenced 3D voxel matrix, to xyz +# coordinates (axes L->R, P->A, I->S). If (0 based) value [3, 3] from +# this matrix is 0, this means the matrix is not recorded. +header_2_dtd = [(Field.MAGIC_NUMBER, 'S6'), + (Field.DIMENSIONS, 'h', 3), + (Field.VOXEL_SIZES, 'f4', 3), + (Field.ORIGIN, 'f4', 3), + (Field.NB_SCALARS_PER_POINT, 'h'), + ('scalar_name', 'S20', MAX_NB_NAMED_SCALARS_PER_POINT), + (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'), + ('property_name', 'S20', + MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE), + (Field.VOXEL_TO_RASMM, 'f4', (4, 4)), # New in version 2. + ('reserved', 'S444'), + (Field.VOXEL_ORDER, 'S4'), + ('pad2', 'S4'), + ('image_orientation_patient', 'f4', 6), + ('pad1', 'S2'), + ('invert_x', 'S1'), + ('invert_y', 'S1'), + ('invert_z', 'S1'), + ('swap_xy', 'S1'), + ('swap_yz', 'S1'), + ('swap_zx', 'S1'), + (Field.NB_STREAMLINES, 'i4'), + ('version', 'i4'), + ('hdr_size', 'i4'), + ] + +# Full header numpy dtypes +header_1_dtype = np.dtype(header_1_dtd) +header_2_dtype = np.dtype(header_2_dtd) + + +def get_affine_trackvis_to_rasmm(header): + """ Get affine mapping trackvis voxelmm space to RAS+ mm space + + The streamlines in a trackvis file are in 'voxelmm' space, where the + coordinates refer to the corner of the voxel. + + Compute the affine matrix that will bring them back to RAS+ mm space, where + the coordinates refer to the center of the voxel. + + Parameters + ---------- + header : dict + Dict containing trackvis header. + + Returns + ------- + aff_tv2ras : shape (4, 4) array + Affine array mapping coordinates in 'voxelmm' space to RAS+ mm space. 
+ """ + # TRK's streamlines are in 'voxelmm' space, we will compute the + # affine matrix that will bring them back to RAS+ and mm space. + affine = np.eye(4) + + # The affine matrix found in the TRK header requires the points to + # be in the voxel space. + # voxelmm -> voxel + scale = np.eye(4) + scale[range(3), range(3)] /= header[Field.VOXEL_SIZES] + affine = np.dot(scale, affine) + + # TrackVis considers coordinate (0,0,0) to be the corner of the + # voxel whereas streamlines returned assumes (0,0,0) to be the + # center of the voxel. Thus, streamlines are shifted by half a voxel. + offset = np.eye(4) + offset[:-1, -1] -= 0.5 + affine = np.dot(offset, affine) + + # If the voxel order implied by the affine does not match the voxel + # order in the TRK header, change the orientation. + # voxel (header) -> voxel (affine) + header_ornt = asstr(header[Field.VOXEL_ORDER]) + affine_ornt = "".join(aff2axcodes(header[Field.VOXEL_TO_RASMM])) + header_ornt = axcodes2ornt(header_ornt) + affine_ornt = axcodes2ornt(affine_ornt) + ornt = nib.orientations.ornt_transform(header_ornt, affine_ornt) + M = nib.orientations.inv_ornt_aff(ornt, header[Field.DIMENSIONS]) + affine = np.dot(M, affine) + + # Applied the affine found in the TRK header. + # voxel -> rasmm + voxel_to_rasmm = header[Field.VOXEL_TO_RASMM] + affine_voxmm_to_rasmm = np.dot(voxel_to_rasmm, affine) + return affine_voxmm_to_rasmm.astype(np.float32) + + +def get_affine_rasmm_to_trackvis(header): + return np.linalg.inv(get_affine_trackvis_to_rasmm(header)) + + +def encode_value_in_name(value, name, max_name_len=20): + """ Return `name` as fixed-length string, appending `value` as string. + + Form output from `name` if `value <= 1` else `name` + ``\x00`` + + str(value). + + Return output as fixed length string length `max_name_len`, padded with + ``\x00``. + + This function also verifies that the modified length of name is less than + `max_name_len`. + + Parameters + ---------- + value : int + Integer value to encode. + name : str + Name to which we may append an ascii / latin-1 representation of + `value`. + max_name_len : int, optional + Maximum length of byte string that output can have. + + Returns + ------- + encoded_name : bytes + Name maybe followed by ``\x00`` and ascii / latin-1 representation of + `value`, padded with ``\x00`` bytes. + """ + if len(name) > max_name_len: + msg = ("Data information named '{0}' is too long" + " (max {1} characters.)").format(name, max_name_len) + raise ValueError(msg) + encoded_name = name if value <= 1 else name + '\x00' + str(value) + if len(encoded_name) > max_name_len: + msg = ("Data information named '{0}' is too long (need to be less" + " than {1} characters when storing more than one value" + " for a given data information." + ).format(name, max_name_len - (len(str(value)) + 1)) + raise ValueError(msg) + # Fill to the end with zeros + return encoded_name.ljust(max_name_len, '\x00').encode('latin1') + + +def decode_value_from_name(encoded_name): + """ Decodes a value that has been encoded in the last bytes of a string. + + Check :func:`encode_value_in_name` to see how the value has been encoded. + + Parameters + ---------- + encoded_name : bytes + Name in which a value has been encoded or not. + + Returns + ------- + name : bytes + Name without the encoded value. + value : int + Value decoded from the name. 
+ """ + encoded_name = asstr(encoded_name) + if len(encoded_name) == 0: + return encoded_name, 0 + + splits = encoded_name.rstrip('\x00').split('\x00') + name = splits[0] + value = 1 + + if len(splits) == 2: + value = int(splits[1]) # Decode value. + elif len(splits) > 2: + # The remaining bytes are not \x00, raising. + msg = ("Wrong scalar_name or property_name: '{0}'." + " Unused characters should be \\x00.").format(encoded_name) + raise HeaderError(msg) + + return name, value + + +def create_empty_header(): + """ Return an empty compliant TRK header. """ + header = np.zeros(1, dtype=header_2_dtype) + + # Default values + header[Field.MAGIC_NUMBER] = TrkFile.MAGIC_NUMBER + header[Field.VOXEL_SIZES] = np.array((1, 1, 1), dtype="f4") + header[Field.DIMENSIONS] = np.array((1, 1, 1), dtype="h") + header[Field.VOXEL_TO_RASMM] = np.eye(4, dtype="f4") + header[Field.VOXEL_ORDER] = b"RAS" + header['version'] = 2 + header['hdr_size'] = TrkFile.HEADER_SIZE + + return header + + +class TrkFile(TractogramFile): + """ Convenience class to encapsulate TRK file format. + + Note + ---- + TrackVis (so its file format: TRK) considers the streamline coordinate + (0,0,0) to be in the corner of the voxel whereas NiBabel's streamlines + internal representation (Voxel space) assumes (0,0,0) to be in the + center of the voxel. + + Thus, streamlines are shifted by half a voxel on load and are shifted + back on save. + """ + + # Constants + MAGIC_NUMBER = b"TRACK" + HEADER_SIZE = 1000 + SUPPORTS_DATA_PER_POINT = True + SUPPORTS_DATA_PER_STREAMLINE = True + + def __init__(self, tractogram, header=None): + """ + Parameters + ---------- + tractogram : :class:`Tractogram` object + Tractogram that will be contained in this :class:`TrkFile`. + + header : dict, optional + Metadata associated to this tractogram file. + + Notes + ----- + Streamlines of the tractogram are assumed to be in *RAS+* + and *mm* space where coordinate (0,0,0) refers to the center + of the voxel. + """ + if header is None: + header_rec = create_empty_header() + header = dict(zip(header_rec.dtype.names, header_rec[0])) + + super(TrkFile, self).__init__(tractogram, header) + + @classmethod + def is_correct_format(cls, fileobj): + """ Check if the file is in TRK format. + + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to TRK file (and ready to read from the beginning + of the TRK header data). Note that calling this function + does not change the file position. + + Returns + ------- + is_correct_format : {True, False} + Returns True if `fileobj` is compatible with TRK format, + otherwise returns False. + """ + with Opener(fileobj) as f: + magic_len = len(cls.MAGIC_NUMBER) + magic_number = f.read(magic_len) + f.seek(-magic_len, os.SEEK_CUR) + return magic_number == cls.MAGIC_NUMBER + + @classmethod + def load(cls, fileobj, lazy_load=False): + """ Loads streamlines from a filename or file-like object. + + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to TRK file (and ready to read from the beginning + of the TRK header). Note that calling this function + does not change the file position. + lazy_load : {False, True}, optional + If True, load streamlines in a lazy manner i.e. they will not be + kept in memory. Otherwise, load all streamlines in memory. 
+
+        Returns
+        -------
+        trk_file : :class:`TrkFile` object
+            Returns an object containing tractogram data and header
+            information.
+
+        Notes
+        -----
+        Streamlines of the returned tractogram are assumed to be in *RAS+*
+        and *mm* space where coordinate (0,0,0) refers to the center of the
+        voxel.
+        """
+        hdr = cls._read_header(fileobj)
+
+        # Find the names of the scalars and properties.
+        data_per_point_slice = {}
+        if hdr[Field.NB_SCALARS_PER_POINT] > 0:
+            cpt = 0
+            for scalar_name in hdr['scalar_name']:
+                scalar_name, nb_scalars = decode_value_from_name(scalar_name)
+
+                if nb_scalars == 0:
+                    continue
+
+                slice_obj = slice(cpt, cpt + nb_scalars)
+                data_per_point_slice[scalar_name] = slice_obj
+                cpt += nb_scalars
+
+            if cpt < hdr[Field.NB_SCALARS_PER_POINT]:
+                slice_obj = slice(cpt, hdr[Field.NB_SCALARS_PER_POINT])
+                data_per_point_slice['scalars'] = slice_obj
+
+        data_per_streamline_slice = {}
+        if hdr[Field.NB_PROPERTIES_PER_STREAMLINE] > 0:
+            cpt = 0
+            for property_name in hdr['property_name']:
+                results = decode_value_from_name(property_name)
+                property_name, nb_properties = results
+
+                if nb_properties == 0:
+                    continue
+
+                slice_obj = slice(cpt, cpt + nb_properties)
+                data_per_streamline_slice[property_name] = slice_obj
+                cpt += nb_properties
+
+            if cpt < hdr[Field.NB_PROPERTIES_PER_STREAMLINE]:
+                slice_obj = slice(cpt, hdr[Field.NB_PROPERTIES_PER_STREAMLINE])
+                data_per_streamline_slice['properties'] = slice_obj
+
+        if lazy_load:
+            def _read():
+                for pts, scals, props in cls._read(fileobj, hdr):
+                    items = data_per_point_slice.items()
+                    data_for_points = dict((k, scals[:, v]) for k, v in items)
+                    items = data_per_streamline_slice.items()
+                    data_for_streamline = dict((k, props[v]) for k, v in items)
+                    yield TractogramItem(pts,
+                                         data_for_streamline,
+                                         data_for_points)
+
+            tractogram = LazyTractogram.from_data_func(_read)
+
+        else:
+            trk_reader = cls._read(fileobj, hdr)
+            arr_seqs = create_arraysequences_from_generator(trk_reader, n=3)
+            streamlines, scalars, properties = arr_seqs
+            properties = np.asarray(properties)  # Actually a 2d array.
+            tractogram = Tractogram(streamlines)
+
+            for name, slice_ in data_per_point_slice.items():
+                tractogram.data_per_point[name] = scalars[:, slice_]
+
+            for name, slice_ in data_per_streamline_slice.items():
+                tractogram.data_per_streamline[name] = properties[:, slice_]
+
+        tractogram.affine_to_rasmm = get_affine_trackvis_to_rasmm(hdr)
+        tractogram = tractogram.to_world()
+
+        return cls(tractogram, header=hdr)
+
+    def save(self, fileobj):
+        """ Save tractogram to a filename or file-like object using TRK format.
+
+        Parameters
+        ----------
+        fileobj : string or file-like object
+            If string, a filename; otherwise an open file-like object
+            pointing to TRK file (and ready to write from the beginning
+            of the TRK header data).
+        """
+        header = create_empty_header()
+
+        # Override the default header fields with those set on this object.
+        for k, v in self.header.items():
+            if k in header_2_dtype.fields.keys():
+                header[k] = v
+
+        # By default, the voxel order is LPS.
+        # http://trackvis.org/blog/forum/diffusion-toolkit-usage/interpretation-of-track-point-coordinates
+        if header[Field.VOXEL_ORDER] == b"":
+            header[Field.VOXEL_ORDER] = b"LPS"
+
+        # Keep counts so that incoherent fields can be corrected, or a
+        # warning raised, once all streamlines have been written.
+        nb_streamlines = 0
+        nb_points = 0
+        nb_scalars = 0
+        nb_properties = 0
+
+        header = header[0]
+        with Opener(fileobj, mode="wb") as f:
+            # Keep track of the beginning of the header.
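+            # The header is written twice: the copy below is a placeholder;
+            # the final version, with the streamline/scalar/property counts
+            # filled in, overwrites it at the end of this method via a seek
+            # back to `beginning`.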
+            beginning = f.tell()
+
+            # Write temporary header that we will update at the end.
+            f.write(header.tostring())
+
+            i4_dtype = np.dtype("<i4")  # Always save in little-endian.
+            f4_dtype = np.dtype("<f4")  # Always save in little-endian.
+
+            try:
+                # Use the first item to check that the tractogram is not
+                # empty and to inspect its scalars and properties.
+                first_item = next(iter(self.tractogram))
+            except StopIteration:
+                # Empty tractogram: update the header and return.
+                header[Field.NB_STREAMLINES] = 0
+                header[Field.NB_SCALARS_PER_POINT] = 0
+                header[Field.NB_PROPERTIES_PER_STREAMLINE] = 0
+                f.seek(beginning, os.SEEK_SET)
+                f.write(header.tostring())
+                return
+
+            # Update field 'property_name' using 'data_per_streamline'.
+            data_for_streamline = first_item.data_for_streamline
+            if (len(data_for_streamline) >
+                    MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE):
+                msg = ("Can only store {0} named data_per_streamline (also"
+                       " known as 'properties' in the TRK format)."
+                       ).format(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE)
+                raise ValueError(msg)
+
+            data_for_streamline_keys = sorted(data_for_streamline.keys())
+            property_name = np.zeros(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE,
+                                     dtype='S20')
+            for i, name in enumerate(data_for_streamline_keys):
+                # Append number of values as ascii to zero-terminated name
+                # to encode number of values into trackvis name.
+                nb_values = data_for_streamline[name].shape[-1]
+                property_name[i] = encode_value_in_name(nb_values, name)
+            header['property_name'][:] = property_name
+
+            # Update field 'scalar_name' using 'tractogram.data_per_point'.
+            data_for_points = first_item.data_for_points
+            if len(data_for_points) > MAX_NB_NAMED_SCALARS_PER_POINT:
+                msg = ("Can only store {0} named data_per_point (also known"
+                       " as 'scalars' in the TRK format)."
+                       ).format(MAX_NB_NAMED_SCALARS_PER_POINT)
+                raise ValueError(msg)
+
+            data_for_points_keys = sorted(data_for_points.keys())
+            scalar_name = np.zeros(MAX_NB_NAMED_SCALARS_PER_POINT, dtype='S20')
+            for i, name in enumerate(data_for_points_keys):
+                # Append number of values as ascii to zero-terminated name
+                # to encode number of values into trackvis name.
+                nb_values = data_for_points[name].shape[-1]
+                scalar_name[i] = encode_value_in_name(nb_values, name)
+            header['scalar_name'][:] = scalar_name
+
+            # Make sure streamlines are in RAS+ mm, then send them to voxmm.
+            tractogram = self.tractogram.to_world(lazy=True)
+            affine_to_trackvis = get_affine_rasmm_to_trackvis(header)
+            tractogram = tractogram.apply_affine(affine_to_trackvis, lazy=True)
+
+            for t in tractogram:
+                if any((len(d) != len(t.streamline)
+                        for d in t.data_for_points.values())):
+                    raise DataError("Missing scalars for some points!")
+
+                points = np.asarray(t.streamline, dtype=f4_dtype)
+                scalars = [np.asarray(t.data_for_points[k], dtype=f4_dtype)
+                           for k in data_for_points_keys]
+                scalars = np.concatenate([np.ndarray((len(points), 0),
+                                                     dtype=f4_dtype)
+                                          ] + scalars, axis=1)
+                properties = [np.asarray(t.data_for_streamline[k],
+                                         dtype=f4_dtype)
+                              for k in data_for_streamline_keys]
+                properties = np.concatenate([np.array([], dtype=f4_dtype)
+                                             ] + properties)
+
+                data = struct.pack(i4_dtype.str[:-1], len(points))
+                data += np.concatenate([points, scalars], axis=1).tostring()
+                data += properties.tostring()
+                f.write(data)
+
+                nb_streamlines += 1
+                nb_points += len(points)
+                nb_scalars += scalars.size
+                nb_properties += len(properties)
+
+            # Use those values to update the header.
+            nb_scalars_per_point = nb_scalars / nb_points
+            nb_properties_per_streamline = nb_properties / nb_streamlines
+
+            # Check for errors
+            if nb_scalars_per_point != int(nb_scalars_per_point):
+                msg = "Nb. of scalars differs from one point to another!"
+                raise DataError(msg)
+
+            if nb_properties_per_streamline != int(nb_properties_per_streamline):
+                msg = ("Nb. of properties differs from one streamline to"
+                       " another!")
+                raise DataError(msg)
+
+            header[Field.NB_STREAMLINES] = nb_streamlines
+            header[Field.NB_SCALARS_PER_POINT] = nb_scalars_per_point
+            header[Field.NB_PROPERTIES_PER_STREAMLINE] = nb_properties_per_streamline
+
+            # Overwrite header with updated one.
+            f.seek(beginning, os.SEEK_SET)
+            f.write(header.tostring())
+
+    @staticmethod
+    def _read_header(fileobj):
+        """ Reads a TRK header from a file.
+
+        Parameters
+        ----------
+        fileobj : string or file-like object
+            If string, a filename; otherwise an open file-like object
+            pointing to TRK file (and ready to read from the beginning
+            of the TRK header). Note that calling this function
+            does not change the file position.
+
+        Returns
+        -------
+        header : dict
+            Metadata associated with this tractogram file.
+        """
+        # Record start position if this is a file-like object
+        start_position = fileobj.tell() if hasattr(fileobj, 'tell') else None
+
+        with Opener(fileobj) as f:
+
+            # Read the header in one block.
+            header_str = f.read(header_2_dtype.itemsize)
+            header_rec = np.fromstring(string=header_str,
+                                       dtype=header_2_dtype)
+
+            # Check endianness.
+            endianness = native_code
+            if header_rec['hdr_size'] != TrkFile.HEADER_SIZE:
+                endianness = swapped_code
+
+                # Swap byte order.
+                header_rec = header_rec.newbyteorder()
+                if header_rec['hdr_size'] != TrkFile.HEADER_SIZE:
+                    msg = "Invalid hdr_size: {0} instead of {1}"
+                    raise HeaderError(msg.format(header_rec['hdr_size'],
+                                                 TrkFile.HEADER_SIZE))
+
+            if header_rec['version'] == 1:
+                header_rec = np.fromstring(string=header_str,
+                                           dtype=header_1_dtype)
+            elif header_rec['version'] == 2:
+                pass  # Nothing more to do.
+            else:
+                raise HeaderError('NiBabel only supports versions 1 and 2 of '
+                                  'the Trackvis file format')
+
+            # Convert the first record of `header_rec` into a dictionary.
+            header = dict(zip(header_rec.dtype.names, header_rec[0]))
+            header[Field.ENDIANNESS] = endianness
+
+            # If vox_to_ras[3][3] is 0, it means the matrix is not recorded.
+            if header[Field.VOXEL_TO_RASMM][3][3] == 0:
+                header[Field.VOXEL_TO_RASMM] = np.eye(4, dtype=np.float32)
+                warnings.warn(("Field 'vox_to_ras' in the TRK's header was"
+                               " not recorded. Will continue assuming it's"
+                               " the identity."), HeaderWarning)
+
+            # Check that the 'vox_to_ras' affine is valid, i.e. that the
+            # axis directions can be determined from it.
+            axcodes = aff2axcodes(header[Field.VOXEL_TO_RASMM])
+            if None in axcodes:
+                msg = ("The 'vox_to_ras' affine is invalid! Could not"
+                       " determine the axis directions from it.\n{0}"
+                       ).format(header[Field.VOXEL_TO_RASMM])
+                raise HeaderError(msg)
+
+            # By default, the voxel order is LPS.
+            # http://trackvis.org/blog/forum/diffusion-toolkit-usage/interpretation-of-track-point-coordinates
+            if header[Field.VOXEL_ORDER] == b"":
+                msg = ("Voxel order is not specified, will assume 'LPS' since"
+                       " it is TrackVis software's default.")
+                warnings.warn(msg, HeaderWarning)
+                header[Field.VOXEL_ORDER] = b"LPS"
+
+            # Keep the file position where the data begin.
+            header['_offset_data'] = f.tell()
+
+        # Set the file position where it was, if it was previously open.
+        if start_position is not None:
+            fileobj.seek(start_position, os.SEEK_SET)
+
+        return header
+
+    @staticmethod
+    def _read(fileobj, header):
+        """ Return generator that reads TRK data from `fileobj` given `header`
+
+        Parameters
+        ----------
+        fileobj : string or file-like object
+            If string, a filename; otherwise an open file-like object
+            pointing to TRK file (and ready to read from the beginning
+            of the TRK header). Note that calling this function
+            does not change the file position.
+        header : dict
+            Metadata associated with this tractogram file.
+
+        Yields
+        ------
+        data : tuple of ndarrays
+            Length 3 tuple of streamline data of form (points, scalars,
+            properties), where:
+
+            * points: ndarray of shape (n_pts, 3)
+            * scalars: ndarray of shape (n_pts, nb_scalars_per_point)
+            * properties: ndarray of shape (nb_properties_per_streamline,)
+        """
+        i4_dtype = np.dtype(header[Field.ENDIANNESS] + "i4")
+        f4_dtype = np.dtype(header[Field.ENDIANNESS] + "f4")
+
+        with Opener(fileobj) as f:
+            start_position = f.tell()
+
+            nb_pts_and_scalars = int(3 +
+                                     header[Field.NB_SCALARS_PER_POINT])
+            pts_and_scalars_size = int(nb_pts_and_scalars * f4_dtype.itemsize)
+            nb_properties = header[Field.NB_PROPERTIES_PER_STREAMLINE]
+            properties_size = int(nb_properties * f4_dtype.itemsize)
+
+            # Set the file position at the beginning of the data.
+            f.seek(header["_offset_data"], os.SEEK_SET)
+
+            # If the 'count' field is 0, i.e. not provided, we have to loop
+            # until the EOF.
+            nb_streamlines = header[Field.NB_STREAMLINES]
+            if nb_streamlines == 0:
+                nb_streamlines = np.inf
+
+            count = 0
+            nb_pts_dtype = i4_dtype.str[:-1]
+            while count < nb_streamlines:
+                nb_pts_str = f.read(i4_dtype.itemsize)
+
+                # Check if we reached EOF.
+                if len(nb_pts_str) == 0:
+                    break
+
+                # Read the number of points of the next streamline.
+                nb_pts = struct.unpack(nb_pts_dtype, nb_pts_str)[0]
+
+                # Read the streamline's data.
+                points_and_scalars = np.ndarray(
+                    shape=(nb_pts, nb_pts_and_scalars),
+                    dtype=f4_dtype,
+                    buffer=f.read(nb_pts * pts_and_scalars_size))
+
+                points = points_and_scalars[:, :3]
+                scalars = points_and_scalars[:, 3:]
+
+                # Read properties.
+                properties = np.ndarray(
+                    shape=(nb_properties,),
+                    dtype=f4_dtype,
+                    buffer=f.read(properties_size))
+
+                yield points, scalars, properties
+                count += 1
+
+            # In case the 'count' field was not provided.
+            header[Field.NB_STREAMLINES] = count
+
+            # Set the file position back to where it was (in case the file
+            # was already open); this is an absolute position, so SEEK_SET.
+            f.seek(start_position, os.SEEK_SET)
+
+    def __str__(self):
+        """ Gets a formatted string of the header of a TRK file.
+
+        Returns
+        -------
+        info : string
+            Header information relevant to the TRK format.
+        """
+        vars = self.header.copy()
+        for attr in dir(Field):
+            if attr[0] in string.ascii_uppercase:
+                hdr_field = getattr(Field, attr)
+                if hdr_field in vars:
+                    vars[attr] = vars[hdr_field]
+        vars['scalar_names'] = '\n '.join([asstr(s)
+                                           for s in vars['scalar_name']
+                                           if len(s) > 0])
+        vars['property_names'] = "\n ".join([asstr(s)
+                                             for s in vars['property_name']
+                                             if len(s) > 0])
+        return """\
+MAGIC NUMBER: {MAGIC_NUMBER}
+v.{version}
+dim: {DIMENSIONS}
+voxel_sizes: {VOXEL_SIZES}
+origin: {ORIGIN}
+nb_scalars: {NB_SCALARS_PER_POINT}
+scalar_name:\n {scalar_names}
+nb_properties: {NB_PROPERTIES_PER_STREAMLINE}
+property_name:\n {property_names}
+vox_to_world:\n{VOXEL_TO_RASMM}
+voxel_order: {VOXEL_ORDER}
+image_orientation_patient: {image_orientation_patient}
+pad1: {pad1}
+pad2: {pad2}
+invert_x: {invert_x}
+invert_y: {invert_y}
+invert_z: {invert_z}
+swap_xy: {swap_xy}
+swap_yz: {swap_yz}
+swap_zx: {swap_zx}
+n_count: {NB_STREAMLINES}
+hdr_size: {hdr_size}""".format(**vars)
diff --git a/nibabel/streamlines/utils.py b/nibabel/streamlines/utils.py
new file mode 100644
index 0000000000..0ef5b740ac
--- /dev/null
+++ b/nibabel/streamlines/utils.py
@@ -0,0 +1,31 @@
+import nibabel
+
+
+def get_affine_from_reference(ref):
+    """ Returns the affine defining the reference space.
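+
+    `ref` may be a filename, an object with an ``affine`` attribute (e.g. a
+    loaded :class:`Nifti1Image`), or a raw (4, 4) ndarray, as detailed below.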
+
+    Parameters
+    ----------
+    ref : str or :class:`Nifti1Image` object or ndarray shape (4, 4)
+        If str, it is the filename of the reference file, which will be
+        loaded using :func:`nibabel.load` in order to obtain the affine.
+        If a :class:`Nifti1Image` object, the affine is obtained from it.
+        If an ndarray of shape (4, 4), it is the affine itself.
+
+    Returns
+    -------
+    affine : ndarray (4, 4)
+        Transformation matrix mapping voxel space to RAS+ mm space.
+    """
+    if hasattr(ref, 'affine'):
+        return ref.affine
+
+    if hasattr(ref, 'shape'):
+        if ref.shape != (4, 4):
+            msg = "`ref` needs to be a numpy array with shape (4, 4)!"
+            raise ValueError(msg)
+
+        return ref
+
+    # Assume `ref` is the name of a neuroimaging file.
+    return nibabel.load(ref).affine
diff --git a/nibabel/testing/__init__.py b/nibabel/testing/__init__.py
index c6b5ebd66b..2200b25182 100644
--- a/nibabel/testing/__init__.py
+++ b/nibabel/testing/__init__.py
@@ -16,6 +16,7 @@ from os.path import dirname, abspath, join as pjoin
 
 import numpy as np
+from numpy.testing import assert_array_equal
 from numpy.testing.decorators import skipif
 
 # Allow failed import of nose if not now running tests
@@ -25,6 +26,8 @@
 except ImportError:
     pass
 
+from nibabel.externals.six.moves import zip_longest
+
 # set path to example data
 data_path = abspath(pjoin(dirname(__file__), '..', 'tests', 'data'))
@@ -62,6 +65,13 @@ def assert_allclose_safely(a, b, match_nans=True, rtol=1e-5, atol=1e-8):
         assert_true(np.allclose(a, b, rtol=rtol, atol=atol))
 
 
+def assert_arrays_equal(arrays1, arrays2):
+    """ Check two iterables yield the same sequence of arrays. """
+    for arr1, arr2 in zip_longest(arrays1, arrays2, fillvalue=None):
+        assert_false(arr1 is None or arr2 is None)
+        assert_array_equal(arr1, arr2)
+
+
 def assert_re_in(regex, c, flags=0):
     """Assert that container (list, str, etc) contains entry matching the regex
     """
@@ -190,3 +200,12 @@ def runif_extra_has(test_str):
     """Decorator checks to see if NIPY_EXTRA_TESTS env var contains test_str"""
     return skipif(test_str not in EXTRA_SET,
                   "Skip {0} tests.".format(test_str))
+
+
+def assert_arr_dict_equal(dict1, dict2):
+    """ Assert that two dicts are equal, where dicts contain arrays
+    """
+    assert_equal(set(dict1), set(dict2))
+    for key, value1 in dict1.items():
+        value2 = dict2[key]
+        assert_array_equal(value1, value2)
diff --git a/nibabel/tests/data/complex.trk b/nibabel/tests/data/complex.trk
new file mode 100644
index 0000000000..e2860ee95a
Binary files /dev/null and b/nibabel/tests/data/complex.trk differ
diff --git a/nibabel/tests/data/complex_big_endian.trk b/nibabel/tests/data/complex_big_endian.trk
new file mode 100644
index 0000000000..0f5b9e71ba
Binary files /dev/null and b/nibabel/tests/data/complex_big_endian.trk differ
diff --git a/nibabel/tests/data/empty.trk b/nibabel/tests/data/empty.trk
new file mode 100644
index 0000000000..fbe0871807
Binary files /dev/null and b/nibabel/tests/data/empty.trk differ
diff --git a/nibabel/tests/data/gen_standard.py b/nibabel/tests/data/gen_standard.py
new file mode 100644
index 0000000000..b97da8ff2f
--- /dev/null
+++ b/nibabel/tests/data/gen_standard.py
@@ -0,0 +1,86 @@
+""" Generate mask and testing tractogram in known formats:
+
+* mask: standard.nii.gz
+* tractogram:
+
+    * standard.trk
+"""
+import numpy as np
+import nibabel as nib
+
+from nibabel.streamlines import FORMATS
+from nibabel.streamlines.header import Field
+
+
+def mark_the_spot(mask):
+    """ Marks every nonzero voxel using streamlines to form a 3D 'X' inside.
+
+    Generates streamlines forming a 3D 'X' inside every nonzero voxel.
+
+    Parameters
+    ----------
+    mask : ndarray
+        Mask containing the spots to be marked.
+
+    Returns
+    -------
+    list of ndarrays
+        All streamlines needed to mark every nonzero voxel in the `mask`.
+    """
+    def _gen_straight_streamline(start, end, steps=3):
+        coords = []
+        for s, e in zip(start, end):
+            coords.append(np.linspace(s, e, steps))
+
+        return np.array(coords).T
+
+    # Generate a 3D 'X' template fitting inside the voxel centered at (0,0,0).
+    X = []
+    X.append(_gen_straight_streamline((-0.5, -0.5, -0.5), (0.5, 0.5, 0.5)))
+    X.append(_gen_straight_streamline((-0.5, 0.5, -0.5), (0.5, -0.5, 0.5)))
+    X.append(_gen_straight_streamline((-0.5, 0.5, 0.5), (0.5, -0.5, -0.5)))
+    X.append(_gen_straight_streamline((-0.5, -0.5, 0.5), (0.5, 0.5, -0.5)))
+
+    # Get the coordinates of voxels 'on' in the mask.
+    # Note: list() is needed under Python 3, where zip returns an iterator.
+    coords = np.array(list(zip(*np.where(mask))))
+
+    streamlines = []
+    for c in coords:
+        for line in X:
+            streamlines.append((line + c) * voxel_size)
+
+    return streamlines
+
+
+rng = np.random.RandomState(42)
+
+width = 4   # Coronal
+height = 5  # Sagittal
+depth = 7   # Axial
+
+voxel_size = np.array((1., 3., 2.))
+
+# Generate a random mask with voxel order RAS+.
+mask = rng.rand(width, height, depth) > 0.8
+mask = (255 * mask).astype(np.uint8)
+
+# Build tractogram
+streamlines = mark_the_spot(mask)
+tractogram = nib.streamlines.Tractogram(streamlines)
+
+# Build header
+affine = np.eye(4)
+affine[range(3), range(3)] = voxel_size
+header = {Field.DIMENSIONS: (width, height, depth),
+          Field.VOXEL_SIZES: voxel_size,
+          Field.VOXEL_TO_RASMM: affine,
+          Field.VOXEL_ORDER: 'RAS'}
+
+# Save the standard mask.
+nii = nib.Nifti1Image(mask, affine=affine)
+nib.save(nii, "standard.nii.gz")
+
+# Save the standard tractogram in every available file format.
+for ext, cls in FORMATS.items():
+    tfile = cls(tractogram, header)
+    nib.streamlines.save(tfile, "standard" + ext)
diff --git a/nibabel/tests/data/simple.trk b/nibabel/tests/data/simple.trk
new file mode 100644
index 0000000000..df601e29a7
Binary files /dev/null and b/nibabel/tests/data/simple.trk differ
diff --git a/nibabel/tests/data/standard.LPS.trk b/nibabel/tests/data/standard.LPS.trk
new file mode 100644
index 0000000000..ebda71bdb8
Binary files /dev/null and b/nibabel/tests/data/standard.LPS.trk differ
diff --git a/nibabel/tests/data/standard.nii.gz b/nibabel/tests/data/standard.nii.gz
new file mode 100644
index 0000000000..98bb31a778
Binary files /dev/null and b/nibabel/tests/data/standard.nii.gz differ
diff --git a/nibabel/tests/data/standard.trk b/nibabel/tests/data/standard.trk
new file mode 100644
index 0000000000..01ea01744a
Binary files /dev/null and b/nibabel/tests/data/standard.trk differ
diff --git a/nibabel/tests/test_parrec.py b/nibabel/tests/test_parrec.py
index 63e96c4938..ed50150706 100644
--- a/nibabel/tests/test_parrec.py
+++ b/nibabel/tests/test_parrec.py
@@ -23,7 +23,8 @@
 from nose.tools import (assert_true, assert_false, assert_raises,
                         assert_equal)
 
-from ..testing import clear_and_catch_warnings, suppress_warnings
+from ..testing import (clear_and_catch_warnings, suppress_warnings,
+                       assert_arr_dict_equal)
 
 from .test_arrayproxy import check_mmap
 from . import test_spatialimages as tsi
@@ -618,13 +619,6 @@ def test_copy_on_init():
     assert_array_equal(HDR_DEFS['image pixel size'], 16)
 
 
-def assert_arr_dict_equal(dict1, dict2):
-    assert_equal(set(dict1), set(dict2))
-    for key, value1 in dict1.items():
-        value2 = dict2[key]
-        assert_array_equal(value1, value2)
-
-
 def assert_structarr_equal(star1, star2):
     # Compare structured arrays (array_equal does not work for np 1.5)
     assert_equal(star1.dtype, star2.dtype)
diff --git a/setup.py b/setup.py
index d5160b4a0c..5e9bf51c29 100755
--- a/setup.py
+++ b/setup.py
@@ -90,6 +90,8 @@ def main(**extra_args):
             'nibabel.testing',
             'nibabel.tests',
             'nibabel.benchmarks',
+            'nibabel.streamlines',
+            'nibabel.streamlines.tests',
             # install nisext as its own package
             'nisext',
             'nisext.tests'],
@@ -104,6 +106,7 @@ def main(**extra_args):
                 pjoin('externals', 'tests', 'data', '*'),
                 pjoin('nicom', 'tests', 'data', '*'),
                 pjoin('gifti', 'tests', 'data', '*'),
+                pjoin('streamlines', 'tests', 'data', '*'),
                 ]},
        scripts = [pjoin('bin', 'parrec2nii'),
                   pjoin('bin', 'nib-ls'),